diff options
| author | Max Nanis | 2026-02-19 02:43:23 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-19 02:43:23 -0500 |
| commit | f0f96f83c2630e890a2cbcab53f77fd4c37e1684 (patch) | |
| tree | c6d2cb092e76bf5d499e0ea9949508d6b22164fd /jb/managers/hit.py | |
| parent | 3eaa56f0306ead818f64c3d99fc6d230d9b970a4 (diff) | |
| download | amt-jb-master.tar.gz amt-jb-master.zip | |
Diffstat (limited to 'jb/managers/hit.py')
| -rw-r--r-- | jb/managers/hit.py | 338 |
1 files changed, 338 insertions, 0 deletions
diff --git a/jb/managers/hit.py b/jb/managers/hit.py new file mode 100644 index 0000000..3832418 --- /dev/null +++ b/jb/managers/hit.py @@ -0,0 +1,338 @@ +from datetime import datetime, timezone +from typing import Optional, List + +from psycopg import sql + +from jb.managers import PostgresManager +from jb.models.definitions import HitStatus +from jb.models.hit import HitQuestion, HitType, Hit + + +class HitQuestionManager(PostgresManager): + + def create(self, question: HitQuestion) -> None: + assert question.id is None + data = question.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_question (url, height) + VALUES (%(url)s, %(height)s) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + question.id = pk + return None + + def get_by_id(self, question_id: int) -> HitQuestion: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_question + WHERE id = %(question_id)s + LIMIT 1; + """, + params={"question_id": question_id}, + ) + assert len(res) == 1 + return HitQuestion.model_validate(res[0]) + + def get_by_values_if_exists(self, url: str, height: int) -> Optional[HitQuestion]: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_question + WHERE url = %(url)s AND height = %(height)s + LIMIT 2; + """, + params={"url": url, "height": height}, + ) + assert len(res) != 2, "More than 1 result!" + if len(res) == 0: + return None + + return HitQuestion.model_validate(res[0]) + + def get_or_create(self, question: HitQuestion) -> HitQuestion: + res = self.get_by_values_if_exists(url=question.url, height=question.height) + if res: + return res + self.create(question=question) + return question + + +class HitTypeManager(PostgresManager): + def create(self, hit_type: HitType) -> None: + assert hit_type.amt_hit_type_id is not None + data = hit_type.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_hittype ( + amt_hit_type_id, + title, + description, + reward, + assignment_duration, + auto_approval_delay, + keywords, + min_active + ) + VALUES ( + %(amt_hit_type_id)s, + %(title)s, + %(description)s, + %(reward)s, + %(assignment_duration)s, + %(auto_approval_delay)s, + %(keywords)s, + %(min_active)s + ) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + hit_type.id = pk + return None + + def filter_active(self) -> List[HitType]: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_hittype + WHERE min_active > 0 + LIMIT 50 + """ + ) + + if len(res) == 50: + raise ValueError("Too many HitTypes!") + + return [HitType.from_postgres(i) for i in res] + + def get(self, amt_hit_type_id: str) -> HitType: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_hittype + WHERE amt_hit_type_id = %(amt_hit_type_id)s + """, + params={"amt_hit_type_id": amt_hit_type_id}, + ) + assert len(res) == 1 + return HitType.from_postgres(res[0]) + + def get_if_exists(self, amt_hit_type_id: str) -> Optional[HitType]: + try: + return self.get(amt_hit_type_id=amt_hit_type_id) + except AssertionError: + return None + + def get_or_create(self, hit_type: HitType) -> None: + res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id) + if res: + hit_type.id = res.id + if res is None: + self.create(hit_type=hit_type) + return None + + def set_min_active(self, hit_type: HitType) -> None: + assert hit_type.id, "must be in the db first!" + query = sql.SQL( + """ + UPDATE mtwerk_hittype + SET min_active = %(min_active)s + WHERE id = %(id)s + """ + ) + data = {"id": hit_type.id, "min_active": hit_type.min_active} + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + conn.commit() + assert c.rowcount == 1, c.rowcount + return None + + +class HitManager(PostgresManager): + def create(self, hit: Hit): + assert hit.amt_hit_id is not None + assert hit.id is None + data = hit.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_hit ( + amt_hit_id, + hit_type_id, + amt_group_id, + status, + review_status, + creation_time, + expiration, + question_id, + created_at, + modified_at, + max_assignments, + assignment_pending_count, + assignment_completed_count, + assignment_available_count + ) + VALUES ( + %(amt_hit_id)s, + %(hit_type_id)s, + %(amt_group_id)s, + %(status)s, + %(review_status)s, + %(creation_time)s, + %(expiration)s, + %(question_id)s, + %(created_at)s, + %(modified_at)s, + %(max_assignments)s, + %(assignment_pending_count)s, + %(assignment_completed_count)s, + %(assignment_available_count)s + ) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + hit.id = pk + return hit + + def update_status(self, amt_hit_id: str, hit_status: HitStatus): + now = datetime.now(tz=timezone.utc) + query = sql.SQL( + """ + UPDATE mtwerk_hit + SET status = %(status)s, modified_at = %(modified_at)s + WHERE amt_hit_id = %(amt_hit_id)s; + """ + ) + + data = { + "amt_hit_id": amt_hit_id, + "status": hit_status.value, + "modified_at": now, + } + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + conn.commit() + assert c.rowcount == 1, c.rowcount + return None + + def update_hit(self, hit: Hit): + hit.modified_at = datetime.now(tz=timezone.utc) + # fields expected to change: + fields = { + "status", + "review_status", + "assignment_pending_count", + "assignment_available_count", + "assignment_completed_count", + "amt_hit_id", + "modified_at", + } + query = sql.SQL( + """ + UPDATE mtwerk_hit + SET status = %(status)s, review_status = %(review_status)s, + assignment_pending_count = %(assignment_pending_count)s, + assignment_available_count = %(assignment_available_count)s, + assignment_completed_count = %(assignment_completed_count)s, + modified_at = %(modified_at)s + WHERE amt_hit_id = %(amt_hit_id)s + RETURNING id; + """ + ) + + data = hit.model_dump(mode="json", include=fields) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + conn.commit() + assert c.rowcount == 1, c.rowcount + hit.id = c.fetchone()["id"] + return None + + def get_from_amt_id(self, amt_hit_id: str) -> Hit: + res = self.pg_config.execute_sql_query( + """ + SELECT + mh.id, + mh.amt_hit_id, + mh.amt_group_id, + mh.status, + mh.review_status, + mh.creation_time, + mh.expiration, + mh.created_at, + mh.modified_at, + mh.assignment_available_count, + mh.assignment_completed_count, + mh.assignment_pending_count, + mh.max_assignments, + mht.amt_hit_type_id, + mht.title, + mht.description, + mht.reward, + mht.assignment_duration, + mht.auto_approval_delay, + mht.keywords, + mq.id as question_id, + mq.height, + mq.url + FROM mtwerk_hit mh + JOIN mtwerk_hittype mht ON mh.hit_type_id = mht.id + JOIN mtwerk_question mq ON mh.question_id = mq.id + WHERE amt_hit_id = %(amt_hit_id)s + LIMIT 2; + """, + params={"amt_hit_id": amt_hit_id}, + ) + assert len(res) == 1 + res = res[0] + question_xml = HitQuestion.model_validate( + {"height": res.pop("height"), "url": res.pop("url")} + ).xml + res["question_id"] = res["question_id"] + res["hit_question_xml"] = question_xml + + return Hit.from_postgres(res) + + def get_active_count(self, hit_type_id: int): + return self.pg_config.execute_sql_query( + """ + SELECT COUNT(1) as active_count + FROM mtwerk_hit + WHERE status = %(status)s + AND hit_type_id = %(hit_type_id)s; + """, + params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id}, + )[0]["active_count"] + + def filter_active_ids(self, hit_type_id: int): + res = self.pg_config.execute_sql_query( + """ + SELECT mh.amt_hit_id + FROM mtwerk_hit mh + WHERE hit_type_id = %(hit_type_id)s; + """, + params={"hit_type_id": hit_type_id}, + ) + return {x["amt_hit_id"] for x in res} |
