from datetime import datetime, timezone
from typing import Optional, List

from psycopg import sql

from jb.managers import PostgresManager
from jb.models.definitions import HitStatus
from jb.models.hit import HitQuestion, HitType, Hit


class HitQuestionManager(PostgresManager):
    """Persistence helpers for rows of the ``mtwerk_question`` table."""

    def create(self, question: HitQuestion) -> None:
        """Insert ``question`` and store the generated primary key on it.

        :param question: a question that has not been saved yet.
        :raises AssertionError: if ``question.id`` is already set.
        """
        assert question.id is None
        data = question.to_postgres()
        query = sql.SQL(
            """
            INSERT INTO mtwerk_question (url, height)
            VALUES (%(url)s, %(height)s)
            RETURNING id;
            """
        )
        with self.pg_config.make_connection() as conn:
            with conn.cursor() as c:
                c.execute(query, data)
                pk = c.fetchone()["id"]
                conn.commit()
        question.id = pk
        return None

    def get_by_id(self, question_id: int) -> HitQuestion:
        """Load the question with primary key ``question_id``.

        :raises AssertionError: if no such row exists.
        """
        res = self.pg_config.execute_sql_query(
            """
            SELECT *
            FROM mtwerk_question
            WHERE id = %(question_id)s
            LIMIT 1;
            """,
            params={"question_id": question_id},
        )
        assert len(res) == 1
        return HitQuestion.model_validate(res[0])

    def get_by_values_if_exists(self, url: str, height: int) -> Optional[HitQuestion]:
        """Return the question matching ``(url, height)``, or ``None`` if absent.

        :raises AssertionError: if more than one row matches.
        """
        # LIMIT 2 is enough to detect duplicates without reading the whole table.
        res = self.pg_config.execute_sql_query(
            """
            SELECT *
            FROM mtwerk_question
            WHERE url = %(url)s AND height = %(height)s
            LIMIT 2;
            """,
            params={"url": url, "height": height},
        )
        assert len(res) != 2, "More than 1 result!"
        if len(res) == 0:
            return None
        return HitQuestion.model_validate(res[0])

    def get_or_create(self, question: HitQuestion) -> HitQuestion:
        """Return the stored question with the same url/height, inserting it if needed.

        When no match exists, ``question`` itself is inserted (its ``id`` is
        set) and returned.
        """
        existing = self.get_by_values_if_exists(url=question.url, height=question.height)
        if existing is not None:
            return existing
        self.create(question=question)
        return question


class HitTypeManager(PostgresManager):
    """Persistence helpers for rows of the ``mtwerk_hittype`` table."""

    def create(self, hit_type: HitType) -> None:
        """Insert ``hit_type`` and store the generated primary key on it.

        :raises AssertionError: if ``hit_type.amt_hit_type_id`` is ``None``.
        """
        assert hit_type.amt_hit_type_id is not None
        data = hit_type.to_postgres()
        query = sql.SQL(
            """
            INSERT INTO mtwerk_hittype (
                amt_hit_type_id, title, description, reward,
                assignment_duration, auto_approval_delay, keywords, min_active
            )
            VALUES (
                %(amt_hit_type_id)s, %(title)s, %(description)s, %(reward)s,
                %(assignment_duration)s, %(auto_approval_delay)s, %(keywords)s,
                %(min_active)s
            )
            RETURNING id;
            """
        )
        with self.pg_config.make_connection() as conn:
            with conn.cursor() as c:
                c.execute(query, data)
                pk = c.fetchone()["id"]
                conn.commit()
        hit_type.id = pk
        return None

    def filter_active(self, *, max_results: int = 50) -> List[HitType]:
        """Return every hit type whose ``min_active`` is greater than zero.

        :param max_results: upper bound on how many hit types may be returned.
        :raises ValueError: if more than ``max_results`` hit types are active.
        """
        # Fetch one extra row so "exactly max_results" can be told apart from
        # "more than max_results" (the old LIMIT 50 / == 50 check raised on
        # exactly 50 rows).
        res = self.pg_config.execute_sql_query(
            """
            SELECT *
            FROM mtwerk_hittype
            WHERE min_active > 0
            LIMIT %(limit)s
            """,
            params={"limit": max_results + 1},
        )
        if len(res) > max_results:
            raise ValueError("Too many HitTypes!")
        return [HitType.from_postgres(i) for i in res]

    def _select_by_amt_hit_type_id(self, amt_hit_type_id: str):
        """Fetch at most two rows for ``amt_hit_type_id``.

        Two rows are enough to tell "exactly one" apart from "duplicates".
        """
        return self.pg_config.execute_sql_query(
            """
            SELECT *
            FROM mtwerk_hittype
            WHERE amt_hit_type_id = %(amt_hit_type_id)s
            LIMIT 2;
            """,
            params={"amt_hit_type_id": amt_hit_type_id},
        )

    def get(self, amt_hit_type_id: str) -> HitType:
        """Load the hit type identified by ``amt_hit_type_id``.

        :raises AssertionError: if there is not exactly one matching row.
        """
        res = self._select_by_amt_hit_type_id(amt_hit_type_id)
        assert len(res) == 1
        return HitType.from_postgres(res[0])

    def get_if_exists(self, amt_hit_type_id: str) -> Optional[HitType]:
        """Return the hit type for ``amt_hit_type_id``, or ``None``.

        ``None`` is returned both when no row matches and when more than one
        row matches, as before. The check is explicit rather than relying on
        catching ``AssertionError``, which would disappear under ``python -O``.
        """
        res = self._select_by_amt_hit_type_id(amt_hit_type_id)
        if len(res) != 1:
            return None
        return HitType.from_postgres(res[0])

    def get_or_create(self, hit_type: HitType) -> None:
        """Ensure ``hit_type`` is stored, setting ``hit_type.id`` either way.

        If a row with the same ``amt_hit_type_id`` exists, its ``id`` is copied
        onto ``hit_type``; otherwise ``hit_type`` is inserted.
        """
        existing = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id)
        if existing is not None:
            hit_type.id = existing.id
        else:
            self.create(hit_type=hit_type)
        return None

    def set_min_active(self, hit_type: HitType) -> None:
        """Persist ``hit_type.min_active`` for an already-stored hit type.

        :raises AssertionError: if ``hit_type`` has no ``id`` or the update does
            not touch exactly one row (in which case nothing is committed).
        """
        assert hit_type.id, "must be in the db first!"
        query = sql.SQL(
            """
            UPDATE mtwerk_hittype
            SET min_active = %(min_active)s
            WHERE id = %(id)s
            """
        )
        data = {"id": hit_type.id, "min_active": hit_type.min_active}
        with self.pg_config.make_connection() as conn:
            with conn.cursor() as c:
                c.execute(query, data)
                # Validate before committing so an unexpected row count is not
                # persisted.
                assert c.rowcount == 1, c.rowcount
                conn.commit()
        return None


class HitManager(PostgresManager):
    """Persistence helpers for rows of the ``mtwerk_hit`` table."""

    def create(self, hit: Hit):
        """Insert ``hit``, store the generated primary key on it and return it.

        :raises AssertionError: if ``hit.amt_hit_id`` is ``None`` or ``hit.id``
            is already set.
        """
        assert hit.amt_hit_id is not None
        assert hit.id is None
        data = hit.to_postgres()
        query = sql.SQL(
            """
            INSERT INTO mtwerk_hit (
                amt_hit_id, hit_type_id, amt_group_id, status, review_status,
                creation_time, expiration, question_id, created_at, modified_at,
                max_assignments, assignment_pending_count,
                assignment_completed_count, assignment_available_count
            )
            VALUES (
                %(amt_hit_id)s, %(hit_type_id)s, %(amt_group_id)s, %(status)s,
                %(review_status)s, %(creation_time)s, %(expiration)s,
                %(question_id)s, %(created_at)s, %(modified_at)s,
                %(max_assignments)s, %(assignment_pending_count)s,
                %(assignment_completed_count)s, %(assignment_available_count)s
            )
            RETURNING id;
            """
        )
        with self.pg_config.make_connection() as conn:
            with conn.cursor() as c:
                c.execute(query, data)
                pk = c.fetchone()["id"]
                conn.commit()
        hit.id = pk
        return hit

    def update_status(self, amt_hit_id: str, hit_status: HitStatus):
        """Set the status of the hit ``amt_hit_id`` and bump ``modified_at``.

        :raises AssertionError: if the update does not touch exactly one row
            (in which case nothing is committed).
        """
        now = datetime.now(tz=timezone.utc)
        query = sql.SQL(
            """
            UPDATE mtwerk_hit
            SET status = %(status)s, modified_at = %(modified_at)s
            WHERE amt_hit_id = %(amt_hit_id)s;
            """
        )
        data = {
            "amt_hit_id": amt_hit_id,
            "status": hit_status.value,
            "modified_at": now,
        }
        with self.pg_config.make_connection() as conn:
            with conn.cursor() as c:
                c.execute(query, data)
                # Validate before committing so a multi-row or no-op update is
                # not persisted.
                assert c.rowcount == 1, c.rowcount
                conn.commit()
        return None

    def update_hit(self, hit: Hit):
        """Write the mutable status/count fields of ``hit`` back to its row.

        The row is matched on ``amt_hit_id``. ``hit.modified_at`` is refreshed
        to now (UTC) and ``hit.id`` is set from the updated row.

        :raises AssertionError: if the update does not touch exactly one row
            (in which case nothing is committed).
        """
        hit.modified_at = datetime.now(tz=timezone.utc)

        # fields expected to change (plus the lookup key amt_hit_id):
        fields = {
            "status",
            "review_status",
            "assignment_pending_count",
            "assignment_available_count",
            "assignment_completed_count",
            "amt_hit_id",
            "modified_at",
        }

        query = sql.SQL(
            """
            UPDATE mtwerk_hit
            SET status = %(status)s, review_status = %(review_status)s,
                assignment_pending_count = %(assignment_pending_count)s,
                assignment_available_count = %(assignment_available_count)s,
                assignment_completed_count = %(assignment_completed_count)s,
                modified_at = %(modified_at)s
            WHERE amt_hit_id = %(amt_hit_id)s
            RETURNING id;
            """
        )
        # mode="json" renders datetimes/enums as JSON-compatible primitives.
        data = hit.model_dump(mode="json", include=fields)
        with self.pg_config.make_connection() as conn:
            with conn.cursor() as c:
                c.execute(query, data)
                # Check and read the RETURNING row before committing.
                assert c.rowcount == 1, c.rowcount
                pk = c.fetchone()["id"]
                conn.commit()
        hit.id = pk
        return None

    def get_from_amt_id(self, amt_hit_id: str) -> Hit:
        """Load the hit ``amt_hit_id`` together with its hit type and question.

        The joined question's ``height``/``url`` are converted into
        ``hit_question_xml`` via ``HitQuestion(...).xml`` before building the
        ``Hit``.

        :raises AssertionError: if there is not exactly one matching hit.
        """
        res = self.pg_config.execute_sql_query(
            """
            SELECT
                mh.id, mh.amt_hit_id, mh.amt_group_id, mh.status,
                mh.review_status, mh.creation_time, mh.expiration,
                mh.created_at, mh.modified_at,
                mh.assignment_available_count, mh.assignment_completed_count,
                mh.assignment_pending_count, mh.max_assignments,
                mht.amt_hit_type_id, mht.title, mht.description, mht.reward,
                mht.assignment_duration, mht.auto_approval_delay, mht.keywords,
                mq.id as question_id, mq.height, mq.url
            FROM mtwerk_hit mh
            JOIN mtwerk_hittype mht ON mh.hit_type_id = mht.id
            JOIN mtwerk_question mq ON mh.question_id = mq.id
            WHERE mh.amt_hit_id = %(amt_hit_id)s
            LIMIT 2;
            """,
            params={"amt_hit_id": amt_hit_id},
        )
        assert len(res) == 1
        row = res[0]
        # height/url are only needed to render the question XML; pop them so
        # they are not passed to Hit.from_postgres.
        question_xml = HitQuestion.model_validate(
            {"height": row.pop("height"), "url": row.pop("url")}
        ).xml
        row["hit_question_xml"] = question_xml
        return Hit.from_postgres(row)

    def get_active_count(self, hit_type_id: int):
        """Count hits of ``hit_type_id`` whose status is ``HitStatus.Assignable``."""
        # Pass the enum's value, matching how update_status writes the column.
        return self.pg_config.execute_sql_query(
            """
            SELECT COUNT(1) as active_count
            FROM mtwerk_hit
            WHERE status = %(status)s AND hit_type_id = %(hit_type_id)s;
            """,
            params={"status": HitStatus.Assignable.value, "hit_type_id": hit_type_id},
        )[0]["active_count"]

    def filter_active_ids(self, hit_type_id: int):
        """Return the set of ``amt_hit_id`` values for hits of ``hit_type_id``.

        NOTE(review): despite the name, no status filter is applied. Every hit
        of this type is included, unlike ``get_active_count``. Confirm this is
        intended.
        """
        res = self.pg_config.execute_sql_query(
            """
            SELECT mh.amt_hit_id
            FROM mtwerk_hit mh
            WHERE hit_type_id = %(hit_type_id)s;
            """,
            params={"hit_type_id": hit_type_id},
        )
        return {x["amt_hit_id"] for x in res}