about | summary | refs | log | tree | commit | diff
path: root/jb/flow
diff options
context:
space:
mode:
author: Max Nanis, 2026-02-24 17:26:15 -0500
committer: Max Nanis, 2026-02-24 17:26:15 -0500
commit: 8c1940445503fd6678d0961600f2be81622793a2 (patch)
tree: b9173562b8824b5eaa805e446d9d780e1f23fb2a /jb/flow
parent: 25d8c3c214baf10f6520cc1351f78473150e5d7a (diff)
download: amt-jb-8c1940445503fd6678d0961600f2be81622793a2.tar.gz
amt-jb-8c1940445503fd6678d0961600f2be81622793a2.zip
Extensive use of type checking. Movement of pytest conf towards handling managers (for db agnostic unittest). Starting to organize pytests.
Diffstat (limited to 'jb/flow')
-rw-r--r-- jb/flow/assignment_tasks.py 52
-rw-r--r-- jb/flow/events.py 35
-rw-r--r-- jb/flow/maintenance.py 14
-rw-r--r-- jb/flow/monitoring.py 53
-rw-r--r-- jb/flow/setup_tasks.py 36
-rw-r--r-- jb/flow/tasks.py 16
6 files changed, 166 insertions, 40 deletions
diff --git a/jb/flow/assignment_tasks.py b/jb/flow/assignment_tasks.py
index 4022716..bb0877d 100644
--- a/jb/flow/assignment_tasks.py
+++ b/jb/flow/assignment_tasks.py
@@ -5,6 +5,7 @@ from typing import Optional
from generalresearchutils.models.thl.definitions import PayoutStatus, StatusCode1
from generalresearchutils.models.thl.wallet.cashout_method import CashoutRequestInfo
+from generalresearchutils.currency import USDCent
from jb.decorators import AM, HM, BM
from jb.flow.monitoring import emit_error_event, emit_assignment_event, emit_bonus_event
@@ -21,21 +22,20 @@ from jb.managers.thl import (
get_user_blocked,
get_task_status,
user_cashout_request,
- AMT_ASSIGNMENT_CASHOUT_METHOD,
manage_pending_cashout,
get_user_blocked_or_not_exists,
get_wallet_balance,
- AMT_BONUS_CASHOUT_METHOD,
)
from jb.models.assignment import Assignment
-from jb.models.currency import USDCent
from jb.models.definitions import AssignmentStatus
from jb.models.event import MTurkEvent
+from jb.config import settings
def process_assignment_submitted(event: MTurkEvent) -> None:
"""
- Called either directly or from the SNS Notification that a HIT was submitted
+ Called either directly or from the SNS Notification that a
+ HIT was submitted
:return: None
"""
@@ -137,13 +137,22 @@ def process_assignment_submitted(event: MTurkEvent) -> None:
return issue_worker_payment(assignment)
-def review_hit(assignment):
+def review_hit(assignment: Assignment) -> None:
# Reviewable to Reviewing
AMTManager.update_hit_review_status(amt_hit_id=assignment.amt_hit_id, revert=False)
hit, _ = AMTManager.get_hit_if_exists(amt_hit_id=assignment.amt_hit_id)
+
+ if hit is None:
+ logging.warning(
+ f"Hit not found when trying to review hit: {assignment.amt_hit_id}"
+ )
+ return None
+
# Update the db
HM.update_hit(hit)
+ return None
+
def handle_assignment_w_no_work(assignment: Assignment) -> Assignment:
"""
@@ -267,6 +276,10 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
amt_worker_id = assignment.amt_worker_id
amt_assignment_id = assignment.amt_assignment_id
tsid = assignment.tsid
+ assert (
+ tsid is not None
+ ), "Assignment must have a tsid to be handled in handle_assignment_w_work"
+
hit = HM.get_from_amt_id(amt_hit_id=assignment.amt_hit_id)
tsr = get_task_status(tsid=tsid)
@@ -289,6 +302,7 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
event_type = "assignment_submitted_quality_fail"
else:
event_type = "assignment_submitted_work_not_complete"
+
emit_error_event(
event_type=event_type,
amt_hit_type_id=hit.amt_hit_type_id,
@@ -324,6 +338,8 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
)
return assignment
+ assert req.id
+
# We've approved the HIT payment, now update the db to reflect this, and approve the assignment
assignment = approve_assignment(
amt_assignment_id=amt_assignment_id,
@@ -331,7 +347,9 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment:
amt_hit_type_id=hit.amt_hit_type_id,
)
# We complete after the assignment is approved
- complete_res = manage_pending_cashout(req.id, PayoutStatus.COMPLETE)
+ complete_res = manage_pending_cashout(
+ cashout_id=req.id, payout_status=PayoutStatus.COMPLETE
+ )
if complete_res.status != PayoutStatus.COMPLETE:
# unclear why this would happen
raise ValueError(f"Failed to complete cashout: {req.id}")
@@ -345,8 +363,9 @@ def submit_and_approve_amt_assignment_request(
req = user_cashout_request(
amt_worker_id=amt_worker_id,
amount=amount,
- cashout_method_id=AMT_ASSIGNMENT_CASHOUT_METHOD,
+ cashout_method_id=settings.amt_assignment_cashout_method,
)
+ assert req.id
if req.status != PayoutStatus.PENDING:
return None
@@ -365,8 +384,9 @@ def submit_and_approve_amt_bonus_request(
req = user_cashout_request(
amt_worker_id=amt_worker_id,
amount=amount,
- cashout_method_id=AMT_BONUS_CASHOUT_METHOD,
+ cashout_method_id=settings.amt_bonus_cashout_method,
)
+ assert req.id
if req.status != PayoutStatus.PENDING:
return None
@@ -404,6 +424,7 @@ def issue_worker_payment(assignment: Assignment) -> None:
amt_hit_type_id=hit.amt_hit_type_id,
)
return None
+ assert pe.id
AMTManager.send_bonus(
amt_worker_id=assignment.amt_worker_id,
@@ -412,13 +433,26 @@ def issue_worker_payment(assignment: Assignment) -> None:
reason=BONUS_MESSAGE,
unique_request_token=pe.id,
)
+
# Confirm it was sent through amt
bonus = AMTManager.get_bonus(
amt_assignment_id=assignment.amt_assignment_id, payout_event_id=pe.id
)
+
+ if bonus is None:
+ logging.warning(
+ f"Failed to find bonus after sending it: {amt_assignment_id} {pe.id}"
+ )
+ emit_error_event(
+ event_type="bonus_not_found_after_sending",
+ amt_hit_type_id=hit.amt_hit_type_id,
+ )
+ return None
+
# Create in DB
- BM.create(bonus)
+ BM.create(bonus=bonus)
emit_bonus_event(amount=amount, amt_hit_type_id=hit.amt_hit_type_id)
+
# Complete cashout
res = manage_pending_cashout(pe.id, PayoutStatus.COMPLETE)
if res.status != PayoutStatus.COMPLETE:
diff --git a/jb/flow/events.py b/jb/flow/events.py
index 3961a64..7b7bd32 100644
--- a/jb/flow/events.py
+++ b/jb/flow/events.py
@@ -1,8 +1,8 @@
import logging
import time
from concurrent import futures
-from concurrent.futures import ThreadPoolExecutor, Executor, as_completed
-from typing import Optional
+from concurrent.futures import ThreadPoolExecutor, Executor
+from typing import Optional, cast, TypedDict
import redis
@@ -16,6 +16,15 @@ from jb.decorators import REDIS
from jb.flow.assignment_tasks import process_assignment_submitted
from jb.models.event import MTurkEvent
+StreamMessages = list[tuple[str, list[tuple[bytes, dict[bytes, bytes]]]]]
+
+
+class PendingEntry(TypedDict):
+ message_id: bytes
+ consumer: bytes
+ time_since_delivered: int
+ times_delivered: int
+
def process_mturk_events_task():
executor = ThreadPoolExecutor(max_workers=5)
@@ -57,24 +66,26 @@ def create_consumer_group():
def process_mturk_events_chunk(executor: Executor) -> Optional[int]:
- msgs = REDIS.xreadgroup(
+ msgs_raw = REDIS.xreadgroup(
groupname=CONSUMER_GROUP,
consumername=CONSUMER_NAME,
streams={JB_EVENTS_STREAM: ">"},
count=10,
)
- if not msgs:
+ if not msgs_raw:
return None
- msgs = msgs[0][1] # the queue, we only have 1
+
+ msgs = cast(StreamMessages, msgs_raw)[0][1] # The queue, we only have 1
fs = []
for msg in msgs:
msg_id, data = msg
- msg_json = data["data"]
- event = MTurkEvent.model_validate_json(msg_json)
+ msg_json: str = data["data"]
+
+ event = MTurkEvent.model_validate_json(json_data=msg_json)
if event.event_type == "AssignmentSubmitted":
fs.append(
- executor.submit(process_assignment_submitted_event, event, msg_id)
+ executor.submit(process_assignment_submitted_event, event, str(msg_id))
)
else:
logging.info(f"Discarding {event}")
@@ -93,9 +104,13 @@ def handle_pending_msgs():
# Looks in the redis queue for msgs that
# are pending (read by a consumer but not ACK). These prob failed.
# NOTE: the code below came from ChatGPT and is unverified; confirm it works.
- pending = REDIS.xpending_range(
- JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10
+ pending = cast(
+ list[PendingEntry],
+ REDIS.xpending_range(
+ JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10
+ ),
)
+
for entry in pending:
msg_id = entry["message_id"]
# Claim message if idle > 10 sec
diff --git a/jb/flow/maintenance.py b/jb/flow/maintenance.py
index 5dc9cea..744ed1d 100644
--- a/jb/flow/maintenance.py
+++ b/jb/flow/maintenance.py
@@ -11,16 +11,20 @@ def check_hit_status(
) -> HitStatus:
"""
(this used to be called "process_hit")
- Request information from Amazon regarding the status of a HIT ID. Update the local state from
- that response.
+ Request information from Amazon regarding the status of a HIT ID. Update
+ the local state from that response.
"""
+
hit_status = AMTManager.get_hit_status(amt_hit_id=amt_hit_id)
- # We're assuming that in the db this Hit is marked as Assignable, or else we wouldn't
- # have called this function.
+
+ # We're assuming that in the db this Hit is marked as Assignable, or
+ # else we wouldn't have called this function.
if hit_status != HitStatus.Assignable:
- # todo: should update also assignment_pending_count, assignment_available_count, assignment_completed_count
+ # TODO: should update also assignment_pending_count,
+ # assignment_available_count, assignment_completed_count
HM.update_status(amt_hit_id=amt_hit_id, hit_status=hit_status)
emit_hit_event(
status=hit_status, amt_hit_type_id=amt_hit_type_id, reason=reason
)
+
return hit_status
diff --git a/jb/flow/monitoring.py b/jb/flow/monitoring.py
index c8432bb..28f7271 100644
--- a/jb/flow/monitoring.py
+++ b/jb/flow/monitoring.py
@@ -5,11 +5,12 @@ from mypy_boto3_mturk.literals import EventTypeType
from jb.config import settings
from jb.decorators import influx_client
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.definitions import HitStatus, AssignmentStatus
-def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int):
+def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int) -> None:
+
tags = {
"host": socket.gethostname(), # could be "amt-jb-0"
"service": "amt-jb",
@@ -23,10 +24,14 @@ def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int):
"fields": {"count": cnt},
}
if influx_client:
- influx_client.write_points([point])
+ influx_client.write_points(points=[point])
+
+ return None
-def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: int):
+def write_assignment_gauge(
+ status: AssignmentStatus, amt_hit_type_id: str, cnt: int
+) -> None:
tags = {
"host": socket.gethostname(),
"service": "amt-jb",
@@ -40,15 +45,17 @@ def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt:
"fields": {"count": cnt},
}
if influx_client:
- influx_client.write_points([point])
+ influx_client.write_points(points=[point])
+
+ return None
def emit_hit_event(
status: HitStatus, amt_hit_type_id: str, reason: Optional[str] = None
-):
+) -> None:
"""
- e.g. a HIT was created, Reviewable, etc. We don't have a "created" HitStatus,
- so it would just be when status=='Assignable'
+ e.g. a HIT was created, Reviewable, etc. We don't have a "created"
+ HitStatus, so it would just be when status=='Assignable'
"""
tags = {
"host": socket.gethostname(),
@@ -57,8 +64,10 @@ def emit_hit_event(
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
if reason:
tags["reason"] = reason
+
point = {
"measurement": "amt_jb.hit_events",
"tags": tags,
@@ -68,10 +77,12 @@ def emit_hit_event(
if influx_client:
influx_client.write_points([point])
+ return None
+
def emit_assignment_event(
status: AssignmentStatus, amt_hit_type_id: str, reason: Optional[str] = None
-):
+) -> None:
"""
e.g. an Assignment was accepted/approved/rejected
"""
@@ -82,8 +93,10 @@ def emit_assignment_event(
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
if reason:
tags["reason"] = reason
+
point = {
"measurement": "amt_jb.assignment_events",
"tags": tags,
@@ -93,10 +106,15 @@ def emit_assignment_event(
if influx_client:
influx_client.write_points([point])
+ return None
+
-def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: str):
+def emit_mturk_notification_event(
+ event_type: EventTypeType, amt_hit_type_id: str
+) -> None:
"""
- e.g. a Mturk notification was received. We just put it in redis, we haven't processed it yet.
+ e.g. a Mturk notification was received. We just put it in redis, we
+ haven't processed it yet.
"""
tags = {
"host": socket.gethostname(),
@@ -105,6 +123,7 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
point = {
"measurement": "amt_jb.mturk_notification_events",
"tags": tags,
@@ -114,8 +133,10 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st
if influx_client:
influx_client.write_points([point])
+ return None
+
-def emit_error_event(event_type: str, amt_hit_type_id: str):
+def emit_error_event(event_type: str, amt_hit_type_id: str) -> None:
"""
e.g. todo: structure the error_types
"""
@@ -126,6 +147,7 @@ def emit_error_event(event_type: str, amt_hit_type_id: str):
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
point = {
"measurement": "amt_jb.error_events",
"tags": tags,
@@ -135,8 +157,10 @@ def emit_error_event(event_type: str, amt_hit_type_id: str):
if influx_client:
influx_client.write_points([point])
+ return None
+
-def emit_bonus_event(amount: USDCent, amt_hit_type_id: str):
+def emit_bonus_event(amount: USDCent, amt_hit_type_id: str) -> None:
"""
An AMT bonus was awarded
"""
@@ -146,6 +170,7 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str):
"amt_hit_type_id": amt_hit_type_id,
"debug": settings.debug,
}
+
point = {
"measurement": "amt_jb.bonus_events",
"tags": tags,
@@ -154,3 +179,5 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str):
if influx_client:
influx_client.write_points([point])
+
+ return None
diff --git a/jb/flow/setup_tasks.py b/jb/flow/setup_tasks.py
new file mode 100644
index 0000000..4664374
--- /dev/null
+++ b/jb/flow/setup_tasks.py
@@ -0,0 +1,36 @@
+from jb.config import TOPIC_ARN, SUBSCRIPTION
+from jb.decorators import SNS_CLIENT, AMT_CLIENT
+from jb.config import settings
+
+
+def initial_setup():
+ # Run once for initial setup. Not on each server start or anything
+ subscription = SNS_CLIENT.subscribe( # type: ignore
+ TopicArn=TOPIC_ARN,
+ Protocol="https",
+ Endpoint=f"https://jamesbillings67.com/{settings.sns_path}/",
+ ReturnSubscriptionArn=True,
+ )
+
+
+def check_sns_configuration():
+ SNS_CLIENT.get_topic_attributes(TopicArn=TOPIC_ARN)
+
+ # check this TOPIC_ARN exists
+ # (doesn't have permission, don't need this anyway)
+ # res = SNS_CLIENT.list_topics()
+ # arns = {x["TopicArn"] for x in res["Topics"]}
+ # assert TOPIC_ARN in arns, f"SNS Topic {TOPIC_ARN} doesn't exist!"
+
+ subs = SNS_CLIENT.list_subscriptions_by_topic(TopicArn=TOPIC_ARN)
+ assert SUBSCRIPTION in subs["Subscriptions"]
+
+ AMT_CLIENT.send_test_event_notification(
+ Notification={
+ "Destination": TOPIC_ARN,
+ "Transport": "SNS",
+ "Version": "2006-05-05",
+ "EventTypes": ["AssignmentSubmitted"],
+ },
+ TestEventType="AssignmentSubmitted",
+ )
diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py
index e7c64b9..808c8f7 100644
--- a/jb/flow/tasks.py
+++ b/jb/flow/tasks.py
@@ -1,5 +1,6 @@
import logging
import time
+from typing import TypedDict, cast
from generalresearchutils.config import is_debug
@@ -15,6 +16,11 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO)
+class HitRow(TypedDict):
+ amt_hit_id: str
+ amt_hit_type_id: str
+
+
def check_stale_hits():
# Check live hits that haven't been modified in a long time. They may not
# be expired yet, but maybe something is wrong?
@@ -28,7 +34,7 @@ def check_stale_hits():
LIMIT 100;""",
params={"status": HitStatus.Assignable.value},
)
- for hit in res:
+ for hit in cast(list[HitRow], res):
logging.info(f"check_stale_hits: {hit["amt_hit_id"]}")
check_hit_status(
amt_hit_id=hit["amt_hit_id"],
@@ -49,7 +55,7 @@ def check_expired_hits():
LIMIT 100;""",
params={"status": HitStatus.Assignable.value},
)
- for hit in res:
+ for hit in cast(list[HitRow], res):
logging.info(f"check_expired_hits: {hit["amt_hit_id"]}")
check_hit_status(
amt_hit_id=hit["amt_hit_id"],
@@ -74,8 +80,12 @@ def create_hit_from_hittype(hit_type: HitType) -> Hit:
def refill_hits() -> None:
+
for hit_type in HTM.filter_active():
- active_count = HM.get_active_count(hit_type.id)
+ assert hit_type.id
+ assert hit_type.amt_hit_type_id
+
+ active_count = HM.get_active_count(hit_type_id=hit_type.id)
logging.info(
f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}"
)