diff options
Diffstat (limited to 'jb/managers/hit.py')
| -rw-r--r-- | jb/managers/hit.py | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/jb/managers/hit.py b/jb/managers/hit.py index 3832418..ce8ffa5 100644 --- a/jb/managers/hit.py +++ b/jb/managers/hit.py @@ -24,7 +24,7 @@ class HitQuestionManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() question.id = pk return None @@ -67,6 +67,7 @@ class HitQuestionManager(PostgresManager): class HitTypeManager(PostgresManager): + def create(self, hit_type: HitType) -> None: assert hit_type.amt_hit_type_id is not None data = hit_type.to_postgres() @@ -99,9 +100,10 @@ class HitTypeManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() hit_type.id = pk + return None def filter_active(self) -> List[HitType]: @@ -137,13 +139,13 @@ class HitTypeManager(PostgresManager): except AssertionError: return None - def get_or_create(self, hit_type: HitType) -> None: + def get_or_create(self, hit_type: HitType) -> HitType: res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id) if res: - hit_type.id = res.id - if res is None: - self.create(hit_type=hit_type) - return None + return res + + self.create(hit_type=hit_type) + return self.get(amt_hit_type_id=hit_type.amt_hit_type_id) def set_min_active(self, hit_type: HitType) -> None: assert hit_type.id, "must be in the db first!" @@ -164,6 +166,7 @@ class HitTypeManager(PostgresManager): class HitManager(PostgresManager): + def create(self, hit: Hit): assert hit.amt_hit_id is not None assert hit.id is None @@ -209,7 +212,7 @@ class HitManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() hit.id = pk return hit @@ -267,7 +270,7 @@ class HitManager(PostgresManager): c.execute(query, data) conn.commit() assert c.rowcount == 1, c.rowcount - hit.id = c.fetchone()["id"] + hit.id = c.fetchone()["id"] # type: ignore return None def get_from_amt_id(self, amt_hit_id: str) -> Hit: @@ -305,17 +308,26 @@ class HitManager(PostgresManager): """, params={"amt_hit_id": amt_hit_id}, ) - assert len(res) == 1 + assert len(res) == 1, "Incorrect number of results" res = res[0] + question_xml = HitQuestion.model_validate( {"height": res.pop("height"), "url": res.pop("url")} ).xml + res["question_id"] = res["question_id"] res["hit_question_xml"] = question_xml return Hit.from_postgres(res) - def get_active_count(self, hit_type_id: int): + def get_from_amt_id_if_exists(self, amt_hit_id: str) -> Optional[Hit]: + try: + return self.get_from_amt_id(amt_hit_id=amt_hit_id) + + except (AssertionError, Exception): + return None + + def get_active_count(self, hit_type_id: int) -> int: return self.pg_config.execute_sql_query( """ SELECT COUNT(1) as active_count @@ -326,7 +338,7 @@ class HitManager(PostgresManager): params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id}, )[0]["active_count"] - def filter_active_ids(self, hit_type_id: int): + def filter_active_ids(self, hit_type_id: int) -> set[str]: res = self.pg_config.execute_sql_query( """ SELECT mh.amt_hit_id |
