aboutsummaryrefslogtreecommitdiff
path: root/jb
diff options
context:
space:
mode:
Diffstat (limited to 'jb')
-rw-r--r--jb/decorators.py4
-rw-r--r--jb/flow/assignment_tasks.py52
-rw-r--r--jb/flow/events.py35
-rw-r--r--jb/flow/maintenance.py14
-rw-r--r--jb/flow/monitoring.py53
-rw-r--r--jb/flow/setup_tasks.py36
-rw-r--r--jb/flow/tasks.py16
-rw-r--r--jb/main.py7
-rw-r--r--jb/managers/__init__.py4
-rw-r--r--jb/managers/amt.py21
-rw-r--r--jb/managers/assignment.py8
-rw-r--r--jb/managers/bonus.py11
-rw-r--r--jb/managers/hit.py36
-rw-r--r--jb/managers/thl.py40
-rw-r--r--jb/models/assignment.py23
-rw-r--r--jb/models/bonus.py7
-rw-r--r--jb/models/currency.py70
-rw-r--r--jb/models/custom_types.py3
-rw-r--r--jb/models/definitions.py28
-rw-r--r--jb/models/event.py19
-rw-r--r--jb/models/hit.py24
-rw-r--r--jb/settings.py2
-rw-r--r--jb/views/common.py68
-rw-r--r--jb/views/tasks.py9
-rw-r--r--jb/views/utils.py7
25 files changed, 308 insertions, 289 deletions
diff --git a/jb/decorators.py b/jb/decorators.py
index 54d36c7..cbc28b5 100644
--- a/jb/decorators.py
+++ b/jb/decorators.py
@@ -66,9 +66,7 @@ AM = AssignmentManager(
pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE]
)
-BM = BonusManager(
- pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE]
-)
+BM = BonusManager(pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE])
influx_client = None
if settings.influx_db:
diff --git a/jb/flow/assignment_tasks.py b/jb/flow/assignment_tasks.py
index 4022716..bb0877d 100644
--- a/jb/flow/assignment_tasks.py
+++ b/jb/flow/assignment_tasks.py
@@ -5,6 +5,7 @@ from typing import Optional
from generalresearchutils.models.thl.definitions import PayoutStatus, StatusCode1
from generalresearchutils.models.thl.wallet.cashout_method import CashoutRequestInfo
+from generalresearchutils.currency import USDCent
from jb.decorators import AM, HM, BM
from jb.flow.monitoring import emit_error_event, emit_assignment_event, emit_bonus_event
@@ -21,21 +22,20 @@ from jb.managers.thl import (
get_user_blocked,
get_task_status,
user_cashout_request,
- AMT_ASSIGNMENT_CASHOUT_METHOD,
manage_pending_cashout,
get_user_blocked_or_not_exists,
get_wallet_balance,
- AMT_BONUS_CASHOUT_METHOD,
)
from jb.models.assignment import Assignment
-from jb.models.currency import USDCent
from jb.models.definitions import AssignmentStatus
from jb.models.event import MTurkEvent
+from jb.config import settings
def process_assignment_submitted(event: MTurkEvent) -> None:
"""
- Called either directly or from the SNS Notification that a HIT was submitted
+ Called either directly or from the SNS Notification that a
+ HIT was submitted
:return: None
"""
@@ -137,13 +137,22 @@ def process_assignment_submitted(event: MTurkEvent) -> None:
return issue_worker_payment(assignment)
-def review_hit(assignment):
+def review_hit(assignment: Assignment) -> None:
# Reviewable to Reviewing
AMTManager.update_hit_review_status(amt_hit_id=assignment.amt_hit_id, revert=False)
hit, _ = AMTManager.get_hit_if_exists(amt_hit_id=assignment.amt_hit_id)
+
+ if hit is None:
+ logging.warning(
+ f"Hit not found when trying to review hit: {assignment.amt_hit_id}"
+ )
+ return None
+
# Update the db
HM.update_hit(hit)
+ return None
+
def handle_assignment_w_no_work(assignment: Assignment) -> Assignment:
"""
@@ -267,6 +276,10 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
amt_worker_id = assignment.amt_worker_id
amt_assignment_id = assignment.amt_assignment_id
tsid = assignment.tsid
+ assert (
+ tsid is not None
+ ), "Assignment must have a tsid to be handled in handle_assignment_w_work"
+
hit = HM.get_from_amt_id(amt_hit_id=assignment.amt_hit_id)
tsr = get_task_status(tsid=tsid)
@@ -289,6 +302,7 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
event_type = "assignment_submitted_quality_fail"
else:
event_type = "assignment_submitted_work_not_complete"
+
emit_error_event(
event_type=event_type,
amt_hit_type_id=hit.amt_hit_type_id,
@@ -324,6 +338,8 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
)
return assignment
+ assert req.id
+
# We've approved the HIT payment, now update the db to reflect this, and approve the assignment
assignment = approve_assignment(
amt_assignment_id=amt_assignment_id,
@@ -331,7 +347,9 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
amt_hit_type_id=hit.amt_hit_type_id,
)
# We complete after the assignment is approved
- complete_res = manage_pending_cashout(req.id, PayoutStatus.COMPLETE)
+ complete_res = manage_pending_cashout(
+ cashout_id=req.id, payout_status=PayoutStatus.COMPLETE
+ )
if complete_res.status != PayoutStatus.COMPLETE:
# unclear wny this would happen
raise ValueError(f"Failed to complete cashout: {req.id}")
@@ -345,8 +363,9 @@ def submit_and_approve_amt_assignment_request(
req = user_cashout_request(
amt_worker_id=amt_worker_id,
amount=amount,
- cashout_method_id=AMT_ASSIGNMENT_CASHOUT_METHOD,
+ cashout_method_id=settings.amt_assignment_cashout_method,
)
+ assert req.id
if req.status != PayoutStatus.PENDING:
return None
@@ -365,8 +384,9 @@ def submit_and_approve_amt_bonus_request(
req = user_cashout_request(
amt_worker_id=amt_worker_id,
amount=amount,
- cashout_method_id=AMT_BONUS_CASHOUT_METHOD,
+ cashout_method_id=settings.amt_bonus_cashout_method,
)
+ assert req.id
if req.status != PayoutStatus.PENDING:
return None
@@ -404,6 +424,7 @@ def issue_worker_payment(assignment: Assignment) -> None:
amt_hit_type_id=hit.amt_hit_type_id,
)
return None
+ assert pe.id
AMTManager.send_bonus(
amt_worker_id=assignment.amt_worker_id,
@@ -412,13 +433,26 @@ def issue_worker_payment(assignment: Assignment) -> None:
reason=BONUS_MESSAGE,
unique_request_token=pe.id,
)
+
# Confirm it was sent through amt
bonus = AMTManager.get_bonus(
amt_assignment_id=assignment.amt_assignment_id, payout_event_id=pe.id
)
+
+ if bonus is None:
+ logging.warning(
+ f"Failed to find bonus after sending it: {amt_assignment_id} {pe.id}"
+ )
+ emit_error_event(
+ event_type="bonus_not_found_after_sending",
+ amt_hit_type_id=hit.amt_hit_type_id,
+ )
+ return None
+
# Create in DB
- BM.create(bonus)
+ BM.create(bonus=bonus)
emit_bonus_event(amount=amount, amt_hit_type_id=hit.amt_hit_type_id)
+
# Complete cashout
res = manage_pending_cashout(pe.id, PayoutStatus.COMPLETE)
if res.status != PayoutStatus.COMPLETE:
diff --git a/jb/flow/events.py b/jb/flow/events.py
index 3961a64..7b7bd32 100644
--- a/jb/flow/events.py
+++ b/jb/flow/events.py
@@ -1,8 +1,8 @@
import logging
import time
from concurrent import futures
-from concurrent.futures import ThreadPoolExecutor, Executor, as_completed
-from typing import Optional
+from concurrent.futures import ThreadPoolExecutor, Executor
+from typing import Optional, cast, TypedDict
import redis
@@ -16,6 +16,15 @@ from jb.decorators import REDIS
from jb.flow.assignment_tasks import process_assignment_submitted
from jb.models.event import MTurkEvent
+StreamMessages = list[tuple[str, list[tuple[bytes, dict[bytes, bytes]]]]]
+
+
+class PendingEntry(TypedDict):
+ message_id: bytes
+ consumer: bytes
+ time_since_delivered: int
+ times_delivered: int
+
def process_mturk_events_task():
executor = ThreadPoolExecutor(max_workers=5)
@@ -57,24 +66,26 @@ def create_consumer_group():
def process_mturk_events_chunk(executor: Executor) -> Optional[int]:
- msgs = REDIS.xreadgroup(
+ msgs_raw = REDIS.xreadgroup(
groupname=CONSUMER_GROUP,
consumername=CONSUMER_NAME,
streams={JB_EVENTS_STREAM: ">"},
count=10,
)
- if not msgs:
+ if not msgs_raw:
return None
- msgs = msgs[0][1] # the queue, we only have 1
+
+ msgs = cast(StreamMessages, msgs_raw)[0][1] # The queue, we only have 1
fs = []
for msg in msgs:
msg_id, data = msg
- msg_json = data["data"]
- event = MTurkEvent.model_validate_json(msg_json)
+ msg_json: str = data["data"]
+
+ event = MTurkEvent.model_validate_json(json_data=msg_json)
if event.event_type == "AssignmentSubmitted":
fs.append(
- executor.submit(process_assignment_submitted_event, event, msg_id)
+ executor.submit(process_assignment_submitted_event, event, str(msg_id))
)
else:
logging.info(f"Discarding {event}")
@@ -93,9 +104,13 @@ def handle_pending_msgs():
# Looks in the redis queue for msgs that
# are pending (read by a consumer but not ACK). These prob failed.
# Below is from chatgpt, idk if it works
- pending = REDIS.xpending_range(
- JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10
+ pending = cast(
+ list[PendingEntry],
+ REDIS.xpending_range(
+ JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10
+ ),
)
+
for entry in pending:
msg_id = entry["message_id"]
# Claim message if idle > 10 sec
diff --git a/jb/flow/maintenance.py b/jb/flow/maintenance.py
index 5dc9cea..744ed1d 100644
--- a/jb/flow/maintenance.py
+++ b/jb/flow/maintenance.py
@@ -11,16 +11,20 @@ def check_hit_status(
) -> HitStatus:
"""
(this used to be called "process_hit")
- Request information from Amazon regarding the status of a HIT ID. Update the local state from
- that response.
+ Request information from Amazon regarding the status of a HIT ID. Update
+ the local state from that response.
"""
+
hit_status = AMTManager.get_hit_status(amt_hit_id=amt_hit_id)
- # We're assuming that in the db this Hit is marked as Assignable, or else we wouldn't
- # have called this function.
+
+ # We're assuming that in the db this Hit is marked as Assignable, or
+ # else we wouldn't have called this function.
if hit_status != HitStatus.Assignable:
- # todo: should update also assignment_pending_count, assignment_available_count, assignment_completed_count
+ # TODO: should update also assignment_pending_count,
+ # assignment_available_count, assignment_completed_count
HM.update_status(amt_hit_id=amt_hit_id, hit_status=hit_status)
emit_hit_event(
status=hit_status, amt_hit_type_id=amt_hit_type_id, reason=reason
)
+
return hit_status
diff --git a/jb/flow/monitoring.py b/jb/flow/monitoring.py
index c8432bb..28f7271 100644
--- a/jb/flow/monitoring.py
+++ b/jb/flow/monitoring.py
@@ -5,11 +5,12 @@ from mypy_boto3_mturk.literals import EventTypeType
from jb.config import settings
from jb.decorators import influx_client
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.definitions import HitStatus, AssignmentStatus
-def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int):
+def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int) -> None:
+
tags = {
"host": socket.gethostname(), # could be "amt-jb-0"
"service": "amt-jb",
@@ -23,10 +24,14 @@ def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int):
"fields": {"count": cnt},
}
if influx_client:
- influx_client.write_points([point])
+ influx_client.write_points(points=[point])
+
+ return None
-def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: int):
+def write_assignment_gauge(
+ status: AssignmentStatus, amt_hit_type_id: str, cnt: int
+) -> None:
tags = {
"host": socket.gethostname(),
"service": "amt-jb",
@@ -40,15 +45,17 @@ def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt:
"fields": {"count": cnt},
}
if influx_client:
- influx_client.write_points([point])
+ influx_client.write_points(points=[point])
+
+ return None
def emit_hit_event(
status: HitStatus, amt_hit_type_id: str, reason: Optional[str] = None
-):
+) -> None:
"""
- e.g. a HIT was created, Reviewable, etc. We don't have a "created" HitStatus,
- so it would just be when status=='Assignable'
+ e.g. a HIT was created, Reviewable, etc. We don't have a "created"
+ HitStatus, so it would just be when status=='Assignable'
"""
tags = {
"host": socket.gethostname(),
@@ -57,8 +64,10 @@ def emit_hit_event(
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
if reason:
tags["reason"] = reason
+
point = {
"measurement": "amt_jb.hit_events",
"tags": tags,
@@ -68,10 +77,12 @@ def emit_hit_event(
if influx_client:
influx_client.write_points([point])
+ return None
+
def emit_assignment_event(
status: AssignmentStatus, amt_hit_type_id: str, reason: Optional[str] = None
-):
+) -> None:
"""
e.g. an Assignment was accepted/approved/reject
"""
@@ -82,8 +93,10 @@ def emit_assignment_event(
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
if reason:
tags["reason"] = reason
+
point = {
"measurement": "amt_jb.assignment_events",
"tags": tags,
@@ -93,10 +106,15 @@ def emit_assignment_event(
if influx_client:
influx_client.write_points([point])
+ return None
+
-def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: str):
+def emit_mturk_notification_event(
+ event_type: EventTypeType, amt_hit_type_id: str
+) -> None:
"""
- e.g. a Mturk notification was received. We just put it in redis, we haven't processed it yet.
+ e.g. a Mturk notification was received. We just put it in redis, we
+ haven't processed it yet.
"""
tags = {
"host": socket.gethostname(),
@@ -105,6 +123,7 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
point = {
"measurement": "amt_jb.mturk_notification_events",
"tags": tags,
@@ -114,8 +133,10 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st
if influx_client:
influx_client.write_points([point])
+ return None
+
-def emit_error_event(event_type: str, amt_hit_type_id: str):
+def emit_error_event(event_type: str, amt_hit_type_id: str) -> None:
"""
e.g. todo: structure the error_types
"""
@@ -126,6 +147,7 @@ def emit_error_event(event_type: str, amt_hit_type_id: str):
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
point = {
"measurement": "amt_jb.error_events",
"tags": tags,
@@ -135,8 +157,10 @@ def emit_error_event(event_type: str, amt_hit_type_id: str):
if influx_client:
influx_client.write_points([point])
+ return None
+
-def emit_bonus_event(amount: USDCent, amt_hit_type_id: str):
+def emit_bonus_event(amount: USDCent, amt_hit_type_id: str) -> None:
"""
An AMT bonus was awarded
"""
@@ -146,6 +170,7 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str):
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
point = {
"measurement": "amt_jb.bonus_events",
"tags": tags,
@@ -154,3 +179,5 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str):
if influx_client:
influx_client.write_points([point])
+
+ return None
diff --git a/jb/flow/setup_tasks.py b/jb/flow/setup_tasks.py
new file mode 100644
index 0000000..4664374
--- /dev/null
+++ b/jb/flow/setup_tasks.py
@@ -0,0 +1,36 @@
+from jb.config import TOPIC_ARN, SUBSCRIPTION
+from jb.decorators import SNS_CLIENT, AMT_CLIENT
+from jb.config import settings
+
+
+def initial_setup():
+ # Run once for initial setup. Not on each server start or anything
+ subscription = SNS_CLIENT.subscribe( # type: ignore
+ TopicArn=TOPIC_ARN,
+ Protocol="https",
+ Endpoint=f"https://jamesbillings67.com/{settings.sns_path}/",
+ ReturnSubscriptionArn=True,
+ )
+
+
+def check_sns_configuration():
+ SNS_CLIENT.get_topic_attributes(TopicArn=TOPIC_ARN)
+
+ # check this TOPIC_ARN exists
+ # (doesnt have permission, dont need this anyways)
+ # res = SNS_CLIENT.list_topics()
+ # arns = {x["TopicArn"] for x in res["Topics"]}
+ # assert TOPIC_ARN in arns, f"SNS Topic {TOPIC_ARN} doesn't exist!"
+
+ subs = SNS_CLIENT.list_subscriptions_by_topic(TopicArn=TOPIC_ARN)
+ assert SUBSCRIPTION in subs["Subscriptions"]
+
+ AMT_CLIENT.send_test_event_notification(
+ Notification={
+ "Destination": TOPIC_ARN,
+ "Transport": "SNS",
+ "Version": "2006-05-05",
+ "EventTypes": ["AssignmentSubmitted"],
+ },
+ TestEventType="AssignmentSubmitted",
+ )
diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py
index e7c64b9..808c8f7 100644
--- a/jb/flow/tasks.py
+++ b/jb/flow/tasks.py
@@ -1,5 +1,6 @@
import logging
import time
+from typing import TypedDict, cast
from generalresearchutils.config import is_debug
@@ -15,6 +16,11 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO)
+class HitRow(TypedDict):
+ amt_hit_id: str
+ amt_hit_type_id: str
+
+
def check_stale_hits():
# Check live hits that haven't been modified in a long time. They may not
# be expired yet, but maybe something is wrong?
@@ -28,7 +34,7 @@ def check_stale_hits():
LIMIT 100;""",
params={"status": HitStatus.Assignable.value},
)
- for hit in res:
+ for hit in cast(list[HitRow], res):
logging.info(f"check_stale_hits: {hit["amt_hit_id"]}")
check_hit_status(
amt_hit_id=hit["amt_hit_id"],
@@ -49,7 +55,7 @@ def check_expired_hits():
LIMIT 100;""",
params={"status": HitStatus.Assignable.value},
)
- for hit in res:
+ for hit in cast(list[HitRow], res):
logging.info(f"check_expired_hits: {hit["amt_hit_id"]}")
check_hit_status(
amt_hit_id=hit["amt_hit_id"],
@@ -74,8 +80,12 @@ def create_hit_from_hittype(hit_type: HitType) -> Hit:
def refill_hits() -> None:
+
for hit_type in HTM.filter_active():
- active_count = HM.get_active_count(hit_type.id)
+ assert hit_type.id
+ assert hit_type.amt_hit_type_id
+
+ active_count = HM.get_active_count(hit_type_id=hit_type.id)
logging.info(
f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}"
)
diff --git a/jb/main.py b/jb/main.py
index 8c1dbed..fa59167 100644
--- a/jb/main.py
+++ b/jb/main.py
@@ -1,6 +1,7 @@
from multiprocessing import Process
+from typing import Any, Dict
-from fastapi import FastAPI, Request
+from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
@@ -35,7 +36,7 @@ app.include_router(router=common_router)
@app.get("/robots.txt")
@app.get("/sitemap.xml")
@app.get("/favicon.ico")
-def return_nothing():
+def return_nothing() -> Dict[str, Any]:
return {}
@@ -47,7 +48,7 @@ def serve_react_app(full_path: str):
def schedule_tasks():
- from jb.flow.events import process_mturk_events_task, handle_pending_msgs_task
+ from jb.flow.events import process_mturk_events_task
from jb.flow.tasks import refill_hits_task
Process(target=process_mturk_events_task).start()
diff --git a/jb/managers/__init__.py b/jb/managers/__init__.py
index e2aab6d..e99569a 100644
--- a/jb/managers/__init__.py
+++ b/jb/managers/__init__.py
@@ -15,8 +15,8 @@ class PostgresManager:
def __init__(
self,
pg_config: PostgresConfig,
- permissions: Collection[Permission] = None,
- **kwargs,
+ permissions: Collection[Permission] = None, # type: ignore
+ **kwargs, # type: ignore
):
super().__init__(**kwargs)
self.pg_config = pg_config
diff --git a/jb/managers/amt.py b/jb/managers/amt.py
index 79661c7..0ec70d3 100644
--- a/jb/managers/amt.py
+++ b/jb/managers/amt.py
@@ -10,7 +10,7 @@ from jb.decorators import AMT_CLIENT
from jb.models import AMTAccount
from jb.models.assignment import Assignment
from jb.models.bonus import Bonus
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.definitions import HitStatus
from jb.models.hit import HitType, HitQuestion, Hit
@@ -48,19 +48,24 @@ class AMTManager:
return hit, None
@classmethod
- def get_hit_status(cls, amt_hit_id: str):
+ def get_hit_status(cls, amt_hit_id: str) -> HitStatus:
res, msg = cls.get_hit_if_exists(amt_hit_id=amt_hit_id)
+
if res is None:
+ if msg is None:
+ return HitStatus.Unassignable
+
if " does not exist. (" in msg:
return HitStatus.Disposed
else:
logging.warning(msg)
return HitStatus.Unassignable
+
return res.status
@staticmethod
def create_hit_type(hit_type: HitType):
- res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body())
+ res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) # type: ignore
hit_type.amt_hit_type_id = res["HITTypeId"]
AMT_CLIENT.update_notification_settings(
HITTypeId=hit_type.amt_hit_type_id,
@@ -94,8 +99,10 @@ class AMTManager:
@staticmethod
def get_assignment(amt_assignment_id: str) -> Assignment:
- # note, you CANNOT get an assignment if it has been only ACCEPTED (by the worker)
- # the api is stupid. it will only show up once it is submitted
+ """
+ You CANNOT get an Assignment if it has been only ACCEPTED (by the
+ worker). The api is stupid, it will only show up once it is Submitted
+ """
res = AMT_CLIENT.get_assignment(AssignmentId=amt_assignment_id)
ass_res: AssignmentTypeDef = res["Assignment"]
assignment = Assignment.from_amt_get_assignment(ass_res)
@@ -158,6 +165,7 @@ class AMTManager:
raise ValueError(error_msg)
# elif "This HIT is currently in the state 'Reviewing'" in error_msg:
# logging.warning(error_msg)
+
return None
@staticmethod
@@ -203,7 +211,7 @@ class AMTManager:
return None
@staticmethod
- def expire_all_hits():
+ def expire_all_hits() -> None:
# used in testing only (or in an emergency I guess)
now = datetime.now(tz=timezone.utc)
paginator = AMT_CLIENT.get_paginator("list_hits")
@@ -214,3 +222,4 @@ class AMTManager:
AMT_CLIENT.update_expiration_for_hit(
HITId=hit["HITId"], ExpireAt=now
)
+ return None
diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py
index fca72e8..dd3c866 100644
--- a/jb/managers/assignment.py
+++ b/jb/managers/assignment.py
@@ -28,7 +28,7 @@ class AssignmentManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
stub.id = pk
return None
@@ -62,7 +62,7 @@ class AssignmentManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
assignment.id = pk
return None
@@ -233,7 +233,7 @@ class AssignmentManager(PostgresManager):
"lookback_interval": f"{lookback_hrs} hour",
},
)
- return int(res[0]["c"])
+ return int(res[0]["c"]) # type: ignore
def rejected_count(
self, amt_worker_id: str, lookback_hrs: int = 24
@@ -256,4 +256,4 @@ class AssignmentManager(PostgresManager):
"status": AssignmentStatus.Rejected.value,
},
)
- return int(res[0]["c"])
+ return int(res[0]["c"]) # type: ignore
diff --git a/jb/managers/bonus.py b/jb/managers/bonus.py
index 0cb8b02..89b81f0 100644
--- a/jb/managers/bonus.py
+++ b/jb/managers/bonus.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Any
from psycopg import sql
@@ -37,12 +37,12 @@ class BonusManager(PostgresManager):
c.execute(query, data)
res = c.fetchone()
conn.commit()
- bonus.id = res["id"]
- bonus.assignment_id = res["assignment_id"]
+ bonus.id = res["id"] # type: ignore
+ bonus.assignment_id = res["assignment_id"] # type: ignore
return None
def filter(self, amt_assignment_id: str) -> List[Bonus]:
- res = self.pg_config.execute_sql_query(
+ res: List[Any] = self.pg_config.execute_sql_query(
"""
SELECT mb.*, ma.amt_assignment_id
FROM mtwerk_bonus mb
@@ -51,4 +51,5 @@ class BonusManager(PostgresManager):
""",
params={"amt_assignment_id": amt_assignment_id},
)
- return [Bonus.from_postgres(x) for x in res]
+
+ return [Bonus.from_postgres(data=x) for x in res]
diff --git a/jb/managers/hit.py b/jb/managers/hit.py
index 3832418..ce8ffa5 100644
--- a/jb/managers/hit.py
+++ b/jb/managers/hit.py
@@ -24,7 +24,7 @@ class HitQuestionManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
question.id = pk
return None
@@ -67,6 +67,7 @@ class HitQuestionManager(PostgresManager):
class HitTypeManager(PostgresManager):
+
def create(self, hit_type: HitType) -> None:
assert hit_type.amt_hit_type_id is not None
data = hit_type.to_postgres()
@@ -99,9 +100,10 @@ class HitTypeManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
hit_type.id = pk
+
return None
def filter_active(self) -> List[HitType]:
@@ -137,13 +139,13 @@ class HitTypeManager(PostgresManager):
except AssertionError:
return None
- def get_or_create(self, hit_type: HitType) -> None:
+ def get_or_create(self, hit_type: HitType) -> HitType:
res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id)
if res:
- hit_type.id = res.id
- if res is None:
- self.create(hit_type=hit_type)
- return None
+ return res
+
+ self.create(hit_type=hit_type)
+ return self.get(amt_hit_type_id=hit_type.amt_hit_type_id)
def set_min_active(self, hit_type: HitType) -> None:
assert hit_type.id, "must be in the db first!"
@@ -164,6 +166,7 @@ class HitTypeManager(PostgresManager):
class HitManager(PostgresManager):
+
def create(self, hit: Hit):
assert hit.amt_hit_id is not None
assert hit.id is None
@@ -209,7 +212,7 @@ class HitManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
hit.id = pk
return hit
@@ -267,7 +270,7 @@ class HitManager(PostgresManager):
c.execute(query, data)
conn.commit()
assert c.rowcount == 1, c.rowcount
- hit.id = c.fetchone()["id"]
+ hit.id = c.fetchone()["id"] # type: ignore
return None
def get_from_amt_id(self, amt_hit_id: str) -> Hit:
@@ -305,17 +308,26 @@ class HitManager(PostgresManager):
""",
params={"amt_hit_id": amt_hit_id},
)
- assert len(res) == 1
+ assert len(res) == 1, "Incorrect number of results"
res = res[0]
+
question_xml = HitQuestion.model_validate(
{"height": res.pop("height"), "url": res.pop("url")}
).xml
+
res["question_id"] = res["question_id"]
res["hit_question_xml"] = question_xml
return Hit.from_postgres(res)
- def get_active_count(self, hit_type_id: int):
+ def get_from_amt_id_if_exists(self, amt_hit_id: str) -> Optional[Hit]:
+ try:
+ return self.get_from_amt_id(amt_hit_id=amt_hit_id)
+
+ except (AssertionError, Exception):
+ return None
+
+ def get_active_count(self, hit_type_id: int) -> int:
return self.pg_config.execute_sql_query(
"""
SELECT COUNT(1) as active_count
@@ -326,7 +338,7 @@ class HitManager(PostgresManager):
params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id},
)[0]["active_count"]
- def filter_active_ids(self, hit_type_id: int):
+ def filter_active_ids(self, hit_type_id: int) -> set[str]:
res = self.pg_config.execute_sql_query(
"""
SELECT mh.amt_hit_id
diff --git a/jb/managers/thl.py b/jb/managers/thl.py
index b1dcbde..83f49f6 100644
--- a/jb/managers/thl.py
+++ b/jb/managers/thl.py
@@ -1,7 +1,3 @@
-from decimal import Decimal
-from typing import Dict, Optional
-
-import requests
from generalresearchutils.models.thl.payout import UserPayoutEvent
from generalresearchutils.models.thl.task_status import TaskStatusResponse
from generalresearchutils.models.thl.wallet.cashout_method import (
@@ -9,45 +5,58 @@ from generalresearchutils.models.thl.wallet.cashout_method import (
CashoutRequestInfo,
)
+from generalresearchutils.models.thl.user_profile import UserProfile
+from generalresearchutils.currency import USDCent
+
from jb.config import settings
-from jb.models.currency import USDCent
-from jb.models.definitions import PayoutStatus
-# TODO: Organize this more with other endpoints (offerwall, cashout requests/approvals, etc).
+from generalresearchutils.models.thl.definitions import PayoutStatus
+
+
+from typing import Optional
+import requests
+
+# TODO: Organize this more with other endpoints (offerwall, cashout
+# requests/approvals, etc).
-def get_user_profile(amt_worker_id: str) -> Dict:
+def get_user_profile(amt_worker_id: str) -> UserProfile:
url = f"{settings.fsb_host}{settings.product_id}/user/{amt_worker_id}/profile/"
res = requests.get(url).json()
if res.get("detail") == "user not found":
raise ValueError("user not found")
- return res["user_profile"]
+
+ return UserProfile.model_validate(res["user_profile"])
def get_user_blocked(amt_worker_id: str) -> bool:
+ # Not blocked if None
res = get_user_profile(amt_worker_id=amt_worker_id)
- return res["user"]["blocked"]
+ return res.user.blocked if res.user.blocked is not None else False
-def get_user_blocked_or_not_exists(amt_worker_id: str) -> bool:
+def get_user_blocked_or_not_exists(amt_worker_id: str) -> Optional[bool]:
try:
res = get_user_profile(amt_worker_id=amt_worker_id)
- return res["user"]["blocked"]
+ return res.user.blocked if res.user.blocked is not None else False
except ValueError as e:
if e.args[0] == "user not found":
return True
+ return None
+
def get_task_status(tsid: str) -> Optional[TaskStatusResponse]:
url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/"
d = requests.get(url).json()
if d.get("msg") == "invalid tsid":
return None
+
return TaskStatusResponse.model_validate(d)
def user_cashout_request(
- amt_worker_id: str, amount: USDCent, cashout_method_id
+ amt_worker_id: str, amount: USDCent, cashout_method_id: str
) -> CashoutRequestInfo:
assert cashout_method_id in {
settings.amt_assignment_cashout_method,
@@ -56,7 +65,8 @@ def user_cashout_request(
assert isinstance(amount, USDCent)
assert USDCent(0) < amount < USDCent(10_00)
url = f"{settings.fsb_host}{settings.product_id}/cashout/"
- body = {
+
+ body: dict[str, str | int] = {
"bpuid": amt_worker_id,
"amount": int(amount),
"cashout_method_id": cashout_method_id,
@@ -81,7 +91,7 @@ def manage_pending_cashout(
return UserPayoutEvent.model_validate(d)
-def get_wallet_balance(amt_worker_id: str):
+def get_wallet_balance(amt_worker_id: str) -> USDCent:
url = f"{settings.fsb_host}{settings.product_id}/wallet/"
params = {"bpuid": amt_worker_id}
return USDCent(requests.get(url, params=params).json()["wallet"]["amount"])
diff --git a/jb/models/assignment.py b/jb/models/assignment.py
index 39ae47c..5dd0167 100644
--- a/jb/models/assignment.py
+++ b/jb/models/assignment.py
@@ -1,6 +1,6 @@
import logging
from datetime import datetime, timezone
-from typing import Optional, TypedDict
+from typing import Optional, TypedDict, Any
from xml.etree import ElementTree
from mypy_boto3_mturk.type_defs import AssignmentTypeDef
@@ -10,7 +10,6 @@ from pydantic import (
ConfigDict,
model_validator,
PositiveInt,
- computed_field,
TypeAdapter,
ValidationError,
)
@@ -116,10 +115,12 @@ class Assignment(AssignmentStub):
default=None,
min_length=3,
max_length=2_000,
- help_text="The feedback string included with the call to the "
- "ApproveAssignment operation or the RejectAssignment "
- "operation, if the Requester approved or rejected the "
- "assignment and specified feedback.",
+ json_schema_extra={
+ "help_text": "The feedback string included with the call to the "
+ "ApproveAssignment operation or the RejectAssignment "
+ "operation, if the Requester approved or rejected the "
+ "assignment and specified feedback."
+ },
)
answer_xml: Optional[str] = Field(default=None, exclude=True)
@@ -131,7 +132,7 @@ class Assignment(AssignmentStub):
# --- Validators ---
@model_validator(mode="before")
- def set_tsid(cls, values: dict):
+ def set_tsid(cls, values: dict[str, Any]) -> dict[str, Any]:
if values.get("tsid") is None and (answer_xml := values.get("answer_xml")):
answer_dict = cls.parse_answer_xml(answer_xml)
tsid = answer_dict.get("tsid")
@@ -175,10 +176,10 @@ class Assignment(AssignmentStub):
if self.answer_xml is None:
return None
- return self.parse_answer_xml(self.answer_xml)
+ return self.parse_answer_xml(self.answer_xml) # type: ignore
@staticmethod
- def parse_answer_xml(answer_xml: str):
+ def parse_answer_xml(answer_xml: str) -> dict[str, Any]:
root = ElementTree.fromstring(answer_xml)
ns = {
"mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"
@@ -186,8 +187,8 @@ class Assignment(AssignmentStub):
res = {}
for a in root.findall("mt:Answer", ns):
- name = a.find("mt:QuestionIdentifier", ns).text
- value = a.find("mt:FreeText", ns).text
+ name = a.find("mt:QuestionIdentifier", ns).text # type: ignore
+ value = a.find("mt:FreeText", ns).text # type: ignore
res[name] = value or ""
EXPECTED_KEYS = {"amt_assignment_id", "amt_worker_id", "tsid"}
diff --git a/jb/models/bonus.py b/jb/models/bonus.py
index 564a32d..a536dd1 100644
--- a/jb/models/bonus.py
+++ b/jb/models/bonus.py
@@ -1,11 +1,10 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, ConfigDict, PositiveInt
from typing_extensions import Self
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr
-from jb.models.definitions import PayoutStatus
class Bonus(BaseModel):
@@ -41,7 +40,7 @@ class Bonus(BaseModel):
return d
@classmethod
- def from_postgres(cls, data: Dict) -> Self:
+ def from_postgres(cls, data: Dict[str, Any]) -> Self:
data["amount"] = USDCent(round(data["amount"] * 100))
fields = set(cls.model_fields.keys())
data = {k: v for k, v in data.items() if k in fields}
diff --git a/jb/models/currency.py b/jb/models/currency.py
deleted file mode 100644
index 3094e2a..0000000
--- a/jb/models/currency.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import warnings
-from decimal import Decimal
-from typing import Any
-
-from pydantic import GetCoreSchemaHandler, NonNegativeInt
-from pydantic_core import CoreSchema, core_schema
-
-
-class USDCent(int):
- def __new__(cls, value, *args, **kwargs):
-
- if isinstance(value, float):
- warnings.warn(
- "USDCent init with a float. Rounding behavior may " "be unexpected"
- )
-
- if isinstance(value, Decimal):
- warnings.warn(
- "USDCent init with a Decimal. Rounding behavior may " "be unexpected"
- )
-
- if value < 0:
- raise ValueError("USDCent not be less than zero")
-
- return super(cls, cls).__new__(cls, value)
-
- def __add__(self, other):
- assert isinstance(other, USDCent)
- res = super(USDCent, self).__add__(other)
- return self.__class__(res)
-
- def __sub__(self, other):
- assert isinstance(other, USDCent)
- res = super(USDCent, self).__sub__(other)
- return self.__class__(res)
-
- def __mul__(self, other):
- assert isinstance(other, USDCent)
- res = super(USDCent, self).__mul__(other)
- return self.__class__(res)
-
- def __abs__(self):
- res = super(USDCent, self).__abs__()
- return self.__class__(res)
-
- def __truediv__(self, other):
- raise ValueError("Division not allowed for USDCent")
-
- def __str__(self):
- return "%d" % int(self)
-
- def __repr__(self):
- return "USDCent(%d)" % int(self)
-
- @classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: Any, handler: GetCoreSchemaHandler
- ) -> CoreSchema:
- """
- https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__
- """
- return core_schema.no_info_after_validator_function(
- cls, handler(NonNegativeInt)
- )
-
- def to_usd(self) -> Decimal:
- return Decimal(int(self) / 100).quantize(Decimal(".01"))
-
- def to_usd_str(self) -> str:
- return "${:,.2f}".format(float(self.to_usd()))
diff --git a/jb/models/custom_types.py b/jb/models/custom_types.py
index 70bc5c1..10bc9d1 100644
--- a/jb/models/custom_types.py
+++ b/jb/models/custom_types.py
@@ -34,8 +34,7 @@ def convert_str_dt(v: Any) -> Optional[AwareDatetime]:
def assert_utc(v: AwareDatetime) -> AwareDatetime:
- if isinstance(v, datetime):
- assert v.tzinfo == timezone.utc, "Timezone is not UTC"
+ assert v.tzinfo == timezone.utc, "Timezone is not UTC"
return v
diff --git a/jb/models/definitions.py b/jb/models/definitions.py
index a3d27ba..4ae7a21 100644
--- a/jb/models/definitions.py
+++ b/jb/models/definitions.py
@@ -1,4 +1,4 @@
-from enum import IntEnum, StrEnum
+from enum import IntEnum
class AssignmentStatus(IntEnum):
@@ -37,32 +37,6 @@ class HitReviewStatus(IntEnum):
ReviewedInappropriate = 3
-class PayoutStatus(StrEnum):
- """These are GRL's payout statuses"""
-
- # The user has requested a payout. The money is taken from their
- # wallet. A PENDING request can either be APPROVED, REJECTED, or
- # CANCELLED. We can also implicitly skip the APPROVED step and go
- # straight to COMPLETE or FAILED.
- PENDING = "PENDING"
- # The request is approved (by us or automatically). Once approved,
- # it can be FAILED or COMPLETE.
- APPROVED = "APPROVED"
- # The request is rejected. The user loses the money.
- REJECTED = "REJECTED"
- # The user requests to cancel the request, the money goes back into their wallet.
- CANCELLED = "CANCELLED"
- # The payment was approved, but failed within external payment provider.
- # This is an "error" state, as the money won't have moved anywhere. A
- # FAILED payment can be tried again and be COMPLETE.
- FAILED = "FAILED"
- # The payment was sent successfully and (usually) a fee was charged
- # to us for it.
- COMPLETE = "COMPLETE"
- # Not supported # REFUNDED: I'm not sure if this is possible or
- # if we'd want to allow it.
-
-
class ReportValue(IntEnum):
"""
The reason a user reported a task.
diff --git a/jb/models/event.py b/jb/models/event.py
index c357772..c167420 100644
--- a/jb/models/event.py
+++ b/jb/models/event.py
@@ -11,13 +11,22 @@ class MTurkEvent(BaseModel):
What AWS SNS will POST to our mturk_notifications endpoint (inside the request body)
"""
- event_type: EventTypeType = Field(example="AssignmentSubmitted")
- event_timestamp: AwareDatetimeISO = Field(example="2025-10-16T18:45:51Z")
- amt_hit_id: AMTBoto3ID = Field(example="12345678901234567890")
+ event_type: EventTypeType = Field(
+ json_schema_extra={"example": "AssignmentSubmitted"}
+ )
+ event_timestamp: AwareDatetimeISO = Field(
+ json_schema_extra={"example": "2025-10-16T18:45:51Z"}
+ )
+ amt_hit_id: AMTBoto3ID = Field(
+ json_schema_extra={"example": "12345678901234567890"}
+ )
amt_assignment_id: str = Field(
- max_length=64, example="1234567890123456789012345678901234567890"
+ max_length=64,
+ json_schema_extra={"example": "1234567890123456789012345678901234567890"},
+ )
+ amt_hit_type_id: AMTBoto3ID = Field(
+ json_schema_extra={"example": "09876543210987654321"}
)
- amt_hit_type_id: AMTBoto3ID = Field(example="09876543210987654321")
@classmethod
def from_sns(cls, data: Dict):
diff --git a/jb/models/hit.py b/jb/models/hit.py
index c3734fa..fba2ecf 100644
--- a/jb/models/hit.py
+++ b/jb/models/hit.py
@@ -1,5 +1,5 @@
from datetime import datetime, timezone, timedelta
-from typing import Optional, List, Dict
+from typing import Optional, List, Dict, Any
from uuid import uuid4
from xml.etree import ElementTree
@@ -13,7 +13,7 @@ from pydantic import (
)
from typing_extensions import Self
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.custom_types import AMTBoto3ID, HttpsUrlStr, AwareDatetimeISO
from jb.models.definitions import HitStatus, HitReviewStatus
@@ -104,11 +104,11 @@ class HitType(HitTypeCommon):
return d
@classmethod
- def from_postgres(cls, data: Dict) -> Self:
+ def from_postgres(cls, data: Dict[str, Any]) -> Self:
data["reward"] = USDCent(round(data["reward"] * 100))
return cls.model_validate(data)
- def generate_hit_amt_request(self, question: HitQuestion):
+ def generate_hit_amt_request(self, question: HitQuestion) -> Dict[str, Any]:
d = dict()
d["HITTypeId"] = self.amt_hit_type_id
d["MaxAssignments"] = 1
@@ -135,7 +135,12 @@ class Hit(HitTypeCommon):
status: HitStatus = Field()
review_status: HitReviewStatus = Field()
- creation_time: AwareDatetimeISO = Field(default=None, description="From aws")
+
+ # TODO: Check if this is actually ever going to be None. I type fixed it,
+ # but I don't have anything to suggest it isn't requred. -- Max 2026-02-24
+ creation_time: Optional[AwareDatetimeISO] = Field(
+ default=None, description="From aws"
+ )
expiration: Optional[AwareDatetimeISO] = Field(default=None)
# GRL Specific
@@ -150,7 +155,7 @@ class Hit(HitTypeCommon):
# -- Hit specific
- qualification_requirements: Optional[List[Dict]] = Field(default=None)
+ qualification_requirements: Optional[List[Dict[str, Any]]] = Field(default=None)
max_assignments: int = Field()
# # this comes back as expiration. only for the request
@@ -171,7 +176,7 @@ class Hit(HitTypeCommon):
assert hit_type.id is not None
assert hit_type.amt_hit_type_id is not None
- h = Hit.model_validate(
+ h = cls.model_validate(
dict(
amt_hit_id=data["HITId"],
amt_hit_type_id=data["HITTypeId"],
@@ -194,11 +199,12 @@ class Hit(HitTypeCommon):
hit_type_id=hit_type.id,
)
)
+
return h
@classmethod
def from_amt_get_hit(cls, data: HITTypeDef) -> Self:
- h = Hit.model_validate(
+ h = cls.model_validate(
dict(
amt_hit_id=data["HITId"],
amt_hit_type_id=data["HITTypeId"],
@@ -229,7 +235,7 @@ class Hit(HitTypeCommon):
return d
@classmethod
- def from_postgres(cls, data: Dict) -> Self:
+ def from_postgres(cls, data: Dict[str, Any]) -> Self:
data["reward"] = USDCent(round(data["reward"] * 100))
return cls.model_validate(data)
diff --git a/jb/settings.py b/jb/settings.py
index 538b89f..5754add 100644
--- a/jb/settings.py
+++ b/jb/settings.py
@@ -45,7 +45,7 @@ class Settings(AmtJbBaseSettings):
debug: bool = False
app_name: str = "AMT JB API"
- fsb_host: HttpUrl = Field(default="https://fsb.generalresearch.com/")
+ fsb_host: HttpUrl = Field(default=HttpUrl("https://fsb.generalresearch.com/"))
# Needed for admin function on fsb w/o authentication
fsb_host_private_route: Optional[str] = Field(default=None)
diff --git a/jb/views/common.py b/jb/views/common.py
index 46ac608..0dc8b56 100644
--- a/jb/views/common.py
+++ b/jb/views/common.py
@@ -11,7 +11,7 @@ from starlette.responses import RedirectResponse
from jb.config import settings, JB_EVENTS_STREAM
from jb.decorators import REDIS, HM
from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.definitions import ReportValue, AssignmentStatus
from jb.models.event import MTurkEvent
from jb.settings import BASE_HTML
@@ -71,6 +71,7 @@ async def work(request: Request):
url=f"/preview/?{request.url.query}" if request.url.query else "/preview/",
status_code=302,
)
+
if amt_assignment_id is None or amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE":
# Worker is previewing the HIT
amt_hit_type_id = "unknown"
@@ -91,71 +92,6 @@ async def work(request: Request):
return HTMLResponse(BASE_HTML)
-@common_router.get(path="/survey/", response_class=JSONResponse)
-def survey(
- request: Request,
- worker_id: str = Query(),
- duration: int = Query(default=1200),
-):
- if not worker_id:
- raise HTTPException(status_code=400, detail="Missing worker_id")
-
- # (1) Check wallet
- wallet_url = f"{settings.fsb_host}{settings.product_id}/wallet/"
- wallet_res = requests.get(wallet_url, params={"bpuid": worker_id})
- if wallet_res.status_code != 200:
- raise HTTPException(status_code=502, detail="Wallet check failed")
-
- wallet_data = wallet_res.json()
- wallet_balance = wallet_data["wallet"]["amount"]
- if wallet_balance < -100:
- return JSONResponse(
- {
- "total_surveys": 0,
- "link": None,
- "duration": None,
- "payout": None,
- }
- )
-
- # (2) Get offerwall
- client_ip = "69.253.144.55" if settings.debug else request.client.host
- offerwall_url = f"{settings.fsb_host}{settings.product_id}/offerwall/d48cce47/"
- offerwall_res = requests.get(
- offerwall_url,
- params={
- "bpuid": worker_id,
- "ip": client_ip,
- "n_bins": 1,
- "duration": duration,
- },
- )
-
- if offerwall_res.status_code != 200:
- raise HTTPException(status_code=502, detail="Offerwall request failed")
-
- try:
- rj = offerwall_res.json()
- bucket = rj["offerwall"]["buckets"][0]
- return JSONResponse(
- {
- "total_surveys": rj["offerwall"]["availability_count"],
- "link": bucket["uri"],
- "duration": round(bucket["duration"]["q2"] / 60),
- "payout": USDCent(bucket["payout"]["q2"]).to_usd_str(),
- }
- )
- except Exception:
- return JSONResponse(
- {
- "total_surveys": 0,
- "link": None,
- "duration": None,
- "payout": None,
- }
- )
-
-
@common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False)
async def mturk_notifications(request: Request):
"""
diff --git a/jb/views/tasks.py b/jb/views/tasks.py
index 7176b29..313295f 100644
--- a/jb/views/tasks.py
+++ b/jb/views/tasks.py
@@ -18,18 +18,23 @@ def process_request(request: Request) -> None:
amt_assignment_id = request.query_params.get("assignmentId", None)
if amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE":
raise ValueError("shouldn't happen")
+
amt_hit_id = request.query_params.get("hitId", None)
amt_worker_id = request.query_params.get("workerId", None)
print(f"process_request: {amt_assignment_id=} {amt_worker_id=} {amt_hit_id=}")
assert amt_worker_id and amt_hit_id and amt_assignment_id
# Check that the HIT is still valid
- hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id)
+ hit = HM.get_from_amt_id_if_exists(amt_hit_id=amt_hit_id)
+ if not hit:
+ raise ValueError(f"Hit {amt_hit_id} not found in DB")
+
_ = check_hit_status(amt_hit_id=amt_hit_id, amt_hit_type_id=hit.amt_hit_type_id)
emit_assignment_event(
status=AssignmentStatus.Accepted,
amt_hit_type_id=hit.amt_hit_type_id,
)
+
# I think it won't be assignable anymore? idk
# assert hit_status == HitStatus.Assignable, f"hit {amt_hit_id} {hit_status=}. Expected Assignable"
@@ -53,7 +58,7 @@ def process_request(request: Request) -> None:
assert assignment_stub.amt_worker_id == amt_worker_id
assert assignment_stub.amt_assignment_id == amt_assignment_id
assert assignment_stub.created_at > (
- datetime.now(tz=timezone.utc) - timedelta(minutes=90)
+ datetime.now(tz=timezone.utc) - timedelta(minutes=90)
)
return None
diff --git a/jb/views/utils.py b/jb/views/utils.py
index 39db5d2..0d08e9b 100644
--- a/jb/views/utils.py
+++ b/jb/views/utils.py
@@ -8,8 +8,11 @@ def get_client_ip(request: Request) -> str:
"""
ip = request.headers.get("X-Forwarded-For")
if not ip:
- ip = request.client.host
+ ip = request.client.host # type: ignore
elif ip == "testclient" or ip.startswith("10."):
forwarded = request.headers.get("X-Forwarded-For")
- ip = forwarded.split(",")[0].strip() if forwarded else request.client.host
+ ip = (
+ forwarded.split(",")[0].strip() if forwarded else request.client.host # type: ignore
+ )
+
return ip