mirror of
https://github.com/Frooodle/Stirling-PDF.git
synced 2025-11-16 01:21:16 +01:00
wip - implementing RAG system
This commit is contained in:
parent
1973c55d10
commit
c37707d9ad
@ -612,6 +612,7 @@ public class ApplicationProperties {
|
||||
private Audit audit = new Audit();
|
||||
private long maxPromptCharacters = 4000;
|
||||
private double minConfidenceNano = 0.65;
|
||||
private Usage usage = new Usage();
|
||||
|
||||
@Data
|
||||
public static class Cache {
|
||||
@ -626,6 +627,8 @@ public class ApplicationProperties {
|
||||
private String primary = "gpt-5-nano";
|
||||
private String fallback = "gpt-5-mini";
|
||||
private String embedding = "text-embedding-3-small";
|
||||
private double temperature = 0.2;
|
||||
private double topP = 0.95;
|
||||
private long connectTimeoutMillis = 10000;
|
||||
private long readTimeoutMillis = 60000;
|
||||
}
|
||||
@ -646,6 +649,12 @@ public class ApplicationProperties {
|
||||
public static class Audit {
|
||||
private boolean enabled = true;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Usage {
|
||||
private long perUserMonthlyTokens = 200000;
|
||||
private double warnAtRatio = 0.7;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ springdoc.swagger-ui.url=/v1/api-docs
|
||||
# Spring AI OpenAI Configuration
|
||||
# Uses GPT-5-nano as primary model and GPT-5-mini as fallback (configured in settings.yml)
|
||||
spring.ai.openai.enabled=true
|
||||
spring.ai.openai.api-key=# todo <API-KEY-HERE>
|
||||
#spring.ai.openai.api-key=# todo <API-KEY-HERE>
|
||||
spring.ai.openai.base-url=https://api.openai.com
|
||||
spring.ai.openai.chat.enabled=true
|
||||
spring.ai.openai.chat.options.model=gpt-5-nano
|
||||
|
||||
@ -97,11 +97,13 @@ premium:
|
||||
cache:
|
||||
ttlMinutes: 720 # Cache entry lifetime (12h)
|
||||
maxEntries: 200 # Maximum number of cached documents per node
|
||||
maxDocumentCharacters: 200000 # Reject uploads exceeding this character count
|
||||
maxDocumentCharacters: 600000 # Reject uploads exceeding this character count
|
||||
models:
|
||||
primary: gpt-5-nano # Default lightweight model
|
||||
fallback: gpt-5-mini # Escalation model for complex prompts
|
||||
embedding: text-embedding-3-small # Embedding model for vector store usage
|
||||
temperature: 0.2 # Sampling temperature for LLM responses
|
||||
topP: 0.95 # Top-p (nucleus) sampling for LLM responses
|
||||
rag:
|
||||
chunkSizeTokens: 512 # Token window used when chunking text
|
||||
chunkOverlapTokens: 128 # Overlap between successive chunks
|
||||
@ -112,6 +114,9 @@ premium:
|
||||
enabled: true # Emit audit records for chatbot activity
|
||||
maxPromptCharacters: 4000 # Server-side guardrail for incoming prompts
|
||||
minConfidenceNano: 0.65 # Minimum nano confidence to avoid escalation
|
||||
usage:
|
||||
perUserMonthlyTokens: 200000 # Monthly RAG + chat token budget per user
|
||||
warnAtRatio: 0.7 # Warn users when usage exceeds 70%
|
||||
enterpriseFeatures:
|
||||
audit:
|
||||
enabled: true # Enable audit logging
|
||||
|
||||
@ -52,9 +52,9 @@ dependencies {
|
||||
api 'org.springframework.boot:spring-boot-starter-cache'
|
||||
api 'com.github.ben-manes.caffeine:caffeine'
|
||||
api 'io.swagger.core.v3:swagger-core-jakarta:2.2.38'
|
||||
api 'org.springframework.ai:spring-ai-starter-model-openai'
|
||||
api 'org.springframework.ai:spring-ai-starter-model-ollama'
|
||||
api 'org.springframework.ai:spring-ai-redis-store'
|
||||
implementation 'org.springframework.ai:spring-ai-starter-model-openai'
|
||||
implementation 'org.springframework.ai:spring-ai-starter-model-ollama'
|
||||
implementation 'org.springframework.ai:spring-ai-starter-vector-store-redis'
|
||||
implementation 'com.bucket4j:bucket4j_jdk17-core:8.15.0'
|
||||
|
||||
// https://mvnrepository.com/artifact/com.bucket4j/bucket4j_jdk17
|
||||
|
||||
@ -0,0 +1,54 @@
|
||||
package stirling.software.proprietary.config;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.ai.vectorstore.redis.RedisVectorStore;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
|
||||
@Configuration
|
||||
@org.springframework.boot.autoconfigure.condition.ConditionalOnProperty(
|
||||
value = "premium.proFeatures.chatbot.enabled",
|
||||
havingValue = "true")
|
||||
@Slf4j
|
||||
public class ChatbotVectorStoreConfig {
|
||||
|
||||
private static final String DEFAULT_INDEX = "stirling-chatbot-index";
|
||||
private static final String DEFAULT_PREFIX = "stirling:chatbot:";
|
||||
|
||||
@Bean
|
||||
@Primary
|
||||
public VectorStore chatbotVectorStore(
|
||||
ObjectProvider<JedisPooled> jedisProvider, EmbeddingModel embeddingModel) {
|
||||
JedisPooled jedis = jedisProvider.getIfAvailable();
|
||||
if (jedis != null) {
|
||||
try {
|
||||
jedis.ping();
|
||||
log.info("Initialising Redis vector store for chatbot usage");
|
||||
return RedisVectorStore.builder(jedis, embeddingModel)
|
||||
.indexName(DEFAULT_INDEX)
|
||||
.prefix(DEFAULT_PREFIX)
|
||||
.initializeSchema(true)
|
||||
.build();
|
||||
} catch (RuntimeException ex) {
|
||||
log.warn(
|
||||
"Redis vector store unavailable ({}). Falling back to SimpleVectorStore.",
|
||||
sanitize(ex.getMessage()));
|
||||
}
|
||||
} else {
|
||||
log.info("No Redis connection detected; using SimpleVectorStore for chatbot.");
|
||||
}
|
||||
return SimpleVectorStore.builder(embeddingModel).build();
|
||||
}
|
||||
|
||||
private String sanitize(String message) {
|
||||
return message == null ? "unknown error" : message.replaceAll("\\s+", " ").trim();
|
||||
}
|
||||
}
|
||||
@ -21,6 +21,7 @@ import stirling.software.proprietary.model.chatbot.ChatbotResponse;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSession;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSessionResponse;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotUsageSummary;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotCacheService;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
|
||||
@ -28,12 +29,10 @@ import stirling.software.proprietary.service.chatbot.ChatbotService;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotSessionRegistry;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/v1/internal/chatbot")
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
// @ConditionalOnProperty(value = "premium.proFeatures.chatbot.enabled", havingValue = "true")
|
||||
// @ConditionalOnBean(ChatbotService.class)
|
||||
@RestController
|
||||
@RequiredArgsConstructor
|
||||
@RequestMapping("/api/v1/internal/chatbot")
|
||||
public class ChatbotController {
|
||||
|
||||
private final ChatbotService chatbotService;
|
||||
@ -54,10 +53,12 @@ public class ChatbotController {
|
||||
.ocrRequested(session.isOcrRequested())
|
||||
.imageContentDetected(session.isImageContentDetected())
|
||||
.textCharacters(session.getTextCharacters())
|
||||
.estimatedTokens(session.getEstimatedTokens())
|
||||
.maxCachedCharacters(cacheService.getMaxDocumentCharacters())
|
||||
.createdAt(session.getCreatedAt())
|
||||
.warnings(sessionWarnings(settings, session))
|
||||
.metadata(session.getMetadata())
|
||||
.usageSummary(session.getUsageSummary())
|
||||
.build();
|
||||
return ResponseEntity.status(HttpStatus.CREATED).body(response);
|
||||
}
|
||||
@ -83,10 +84,12 @@ public class ChatbotController {
|
||||
.ocrRequested(session.isOcrRequested())
|
||||
.imageContentDetected(session.isImageContentDetected())
|
||||
.textCharacters(session.getTextCharacters())
|
||||
.estimatedTokens(session.getEstimatedTokens())
|
||||
.maxCachedCharacters(cacheService.getMaxDocumentCharacters())
|
||||
.createdAt(session.getCreatedAt())
|
||||
.warnings(sessionWarnings(settings, session))
|
||||
.metadata(session.getMetadata())
|
||||
.usageSummary(session.getUsageSummary())
|
||||
.build();
|
||||
return ResponseEntity.ok(response);
|
||||
}
|
||||
@ -106,7 +109,16 @@ public class ChatbotController {
|
||||
|
||||
warnings.add("Only extracted text is sent for analysis.");
|
||||
if (session != null && session.isOcrRequested()) {
|
||||
warnings.add("OCR was requested – extra processing charges may apply.");
|
||||
warnings.add("OCR requested – uses credits .");
|
||||
}
|
||||
|
||||
if (session != null && session.getUsageSummary() != null) {
|
||||
ChatbotUsageSummary usage = session.getUsageSummary();
|
||||
if (usage.isLimitExceeded()) {
|
||||
warnings.add("Monthly chatbot allocation exceeded – requests may be throttled.");
|
||||
} else if (usage.isNearingLimit()) {
|
||||
warnings.add("You are approaching the monthly chatbot allocation.");
|
||||
}
|
||||
}
|
||||
|
||||
return warnings;
|
||||
|
||||
@ -2,7 +2,6 @@ package stirling.software.proprietary.model.chatbot;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
@ -20,19 +19,12 @@ public class ChatbotDocumentCacheEntry {
|
||||
private String sessionId;
|
||||
private String documentId;
|
||||
private Map<String, String> metadata;
|
||||
private String text;
|
||||
private List<ChatbotTextChunk> chunks;
|
||||
private boolean ocrApplied;
|
||||
private boolean imageContentDetected;
|
||||
private long textCharacters;
|
||||
private String vectorStoreId;
|
||||
private Instant storedAt;
|
||||
|
||||
public Map<String, String> getMetadata() {
|
||||
return metadata == null ? Collections.emptyMap() : metadata;
|
||||
}
|
||||
|
||||
public List<ChatbotTextChunk> getChunks() {
|
||||
return chunks == null ? Collections.emptyList() : chunks;
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,10 @@ public class ChatbotResponse {
|
||||
private Instant respondedAt;
|
||||
private List<String> warnings;
|
||||
private Map<String, Object> metadata;
|
||||
private long promptTokens;
|
||||
private long completionTokens;
|
||||
private long totalTokens;
|
||||
private ChatbotUsageSummary usageSummary;
|
||||
|
||||
public List<String> getWarnings() {
|
||||
return warnings == null ? Collections.emptyList() : warnings;
|
||||
|
||||
@ -21,9 +21,11 @@ public class ChatbotSession {
|
||||
private boolean alphaWarningRequired;
|
||||
private boolean imageContentDetected;
|
||||
private long textCharacters;
|
||||
private long estimatedTokens;
|
||||
private String cacheKey;
|
||||
private String vectorStoreId;
|
||||
private Instant createdAt;
|
||||
private ChatbotUsageSummary usageSummary;
|
||||
|
||||
public static String randomSessionId() {
|
||||
return UUID.randomUUID().toString();
|
||||
|
||||
@ -23,9 +23,11 @@ public class ChatbotSessionResponse {
|
||||
private boolean imageContentDetected;
|
||||
private long maxCachedCharacters;
|
||||
private long textCharacters;
|
||||
private long estimatedTokens;
|
||||
private Instant createdAt;
|
||||
private List<String> warnings;
|
||||
private Map<String, String> metadata;
|
||||
private ChatbotUsageSummary usageSummary;
|
||||
|
||||
public List<String> getWarnings() {
|
||||
return warnings == null ? Collections.emptyList() : warnings;
|
||||
|
||||
@ -1,20 +0,0 @@
|
||||
package stirling.software.proprietary.model.chatbot;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ChatbotTextChunk {
|
||||
|
||||
private String id;
|
||||
private String text;
|
||||
private int order;
|
||||
private List<Double> embedding;
|
||||
}
|
||||
@ -0,0 +1,22 @@
|
||||
package stirling.software.proprietary.model.chatbot;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ChatbotUsageSummary {
|
||||
|
||||
private long allocatedTokens;
|
||||
private long consumedTokens;
|
||||
private long remainingTokens;
|
||||
private double usageRatio;
|
||||
private boolean nearingLimit;
|
||||
private boolean limitExceeded;
|
||||
private long lastIncrementTokens;
|
||||
private String window;
|
||||
}
|
||||
@ -11,7 +11,6 @@ import org.springframework.context.annotation.DependsOn;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.security.authentication.ProviderManager;
|
||||
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
|
||||
import org.springframework.security.config.Customizer;
|
||||
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
||||
@ -225,9 +224,19 @@ public class SecurityConfiguration {
|
||||
csrf.ignoringRequestMatchers(
|
||||
request -> {
|
||||
String uri = request.getRequestURI();
|
||||
String contextPath = request.getContextPath();
|
||||
String trimmedUri =
|
||||
uri.startsWith(contextPath)
|
||||
? uri.substring(
|
||||
contextPath.length())
|
||||
: uri;
|
||||
|
||||
// Ignore CSRF for auth endpoints
|
||||
if (uri.startsWith("/api/v1/auth/")) {
|
||||
// Ignore CSRF for auth endpoints + oauth/saml
|
||||
if (trimmedUri.startsWith("/api/v1/auth/")
|
||||
|| trimmedUri.startsWith("/oauth2")
|
||||
|| trimmedUri.startsWith("/saml2")
|
||||
|| trimmedUri.startsWith(
|
||||
"/login/oauth2/code/")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -363,7 +372,8 @@ public class SecurityConfiguration {
|
||||
loginAttemptService,
|
||||
securityProperties.getOauth2(),
|
||||
userService,
|
||||
jwtService))
|
||||
jwtService,
|
||||
applicationProperties))
|
||||
.failureHandler(new CustomOAuth2AuthenticationFailureHandler())
|
||||
// Add existing Authorities from the database
|
||||
.userInfoEndpoint(
|
||||
|
||||
@ -50,6 +50,7 @@ public class CustomOAuth2AuthenticationSuccessHandler
|
||||
private final ApplicationProperties.Security.OAUTH2 oauth2Properties;
|
||||
private final UserService userService;
|
||||
private final JwtServiceInterface jwtService;
|
||||
private final ApplicationProperties applicationProperties;
|
||||
|
||||
@Override
|
||||
@Audited(type = AuditEventType.USER_LOGIN, level = AuditLevel.BASIC)
|
||||
@ -57,6 +58,12 @@ public class CustomOAuth2AuthenticationSuccessHandler
|
||||
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
|
||||
throws ServletException, IOException {
|
||||
|
||||
System.out.println("[OAuth Success Handler] Starting authentication success handling");
|
||||
System.out.println("[OAuth Success Handler] Request URL: " + request.getRequestURL());
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] Frontend URL from config: "
|
||||
+ applicationProperties.getSystem().getFrontendUrl());
|
||||
|
||||
Object principal = authentication.getPrincipal();
|
||||
String username = "";
|
||||
|
||||
@ -127,10 +134,16 @@ public class CustomOAuth2AuthenticationSuccessHandler
|
||||
jwtService.generateToken(
|
||||
authentication, Map.of("authType", AuthenticationType.OAUTH2));
|
||||
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] JWT generated: "
|
||||
+ (jwt != null ? "YES (length: " + jwt.length() + ")" : "NO"));
|
||||
|
||||
// Build context-aware redirect URL based on the original request
|
||||
String redirectUrl =
|
||||
buildContextAwareRedirectUrl(request, response, contextPath, jwt);
|
||||
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] Final redirect URL: " + redirectUrl);
|
||||
response.sendRedirect(redirectUrl);
|
||||
} else {
|
||||
// v1: redirect directly to home
|
||||
@ -170,14 +183,44 @@ public class CustomOAuth2AuthenticationSuccessHandler
|
||||
String contextPath,
|
||||
String jwt) {
|
||||
String redirectPath = resolveRedirectPath(request, contextPath);
|
||||
String origin =
|
||||
resolveForwardedOrigin(request)
|
||||
.orElseGet(
|
||||
() ->
|
||||
resolveOriginFromReferer(request)
|
||||
.orElseGet(() -> buildOriginFromRequest(request)));
|
||||
|
||||
System.out.println("[OAuth Success Handler] Resolving origin...");
|
||||
Optional<String> configuredOrigin = resolveConfiguredFrontendOrigin();
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] Configured frontend origin: "
|
||||
+ configuredOrigin.orElse("NOT SET"));
|
||||
|
||||
String origin;
|
||||
if (configuredOrigin.isPresent()) {
|
||||
origin = configuredOrigin.get();
|
||||
System.out.println("[OAuth Success Handler] Using configured origin: " + origin);
|
||||
} else {
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] No configured origin, trying other methods...");
|
||||
origin =
|
||||
resolveForwardedOrigin(request)
|
||||
.orElseGet(
|
||||
() -> {
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] No forwarded origin, trying referer...");
|
||||
return resolveOriginFromReferer(request)
|
||||
.orElseGet(
|
||||
() -> {
|
||||
System.out.println(
|
||||
"[OAuth Success Handler] No referer, building from request...");
|
||||
return buildOriginFromRequest(request);
|
||||
});
|
||||
});
|
||||
System.out.println("[OAuth Success Handler] Resolved origin: " + origin);
|
||||
}
|
||||
|
||||
clearRedirectCookie(response);
|
||||
return origin + redirectPath + "#access_token=" + jwt;
|
||||
String finalUrl = origin + redirectPath + "#access_token=" + jwt;
|
||||
System.out.println("[OAuth Success Handler] Building redirect URL:");
|
||||
System.out.println(" - Origin: " + origin);
|
||||
System.out.println(" - Redirect path: " + redirectPath);
|
||||
System.out.println(" - Final URL: " + finalUrl);
|
||||
return finalUrl;
|
||||
}
|
||||
|
||||
private String resolveRedirectPath(HttpServletRequest request, String contextPath) {
|
||||
@ -278,6 +321,26 @@ public class CustomOAuth2AuthenticationSuccessHandler
|
||||
return origin.toString();
|
||||
}
|
||||
|
||||
private Optional<String> resolveConfiguredFrontendOrigin() {
|
||||
System.out.println("[OAuth Success Handler] Checking configured frontend URL...");
|
||||
if (applicationProperties.getSystem() == null) {
|
||||
System.out.println("[OAuth Success Handler] applicationProperties.getSystem() is NULL");
|
||||
return Optional.empty();
|
||||
}
|
||||
String configured = applicationProperties.getSystem().getFrontendUrl();
|
||||
System.out.println("[OAuth Success Handler] Frontend URL from config: " + configured);
|
||||
if (configured == null || configured.isBlank()) {
|
||||
System.out.println("[OAuth Success Handler] Frontend URL is null or blank");
|
||||
return Optional.empty();
|
||||
}
|
||||
String trimmed = configured.trim();
|
||||
if (trimmed.endsWith("/")) {
|
||||
trimmed = trimmed.substring(0, trimmed.length() - 1);
|
||||
}
|
||||
System.out.println("[OAuth Success Handler] Returning configured frontend URL: " + trimmed);
|
||||
return Optional.of(trimmed);
|
||||
}
|
||||
|
||||
private boolean isDefaultPort(String scheme, String port) {
|
||||
if (port == null) {
|
||||
return true;
|
||||
|
||||
@ -2,7 +2,6 @@ package stirling.software.proprietary.service.chatbot;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
@ -21,15 +20,14 @@ import stirling.software.common.model.ApplicationProperties.Premium;
|
||||
import stirling.software.common.model.ApplicationProperties.Premium.ProFeatures;
|
||||
import stirling.software.common.model.ApplicationProperties.Premium.ProFeatures.Chatbot;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
|
||||
@Service
|
||||
// @ConditionalOnProperty(value = "premium.proFeatures.chatbot.enabled", havingValue = "true")
|
||||
@Slf4j
|
||||
public class ChatbotCacheService {
|
||||
|
||||
private final Cache<String, ChatbotDocumentCacheEntry> documentCache;
|
||||
private final Cache<String, ChatbotDocumentCacheEntry>
|
||||
documentCache; // todo: can redis be used instead?
|
||||
private final long maxDocumentCharacters;
|
||||
private final Map<String, String> sessionToCacheKey = new ConcurrentHashMap<>();
|
||||
|
||||
@ -61,18 +59,12 @@ public class ChatbotCacheService {
|
||||
public String register(
|
||||
String sessionId,
|
||||
String documentId,
|
||||
String rawText,
|
||||
Map<String, String> metadata,
|
||||
boolean ocrApplied,
|
||||
boolean imageContentDetected,
|
||||
long textCharacters) {
|
||||
Objects.requireNonNull(sessionId, "sessionId must not be null");
|
||||
Objects.requireNonNull(documentId, "documentId must not be null");
|
||||
Objects.requireNonNull(rawText, "rawText must not be null");
|
||||
if (rawText.length() > maxDocumentCharacters) {
|
||||
throw new ChatbotException(
|
||||
"Document text exceeds maximum allowed characters: " + maxDocumentCharacters);
|
||||
}
|
||||
String cacheKey =
|
||||
sessionToCacheKey.computeIfAbsent(sessionId, k -> UUID.randomUUID().toString());
|
||||
ChatbotDocumentCacheEntry entry =
|
||||
@ -81,7 +73,6 @@ public class ChatbotCacheService {
|
||||
.sessionId(sessionId)
|
||||
.documentId(documentId)
|
||||
.metadata(metadata)
|
||||
.text(rawText)
|
||||
.ocrApplied(ocrApplied)
|
||||
.imageContentDetected(imageContentDetected)
|
||||
.textCharacters(textCharacters)
|
||||
@ -91,17 +82,6 @@ public class ChatbotCacheService {
|
||||
return cacheKey;
|
||||
}
|
||||
|
||||
public void attachChunks(String cacheKey, List<ChatbotTextChunk> chunks) {
|
||||
documentCache
|
||||
.asMap()
|
||||
.computeIfPresent(
|
||||
cacheKey,
|
||||
(key, existing) -> {
|
||||
existing.setChunks(chunks);
|
||||
return existing;
|
||||
});
|
||||
}
|
||||
|
||||
public Optional<ChatbotDocumentCacheEntry> resolveByCacheKey(String cacheKey) {
|
||||
return Optional.ofNullable(documentCache.getIfPresent(cacheKey));
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
@ -32,12 +33,12 @@ import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSession;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotUsageSummary;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class ChatbotConversationService {
|
||||
|
||||
@ -46,6 +47,7 @@ public class ChatbotConversationService {
|
||||
private final ChatbotCacheService cacheService;
|
||||
private final ChatbotFeatureProperties featureProperties;
|
||||
private final ChatbotRetrievalService retrievalService;
|
||||
private final ChatbotUsageService usageService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final AtomicBoolean modelSwitchVerified = new AtomicBoolean(false);
|
||||
|
||||
@ -74,7 +76,7 @@ public class ChatbotConversationService {
|
||||
|
||||
List<String> warnings = buildWarnings(settings, session);
|
||||
|
||||
List<ChatbotTextChunk> context =
|
||||
List<Document> context =
|
||||
retrievalService.retrieveTopK(
|
||||
request.getSessionId(), request.getPrompt(), settings);
|
||||
|
||||
@ -97,17 +99,23 @@ public class ChatbotConversationService {
|
||||
boolean escalated = false;
|
||||
if (shouldEscalate) {
|
||||
escalated = true;
|
||||
List<ChatbotTextChunk> expandedContext = ensureMinimumContext(context, cacheEntry);
|
||||
finalReply =
|
||||
invokeModel(
|
||||
settings,
|
||||
settings.models().fallback(),
|
||||
request.getPrompt(),
|
||||
session,
|
||||
expandedContext,
|
||||
context,
|
||||
cacheEntry.getMetadata());
|
||||
}
|
||||
|
||||
ChatbotUsageSummary usageSummary =
|
||||
usageService.registerGeneration(
|
||||
session.getUserId(),
|
||||
finalReply.promptTokens(),
|
||||
finalReply.completionTokens());
|
||||
session.setUsageSummary(usageSummary);
|
||||
|
||||
return ChatbotResponse.builder()
|
||||
.sessionId(request.getSessionId())
|
||||
.modelUsed(
|
||||
@ -120,6 +128,10 @@ public class ChatbotConversationService {
|
||||
.respondedAt(Instant.now())
|
||||
.warnings(warnings)
|
||||
.metadata(buildMetadata(settings, session, finalReply, context.size(), escalated))
|
||||
.promptTokens(finalReply.promptTokens())
|
||||
.completionTokens(finalReply.completionTokens())
|
||||
.totalTokens(finalReply.totalTokens())
|
||||
.usageSummary(usageSummary)
|
||||
.build();
|
||||
}
|
||||
|
||||
@ -151,6 +163,9 @@ public class ChatbotConversationService {
|
||||
metadata.put("modelProvider", settings.models().provider().name());
|
||||
metadata.put("imageContentDetected", session.isImageContentDetected());
|
||||
metadata.put("charactersCached", session.getTextCharacters());
|
||||
metadata.put("promptTokens", reply.promptTokens());
|
||||
metadata.put("completionTokens", reply.completionTokens());
|
||||
metadata.put("totalTokens", reply.totalTokens());
|
||||
return metadata;
|
||||
}
|
||||
|
||||
@ -179,29 +194,12 @@ public class ChatbotConversationService {
|
||||
}
|
||||
}
|
||||
|
||||
private List<ChatbotTextChunk> ensureMinimumContext(
|
||||
List<ChatbotTextChunk> context, ChatbotDocumentCacheEntry entry) {
|
||||
if (context.size() >= 3 || entry.getChunks().size() <= context.size()) {
|
||||
return context;
|
||||
}
|
||||
List<ChatbotTextChunk> augmented = new ArrayList<>(context);
|
||||
for (ChatbotTextChunk chunk : entry.getChunks()) {
|
||||
if (augmented.size() >= 3) {
|
||||
break;
|
||||
}
|
||||
if (!augmented.contains(chunk)) {
|
||||
augmented.add(chunk);
|
||||
}
|
||||
}
|
||||
return augmented;
|
||||
}
|
||||
|
||||
private ModelReply invokeModel(
|
||||
ChatbotSettings settings,
|
||||
String model,
|
||||
String prompt,
|
||||
ChatbotSession session,
|
||||
List<ChatbotTextChunk> context,
|
||||
List<Document> context,
|
||||
Map<String, String> metadata) {
|
||||
Prompt requestPrompt = buildPrompt(settings, model, prompt, session, context, metadata);
|
||||
ChatResponse response;
|
||||
@ -217,13 +215,27 @@ public class ChatbotConversationService {
|
||||
+ sanitizeRemoteMessage(ex.getMessage()),
|
||||
ex);
|
||||
}
|
||||
long promptTokens = 0L;
|
||||
long completionTokens = 0L;
|
||||
long totalTokens = 0L;
|
||||
if (response != null && response.getMetadata() != null) {
|
||||
org.springframework.ai.chat.metadata.Usage usage = response.getMetadata().getUsage();
|
||||
if (usage != null) {
|
||||
promptTokens = toLong(usage.getPromptTokens());
|
||||
completionTokens = toLong(usage.getCompletionTokens());
|
||||
totalTokens =
|
||||
usage.getTotalTokens() != null
|
||||
? usage.getTotalTokens()
|
||||
: promptTokens + completionTokens;
|
||||
}
|
||||
}
|
||||
String content =
|
||||
Optional.ofNullable(response)
|
||||
.map(ChatResponse::getResults)
|
||||
.filter(results -> !results.isEmpty())
|
||||
.map(results -> results.get(0).getOutput().getText())
|
||||
.orElse("");
|
||||
return parseModelResponse(content);
|
||||
return parseModelResponse(content, promptTokens, completionTokens, totalTokens);
|
||||
}
|
||||
|
||||
private Prompt buildPrompt(
|
||||
@ -231,13 +243,13 @@ public class ChatbotConversationService {
|
||||
String model,
|
||||
String question,
|
||||
ChatbotSession session,
|
||||
List<ChatbotTextChunk> context,
|
||||
List<Document> context,
|
||||
Map<String, String> metadata) {
|
||||
StringBuilder contextBuilder = new StringBuilder();
|
||||
for (ChatbotTextChunk chunk : context) {
|
||||
for (Document chunk : context) {
|
||||
contextBuilder
|
||||
.append("[Chunk ")
|
||||
.append(chunk.getOrder())
|
||||
.append(chunk.getMetadata().getOrDefault("chunkOrder", "?"))
|
||||
.append("]\n")
|
||||
.append(chunk.getText())
|
||||
.append("\n\n");
|
||||
@ -270,18 +282,24 @@ public class ChatbotConversationService {
|
||||
+ "Question: "
|
||||
+ question;
|
||||
|
||||
OpenAiChatOptions options = buildChatOptions(model);
|
||||
OpenAiChatOptions options = buildChatOptions(settings, model);
|
||||
|
||||
return new Prompt(
|
||||
List.of(new SystemMessage(systemPrompt), new UserMessage(userPrompt)), options);
|
||||
}
|
||||
|
||||
private OpenAiChatOptions buildChatOptions(String model) {
|
||||
// Note: Some models only support default temperature value of 1.0
|
||||
return OpenAiChatOptions.builder().model(model).temperature(1.0).build();
|
||||
private OpenAiChatOptions buildChatOptions(ChatbotSettings settings, String model) {
|
||||
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder().model(model);
|
||||
String normalizedModel = model == null ? "" : model.toLowerCase();
|
||||
boolean reasoningModel = normalizedModel.startsWith("gpt-5-");
|
||||
if (!reasoningModel) {
|
||||
builder.temperature(settings.models().temperature()).topP(settings.models().topP());
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private ModelReply parseModelResponse(String raw) {
|
||||
private ModelReply parseModelResponse(
|
||||
String raw, long promptTokens, long completionTokens, long totalTokens) {
|
||||
if (!StringUtils.hasText(raw)) {
|
||||
throw new ChatbotException("Model returned empty response");
|
||||
}
|
||||
@ -301,15 +319,35 @@ public class ChatbotConversationService {
|
||||
Optional.ofNullable(node.get("rationale"))
|
||||
.map(JsonNode::asText)
|
||||
.orElse("Model did not provide rationale");
|
||||
return new ModelReply(answer, confidence, requiresEscalation, rationale);
|
||||
return new ModelReply(
|
||||
answer,
|
||||
confidence,
|
||||
requiresEscalation,
|
||||
rationale,
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
totalTokens);
|
||||
} catch (IOException ex) {
|
||||
log.warn("Failed to parse model JSON response, returning raw text", ex);
|
||||
return new ModelReply(raw, 0.0D, true, "Unable to parse JSON response");
|
||||
return new ModelReply(
|
||||
raw,
|
||||
0.0D,
|
||||
true,
|
||||
"Unable to parse JSON response",
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
totalTokens);
|
||||
}
|
||||
}
|
||||
|
||||
private record ModelReply(
|
||||
String answer, double confidence, boolean requiresEscalation, String rationale) {}
|
||||
String answer,
|
||||
double confidence,
|
||||
boolean requiresEscalation,
|
||||
String rationale,
|
||||
long promptTokens,
|
||||
long completionTokens,
|
||||
long totalTokens) {}
|
||||
|
||||
private String sanitizeRemoteMessage(String message) {
|
||||
if (!StringUtils.hasText(message)) {
|
||||
@ -317,4 +355,8 @@ public class ChatbotConversationService {
|
||||
}
|
||||
return message.replaceAll("(?i)api[-_ ]?key\\s*=[^\\s]+", "api-key=***");
|
||||
}
|
||||
|
||||
private long toLong(Integer value) {
|
||||
return value == null ? 0L : value.longValue();
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,7 +26,9 @@ public class ChatbotFeatureProperties {
|
||||
resolveProvider(chatbot.getModels().getProvider()),
|
||||
chatbot.getModels().getPrimary(),
|
||||
chatbot.getModels().getFallback(),
|
||||
chatbot.getModels().getEmbedding());
|
||||
chatbot.getModels().getEmbedding(),
|
||||
chatbot.getModels().getTemperature(),
|
||||
chatbot.getModels().getTopP());
|
||||
return new ChatbotSettings(
|
||||
chatbot.isEnabled(),
|
||||
chatbot.isAlphaWarning(),
|
||||
@ -42,7 +44,10 @@ public class ChatbotFeatureProperties {
|
||||
chatbot.getCache().getMaxEntries(),
|
||||
chatbot.getCache().getMaxDocumentCharacters()),
|
||||
new ChatbotSettings.OcrSettings(chatbot.getOcr().isEnabledByDefault()),
|
||||
new ChatbotSettings.AuditSettings(chatbot.getAudit().isEnabled()));
|
||||
new ChatbotSettings.AuditSettings(chatbot.getAudit().isEnabled()),
|
||||
new ChatbotSettings.UsageSettings(
|
||||
chatbot.getUsage().getPerUserMonthlyTokens(),
|
||||
chatbot.getUsage().getWarnAtRatio()));
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
@ -77,10 +82,16 @@ public class ChatbotFeatureProperties {
|
||||
RagSettings rag,
|
||||
CacheSettings cache,
|
||||
OcrSettings ocr,
|
||||
AuditSettings audit) {
|
||||
AuditSettings audit,
|
||||
UsageSettings usage) {
|
||||
|
||||
public record ModelSettings(
|
||||
ModelProvider provider, String primary, String fallback, String embedding) {}
|
||||
ModelProvider provider,
|
||||
String primary,
|
||||
String fallback,
|
||||
String embedding,
|
||||
double temperature,
|
||||
double topP) {}
|
||||
|
||||
public record RagSettings(int chunkSizeTokens, int chunkOverlapTokens, int topK) {}
|
||||
|
||||
@ -90,6 +101,8 @@ public class ChatbotFeatureProperties {
|
||||
|
||||
public record AuditSettings(boolean enabled) {}
|
||||
|
||||
public record UsageSettings(long perUserMonthlyTokens, double warnAtRatio) {}
|
||||
|
||||
public enum ModelProvider {
|
||||
OPENAI,
|
||||
OLLAMA
|
||||
|
||||
@ -5,10 +5,9 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
@ -17,7 +16,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSession;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
import stirling.software.proprietary.service.chatbot.exception.NoTextDetectedException;
|
||||
@ -30,7 +28,8 @@ public class ChatbotIngestionService {
|
||||
private final ChatbotCacheService cacheService;
|
||||
private final ChatbotSessionRegistry sessionRegistry;
|
||||
private final ChatbotFeatureProperties featureProperties;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
private final VectorStore vectorStore;
|
||||
private final ChatbotUsageService usageService;
|
||||
|
||||
public ChatbotSession ingest(ChatbotSessionCreateRequest request) {
|
||||
ChatbotSettings settings = featureProperties.current();
|
||||
@ -40,18 +39,23 @@ public class ChatbotIngestionService {
|
||||
if (!request.isWarningsAccepted() && settings.alphaWarning()) {
|
||||
throw new ChatbotException("Alpha warning must be accepted before use");
|
||||
}
|
||||
boolean hasText = StringUtils.hasText(request.getText());
|
||||
if (!hasText) {
|
||||
if (!StringUtils.hasText(request.getText())) {
|
||||
throw new NoTextDetectedException(
|
||||
"No text detected in document payload. Images are currently unsupported – enable OCR to continue.");
|
||||
}
|
||||
|
||||
long characterLimit = cacheService.getMaxDocumentCharacters();
|
||||
long textCharacters = request.getText().length();
|
||||
if (textCharacters > characterLimit) {
|
||||
throw new ChatbotException(
|
||||
"Document text exceeds maximum allowed characters: " + characterLimit);
|
||||
}
|
||||
|
||||
String sessionId =
|
||||
StringUtils.hasText(request.getSessionId())
|
||||
? request.getSessionId()
|
||||
: ChatbotSession.randomSessionId();
|
||||
boolean imagesDetected = request.isImagesDetected();
|
||||
long textCharacters = request.getText().length();
|
||||
boolean ocrApplied = request.isOcrRequested();
|
||||
Map<String, String> metadata = new HashMap<>();
|
||||
if (request.getMetadata() != null) {
|
||||
@ -63,23 +67,28 @@ public class ChatbotIngestionService {
|
||||
"content.extractionSource", ocrApplied ? "ocr-text-layer" : "embedded-text-layer");
|
||||
Map<String, String> immutableMetadata = Map.copyOf(metadata);
|
||||
|
||||
List<Document> documents =
|
||||
buildDocuments(
|
||||
sessionId, request.getDocumentId(), request.getText(), metadata, settings);
|
||||
try {
|
||||
vectorStore.add(documents);
|
||||
} catch (RuntimeException ex) {
|
||||
throw new ChatbotException(
|
||||
"Failed to index document content in vector store: "
|
||||
+ sanitizeRemoteMessage(ex.getMessage()),
|
||||
ex);
|
||||
}
|
||||
|
||||
String cacheKey =
|
||||
cacheService.register(
|
||||
sessionId,
|
||||
request.getDocumentId(),
|
||||
request.getText(),
|
||||
immutableMetadata,
|
||||
ocrApplied,
|
||||
imagesDetected,
|
||||
textCharacters);
|
||||
|
||||
List<String> chunkTexts =
|
||||
chunkText(
|
||||
request.getText(),
|
||||
settings.rag().chunkSizeTokens(),
|
||||
settings.rag().chunkOverlapTokens());
|
||||
List<ChatbotTextChunk> chunks = embedChunks(sessionId, cacheKey, chunkTexts, metadata);
|
||||
cacheService.attachChunks(cacheKey, chunks);
|
||||
long estimatedTokens = Math.max(1L, Math.round(textCharacters / 4.0));
|
||||
|
||||
ChatbotSession session =
|
||||
ChatbotSession.builder()
|
||||
@ -90,93 +99,65 @@ public class ChatbotIngestionService {
|
||||
.ocrRequested(ocrApplied)
|
||||
.imageContentDetected(imagesDetected)
|
||||
.textCharacters(textCharacters)
|
||||
.estimatedTokens(estimatedTokens)
|
||||
.warningsAccepted(request.isWarningsAccepted())
|
||||
.alphaWarningRequired(settings.alphaWarning())
|
||||
.cacheKey(cacheKey)
|
||||
.createdAt(Instant.now())
|
||||
.build();
|
||||
session.setUsageSummary(
|
||||
usageService.registerIngestion(session.getUserId(), estimatedTokens));
|
||||
sessionRegistry.register(session);
|
||||
log.info(
|
||||
"Registered chatbot session {} for document {} with {} chunks",
|
||||
"Registered chatbot session {} for document {} with {} RAG chunks",
|
||||
sessionId,
|
||||
request.getDocumentId(),
|
||||
chunks.size());
|
||||
documents.size());
|
||||
return session;
|
||||
}
|
||||
|
||||
private List<String> chunkText(String text, int chunkSizeTokens, int overlapTokens) {
|
||||
String[] tokens = text.split("\\s+");
|
||||
List<String> chunks = new ArrayList<>();
|
||||
if (tokens.length == 0) {
|
||||
return chunks;
|
||||
private List<Document> buildDocuments(
|
||||
String sessionId,
|
||||
String documentId,
|
||||
String text,
|
||||
Map<String, String> metadata,
|
||||
ChatbotSettings settings) {
|
||||
List<Document> documents = new ArrayList<>();
|
||||
if (!StringUtils.hasText(text)) {
|
||||
return documents;
|
||||
}
|
||||
int effectiveChunk = Math.max(chunkSizeTokens, 1);
|
||||
int effectiveOverlap = Math.max(Math.min(overlapTokens, effectiveChunk - 1), 0);
|
||||
|
||||
int chunkChars = Math.max(512, settings.rag().chunkSizeTokens() * 4);
|
||||
int overlapChars = Math.max(64, settings.rag().chunkOverlapTokens() * 4);
|
||||
|
||||
int index = 0;
|
||||
while (index < tokens.length) {
|
||||
int end = Math.min(tokens.length, index + effectiveChunk);
|
||||
String chunk = String.join(" ", java.util.Arrays.copyOfRange(tokens, index, end));
|
||||
if (StringUtils.hasText(chunk)) {
|
||||
chunks.add(chunk);
|
||||
int order = 0;
|
||||
while (index < text.length()) {
|
||||
int end = Math.min(text.length(), index + chunkChars);
|
||||
String chunk = text.substring(index, end).trim();
|
||||
if (!chunk.isEmpty()) {
|
||||
Document document = new Document(chunk);
|
||||
document.getMetadata().putAll(metadata);
|
||||
document.getMetadata().put("sessionId", sessionId);
|
||||
document.getMetadata().put("documentId", documentId);
|
||||
document.getMetadata().put("chunkOrder", Integer.toString(order));
|
||||
documents.add(document);
|
||||
order++;
|
||||
}
|
||||
if (end == tokens.length) {
|
||||
if (end == text.length()) {
|
||||
break;
|
||||
}
|
||||
index = end - effectiveOverlap;
|
||||
if (index <= 0) {
|
||||
index = end;
|
||||
int nextIndex = end - overlapChars;
|
||||
if (nextIndex <= index) {
|
||||
nextIndex = end;
|
||||
}
|
||||
index = nextIndex;
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
private List<ChatbotTextChunk> embedChunks(
|
||||
String sessionId,
|
||||
String cacheKey,
|
||||
List<String> chunkTexts,
|
||||
Map<String, String> metadata) {
|
||||
if (chunkTexts.isEmpty()) {
|
||||
throw new ChatbotException("Unable to split document text into retrievable chunks");
|
||||
if (documents.isEmpty()) {
|
||||
throw new ChatbotException("Unable to split document text into searchable chunks");
|
||||
}
|
||||
EmbeddingResponse response;
|
||||
try {
|
||||
response = embeddingModel.embedForResponse(chunkTexts);
|
||||
} catch (org.eclipse.jetty.client.HttpResponseException ex) {
|
||||
throw new ChatbotException(
|
||||
"Embedding provider rejected the request: "
|
||||
+ sanitizeRemoteMessage(ex.getMessage()),
|
||||
ex);
|
||||
} catch (RuntimeException ex) {
|
||||
throw new ChatbotException(
|
||||
"Failed to compute embeddings for chatbot ingestion: "
|
||||
+ sanitizeRemoteMessage(ex.getMessage()),
|
||||
ex);
|
||||
}
|
||||
if (response.getResults().size() != chunkTexts.size()) {
|
||||
throw new ChatbotException("Mismatch between chunks and embedding results");
|
||||
}
|
||||
List<ChatbotTextChunk> chunks = new ArrayList<>();
|
||||
for (int i = 0; i < chunkTexts.size(); i++) {
|
||||
String chunkId = sessionId + ":" + i + ":" + UUID.randomUUID();
|
||||
float[] embeddingArray = response.getResults().get(i).getOutput();
|
||||
List<Double> embedding = new ArrayList<>(embeddingArray.length);
|
||||
for (float value : embeddingArray) {
|
||||
embedding.add((double) value);
|
||||
}
|
||||
chunks.add(
|
||||
ChatbotTextChunk.builder()
|
||||
.id(chunkId)
|
||||
.order(i)
|
||||
.text(chunkTexts.get(i))
|
||||
.embedding(embedding)
|
||||
.build());
|
||||
}
|
||||
log.debug(
|
||||
"Computed embeddings for session {} cacheKey {} ({} vectors)",
|
||||
sessionId,
|
||||
cacheKey,
|
||||
chunks.size());
|
||||
return chunks;
|
||||
return documents;
|
||||
}
|
||||
|
||||
private String sanitizeRemoteMessage(String message) {
|
||||
|
||||
@ -1,20 +1,16 @@
|
||||
package stirling.software.proprietary.service.chatbot;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
|
||||
@ -24,69 +20,53 @@ import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
public class ChatbotRetrievalService {
|
||||
|
||||
private final ChatbotCacheService cacheService;
|
||||
private final EmbeddingModel embeddingModel;
|
||||
private final VectorStore vectorStore;
|
||||
|
||||
public List<ChatbotTextChunk> retrieveTopK(
|
||||
String sessionId, String query, ChatbotSettings settings) {
|
||||
ChatbotDocumentCacheEntry entry =
|
||||
cacheService
|
||||
.resolveBySessionId(sessionId)
|
||||
.orElseThrow(() -> new ChatbotException("Unknown chatbot session"));
|
||||
List<ChatbotTextChunk> chunks = entry.getChunks();
|
||||
if (CollectionUtils.isEmpty(chunks)) {
|
||||
throw new ChatbotException("Chatbot cache does not contain pre-computed chunks");
|
||||
public List<Document> retrieveTopK(String sessionId, String query, ChatbotSettings settings) {
|
||||
cacheService
|
||||
.resolveBySessionId(sessionId)
|
||||
.orElseThrow(() -> new ChatbotException("Unknown chatbot session"));
|
||||
|
||||
int topK = Math.max(settings.rag().topK(), 1);
|
||||
String sanitizedQuery = StringUtils.hasText(query) ? query : "";
|
||||
String filterExpression = "sessionId == '" + escape(sessionId) + "'";
|
||||
SearchRequest searchRequest =
|
||||
SearchRequest.builder()
|
||||
.query(sanitizedQuery)
|
||||
.topK(topK)
|
||||
.filterExpression(filterExpression)
|
||||
.build();
|
||||
List<Document> results;
|
||||
try {
|
||||
results = vectorStore.similaritySearch(searchRequest);
|
||||
} catch (RuntimeException ex) {
|
||||
throw new ChatbotException(
|
||||
"Failed to perform vector similarity search: "
|
||||
+ sanitizeRemoteMessage(ex.getMessage()),
|
||||
ex);
|
||||
}
|
||||
List<Double> queryEmbedding = computeQueryEmbedding(query);
|
||||
List<ScoredChunk> scoredChunks = new ArrayList<>();
|
||||
for (ChatbotTextChunk chunk : chunks) {
|
||||
if (CollectionUtils.isEmpty(chunk.getEmbedding())) {
|
||||
log.warn("Chunk {} missing embedding, skipping", chunk.getId());
|
||||
continue;
|
||||
}
|
||||
double score = cosineSimilarity(queryEmbedding, chunk.getEmbedding());
|
||||
scoredChunks.add(new ScoredChunk(chunk, score));
|
||||
results =
|
||||
results.stream()
|
||||
.filter(
|
||||
doc ->
|
||||
sessionId.equals(
|
||||
doc.getMetadata().getOrDefault("sessionId", "")))
|
||||
.limit(topK)
|
||||
.toList();
|
||||
if (results.isEmpty()) {
|
||||
throw new ChatbotException("No context available for this chatbot session");
|
||||
}
|
||||
return scoredChunks.stream()
|
||||
.sorted(Comparator.comparingDouble(ScoredChunk::score).reversed())
|
||||
.limit(Math.max(settings.rag().topK(), 1))
|
||||
.map(ScoredChunk::chunk)
|
||||
.toList();
|
||||
return results;
|
||||
}
|
||||
|
||||
private List<Double> computeQueryEmbedding(String query) {
|
||||
EmbeddingResponse response = embeddingModel.embedForResponse(List.of(query));
|
||||
float[] embeddingArray =
|
||||
Optional.ofNullable(response.getResults().stream().findFirst().orElse(null))
|
||||
.map(org.springframework.ai.embedding.Embedding::getOutput)
|
||||
.orElseThrow(
|
||||
() -> new ChatbotException("Failed to compute query embedding"));
|
||||
List<Double> embedding = new ArrayList<>(embeddingArray.length);
|
||||
for (float value : embeddingArray) {
|
||||
embedding.add((double) value);
|
||||
private String sanitizeRemoteMessage(String message) {
|
||||
if (!StringUtils.hasText(message)) {
|
||||
return "unexpected provider error";
|
||||
}
|
||||
return embedding;
|
||||
return message.replaceAll("(?i)api[-_ ]?key\\s*=[^\\s]+", "api-key=***");
|
||||
}
|
||||
|
||||
private double cosineSimilarity(List<Double> v1, List<Double> v2) {
|
||||
int size = Math.min(v1.size(), v2.size());
|
||||
if (size == 0) {
|
||||
return -1.0;
|
||||
}
|
||||
double dot = 0.0;
|
||||
double mag1 = 0.0;
|
||||
double mag2 = 0.0;
|
||||
for (int i = 0; i < size; i++) {
|
||||
double a = v1.get(i);
|
||||
double b = v2.get(i);
|
||||
dot += a * b;
|
||||
mag1 += a * a;
|
||||
mag2 += b * b;
|
||||
}
|
||||
if (mag1 == 0.0 || mag2 == 0.0) {
|
||||
return -1.0;
|
||||
}
|
||||
return dot / (Math.sqrt(mag1) * Math.sqrt(mag2));
|
||||
private String escape(String value) {
|
||||
return value.replace("'", "\\'");
|
||||
}
|
||||
|
||||
private record ScoredChunk(ChatbotTextChunk chunk, double score) {}
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@ -12,6 +13,7 @@ import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSession;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
|
||||
import stirling.software.proprietary.security.service.UserService;
|
||||
import stirling.software.proprietary.service.AuditService;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
|
||||
@ -26,8 +28,12 @@ public class ChatbotService {
|
||||
private final ChatbotCacheService cacheService;
|
||||
private final ChatbotFeatureProperties featureProperties;
|
||||
private final AuditService auditService;
|
||||
private final UserService userService;
|
||||
|
||||
public ChatbotSession createSession(ChatbotSessionCreateRequest request) {
|
||||
if (!StringUtils.hasText(request.getUserId())) {
|
||||
request.setUserId(userService.getCurrentUsername());
|
||||
}
|
||||
ChatbotSession session = ingestionService.ingest(request);
|
||||
log.debug("Chatbot session {} initialised", session.getSessionId());
|
||||
audit(
|
||||
|
||||
@ -0,0 +1,106 @@
|
||||
package stirling.software.proprietary.service.chatbot;
|
||||
|
||||
import java.time.YearMonth;
|
||||
import java.time.ZoneOffset;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotUsageSummary;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class ChatbotUsageService {
|
||||
|
||||
private final ChatbotFeatureProperties featureProperties;
|
||||
|
||||
private final Map<String, UsageWindow> usageByUser = new ConcurrentHashMap<>();
|
||||
|
||||
public ChatbotUsageSummary registerIngestion(String userId, long estimatedTokens) {
|
||||
return incrementUsage(userId, Math.max(estimatedTokens, 0L));
|
||||
}
|
||||
|
||||
public ChatbotUsageSummary registerGeneration(
|
||||
String userId, long promptTokens, long completionTokens) {
|
||||
long total = Math.max(promptTokens + completionTokens, 0L);
|
||||
return incrementUsage(userId, total);
|
||||
}
|
||||
|
||||
public ChatbotUsageSummary currentUsage(String userId) {
|
||||
String key = normalizeUserId(userId);
|
||||
UsageWindow window = usageByUser.get(key);
|
||||
if (window == null) {
|
||||
return buildSummary(key, 0L, 0L);
|
||||
}
|
||||
return buildSummary(key, window.tokens.get(), 0L);
|
||||
}
|
||||
|
||||
private ChatbotUsageSummary incrementUsage(String userId, long deltaTokens) {
|
||||
String key = normalizeUserId(userId);
|
||||
YearMonth now = YearMonth.now(ZoneOffset.UTC);
|
||||
UsageWindow window =
|
||||
usageByUser.compute(
|
||||
key,
|
||||
(ignored, existing) -> {
|
||||
if (existing == null || !existing.window.equals(now)) {
|
||||
existing = new UsageWindow(now);
|
||||
}
|
||||
if (deltaTokens > 0) {
|
||||
existing.tokens.addAndGet(deltaTokens);
|
||||
}
|
||||
return existing;
|
||||
});
|
||||
return buildSummary(key, window.tokens.get(), deltaTokens);
|
||||
}
|
||||
|
||||
private ChatbotUsageSummary buildSummary(String userKey, long consumed, long deltaTokens) {
|
||||
ChatbotSettings settings = featureProperties.current();
|
||||
long allocation = Math.max(settings.usage().perUserMonthlyTokens(), 1L);
|
||||
double ratio = allocation == 0 ? 1.0 : (double) consumed / allocation;
|
||||
long remaining = Math.max(allocation - consumed, 0L);
|
||||
boolean limitExceeded = consumed > allocation;
|
||||
boolean nearingLimit = ratio >= settings.usage().warnAtRatio();
|
||||
return ChatbotUsageSummary.builder()
|
||||
.allocatedTokens(allocation)
|
||||
.consumedTokens(consumed)
|
||||
.remainingTokens(remaining)
|
||||
.usageRatio(Math.min(ratio, 1.0))
|
||||
.nearingLimit(nearingLimit)
|
||||
.limitExceeded(limitExceeded)
|
||||
.lastIncrementTokens(deltaTokens)
|
||||
.window(currentWindowDescription(userKey))
|
||||
.build();
|
||||
}
|
||||
|
||||
private String currentWindowDescription(String userKey) {
|
||||
UsageWindow window = usageByUser.get(userKey);
|
||||
if (window == null) {
|
||||
return YearMonth.now(ZoneOffset.UTC).toString();
|
||||
}
|
||||
return window.window.toString();
|
||||
}
|
||||
|
||||
private String normalizeUserId(String userId) {
|
||||
if (!StringUtils.hasText(userId)) {
|
||||
return "anonymous";
|
||||
}
|
||||
return userId.trim().toLowerCase();
|
||||
}
|
||||
|
||||
private static final class UsageWindow {
|
||||
private final YearMonth window;
|
||||
private final AtomicLong tokens = new AtomicLong();
|
||||
|
||||
private UsageWindow(YearMonth window) {
|
||||
this.window = window;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,7 +1,6 @@
|
||||
package stirling.software.proprietary.service.chatbot;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import java.util.Map;
|
||||
@ -11,7 +10,6 @@ import org.junit.jupiter.api.Test;
|
||||
|
||||
import stirling.software.common.model.ApplicationProperties;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
|
||||
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
|
||||
|
||||
class ChatbotCacheServiceTest {
|
||||
|
||||
@ -34,23 +32,6 @@ class ChatbotCacheServiceTest {
|
||||
properties.setPremium(premium);
|
||||
}
|
||||
|
||||
@Test
|
||||
void registerRejectsOversizedText() {
|
||||
ChatbotCacheService cacheService = new ChatbotCacheService(properties);
|
||||
String longText = "a".repeat(51);
|
||||
assertThrows(
|
||||
ChatbotException.class,
|
||||
() ->
|
||||
cacheService.register(
|
||||
"session",
|
||||
"doc",
|
||||
longText,
|
||||
Map.of(),
|
||||
false,
|
||||
false,
|
||||
longText.length()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void registerAndResolveSession() {
|
||||
ChatbotCacheService cacheService = new ChatbotCacheService(properties);
|
||||
@ -58,7 +39,6 @@ class ChatbotCacheServiceTest {
|
||||
cacheService.register(
|
||||
"session1",
|
||||
"doc1",
|
||||
"hello world",
|
||||
Map.of("title", "Sample"),
|
||||
false,
|
||||
false,
|
||||
|
||||
@ -22,6 +22,7 @@ import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSession;
|
||||
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
|
||||
import stirling.software.proprietary.security.service.UserService;
|
||||
import stirling.software.proprietary.service.AuditService;
|
||||
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
|
||||
|
||||
@ -34,6 +35,7 @@ class ChatbotServiceTest {
|
||||
@Mock private ChatbotCacheService cacheService;
|
||||
@Mock private ChatbotFeatureProperties featureProperties;
|
||||
@Mock private AuditService auditService;
|
||||
@Mock private UserService userService;
|
||||
|
||||
@InjectMocks private ChatbotService chatbotService;
|
||||
|
||||
@ -52,11 +54,14 @@ class ChatbotServiceTest {
|
||||
ChatbotSettings.ModelProvider.OPENAI,
|
||||
"gpt-5-nano",
|
||||
"gpt-5-mini",
|
||||
"embed"),
|
||||
"embed",
|
||||
0.2D,
|
||||
0.95D),
|
||||
new ChatbotSettings.RagSettings(512, 128, 4),
|
||||
new ChatbotSettings.CacheSettings(60, 10, 1000),
|
||||
new ChatbotSettings.OcrSettings(false),
|
||||
new ChatbotSettings.AuditSettings(true));
|
||||
new ChatbotSettings.AuditSettings(true),
|
||||
new ChatbotSettings.UsageSettings(100000L, 0.7D));
|
||||
|
||||
auditDisabledSettings =
|
||||
new ChatbotSettings(
|
||||
@ -68,11 +73,14 @@ class ChatbotServiceTest {
|
||||
ChatbotSettings.ModelProvider.OPENAI,
|
||||
"gpt-5-nano",
|
||||
"gpt-5-mini",
|
||||
"embed"),
|
||||
"embed",
|
||||
0.2D,
|
||||
0.95D),
|
||||
new ChatbotSettings.RagSettings(512, 128, 4),
|
||||
new ChatbotSettings.CacheSettings(60, 10, 1000),
|
||||
new ChatbotSettings.OcrSettings(false),
|
||||
new ChatbotSettings.AuditSettings(false));
|
||||
new ChatbotSettings.AuditSettings(false),
|
||||
new ChatbotSettings.UsageSettings(100000L, 0.7D));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -86,6 +94,7 @@ class ChatbotServiceTest {
|
||||
.build();
|
||||
when(ingestionService.ingest(any())).thenReturn(session);
|
||||
when(featureProperties.current()).thenReturn(auditEnabledSettings);
|
||||
when(userService.getCurrentUsername()).thenReturn("tester");
|
||||
|
||||
chatbotService.createSession(
|
||||
ChatbotSessionCreateRequest.builder().text("abc").warningsAccepted(true).build());
|
||||
@ -97,6 +106,7 @@ class ChatbotServiceTest {
|
||||
payloadCaptor.capture());
|
||||
Map<String, Object> payload = payloadCaptor.getValue();
|
||||
verify(cacheService, times(0)).invalidateSession(any());
|
||||
verify(userService).getCurrentUsername();
|
||||
org.junit.jupiter.api.Assertions.assertEquals("session-1", payload.get("sessionId"));
|
||||
}
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ import { runOcrForChat } from '@app/services/chatbotOcrService';
|
||||
import {
|
||||
ChatbotMessageResponse,
|
||||
ChatbotSessionInfo,
|
||||
ChatbotUsageSummary,
|
||||
createChatbotSession,
|
||||
sendChatbotPrompt,
|
||||
} from '@app/services/chatbotService';
|
||||
@ -81,6 +82,7 @@ const ChatbotDrawer = () => {
|
||||
const [pendingOcrRetry, setPendingOcrRetry] = useState(false);
|
||||
const scrollViewportRef = useRef<HTMLDivElement>(null);
|
||||
const [panelAnchor, setPanelAnchor] = useState<{ right: number; top: number } | null>(null);
|
||||
const usageAlertState = useRef<'none' | 'warned' | 'limit'>('none');
|
||||
|
||||
const selectedFile = useMemo<StirlingFile | undefined>(
|
||||
() => files.find((file) => file.fileId === selectedFileId),
|
||||
@ -120,6 +122,39 @@ const ChatbotDrawer = () => {
|
||||
}
|
||||
}, [messages, isOpen]);
|
||||
|
||||
useEffect(() => {
|
||||
usageAlertState.current = 'none';
|
||||
}, [sessionInfo?.sessionId]);
|
||||
|
||||
const maybeShowUsageWarning = (usage?: ChatbotUsageSummary | null) => {
|
||||
if (!usage) {
|
||||
return;
|
||||
}
|
||||
if (usage.limitExceeded && usageAlertState.current !== 'limit') {
|
||||
usageAlertState.current = 'limit';
|
||||
show({
|
||||
alertType: 'warning',
|
||||
title: t('chatbot.usage.limitReachedTitle', 'Chatbot limit reached'),
|
||||
body: t(
|
||||
'chatbot.usage.limitReachedBody',
|
||||
'You have exceeded the current monthly allocation for the chatbot. Further responses may be throttled.'
|
||||
),
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (usage.nearingLimit && usageAlertState.current === 'none') {
|
||||
usageAlertState.current = 'warned';
|
||||
show({
|
||||
alertType: 'warning',
|
||||
title: t('chatbot.usage.nearingLimitTitle', 'Approaching usage limit'),
|
||||
body: t(
|
||||
'chatbot.usage.nearingLimitBody',
|
||||
'You are nearing your monthly chatbot allocation. Consider limiting very large requests.'
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (sessionInfo && sessionInfo.documentId !== selectedFileId) {
|
||||
setSessionInfo(null);
|
||||
@ -245,6 +280,7 @@ const ChatbotDrawer = () => {
|
||||
);
|
||||
|
||||
setSessionInfo(response);
|
||||
maybeShowUsageWarning(response.usageSummary);
|
||||
setContextStats({
|
||||
pageCount: extractionResult.pageCount,
|
||||
characterCount: extractionResult.characterCount,
|
||||
@ -295,6 +331,7 @@ const ChatbotDrawer = () => {
|
||||
prompt: trimmedPrompt,
|
||||
allowEscalation: true,
|
||||
});
|
||||
maybeShowUsageWarning(reply.usageSummary);
|
||||
setWarnings(reply.warnings ?? []);
|
||||
const assistant = convertAssistantMessage(reply);
|
||||
setMessages((prev) => [...prev, assistant]);
|
||||
|
||||
@ -87,6 +87,34 @@ export const AppConfigProvider: React.FC<AppConfigProviderProps> = ({
|
||||
const initialDelay = retryOptions?.initialDelay ?? 1000;
|
||||
|
||||
const fetchConfig = useCallback(async (force = false) => {
|
||||
// First check if user has a JWT token - if not, they're not authenticated
|
||||
const hasJWT = localStorage.getItem('stirling_jwt');
|
||||
|
||||
// Check if on auth page
|
||||
// Need to check for paths with or without base path
|
||||
const pathname = window.location.pathname;
|
||||
const isAuthPage = pathname.endsWith('/login') ||
|
||||
pathname.endsWith('/signup') ||
|
||||
pathname.endsWith('/auth/callback') ||
|
||||
pathname.includes('/auth/') ||
|
||||
pathname.includes('/invite/');
|
||||
|
||||
// Skip config fetch if:
|
||||
// 1. On auth page, OR
|
||||
// 2. No JWT token (not authenticated) and not forcing
|
||||
if (isAuthPage || (!hasJWT && !force)) {
|
||||
console.debug('[AppConfig] Skipping config fetch:', {
|
||||
reason: isAuthPage ? 'On auth page' : 'No JWT token',
|
||||
pathname,
|
||||
hasJWT: !!hasJWT,
|
||||
force
|
||||
});
|
||||
setLoading(false);
|
||||
setConfig({ enableLogin: true });
|
||||
setHasResolvedConfig(true);
|
||||
return;
|
||||
}
|
||||
|
||||
// Prevent duplicate fetches unless forced
|
||||
if (!force && fetchCount > 0) {
|
||||
console.debug('[AppConfig] Already fetched, skipping');
|
||||
@ -109,6 +137,16 @@ export const AppConfigProvider: React.FC<AppConfigProviderProps> = ({
|
||||
console.log('[AppConfig] Fetching app config...');
|
||||
}
|
||||
|
||||
// GUARD: Only make the API call if user has JWT token
|
||||
const currentJWT = localStorage.getItem('stirling_jwt');
|
||||
if (!currentJWT && !force) {
|
||||
console.debug('[AppConfig] No JWT token, skipping API call entirely');
|
||||
setConfig({ enableLogin: true });
|
||||
setHasResolvedConfig(true);
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// apiClient automatically adds JWT header if available via interceptors
|
||||
const response = await apiClient.get<AppConfig>('/api/v1/config/app-config', !isBlockingMode ? { suppressErrorToast: true } : undefined);
|
||||
const data = response.data;
|
||||
@ -155,12 +193,31 @@ export const AppConfigProvider: React.FC<AppConfigProviderProps> = ({
|
||||
}, [fetchCount, hasResolvedConfig, isBlockingMode, maxRetries, initialDelay]);
|
||||
|
||||
useEffect(() => {
|
||||
// Skip fetching config on login and auth callback pages
|
||||
// Need to check for paths with or without base path
|
||||
const pathname = window.location.pathname;
|
||||
const isAuthPage = pathname.endsWith('/login') ||
|
||||
pathname.endsWith('/signup') ||
|
||||
pathname.endsWith('/auth/callback') ||
|
||||
pathname.includes('/auth/') ||
|
||||
pathname.includes('/invite/');
|
||||
|
||||
if (isAuthPage) {
|
||||
console.debug('[AppConfig] On auth page, skipping config fetch in useEffect');
|
||||
console.debug('[AppConfig] Current pathname:', pathname);
|
||||
setLoading(false);
|
||||
// Set minimal config for auth pages
|
||||
setConfig({ enableLogin: true });
|
||||
setHasResolvedConfig(true);
|
||||
return;
|
||||
}
|
||||
|
||||
// Always try to fetch config to check if login is disabled
|
||||
// The endpoint should be public and return proper JSON
|
||||
if (autoFetch) {
|
||||
fetchConfig();
|
||||
}
|
||||
}, [autoFetch, fetchConfig]);
|
||||
}, [autoFetch]); // Remove fetchConfig from deps to prevent re-runs
|
||||
|
||||
// Listen for JWT availability (triggered on login/signup)
|
||||
useEffect(() => {
|
||||
|
||||
@ -1,5 +1,16 @@
|
||||
import apiClient from '@app/services/apiClient';
|
||||
|
||||
export interface ChatbotUsageSummary {
|
||||
allocatedTokens: number;
|
||||
consumedTokens: number;
|
||||
remainingTokens: number;
|
||||
usageRatio: number;
|
||||
nearingLimit: boolean;
|
||||
limitExceeded: boolean;
|
||||
lastIncrementTokens: number;
|
||||
window?: string;
|
||||
}
|
||||
|
||||
export interface ChatbotSessionPayload {
|
||||
sessionId?: string;
|
||||
documentId: string;
|
||||
@ -17,8 +28,11 @@ export interface ChatbotSessionInfo {
|
||||
ocrRequested: boolean;
|
||||
maxCachedCharacters: number;
|
||||
createdAt: string;
|
||||
textCharacters: number;
|
||||
estimatedTokens: number;
|
||||
warnings?: string[];
|
||||
metadata?: Record<string, string>;
|
||||
usageSummary?: ChatbotUsageSummary;
|
||||
}
|
||||
|
||||
export interface ChatbotQueryPayload {
|
||||
@ -37,6 +51,10 @@ export interface ChatbotMessageResponse {
|
||||
cacheHit?: boolean;
|
||||
warnings?: string[];
|
||||
metadata?: Record<string, unknown>;
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
totalTokens?: number;
|
||||
usageSummary?: ChatbotUsageSummary;
|
||||
}
|
||||
|
||||
export async function createChatbotSession(payload: ChatbotSessionPayload) {
|
||||
@ -48,4 +66,3 @@ export async function sendChatbotPrompt(payload: ChatbotQueryPayload) {
|
||||
const { data } = await apiClient.post<ChatbotMessageResponse>('/api/v1/internal/chatbot/query', payload);
|
||||
return data;
|
||||
}
|
||||
|
||||
|
||||
@ -95,6 +95,30 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
try {
|
||||
console.debug('[Auth] Initializing auth...');
|
||||
|
||||
// Skip auth check if we're on auth pages
|
||||
// Need to check for paths with or without base path
|
||||
const pathname = window.location.pathname;
|
||||
const isAuthPage = pathname.endsWith('/login') ||
|
||||
pathname.endsWith('/signup') ||
|
||||
pathname.endsWith('/auth/callback') ||
|
||||
pathname.includes('/auth/') ||
|
||||
pathname.includes('/invite/');
|
||||
|
||||
if (isAuthPage) {
|
||||
console.log('[Auth] On auth page, completely skipping session check');
|
||||
console.log('[Auth] Current path:', pathname);
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// GUARD: Check if JWT exists before making session call
|
||||
const hasJWT = localStorage.getItem('stirling_jwt');
|
||||
if (!hasJWT) {
|
||||
console.debug('[Auth] No JWT token found, skipping session check');
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip config check entirely - let the app handle login state
|
||||
// The config will be fetched by useAppConfig when needed
|
||||
const { data, error } = await springAuth.getSession();
|
||||
@ -126,6 +150,22 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
|
||||
initializeAuth();
|
||||
|
||||
// Listen for JWT availability after OAuth callback or login
|
||||
const handleJwtAvailable = async () => {
|
||||
console.debug('[Auth] JWT available event detected, loading session');
|
||||
try {
|
||||
const { data, error } = await springAuth.getSession();
|
||||
if (!error && data.session) {
|
||||
console.debug('[Auth] Session loaded after JWT available:', data.session);
|
||||
setSession(data.session);
|
||||
setLoading(false);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[Auth] Error loading session after JWT available:', err);
|
||||
}
|
||||
};
|
||||
window.addEventListener('jwt-available', handleJwtAvailable);
|
||||
|
||||
// Subscribe to auth state changes
|
||||
const { data: { subscription } } = springAuth.onAuthStateChange(
|
||||
async (event: AuthChangeEvent, newSession: Session | null) => {
|
||||
@ -163,6 +203,7 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
return () => {
|
||||
mounted = false;
|
||||
subscription.unsubscribe();
|
||||
window.removeEventListener('jwt-available', handleJwtAvailable);
|
||||
};
|
||||
}, []);
|
||||
|
||||
|
||||
@ -120,25 +120,32 @@ class SpringAuthClient {
|
||||
async getSession(): Promise<{ data: { session: Session | null }; error: AuthError | null }> {
|
||||
try {
|
||||
// Get JWT from localStorage
|
||||
console.log('[SpringAuth] getSession: Checking localStorage for JWT...');
|
||||
const token = localStorage.getItem('stirling_jwt');
|
||||
|
||||
// Log all localStorage keys for debugging
|
||||
console.log('[SpringAuth] All localStorage keys:', Object.keys(localStorage));
|
||||
|
||||
if (!token) {
|
||||
console.debug('[SpringAuth] getSession: No JWT in localStorage');
|
||||
console.warn('[SpringAuth] getSession: No JWT found in localStorage!');
|
||||
console.warn('[SpringAuth] This will cause logout on refresh. Make sure JWT is saved after login.');
|
||||
return { data: { session: null }, error: null };
|
||||
}
|
||||
|
||||
console.log('[SpringAuth] getSession: Found JWT in localStorage, length:', token.length);
|
||||
|
||||
// Verify with backend
|
||||
// Note: We pass the token explicitly here, overriding the interceptor's default
|
||||
console.debug('[SpringAuth] getSession: Verifying JWT with /api/v1/auth/me');
|
||||
console.log('[SpringAuth] getSession: Verifying JWT with /api/v1/auth/me');
|
||||
const response = await apiClient.get('/api/v1/auth/me', {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
console.debug('[SpringAuth] /me response status:', response.status);
|
||||
console.log('[SpringAuth] /me response status:', response.status);
|
||||
const data = response.data;
|
||||
console.debug('[SpringAuth] /me response data:', data);
|
||||
console.log('[SpringAuth] /me response data:', data);
|
||||
|
||||
// Create session object
|
||||
const session: Session = {
|
||||
@ -148,7 +155,8 @@ class SpringAuthClient {
|
||||
expires_at: Date.now() + 3600 * 1000,
|
||||
};
|
||||
|
||||
console.debug('[SpringAuth] getSession: Session retrieved successfully');
|
||||
console.log('[SpringAuth] getSession: ✓ Session retrieved successfully');
|
||||
console.log('[SpringAuth] User is logged in as:', data.user?.email || data.user?.username);
|
||||
return { data: { session }, error: null };
|
||||
} catch (error: unknown) {
|
||||
console.error('[SpringAuth] getSession error:', error);
|
||||
@ -187,9 +195,22 @@ class SpringAuthClient {
|
||||
const data = response.data;
|
||||
const token = data.session.access_token;
|
||||
|
||||
// Store JWT in localStorage
|
||||
// Store JWT in localStorage - CRITICAL for persistence
|
||||
localStorage.setItem('stirling_jwt', token);
|
||||
console.log('[SpringAuth] JWT stored in localStorage');
|
||||
console.log('[SpringAuth] JWT stored in localStorage after password login');
|
||||
console.log('[SpringAuth] JWT token length:', token ? token.length : 0);
|
||||
|
||||
// Verify it was actually saved
|
||||
const savedToken = localStorage.getItem('stirling_jwt');
|
||||
if (!savedToken) {
|
||||
console.error('[SpringAuth] CRITICAL: JWT was not saved to localStorage!');
|
||||
// Try again
|
||||
localStorage.setItem('stirling_jwt', token);
|
||||
} else if (savedToken !== token) {
|
||||
console.error('[SpringAuth] CRITICAL: Saved token differs from received token!');
|
||||
} else {
|
||||
console.log('[SpringAuth] ✓ Verified JWT is correctly saved in localStorage');
|
||||
}
|
||||
|
||||
// Dispatch custom event for other components to react to JWT availability
|
||||
window.dispatchEvent(new CustomEvent('jwt-available'));
|
||||
@ -317,19 +338,37 @@ class SpringAuthClient {
|
||||
});
|
||||
|
||||
const data = response.data;
|
||||
const token = data.session.access_token;
|
||||
|
||||
// Handle different response structures - the API might return the token directly or nested
|
||||
const token = data?.session?.access_token || data?.access_token || data?.token;
|
||||
|
||||
if (!token) {
|
||||
console.error('[SpringAuth] refreshSession: No access token in response:', data);
|
||||
throw new Error('No access token received from refresh endpoint');
|
||||
}
|
||||
|
||||
// Update local storage with new token
|
||||
localStorage.setItem('stirling_jwt', token);
|
||||
console.log('[SpringAuth] refreshSession: New JWT stored in localStorage');
|
||||
|
||||
// Verify it was saved
|
||||
const savedToken = localStorage.getItem('stirling_jwt');
|
||||
if (savedToken !== token) {
|
||||
console.error('[SpringAuth] CRITICAL: JWT was not properly saved during refresh!');
|
||||
} else {
|
||||
console.log('[SpringAuth] refreshSession: ✓ JWT refreshed and verified in localStorage');
|
||||
}
|
||||
|
||||
// Dispatch custom event for other components to react to JWT availability
|
||||
window.dispatchEvent(new CustomEvent('jwt-available'));
|
||||
|
||||
// Build session object, handling different response structures
|
||||
const expires_in = data?.session?.expires_in || data?.expires_in || 3600; // Default to 1 hour
|
||||
const session: Session = {
|
||||
user: data.user,
|
||||
user: data?.user || data?.session?.user || null,
|
||||
access_token: token,
|
||||
expires_in: data.session.expires_in,
|
||||
expires_at: Date.now() + data.session.expires_in * 1000,
|
||||
expires_in: expires_in,
|
||||
expires_at: Date.now() + expires_in * 1000,
|
||||
};
|
||||
|
||||
// Notify listeners
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useAuth } from '@app/auth/UseSession';
|
||||
|
||||
/**
|
||||
* OAuth Callback Handler
|
||||
@ -11,7 +10,6 @@ import { useAuth } from '@app/auth/UseSession';
|
||||
*/
|
||||
export default function AuthCallback() {
|
||||
const navigate = useNavigate();
|
||||
const { refreshSession } = useAuth();
|
||||
|
||||
useEffect(() => {
|
||||
const handleCallback = async () => {
|
||||
@ -32,17 +30,31 @@ export default function AuthCallback() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Store JWT in localStorage
|
||||
// Store JWT in localStorage - CRITICAL for persistence
|
||||
localStorage.setItem('stirling_jwt', token);
|
||||
console.log('[AuthCallback] JWT stored in localStorage');
|
||||
console.log('[AuthCallback] JWT stored in localStorage after OAuth');
|
||||
console.log('[AuthCallback] JWT token length:', token.length);
|
||||
|
||||
// Verify it was actually saved
|
||||
const savedToken = localStorage.getItem('stirling_jwt');
|
||||
if (!savedToken) {
|
||||
console.error('[AuthCallback] CRITICAL: JWT was not saved to localStorage!');
|
||||
// Try again
|
||||
localStorage.setItem('stirling_jwt', token);
|
||||
} else if (savedToken !== token) {
|
||||
console.error('[AuthCallback] CRITICAL: Saved token differs from received token!');
|
||||
} else {
|
||||
console.log('[AuthCallback] ✓ Verified JWT is correctly saved in localStorage');
|
||||
}
|
||||
|
||||
// Dispatch custom event for other components to react to JWT availability
|
||||
window.dispatchEvent(new CustomEvent('jwt-available'))
|
||||
// This will trigger the auth provider to load the session with the new JWT
|
||||
window.dispatchEvent(new CustomEvent('jwt-available'));
|
||||
|
||||
// Refresh session to load user info into state
|
||||
await refreshSession();
|
||||
// Small delay to ensure event is processed
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
console.log('[AuthCallback] Session refreshed, redirecting to home');
|
||||
console.log('[AuthCallback] JWT saved, redirecting to home');
|
||||
|
||||
// Clear the hash from URL and redirect to home page
|
||||
navigate('/', { replace: true });
|
||||
@ -56,7 +68,7 @@ export default function AuthCallback() {
|
||||
};
|
||||
|
||||
handleCallback();
|
||||
}, [navigate, refreshSession]);
|
||||
}, [navigate]);
|
||||
|
||||
return (
|
||||
<div style={{
|
||||
|
||||
@ -118,8 +118,12 @@ export default function Login() {
|
||||
setError(error.message);
|
||||
} else if (user && session) {
|
||||
console.log('[Login] Email sign in successful');
|
||||
// Auth state will update automatically and Landing will redirect to home
|
||||
// No need to navigate manually here
|
||||
// Dispatch event to trigger auth state update
|
||||
window.dispatchEvent(new CustomEvent('jwt-available'));
|
||||
// Navigate to home page
|
||||
setTimeout(() => {
|
||||
navigate('/', { replace: true });
|
||||
}, 100); // Small delay to ensure auth state updates
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[Login] Unexpected error:', err);
|
||||
|
||||
@ -36,19 +36,16 @@ export default defineConfig(({ mode }) => {
|
||||
target: 'http://localhost:8080',
|
||||
changeOrigin: true,
|
||||
secure: false,
|
||||
xfwd: true,
|
||||
},
|
||||
'/oauth2': {
|
||||
target: 'http://localhost:8080',
|
||||
changeOrigin: true,
|
||||
secure: false,
|
||||
xfwd: true,
|
||||
},
|
||||
'/login/oauth2': {
|
||||
target: 'http://localhost:8080',
|
||||
changeOrigin: true,
|
||||
secure: false,
|
||||
xfwd: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user