diff options
Diffstat (limited to 'generalresearch/managers/thl/wall.py')
| -rw-r--r-- | generalresearch/managers/thl/wall.py | 675 |
1 files changed, 675 insertions, 0 deletions
diff --git a/generalresearch/managers/thl/wall.py b/generalresearch/managers/thl/wall.py new file mode 100644 index 0000000..cd1bdbf --- /dev/null +++ b/generalresearch/managers/thl/wall.py @@ -0,0 +1,675 @@ +import logging +from collections import defaultdict +from datetime import datetime, timezone, timedelta +from decimal import Decimal, ROUND_DOWN +from functools import cached_property +from random import choice as rchoice +from typing import Optional, Collection, List +from uuid import uuid4 + +from faker import Faker +from psycopg import sql +from psycopg.rows import dict_row +from pydantic import AwareDatetime, PositiveInt, PostgresDsn, RedisDsn + +from generalresearch.managers import parse_order_by +from generalresearch.managers.base import ( + Permission, + PostgresManager, + PostgresManagerWithRedis, +) +from generalresearch.models import Source +from generalresearch.models.custom_types import UUIDStr, SurveyKey +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallStatusCode2, + ReportValue, + WallAdjustedStatus, +) +from generalresearch.models.thl.ledger import OrderBy +from generalresearch.models.thl.session import ( + check_adjusted_status_wall_consistent, + Wall, + WallAttempt, +) +from generalresearch.models.thl.survey.model import TaskActivity +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +logger = logging.getLogger("WallManager") +fake = Faker() + + +class WallManager(PostgresManager): + def __init__( + self, + pg_config: PostgresConfig, + permissions: Optional[Collection[Permission]] = None, + ): + assert pg_config.row_factory == dict_row + super().__init__(pg_config=pg_config, permissions=permissions) + + def create( + self, + session_id: int, + user_id: int, + started: datetime, + source: Source, + req_survey_id: str, + req_cpi: Decimal, + buyer_id: Optional[str] = None, + uuid_id: Optional[str] = None, + ) -> Wall: + """ + Creates a Wall event. Prefer to use this rather than instantiating + the model directly, because we're explicitly defining here which keys + should be set and which won't get set until later. + """ + if uuid_id is None: + uuid_id = uuid4().hex + + wall = Wall( + session_id=session_id, + user_id=user_id, + uuid=uuid_id, + started=started, + source=source, + buyer_id=buyer_id, + req_survey_id=req_survey_id, + req_cpi=req_cpi, + ) + d = wall.model_dump_mysql() + query = """ + INSERT INTO thl_wall ( + uuid, started, source, buyer_id, req_survey_id, + req_cpi, survey_id, cpi, session_id + ) VALUES ( + %(uuid)s, %(started)s, %(source)s, + %(buyer_id)s, %(req_survey_id)s, %(req_cpi)s, + %(survey_id)s, %(cpi)s, %(session_id)s + ); + """ + self.pg_config.execute_write(query=query, params=d) + return wall + + def create_dummy( + self, + session_id: Optional[int] = None, + user_id: Optional[int] = None, + started: Optional[datetime] = None, + source: Optional[Source] = None, + req_survey_id: Optional[str] = None, + req_cpi: Optional[Decimal] = None, + buyer_id: Optional[str] = None, + uuid_id: Optional[str] = None, + ): + """To be used in tests, where we don't care about certain fields""" + + user_id = user_id or fake.random_int(min=1, max=2_147_483_648) + started = started or fake.date_time_between( + start_date=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + end_date=datetime.now(tz=timezone.utc), + tzinfo=timezone.utc, + ) + + if session_id is None: + from generalresearch.managers.thl.session import SessionManager + + session = SessionManager(pg_config=self.pg_config).create_dummy( + started=started + ) + session_id = session.id + + source = source or rchoice(list(Source)) + req_survey_id = req_survey_id or uuid4().hex + req_cpi = req_cpi or Decimal(fake.random_int(min=1, max=150) / 100).quantize( + Decimal(".01"), rounding=ROUND_DOWN + ) + + return self.create( + session_id=session_id, + user_id=user_id, + started=started, + source=source, + req_survey_id=req_survey_id, + req_cpi=req_cpi, + buyer_id=buyer_id, + uuid_id=uuid_id, + ) + + def get_from_uuid(self, wall_uuid: UUIDStr) -> Wall: + query = """ + SELECT + tw.uuid, tw.source, tw.buyer_id, tw.survey_id, + tw.req_survey_id, tw.cpi, tw.req_cpi, tw.started, + tw.finished, tw.status, tw.status_code_1, + tw.status_code_2, tw.ext_status_code_1, + tw.ext_status_code_2, tw.ext_status_code_3, + tw.report_value, tw.report_notes, tw.adjusted_status, + tw.adjusted_cpi, tw.adjusted_timestamp, tw.session_id, + ts.user_id + FROM thl_wall AS tw + JOIN thl_session AS ts + ON tw.session_id = ts.id + WHERE tw.uuid = %(wall_uuid)s + LIMIT 2; + """ + res = self.pg_config.execute_sql_query(query, params={"wall_uuid": wall_uuid}) + assert len(res) == 1, f"Expected 1 result, got {len(res)}" + return Wall.model_validate(res[0]) + + def get_from_uuid_if_exists(self, wall_uuid: UUIDStr) -> Optional[Wall]: + try: + return self.get_from_uuid(wall_uuid=wall_uuid) + except AssertionError: + return None + + def finish( + self, + wall: Wall, + status: Status, + status_code_1: StatusCode1, + finished: datetime, + ext_status_code_1: Optional[str] = None, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, + status_code_2: Optional[WallStatusCode2] = None, + survey_id: Optional[str] = None, + cpi: Optional[Decimal] = None, + ) -> None: + """This wall event is finished. This would be called if/when we get a + callback for this wall event. Some other code is responsible for + translating external status codes to grl statuses + """ + wall.finish( + status=status, + status_code_1=status_code_1, + status_code_2=status_code_2, + ext_status_code_1=ext_status_code_1, + ext_status_code_2=ext_status_code_2, + ext_status_code_3=ext_status_code_3, + finished=finished, + survey_id=survey_id, + cpi=cpi, + ) + d = { + "status": status, + "status_code_1": status_code_1.value, + "status_code_2": status_code_2.value if status_code_2 else None, + "finished": finished, + "ext_status_code_1": ext_status_code_1, + "ext_status_code_2": ext_status_code_2, + "ext_status_code_3": ext_status_code_3, + "uuid": wall.uuid, + } + extra = [] + if survey_id is not None: + extra.append("survey_id = %(survey_id)s") + d["survey_id"] = survey_id + if cpi is not None: + extra.append("cpi = %(cpi)s") + d["cpi"] = str(cpi) + extra_str = "," + ", ".join(extra) if extra else "" + + query = f""" + UPDATE thl_wall + SET status=%(status)s, status_code_1=%(status_code_1)s, + status_code_2=%(status_code_2)s, finished=%(finished)s, + ext_status_code_1=%(ext_status_code_1)s, + ext_status_code_2=%(ext_status_code_2)s, + ext_status_code_3=%(ext_status_code_3)s + {extra_str} + WHERE uuid = %(uuid)s; + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params=d) + assert c.rowcount == 1 + conn.commit() + + return None + + def get_wall_events( + self, + session_id: Optional[PositiveInt] = None, + session_ids: Optional[List[PositiveInt]] = None, + order_by: OrderBy = OrderBy.ASC, + ) -> List[Wall]: + + if session_id is not None and session_ids is not None: + raise ValueError("Cannot provide both session_id and session_ids") + + if session_id is None and session_ids is None: + raise ValueError("Must provide either session_id or session_ids") + + ids = session_ids if session_ids is not None else [session_id] + + if len(ids) > 500: + raise ValueError("Cannot look up more than 500 Sessions at once.") + + query = f""" + SELECT + tw.uuid, tw.source, tw.buyer_id, tw.survey_id, + tw.req_survey_id, tw.cpi, tw.req_cpi, tw.started, + tw.finished, tw.status, tw.status_code_1, + tw.status_code_2, tw.ext_status_code_1, + tw.ext_status_code_2, tw.ext_status_code_3, + tw.report_value, tw.report_notes, tw.adjusted_status, + tw.adjusted_cpi, tw.adjusted_timestamp, tw.session_id, + ts.user_id + FROM thl_wall AS tw + JOIN thl_session AS ts + ON tw.session_id = ts.id + WHERE tw.session_id = ANY(%s) + ORDER BY tw.started {order_by.value} + """ + res = self.pg_config.execute_sql_query(query=query, params=[ids]) + return [Wall.model_validate(d) for d in res] + + def adjust_status( + self, + wall: Wall, + adjusted_timestamp: AwareDatetime, + adjusted_status: Optional[WallAdjustedStatus] = None, + adjusted_cpi: Optional[Decimal] = None, + ) -> None: + assert wall.status, "Wall must have an existing Status" + + # Be generous here, and if adjusted_status is adj to fail and + # adjusted_cpi is None, set it to 0 + if ( + adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + and adjusted_cpi is None + ): + adjusted_cpi = 0 + elif ( + adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE + and adjusted_cpi is None + ): + adjusted_cpi = wall.cpi + + allowed, msg = check_adjusted_status_wall_consistent( + status=wall.status, + cpi=wall.cpi, + adjusted_status=wall.adjusted_status, + adjusted_cpi=wall.adjusted_cpi, + new_adjusted_status=adjusted_status, + new_adjusted_cpi=adjusted_cpi, + ) + + if not allowed: + raise ValueError(msg) + + wall.update( + adjusted_status=adjusted_status, + adjusted_cpi=adjusted_cpi, + adjusted_timestamp=adjusted_timestamp, + ) + d = { + "adjusted_status": ( + wall.adjusted_status.value if wall.adjusted_status else None + ), + "adjusted_timestamp": adjusted_timestamp, + "adjusted_cpi": ( + str(wall.adjusted_cpi) if wall.adjusted_cpi is not None else None + ), + "uuid": wall.uuid, + } + + query = sql.SQL( + """ + UPDATE thl_wall + SET adjusted_status = %(adjusted_status)s, + adjusted_timestamp = %(adjusted_timestamp)s, + adjusted_cpi = %(adjusted_cpi)s + WHERE uuid = %(uuid)s; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=d) + assert c.rowcount == 1 + conn.commit() + + return None + + def report( + self, + wall: Wall, + report_value: ReportValue, + report_notes: Optional[str] = None, + report_timestamp: Optional[AwareDatetime] = None, + ) -> None: + wall.report( + report_value=report_value, + report_notes=report_notes, + report_timestamp=report_timestamp, + ) + params = { + "uuid": wall.uuid, + "report_value": report_value.value, + "status": wall.status.value, + "finished": wall.finished, + "report_notes": report_notes, + } + query = sql.SQL( + """ + UPDATE thl_wall + SET report_value = %(report_value)s, + report_notes = %(report_notes)s, + status = %(status)s, + finished = %(finished)s + WHERE uuid = %(uuid)s; + """ + ) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + assert c.rowcount == 1 + conn.commit() + return None + + def filter_count_attempted_live(self, user_id: int) -> int: + """ + Get the number of surveys this user has attempted that + are still currently live. This can be shown as port of + a "progress bar" for eligible, live, surveys they've + already attempted. + """ + query = f""" + SELECT + COUNT(1) as cnt + FROM thl_wall w + JOIN thl_session s ON w.session_id = s.id + JOIN marketplace_survey ms ON + ms.source = w.source AND + ms.survey_id = w.req_survey_id AND + ms.is_live + WHERE user_id = %(user_id)s + AND w.source != 'g' + """ + params = {"user_id": user_id} + res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + return res[0]["cnt"] + + def filter_wall_attempts_paginated( + self, + user_id: int, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + page: int = 1, + size: int = 100, + order_by: Optional[str] = "-started", + ) -> List[WallAttempt]: + """ + Returns WallAttempt + """ + filters = [] + params = {} + filters.append("user_id = %(user_id)s") + params["user_id"] = user_id + default_started = datetime.now(tz=timezone.utc) - timedelta(days=90) + started_after = started_after or default_started + started_before = started_before or datetime.now(tz=timezone.utc) + assert ( + started_before.tzinfo == timezone.utc + ), "started_before must be tz-aware as UTC" + assert ( + started_after < started_before + ), "started_after must be before started_before" + # Don't use BETWEEN b/c we want exclusive started_after here + filters.append( + "(w.started > %(started_after)s AND w.started <= %(started_before)s)" + ) + params["started_after"] = started_after + params["started_before"] = started_before + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + + assert page >= 1, "page starts at 1" + assert 1 <= size <= 500 + params["offset"] = (page - 1) * size + params["limit"] = size + paginated_filter_str = "LIMIT %(limit)s OFFSET %(offset)s" + + order_by_str = parse_order_by(order_by) + query = f""" + SELECT + w.req_survey_id, + w.started::timestamptz, + w.source, + w.uuid::uuid, + s.user_id + FROM thl_wall w + JOIN thl_session s on w.session_id = s.id + + {filter_str} + {order_by_str} + {paginated_filter_str} + """ + res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + return [WallAttempt.model_validate(x) for x in res] + + def filter_wall_attempts( + self, + user_id: int, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + order_by: Optional[str] = "-started", + ) -> List[WallAttempt]: + started_before = started_before or datetime.now(tz=timezone.utc) + res = [] + page = 1 + while True: + chunk = self.filter_wall_attempts_paginated( + user_id=user_id, + started_after=started_after, + started_before=started_before, + order_by=order_by, + page=page, + size=250, + ) + res.extend(chunk) + if not chunk: + break + page += 1 + + return res + + def get_survey_activities( + self, survey_keys: Collection[SurveyKey], product_id: Optional[str] = None + ) -> List[TaskActivity]: + query_base = """ + row_stats AS ( + SELECT + source, survey_id, + count(*) FILTER (WHERE effective_status IS NULL) AS in_progress_count, + max(started) AS last_entrance, + max(finished) FILTER (WHERE effective_status = 'c') AS last_complete + FROM classified + GROUP BY source, survey_id + ), + status_agg AS ( + SELECT + source, survey_id, + jsonb_object_agg(effective_status, cnt) AS status_counts + FROM ( + SELECT source, survey_id, effective_status, count(*) AS cnt + FROM classified + WHERE effective_status IS NOT NULL + GROUP BY source, survey_id, effective_status + ) s + GROUP BY source, survey_id + ), + status_code_1_agg AS ( + SELECT + source, survey_id, + jsonb_object_agg(status_code_1, cnt) AS status_code_1_counts + FROM ( + SELECT source, survey_id, status_code_1, count(*) AS cnt + FROM classified + WHERE status_code_1 IS NOT NULL + GROUP BY source, survey_id, status_code_1 + ) sc + GROUP BY source, survey_id + ) + SELECT + rs.source, + rs.survey_id, + rs.in_progress_count, + rs.last_entrance, + rs.last_complete, + COALESCE(sa.status_counts, '{}'::jsonb) as status_counts, + COALESCE(sc1.status_code_1_counts, '{}'::jsonb) as status_code_1_counts + FROM row_stats rs + LEFT JOIN status_agg sa + ON sa.source = rs.source + AND sa.survey_id = rs.survey_id + LEFT JOIN status_code_1_agg sc1 + ON sc1.source = rs.source + AND sc1.survey_id = rs.survey_id + ORDER BY rs.source, rs.survey_id; + """ + + params = dict() + filters = [] + + # Instead of doing a big IN with a big set of tuples, since we know + # we only have N possible sources, we just split by that and do e.g.: + # ( (source = 'x' and survey_id IN ('1', '2') ) OR + # (source = 'y' and survey_id IN ('3', '4') ) ... ) + sk_filters = [] + survey_source_ids = defaultdict(set) + for sk in survey_keys: + source, survey_id = sk.split(":") + survey_source_ids[Source(source).value].add(survey_id) + for source, survey_ids in survey_source_ids.items(): + sk_filters.append( + f"(source = '{source}' AND survey_id = ANY(%(survey_ids_{source})s))" + ) + params[f"survey_ids_{source}"] = list(survey_ids) + # Make sure this is wrapped in parentheses! + filters.append(f"({' OR '.join(sk_filters)})") + + product_query_join = "" + if product_id is not None: + product_query_join = """ + JOIN thl_session ON w.session_id = thl_session.id + JOIN thl_user ON thl_user.id = thl_session.user_id""" + filters.append("product_id = %(product_id)s") + params["product_id"] = product_id + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + query_filter = f""" + WITH classified AS ( + SELECT + CASE WHEN w.status IS NULL AND now() - w.started >= interval '90 minutes' + THEN 't' ELSE w.status + END AS effective_status, + w.status_code_1, + w.started, + w.finished, + w.source, + w.survey_id + FROM thl_wall w {product_query_join} + {filter_str} + ), + """ + query = query_filter + query_base + res = self.pg_config.execute_sql_query(query, params) + return [TaskActivity.model_validate(x) for x in res] + + +class WallCacheManager(PostgresManagerWithRedis): + + @cached_property + def wall_manager(self): + return WallManager(pg_config=self.pg_config) + + def get_cache_key_(self, user_id: int) -> str: + assert type(user_id) is int, "user_id must be int" + return f"{self.cache_prefix}:generate_attempts:{user_id}" + + def get_flag_key_(self, user_id: int) -> str: + assert type(user_id) is int, "user_id must be int" + return f"{self.cache_prefix}:generate_attempts:flag:{user_id}" + + def is_flag_set(self, user_id: int) -> bool: + # This flag gets set if a new wall event is created. Whenever we + # update the cache we'll delete the flag. + assert type(user_id) is int, "user_id must be int" + return bool(self.redis_client.get(self.get_flag_key_(user_id=user_id))) + + def set_flag(self, user_id: int) -> None: + # Upon a wall entrance, set this, so we know we have to refresh the cache + assert type(user_id) is int, "user_id must be int" + self.redis_client.set(self.get_flag_key_(user_id=user_id), 1, ex=60 * 60 * 24) + + def clear_flag(self, user_id: int) -> None: + assert type(user_id) is int, "user_id must be int" + self.redis_client.delete(self.get_flag_key_(user_id=user_id)) + + def get_attempts_redis_(self, user_id: int) -> List[WallAttempt]: + redis_key = self.get_cache_key_(user_id=user_id) + # Returns a list even if there is nothing set + res = self.redis_client.lrange(redis_key, 0, 5000) + attempts = [WallAttempt.model_validate_json(x) for x in res] + return attempts + + def update_attempts_redis_(self, attempts: List[WallAttempt], user_id: int) -> None: + if not attempts: + return None + redis_key = self.get_cache_key_(user_id=user_id) + # Make sure attempts is ordered, so the most recent is last + # "LPUSH mylist a b c will result into a list containing c as first element, + # b as second element and a as third element" + attempts = sorted(attempts, key=lambda x: x.started) + json_res = [attempt.model_dump_json() for attempt in attempts] + res = self.redis_client.lpush(redis_key, *json_res) + self.redis_client.expire(redis_key, time=60 * 60 * 24) + # So this doesn't grow forever, keep only the most recent 5k + self.redis_client.ltrim(redis_key, 0, 4999) + return None + + def get_attempts(self, user_id: int) -> List[WallAttempt]: + """ + This is used in the GetOpportunityIDs call to get a list of surveys + (& surveygroups) which should be excluded for this user. We don't + need to know the status or if they finished the survey, just they + entered it, so we don't need to fetch 90 min backfills. The + WallAttempts are stored in a Redis List, ordered most-recent + in index 0. + """ + assert type(user_id) is int, "user_id must be int" + + wall_modified = self.is_flag_set(user_id=user_id) + if not wall_modified: + return self.get_attempts_redis_(user_id=user_id) + + # Attempt to get the most recent wall attempt + redis_key = self.get_cache_key_(user_id=user_id) + res: Optional[str] = self.redis_client.lindex(redis_key, 0) # type: ignore[assignment] + if res is None: + # Nothing in the cache, query for all from db + attempts = self.wall_manager.filter_wall_attempts(user_id=user_id) + self.update_attempts_redis_(attempts=attempts, user_id=user_id) + self.clear_flag(user_id=user_id) + return attempts + + # See if there is anything after the latest cached wall event we have + w = WallAttempt.model_validate_json(res) + new_attempts = self.wall_manager.filter_wall_attempts( + user_id=user_id, started_after=w.started + ) + self.update_attempts_redis_(attempts=new_attempts, user_id=user_id) + self.clear_flag(user_id=user_id) + return self.get_attempts_redis_(user_id=user_id) |
