diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx
index 63133842a..aa2f94c6a 100644
--- a/web/src/views/classification/ModelSelectionView.tsx
+++ b/web/src/views/classification/ModelSelectionView.tsx
@@ -1,3 +1,4 @@
+import { baseUrl } from "@/api/baseUrl";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import { cn } from "@/lib/utils";
import {
@@ -37,27 +38,60 @@ export default function ModelSelectionView({
return (
{classificationConfigs.map((config) => (
-
onClick(config)}
- onContextMenu={() => {
- // e.stopPropagation();
- // e.preventDefault();
- // handleClickEvent(true);
- }}
- >
-
-
- {config.name} ({config.state_config != null ? "State" : "Object"}{" "}
- Classification)
-
-
+
onClick(config)} />
))}
);
}
+
+type ModelCardProps = {
+ config: CustomClassificationModelConfig;
+ onClick: () => void;
+};
+function ModelCard({ config, onClick }: ModelCardProps) {
+ const { data: dataset } = useSWR<{
+ [id: string]: string[];
+ }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false });
+
+ const coverImages = useMemo(() => {
+ if (!dataset) {
+ return {};
+ }
+
+ const imageMap: { [key: string]: string } = {};
+
+ for (const [key, imageList] of Object.entries(dataset)) {
+ if (imageList.length > 0) {
+ imageMap[key] = imageList[0];
+ }
+ }
+
+ return imageMap;
+ }, [dataset]);
+
+ return (
+ onClick()}
+ >
+
+ {Object.entries(coverImages).map(([key, image]) => (
+

+ ))}
+
+
+ {config.name} ({config.state_config != null ? "State" : "Object"}{" "}
+ Classification)
+
+
+ );
+}