aboutsummaryrefslogtreecommitdiff
path: root/jb/managers
diff options
context:
space:
mode:
Diffstat (limited to 'jb/managers')
-rw-r--r--jb/managers/__init__.py4
-rw-r--r--jb/managers/amt.py21
-rw-r--r--jb/managers/assignment.py8
-rw-r--r--jb/managers/bonus.py11
-rw-r--r--jb/managers/hit.py36
-rw-r--r--jb/managers/thl.py40
6 files changed, 76 insertions, 44 deletions
diff --git a/jb/managers/__init__.py b/jb/managers/__init__.py
index e2aab6d..e99569a 100644
--- a/jb/managers/__init__.py
+++ b/jb/managers/__init__.py
@@ -15,8 +15,8 @@ class PostgresManager:
def __init__(
self,
pg_config: PostgresConfig,
- permissions: Collection[Permission] = None,
- **kwargs,
+ permissions: Collection[Permission] = None, # type: ignore
+ **kwargs, # type: ignore
):
super().__init__(**kwargs)
self.pg_config = pg_config
diff --git a/jb/managers/amt.py b/jb/managers/amt.py
index 79661c7..0ec70d3 100644
--- a/jb/managers/amt.py
+++ b/jb/managers/amt.py
@@ -10,7 +10,7 @@ from jb.decorators import AMT_CLIENT
from jb.models import AMTAccount
from jb.models.assignment import Assignment
from jb.models.bonus import Bonus
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.definitions import HitStatus
from jb.models.hit import HitType, HitQuestion, Hit
@@ -48,19 +48,24 @@ class AMTManager:
return hit, None
@classmethod
- def get_hit_status(cls, amt_hit_id: str):
+ def get_hit_status(cls, amt_hit_id: str) -> HitStatus:
res, msg = cls.get_hit_if_exists(amt_hit_id=amt_hit_id)
+
if res is None:
+ if msg is None:
+ return HitStatus.Unassignable
+
if " does not exist. (" in msg:
return HitStatus.Disposed
else:
logging.warning(msg)
return HitStatus.Unassignable
+
return res.status
@staticmethod
def create_hit_type(hit_type: HitType):
- res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body())
+ res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) # type: ignore
hit_type.amt_hit_type_id = res["HITTypeId"]
AMT_CLIENT.update_notification_settings(
HITTypeId=hit_type.amt_hit_type_id,
@@ -94,8 +99,10 @@ class AMTManager:
@staticmethod
def get_assignment(amt_assignment_id: str) -> Assignment:
- # note, you CANNOT get an assignment if it has been only ACCEPTED (by the worker)
- # the api is stupid. it will only show up once it is submitted
+ """
+ You CANNOT get an Assignment if it has been only ACCEPTED (by the
+ worker). The api is stupid, it will only show up once it is Submitted
+ """
res = AMT_CLIENT.get_assignment(AssignmentId=amt_assignment_id)
ass_res: AssignmentTypeDef = res["Assignment"]
assignment = Assignment.from_amt_get_assignment(ass_res)
@@ -158,6 +165,7 @@ class AMTManager:
raise ValueError(error_msg)
# elif "This HIT is currently in the state 'Reviewing'" in error_msg:
# logging.warning(error_msg)
+
return None
@staticmethod
@@ -203,7 +211,7 @@ class AMTManager:
return None
@staticmethod
- def expire_all_hits():
+ def expire_all_hits() -> None:
# used in testing only (or in an emergency I guess)
now = datetime.now(tz=timezone.utc)
paginator = AMT_CLIENT.get_paginator("list_hits")
@@ -214,3 +222,4 @@ class AMTManager:
AMT_CLIENT.update_expiration_for_hit(
HITId=hit["HITId"], ExpireAt=now
)
+ return None
diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py
index fca72e8..dd3c866 100644
--- a/jb/managers/assignment.py
+++ b/jb/managers/assignment.py
@@ -28,7 +28,7 @@ class AssignmentManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
stub.id = pk
return None
@@ -62,7 +62,7 @@ class AssignmentManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
assignment.id = pk
return None
@@ -233,7 +233,7 @@ class AssignmentManager(PostgresManager):
"lookback_interval": f"{lookback_hrs} hour",
},
)
- return int(res[0]["c"])
+ return int(res[0]["c"]) # type: ignore
def rejected_count(
self, amt_worker_id: str, lookback_hrs: int = 24
@@ -256,4 +256,4 @@ class AssignmentManager(PostgresManager):
"status": AssignmentStatus.Rejected.value,
},
)
- return int(res[0]["c"])
+ return int(res[0]["c"]) # type: ignore
diff --git a/jb/managers/bonus.py b/jb/managers/bonus.py
index 0cb8b02..89b81f0 100644
--- a/jb/managers/bonus.py
+++ b/jb/managers/bonus.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Any
from psycopg import sql
@@ -37,12 +37,12 @@ class BonusManager(PostgresManager):
c.execute(query, data)
res = c.fetchone()
conn.commit()
- bonus.id = res["id"]
- bonus.assignment_id = res["assignment_id"]
+ bonus.id = res["id"] # type: ignore
+ bonus.assignment_id = res["assignment_id"] # type: ignore
return None
def filter(self, amt_assignment_id: str) -> List[Bonus]:
- res = self.pg_config.execute_sql_query(
+ res: List[Any] = self.pg_config.execute_sql_query(
"""
SELECT mb.*, ma.amt_assignment_id
FROM mtwerk_bonus mb
@@ -51,4 +51,5 @@ class BonusManager(PostgresManager):
""",
params={"amt_assignment_id": amt_assignment_id},
)
- return [Bonus.from_postgres(x) for x in res]
+
+ return [Bonus.from_postgres(data=x) for x in res]
diff --git a/jb/managers/hit.py b/jb/managers/hit.py
index 3832418..ce8ffa5 100644
--- a/jb/managers/hit.py
+++ b/jb/managers/hit.py
@@ -24,7 +24,7 @@ class HitQuestionManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
question.id = pk
return None
@@ -67,6 +67,7 @@ class HitQuestionManager(PostgresManager):
class HitTypeManager(PostgresManager):
+
def create(self, hit_type: HitType) -> None:
assert hit_type.amt_hit_type_id is not None
data = hit_type.to_postgres()
@@ -99,9 +100,10 @@ class HitTypeManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
hit_type.id = pk
+
return None
def filter_active(self) -> List[HitType]:
@@ -137,13 +139,13 @@ class HitTypeManager(PostgresManager):
except AssertionError:
return None
- def get_or_create(self, hit_type: HitType) -> None:
+ def get_or_create(self, hit_type: HitType) -> HitType:
res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id)
if res:
- hit_type.id = res.id
- if res is None:
- self.create(hit_type=hit_type)
- return None
+ return res
+
+ self.create(hit_type=hit_type)
+ return self.get(amt_hit_type_id=hit_type.amt_hit_type_id)
def set_min_active(self, hit_type: HitType) -> None:
assert hit_type.id, "must be in the db first!"
@@ -164,6 +166,7 @@ class HitTypeManager(PostgresManager):
class HitManager(PostgresManager):
+
def create(self, hit: Hit):
assert hit.amt_hit_id is not None
assert hit.id is None
@@ -209,7 +212,7 @@ class HitManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
hit.id = pk
return hit
@@ -267,7 +270,7 @@ class HitManager(PostgresManager):
c.execute(query, data)
conn.commit()
assert c.rowcount == 1, c.rowcount
- hit.id = c.fetchone()["id"]
+ hit.id = c.fetchone()["id"] # type: ignore
return None
def get_from_amt_id(self, amt_hit_id: str) -> Hit:
@@ -305,17 +308,26 @@ class HitManager(PostgresManager):
""",
params={"amt_hit_id": amt_hit_id},
)
- assert len(res) == 1
+ assert len(res) == 1, "Incorrect number of results"
res = res[0]
+
question_xml = HitQuestion.model_validate(
{"height": res.pop("height"), "url": res.pop("url")}
).xml
+
res["question_id"] = res["question_id"]
res["hit_question_xml"] = question_xml
return Hit.from_postgres(res)
- def get_active_count(self, hit_type_id: int):
+ def get_from_amt_id_if_exists(self, amt_hit_id: str) -> Optional[Hit]:
+ try:
+ return self.get_from_amt_id(amt_hit_id=amt_hit_id)
+
+ except (AssertionError, Exception):
+ return None
+
+ def get_active_count(self, hit_type_id: int) -> int:
return self.pg_config.execute_sql_query(
"""
SELECT COUNT(1) as active_count
@@ -326,7 +338,7 @@ class HitManager(PostgresManager):
params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id},
)[0]["active_count"]
- def filter_active_ids(self, hit_type_id: int):
+ def filter_active_ids(self, hit_type_id: int) -> set[str]:
res = self.pg_config.execute_sql_query(
"""
SELECT mh.amt_hit_id
diff --git a/jb/managers/thl.py b/jb/managers/thl.py
index b1dcbde..83f49f6 100644
--- a/jb/managers/thl.py
+++ b/jb/managers/thl.py
@@ -1,7 +1,3 @@
-from decimal import Decimal
-from typing import Dict, Optional
-
-import requests
from generalresearchutils.models.thl.payout import UserPayoutEvent
from generalresearchutils.models.thl.task_status import TaskStatusResponse
from generalresearchutils.models.thl.wallet.cashout_method import (
@@ -9,45 +5,58 @@ from generalresearchutils.models.thl.wallet.cashout_method import (
CashoutRequestInfo,
)
+from generalresearchutils.models.thl.user_profile import UserProfile
+from generalresearchutils.currency import USDCent
+
from jb.config import settings
-from jb.models.currency import USDCent
-from jb.models.definitions import PayoutStatus
-# TODO: Organize this more with other endpoints (offerwall, cashout requests/approvals, etc).
+from generalresearchutils.models.thl.definitions import PayoutStatus
+
+
+from typing import Optional
+import requests
+
+# TODO: Organize this more with other endpoints (offerwall, cashout
+# requests/approvals, etc).
-def get_user_profile(amt_worker_id: str) -> Dict:
+def get_user_profile(amt_worker_id: str) -> UserProfile:
url = f"{settings.fsb_host}{settings.product_id}/user/{amt_worker_id}/profile/"
res = requests.get(url).json()
if res.get("detail") == "user not found":
raise ValueError("user not found")
- return res["user_profile"]
+
+ return UserProfile.model_validate(res["user_profile"])
def get_user_blocked(amt_worker_id: str) -> bool:
+ # Not blocked if None
res = get_user_profile(amt_worker_id=amt_worker_id)
- return res["user"]["blocked"]
+ return res.user.blocked if res.user.blocked is not None else False
-def get_user_blocked_or_not_exists(amt_worker_id: str) -> bool:
+def get_user_blocked_or_not_exists(amt_worker_id: str) -> Optional[bool]:
try:
res = get_user_profile(amt_worker_id=amt_worker_id)
- return res["user"]["blocked"]
+ return res.user.blocked if res.user.blocked is not None else False
except ValueError as e:
if e.args[0] == "user not found":
return True
+ return None
+
def get_task_status(tsid: str) -> Optional[TaskStatusResponse]:
url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/"
d = requests.get(url).json()
if d.get("msg") == "invalid tsid":
return None
+
return TaskStatusResponse.model_validate(d)
def user_cashout_request(
- amt_worker_id: str, amount: USDCent, cashout_method_id
+ amt_worker_id: str, amount: USDCent, cashout_method_id: str
) -> CashoutRequestInfo:
assert cashout_method_id in {
settings.amt_assignment_cashout_method,
@@ -56,7 +65,8 @@ def user_cashout_request(
assert isinstance(amount, USDCent)
assert USDCent(0) < amount < USDCent(10_00)
url = f"{settings.fsb_host}{settings.product_id}/cashout/"
- body = {
+
+ body: dict[str, str | int] = {
"bpuid": amt_worker_id,
"amount": int(amount),
"cashout_method_id": cashout_method_id,
@@ -81,7 +91,7 @@ def manage_pending_cashout(
return UserPayoutEvent.model_validate(d)
-def get_wallet_balance(amt_worker_id: str):
+def get_wallet_balance(amt_worker_id: str) -> USDCent:
url = f"{settings.fsb_host}{settings.product_id}/wallet/"
params = {"bpuid": amt_worker_id}
return USDCent(requests.get(url, params=params).json()["wallet"]["amount"])