diff --git a/frigate/api/review.py b/frigate/api/review.py index 4788356f3..b04c8353a 100644 --- a/frigate/api/review.py +++ b/frigate/api/review.py @@ -9,10 +9,10 @@ import pandas as pd from fastapi import APIRouter from fastapi.params import Depends from fastapi.responses import JSONResponse -from peewee import Case, DoesNotExist, fn, operator +from peewee import Case, DoesNotExist, IntegrityError, fn, operator from playhouse.shortcuts import model_to_dict -from frigate.api.auth import require_role +from frigate.api.auth import get_current_user, require_role from frigate.api.defs.query.review_query_parameters import ( ReviewActivityMotionQueryParams, ReviewQueryParams, @@ -26,7 +26,7 @@ from frigate.api.defs.response.review_response import ( ReviewSummaryResponse, ) from frigate.api.defs.tags import Tags -from frigate.models import Recordings, ReviewSegment +from frigate.models import Recordings, ReviewSegment, UserReviewStatus from frigate.review.types import SeverityEnum from frigate.util.builtin import get_tz_modifiers @@ -36,7 +36,15 @@ router = APIRouter(tags=[Tags.review]) @router.get("/review", response_model=list[ReviewSegmentResponse]) -def review(params: ReviewQueryParams = Depends()): +async def review( + params: ReviewQueryParams = Depends(), + current_user: dict = Depends(get_current_user), +): + if isinstance(current_user, JSONResponse): + return current_user + + user_id = current_user["username"] + cameras = params.cameras labels = params.labels zones = params.zones @@ -74,9 +82,7 @@ def review(params: ReviewQueryParams = Depends()): (ReviewSegment.data["objects"].cast("text") % f'*"{label}"*') | (ReviewSegment.data["audio"].cast("text") % f'*"{label}"*') ) - - label_clause = reduce(operator.or_, label_clauses) - clauses.append((label_clause)) + clauses.append(reduce(operator.or_, label_clauses)) if zones != "all": # use matching so segments with multiple zones @@ -88,27 +94,52 @@ def review(params: ReviewQueryParams = Depends()): zone_clauses.append( (ReviewSegment.data["zones"].cast("text") % f'*"{zone}"*') ) - - zone_clause = reduce(operator.or_, zone_clauses) - clauses.append((zone_clause)) - - if reviewed == 0: - clauses.append((ReviewSegment.has_been_reviewed == False)) + clauses.append(reduce(operator.or_, zone_clauses)) if severity: clauses.append((ReviewSegment.severity == severity)) - review = ( - ReviewSegment.select() + # Join with UserReviewStatus to get per-user review status + review_query = ( + ReviewSegment.select( + ReviewSegment.id, + ReviewSegment.camera, + ReviewSegment.start_time, + ReviewSegment.end_time, + ReviewSegment.severity, + ReviewSegment.thumb_path, + ReviewSegment.data, + fn.COALESCE(UserReviewStatus.has_been_reviewed, False).alias( + "has_been_reviewed" + ), + ) + .left_outer_join( + UserReviewStatus, + on=( + (ReviewSegment.id == UserReviewStatus.review_segment) + & (UserReviewStatus.user_id == user_id) + ), + ) .where(reduce(operator.and_, clauses)) - .order_by(ReviewSegment.severity.asc()) + ) + + # Filter unreviewed items without subquery + if reviewed == 0: + review_query = review_query.where( + (UserReviewStatus.has_been_reviewed == False) + | (UserReviewStatus.has_been_reviewed.is_null()) + ) + + # Apply ordering and limit + review_query = ( + review_query.order_by(ReviewSegment.severity.asc()) .order_by(ReviewSegment.start_time.desc()) .limit(limit) .dicts() .iterator() ) - return JSONResponse(content=[r for r in review]) + return JSONResponse(content=[r for r in review_query]) @router.get("/review_ids", response_model=list[ReviewSegmentResponse]) @@ -134,7 +165,15 @@ def review_ids(ids: str): @router.get("/review/summary", response_model=ReviewSummaryResponse) -def review_summary(params: ReviewSummaryQueryParams = Depends()): +async def review_summary( + params: ReviewSummaryQueryParams = Depends(), + current_user: dict = Depends(get_current_user), +): + if isinstance(current_user, JSONResponse): + return current_user + + user_id = current_user["username"] + hour_modifier, minute_modifier, seconds_offset = get_tz_modifiers(params.timezone) day_ago = (datetime.datetime.now() - datetime.timedelta(hours=24)).timestamp() month_ago = (datetime.datetime.now() - datetime.timedelta(days=30)).timestamp() @@ -160,10 +199,7 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): (ReviewSegment.data["objects"].cast("text") % f'*"{label}"*') | (ReviewSegment.data["audio"].cast("text") % f'*"{label}"*') ) - - label_clause = reduce(operator.or_, label_clauses) - clauses.append((label_clause)) - + clauses.append(reduce(operator.or_, label_clauses)) if zones != "all": # use matching so segments with multiple zones # still match on a search where any zone matches @@ -172,21 +208,20 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): for zone in filtered_zones: zone_clauses.append( - (ReviewSegment.data["zones"].cast("text") % f'*"{zone}"*') + ReviewSegment.data["zones"].cast("text") % f'*"{zone}"*' ) + clauses.append(reduce(operator.or_, zone_clauses)) - zone_clause = reduce(operator.or_, zone_clauses) - clauses.append((zone_clause)) - - last_24 = ( + last_24_query = ( ReviewSegment.select( fn.SUM( Case( None, [ ( - (ReviewSegment.severity == SeverityEnum.alert), - ReviewSegment.has_been_reviewed, + (ReviewSegment.severity == SeverityEnum.alert) + & (UserReviewStatus.has_been_reviewed == True), + 1, ) ], 0, @@ -197,8 +232,9 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): None, [ ( - (ReviewSegment.severity == SeverityEnum.detection), - ReviewSegment.has_been_reviewed, + (ReviewSegment.severity == SeverityEnum.detection) + & (UserReviewStatus.has_been_reviewed == True), + 1, ) ], 0, @@ -229,6 +265,13 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): ) ).alias("total_detection"), ) + .left_outer_join( + UserReviewStatus, + on=( + (ReviewSegment.id == UserReviewStatus.review_segment) + & (UserReviewStatus.user_id == user_id) + ), + ) .where(reduce(operator.and_, clauses)) .dicts() .get() @@ -248,14 +291,12 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): for label in filtered_labels: label_clauses.append( - (ReviewSegment.data["objects"].cast("text") % f'*"{label}"*') + ReviewSegment.data["objects"].cast("text") % f'*"{label}"*' ) - - label_clause = reduce(operator.or_, label_clauses) - clauses.append((label_clause)) + clauses.append(reduce(operator.or_, label_clauses)) day_in_seconds = 60 * 60 * 24 - last_month = ( + last_month_query = ( ReviewSegment.select( fn.strftime( "%Y-%m-%d", @@ -271,8 +312,9 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): None, [ ( - (ReviewSegment.severity == SeverityEnum.alert), - ReviewSegment.has_been_reviewed, + (ReviewSegment.severity == SeverityEnum.alert) + & (UserReviewStatus.has_been_reviewed == True), + 1, ) ], 0, @@ -283,8 +325,9 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): None, [ ( - (ReviewSegment.severity == SeverityEnum.detection), - ReviewSegment.has_been_reviewed, + (ReviewSegment.severity == SeverityEnum.detection) + & (UserReviewStatus.has_been_reviewed == True), + 1, ) ], 0, @@ -315,28 +358,59 @@ def review_summary(params: ReviewSummaryQueryParams = Depends()): ) ).alias("total_detection"), ) + .left_outer_join( + UserReviewStatus, + on=( + (ReviewSegment.id == UserReviewStatus.review_segment) + & (UserReviewStatus.user_id == user_id) + ), + ) .where(reduce(operator.and_, clauses)) .group_by( - (ReviewSegment.start_time + seconds_offset).cast("int") / day_in_seconds, + (ReviewSegment.start_time + seconds_offset).cast("int") / day_in_seconds ) .order_by(ReviewSegment.start_time.desc()) ) data = { - "last24Hours": last_24, + "last24Hours": last_24_query, } - for e in last_month.dicts().iterator(): + for e in last_month_query.dicts().iterator(): data[e["day"]] = e return JSONResponse(content=data) @router.post("/reviews/viewed", response_model=GenericResponse) -def set_multiple_reviewed(body: ReviewModifyMultipleBody): - ReviewSegment.update(has_been_reviewed=True).where( - ReviewSegment.id << body.ids - ).execute() +async def set_multiple_reviewed( + body: ReviewModifyMultipleBody, + current_user: dict = Depends(get_current_user), +): + if isinstance(current_user, JSONResponse): + return current_user + + user_id = current_user["username"] + + for review_id in body.ids: + try: + review_status = UserReviewStatus.get( + UserReviewStatus.user_id == user_id, + UserReviewStatus.review_segment == review_id, + ) + # If it exists and isn’t reviewed, update it + if not review_status.has_been_reviewed: + review_status.has_been_reviewed = True + review_status.save() + except DoesNotExist: + try: + UserReviewStatus.create( + user_id=user_id, + review_segment=ReviewSegment.get(id=review_id), + has_been_reviewed=True, + ) + except (DoesNotExist, IntegrityError): + pass return JSONResponse( content=({"success": True, "message": "Reviewed multiple items"}), @@ -389,6 +463,9 @@ def delete_reviews(body: ReviewModifyMultipleBody): # delete recordings and review segments Recordings.delete().where(Recordings.id << recording_ids).execute() ReviewSegment.delete().where(ReviewSegment.id << list_of_ids).execute() + UserReviewStatus.delete().where( + UserReviewStatus.review_segment << list_of_ids + ).execute() return JSONResponse( content=({"success": True, "message": "Deleted review items."}), status_code=200 @@ -502,7 +579,15 @@ def get_review(review_id: str): @router.delete("/review/{review_id}/viewed", response_model=GenericResponse) -def set_not_reviewed(review_id: str): +async def set_not_reviewed( + review_id: str, + current_user: dict = Depends(get_current_user), +): + if isinstance(current_user, JSONResponse): + return current_user + + user_id = current_user["username"] + try: review: ReviewSegment = ReviewSegment.get(ReviewSegment.id == review_id) except DoesNotExist: @@ -513,8 +598,15 @@ def set_not_reviewed(review_id: str): status_code=404, ) - review.has_been_reviewed = False - review.save() + try: + user_review = UserReviewStatus.get( + UserReviewStatus.user_id == user_id, + UserReviewStatus.review_segment == review, + ) + # we could update here instead of delete if we need + user_review.delete_instance() + except DoesNotExist: + pass # Already effectively "not reviewed" return JSONResponse( content=({"success": True, "message": f"Set Review {review_id} as not viewed"}), diff --git a/frigate/models.py b/frigate/models.py index 11b25b938..5aa0dc5b2 100644 --- a/frigate/models.py +++ b/frigate/models.py @@ -3,6 +3,7 @@ from peewee import ( CharField, DateTimeField, FloatField, + ForeignKeyField, IntegerField, Model, TextField, @@ -92,12 +93,20 @@ class ReviewSegment(Model): # type: ignore[misc] camera = CharField(index=True, max_length=20) start_time = DateTimeField() end_time = DateTimeField() - has_been_reviewed = BooleanField(default=False) severity = CharField(max_length=30) # alert, detection thumb_path = CharField(unique=True) data = JSONField() # additional data about detection like list of labels, zone, areas of significant motion +class UserReviewStatus(Model): # type: ignore[misc] + user_id = CharField(max_length=30) + review_segment = ForeignKeyField(ReviewSegment, backref="user_reviews") + has_been_reviewed = BooleanField(default=False) + + class Meta: + indexes = ((("user_id", "review_segment"), True),) + + class Previews(Model): # type: ignore[misc] id = CharField(null=False, primary_key=True, max_length=30) camera = CharField(index=True, max_length=20) diff --git a/frigate/record/cleanup.py b/frigate/record/cleanup.py index e526b020d..c86c81859 100644 --- a/frigate/record/cleanup.py +++ b/frigate/record/cleanup.py @@ -12,7 +12,7 @@ from playhouse.sqlite_ext import SqliteExtDatabase from frigate.config import CameraConfig, FrigateConfig, RetainModeEnum from frigate.const import CACHE_DIR, CLIPS_DIR, MAX_WAL_SIZE, RECORD_DIR -from frigate.models import Previews, Recordings, ReviewSegment +from frigate.models import Previews, Recordings, ReviewSegment, UserReviewStatus from frigate.record.util import remove_empty_directories, sync_recordings from frigate.util.builtin import clear_and_unlink, get_tomorrow_at_time @@ -90,6 +90,10 @@ class RecordingCleanup(threading.Thread): ReviewSegment.delete().where( ReviewSegment.id << deleted_reviews_list[i : i + max_deletes] ).execute() + UserReviewStatus.delete().where( + UserReviewStatus.review_segment + << deleted_reviews_list[i : i + max_deletes] + ).execute() def expire_existing_camera_recordings( self, expire_date: float, config: CameraConfig, reviews: ReviewSegment diff --git a/frigate/test/http_api/base_http_test.py b/frigate/test/http_api/base_http_test.py index 35cda7b79..3c4a7ccdc 100644 --- a/frigate/test/http_api/base_http_test.py +++ b/frigate/test/http_api/base_http_test.py @@ -157,16 +157,14 @@ class BaseTestHttp(unittest.TestCase): start_time: float = datetime.datetime.now().timestamp(), end_time: float = datetime.datetime.now().timestamp() + 20, severity: SeverityEnum = SeverityEnum.alert, - has_been_reviewed: bool = False, data: Json = {}, - ) -> Event: + ) -> ReviewSegment: """Inserts a review segment model with a given id.""" return ReviewSegment.insert( id=id, camera="front_door", start_time=start_time, end_time=end_time, - has_been_reviewed=has_been_reviewed, severity=severity, thumb_path=False, data=data, diff --git a/frigate/test/http_api/test_http_review.py b/frigate/test/http_api/test_http_review.py index ee7d96bc5..19c589a67 100644 --- a/frigate/test/http_api/test_http_review.py +++ b/frigate/test/http_api/test_http_review.py @@ -1,16 +1,29 @@ from datetime import datetime, timedelta from fastapi.testclient import TestClient +from peewee import DoesNotExist -from frigate.models import Event, Recordings, ReviewSegment +from frigate.api.auth import get_current_user +from frigate.models import Event, Recordings, ReviewSegment, UserReviewStatus from frigate.review.types import SeverityEnum from frigate.test.http_api.base_http_test import BaseTestHttp class TestHttpReview(BaseTestHttp): def setUp(self): - super().setUp([Event, Recordings, ReviewSegment]) + super().setUp([Event, Recordings, ReviewSegment, UserReviewStatus]) self.app = super().create_app() + self.user_id = "admin" + + # Mock get_current_user for all tests + async def mock_get_current_user(): + return {"username": self.user_id, "role": "admin"} + + self.app.dependency_overrides[get_current_user] = mock_get_current_user + + def tearDown(self): + self.app.dependency_overrides.clear() + super().tearDown() def _get_reviews(self, ids: list[str]): return list( @@ -24,6 +37,13 @@ class TestHttpReview(BaseTestHttp): Recordings.select(Recordings.id).where(Recordings.id.in_(ids)).execute() ) + def _insert_user_review_status(self, review_id: str, reviewed: bool = True): + UserReviewStatus.create( + user_id=self.user_id, + review_segment=ReviewSegment.get(ReviewSegment.id == review_id), + has_been_reviewed=reviewed, + ) + #################################################################################################################### ################################### GET /review Endpoint ######################################################## #################################################################################################################### @@ -43,11 +63,14 @@ class TestHttpReview(BaseTestHttp): now = datetime.now().timestamp() with TestClient(self.app) as client: - super().insert_mock_review_segment("123456.random", now - 2, now - 1) + id = "123456.random" + super().insert_mock_review_segment(id, now - 2, now - 1) response = client.get("/review") assert response.status_code == 200 response_json = response.json() assert len(response_json) == 1 + assert response_json[0]["id"] == id + assert response_json[0]["has_been_reviewed"] == False def test_get_review_with_time_filter_no_matches(self): now = datetime.now().timestamp() @@ -391,37 +414,27 @@ class TestHttpReview(BaseTestHttp): with TestClient(self.app) as client: five_days_ago_ts = five_days_ago.timestamp() for i in range(10): + id = f"123456_{i}.random_alert_not_reviewed" super().insert_mock_review_segment( - f"123456_{i}.random_alert_not_reviewed", - five_days_ago_ts, - five_days_ago_ts, - SeverityEnum.alert, - False, + id, five_days_ago_ts, five_days_ago_ts, SeverityEnum.alert ) for i in range(10): + id = f"123456_{i}.random_alert_reviewed" super().insert_mock_review_segment( - f"123456_{i}.random_alert_reviewed", - five_days_ago_ts, - five_days_ago_ts, - SeverityEnum.alert, - True, + id, five_days_ago_ts, five_days_ago_ts, SeverityEnum.alert ) + self._insert_user_review_status(id, reviewed=True) for i in range(10): + id = f"123456_{i}.random_detection_not_reviewed" super().insert_mock_review_segment( - f"123456_{i}.random_detection_not_reviewed", - five_days_ago_ts, - five_days_ago_ts, - SeverityEnum.detection, - False, + id, five_days_ago_ts, five_days_ago_ts, SeverityEnum.detection ) for i in range(5): + id = f"123456_{i}.random_detection_reviewed" super().insert_mock_review_segment( - f"123456_{i}.random_detection_reviewed", - five_days_ago_ts, - five_days_ago_ts, - SeverityEnum.detection, - True, + id, five_days_ago_ts, five_days_ago_ts, SeverityEnum.detection ) + self._insert_user_review_status(id, reviewed=True) response = client.get("/review/summary") assert response.status_code == 200 response_json = response.json() @@ -447,6 +460,7 @@ class TestHttpReview(BaseTestHttp): #################################################################################################################### ################################### POST reviews/viewed Endpoint ################################################ #################################################################################################################### + def test_post_reviews_viewed_no_body(self): with TestClient(self.app) as client: super().insert_mock_review_segment("123456.random") @@ -473,12 +487,11 @@ class TestHttpReview(BaseTestHttp): assert response["success"] == True assert response["message"] == "Reviewed multiple items" # Verify that in DB the review segment was not changed - review_segment_in_db = ( - ReviewSegment.select(ReviewSegment.has_been_reviewed) - .where(ReviewSegment.id == id) - .get() - ) - assert review_segment_in_db.has_been_reviewed == False + with self.assertRaises(DoesNotExist): + UserReviewStatus.get( + UserReviewStatus.user_id == self.user_id, + UserReviewStatus.review_segment == "1", + ) def test_post_reviews_viewed(self): with TestClient(self.app) as client: @@ -487,16 +500,15 @@ class TestHttpReview(BaseTestHttp): body = {"ids": [id]} response = client.post("/reviews/viewed", json=body) assert response.status_code == 200 - response = response.json() - assert response["success"] == True - assert response["message"] == "Reviewed multiple items" - # Verify that in DB the review segment was changed - review_segment_in_db = ( - ReviewSegment.select(ReviewSegment.has_been_reviewed) - .where(ReviewSegment.id == id) - .get() + response_json = response.json() + assert response_json["success"] == True + assert response_json["message"] == "Reviewed multiple items" + # Verify UserReviewStatus was created + user_review = UserReviewStatus.get( + UserReviewStatus.user_id == self.user_id, + UserReviewStatus.review_segment == id, ) - assert review_segment_in_db.has_been_reviewed == True + assert user_review.has_been_reviewed == True #################################################################################################################### ################################### POST reviews/delete Endpoint ################################################ @@ -672,8 +684,7 @@ class TestHttpReview(BaseTestHttp): "camera": "front_door", "start_time": now + 1, "end_time": now + 2, - "has_been_reviewed": False, - "severity": SeverityEnum.alert, + "severity": "alert", "thumb_path": "False", "data": {"detections": {"event_id": event_id}}, }, @@ -708,8 +719,7 @@ class TestHttpReview(BaseTestHttp): "camera": "front_door", "start_time": now + 1, "end_time": now + 2, - "has_been_reviewed": False, - "severity": SeverityEnum.alert, + "severity": "alert", "thumb_path": "False", "data": {}, }, @@ -719,6 +729,7 @@ class TestHttpReview(BaseTestHttp): #################################################################################################################### ################################### DELETE /review/{review_id}/viewed Endpoint ################################## #################################################################################################################### + def test_delete_review_viewed_review_not_found(self): with TestClient(self.app) as client: review_id = "123456.random" @@ -735,11 +746,10 @@ class TestHttpReview(BaseTestHttp): with TestClient(self.app) as client: review_id = "123456.review.random" - super().insert_mock_review_segment( - review_id, now + 1, now + 2, has_been_reviewed=True - ) - review_before = ReviewSegment.get(ReviewSegment.id == review_id) - assert review_before.has_been_reviewed == True + super().insert_mock_review_segment(review_id, now + 1, now + 2) + self._insert_user_review_status(review_id, reviewed=True) + # Verify it’s reviewed before + response = client.get(f"/review/{review_id}") response = client.delete(f"/review/{review_id}/viewed") assert response.status_code == 200 @@ -749,5 +759,9 @@ class TestHttpReview(BaseTestHttp): response_json, ) - review_after = ReviewSegment.get(ReviewSegment.id == review_id) - assert review_after.has_been_reviewed == False + # Verify it’s unreviewed after + with self.assertRaises(DoesNotExist): + UserReviewStatus.get( + UserReviewStatus.user_id == self.user_id, + UserReviewStatus.review_segment == review_id, + ) diff --git a/migrations/030_create_user_review_status.py b/migrations/030_create_user_review_status.py new file mode 100644 index 000000000..d24738438 --- /dev/null +++ b/migrations/030_create_user_review_status.py @@ -0,0 +1,85 @@ +"""Peewee migrations -- 030_create_user_review_status.py. + +This migration creates the UserReviewStatus table to track per-user review states, +migrates existing has_been_reviewed data from ReviewSegment to all users in the user table, +and drops the has_been_reviewed column. Rollback drops UserReviewStatus and restores the column. + +Some examples (model - class or model_name):: + + > Model = migrator.orm['model_name'] # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.python(func, *args, **kwargs) # Run python code + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.drop_index(model, *col_names) + > migrator.add_not_null(model, *field_names) + > migrator.drop_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + +""" + +import peewee as pw + +from frigate.models import User, UserReviewStatus + +SQL = pw.SQL + + +def migrate(migrator, database, fake=False, **kwargs): + User._meta.database = database + UserReviewStatus._meta.database = database + + migrator.sql( + """ + CREATE TABLE IF NOT EXISTS "userreviewstatus" ( + "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "user_id" VARCHAR(30) NOT NULL, + "review_segment_id" VARCHAR(30) NOT NULL, + "has_been_reviewed" INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY ("review_segment_id") REFERENCES "reviewsegment" ("id") ON DELETE CASCADE + ) + """ + ) + + # Add unique index on (user_id, review_segment_id) + migrator.sql( + 'CREATE UNIQUE INDEX IF NOT EXISTS "userreviewstatus_user_segment" ON "userreviewstatus" ("user_id", "review_segment_id")' + ) + + # Migrate existing has_been_reviewed data to UserReviewStatus for all users + def migrate_data(): + all_users = list(User.select()) + if not all_users: + return + + cursor = database.execute_sql( + 'SELECT "id" FROM "reviewsegment" WHERE "has_been_reviewed" = 1' + ) + reviewed_segment_ids = [row[0] for row in cursor.fetchall()] + + for segment_id in reviewed_segment_ids: + for user in all_users: + UserReviewStatus.create( + user_id=user.username, + review_segment=segment_id, + has_been_reviewed=True, + ) + + if not fake: # Only run data migration if not faking + migrator.python(migrate_data) + + migrator.sql('ALTER TABLE "reviewsegment" DROP COLUMN "has_been_reviewed"') + + +def rollback(migrator, database, fake=False, **kwargs): + migrator.sql('DROP TABLE IF EXISTS "userreviewstatus"') + # Restore has_been_reviewed column to reviewsegment (no data restoration) + migrator.sql( + 'ALTER TABLE "reviewsegment" ADD COLUMN "has_been_reviewed" INTEGER NOT NULL DEFAULT 0' + )