aboutsummaryrefslogtreecommitdiff
path: root/jb/managers/hit.py
diff options
context:
space:
mode:
Diffstat (limited to 'jb/managers/hit.py')
-rw-r--r--jb/managers/hit.py36
1 files changed, 24 insertions, 12 deletions
diff --git a/jb/managers/hit.py b/jb/managers/hit.py
index 3832418..ce8ffa5 100644
--- a/jb/managers/hit.py
+++ b/jb/managers/hit.py
@@ -24,7 +24,7 @@ class HitQuestionManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
question.id = pk
return None
@@ -67,6 +67,7 @@ class HitQuestionManager(PostgresManager):
class HitTypeManager(PostgresManager):
+
def create(self, hit_type: HitType) -> None:
assert hit_type.amt_hit_type_id is not None
data = hit_type.to_postgres()
@@ -99,9 +100,10 @@ class HitTypeManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
hit_type.id = pk
+
return None
def filter_active(self) -> List[HitType]:
@@ -137,13 +139,13 @@ class HitTypeManager(PostgresManager):
except AssertionError:
return None
- def get_or_create(self, hit_type: HitType) -> None:
+ def get_or_create(self, hit_type: HitType) -> HitType:
res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id)
if res:
- hit_type.id = res.id
- if res is None:
- self.create(hit_type=hit_type)
- return None
+ return res
+
+ self.create(hit_type=hit_type)
+ return self.get(amt_hit_type_id=hit_type.amt_hit_type_id)
def set_min_active(self, hit_type: HitType) -> None:
assert hit_type.id, "must be in the db first!"
@@ -164,6 +166,7 @@ class HitTypeManager(PostgresManager):
class HitManager(PostgresManager):
+
def create(self, hit: Hit):
assert hit.amt_hit_id is not None
assert hit.id is None
@@ -209,7 +212,7 @@ class HitManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
hit.id = pk
return hit
@@ -267,7 +270,7 @@ class HitManager(PostgresManager):
c.execute(query, data)
conn.commit()
assert c.rowcount == 1, c.rowcount
- hit.id = c.fetchone()["id"]
+ hit.id = c.fetchone()["id"] # type: ignore
return None
def get_from_amt_id(self, amt_hit_id: str) -> Hit:
@@ -305,17 +308,26 @@ class HitManager(PostgresManager):
""",
params={"amt_hit_id": amt_hit_id},
)
- assert len(res) == 1
+ assert len(res) == 1, "Incorrect number of results"
res = res[0]
+
question_xml = HitQuestion.model_validate(
{"height": res.pop("height"), "url": res.pop("url")}
).xml
+
res["question_id"] = res["question_id"]
res["hit_question_xml"] = question_xml
return Hit.from_postgres(res)
- def get_active_count(self, hit_type_id: int):
+ def get_from_amt_id_if_exists(self, amt_hit_id: str) -> Optional[Hit]:
+ try:
+ return self.get_from_amt_id(amt_hit_id=amt_hit_id)
+
+ except (AssertionError, Exception):
+ return None
+
+ def get_active_count(self, hit_type_id: int) -> int:
return self.pg_config.execute_sql_query(
"""
SELECT COUNT(1) as active_count
@@ -326,7 +338,7 @@ class HitManager(PostgresManager):
params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id},
)[0]["active_count"]
- def filter_active_ids(self, hit_type_id: int):
+ def filter_active_ids(self, hit_type_id: int) -> set[str]:
res = self.pg_config.execute_sql_query(
"""
SELECT mh.amt_hit_id