diff options
| author | Max Nanis | 2026-02-24 17:26:15 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-24 17:26:15 -0500 |
| commit | 8c1940445503fd6678d0961600f2be81622793a2 (patch) | |
| tree | b9173562b8824b5eaa805e446d9d780e1f23fb2a /jb/flow | |
| parent | 25d8c3c214baf10f6520cc1351f78473150e5d7a (diff) | |
| download | amt-jb-8c1940445503fd6678d0961600f2be81622793a2.tar.gz amt-jb-8c1940445503fd6678d0961600f2be81622793a2.zip | |
Extensive use of type checking. Movement of pytest conf towards handling managers (for db agnostic unittest). Starting to organize pytests.
Diffstat (limited to 'jb/flow')
| -rw-r--r-- | jb/flow/assignment_tasks.py | 52 | ||||
| -rw-r--r-- | jb/flow/events.py | 35 | ||||
| -rw-r--r-- | jb/flow/maintenance.py | 14 | ||||
| -rw-r--r-- | jb/flow/monitoring.py | 53 | ||||
| -rw-r--r-- | jb/flow/setup_tasks.py | 36 | ||||
| -rw-r--r-- | jb/flow/tasks.py | 16 |
6 files changed, 166 insertions, 40 deletions
diff --git a/jb/flow/assignment_tasks.py b/jb/flow/assignment_tasks.py index 4022716..bb0877d 100644 --- a/jb/flow/assignment_tasks.py +++ b/jb/flow/assignment_tasks.py @@ -5,6 +5,7 @@ from typing import Optional from generalresearchutils.models.thl.definitions import PayoutStatus, StatusCode1 from generalresearchutils.models.thl.wallet.cashout_method import CashoutRequestInfo +from generalresearchutils.currency import USDCent from jb.decorators import AM, HM, BM from jb.flow.monitoring import emit_error_event, emit_assignment_event, emit_bonus_event @@ -21,21 +22,20 @@ from jb.managers.thl import ( get_user_blocked, get_task_status, user_cashout_request, - AMT_ASSIGNMENT_CASHOUT_METHOD, manage_pending_cashout, get_user_blocked_or_not_exists, get_wallet_balance, - AMT_BONUS_CASHOUT_METHOD, ) from jb.models.assignment import Assignment -from jb.models.currency import USDCent from jb.models.definitions import AssignmentStatus from jb.models.event import MTurkEvent +from jb.config import settings def process_assignment_submitted(event: MTurkEvent) -> None: """ - Called either directly or from the SNS Notification that a HIT was submitted + Called either directly or from the SNS Notification that a + HIT was submitted :return: None """ @@ -137,13 +137,22 @@ def process_assignment_submitted(event: MTurkEvent) -> None: return issue_worker_payment(assignment) -def review_hit(assignment): +def review_hit(assignment: Assignment) -> None: # Reviewable to Reviewing AMTManager.update_hit_review_status(amt_hit_id=assignment.amt_hit_id, revert=False) hit, _ = AMTManager.get_hit_if_exists(amt_hit_id=assignment.amt_hit_id) + + if hit is None: + logging.warning( + f"Hit not found when trying to review hit: {assignment.amt_hit_id}" + ) + return None + # Update the db HM.update_hit(hit) + return None + def handle_assignment_w_no_work(assignment: Assignment) -> Assignment: """ @@ -267,6 +276,10 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: amt_worker_id = assignment.amt_worker_id amt_assignment_id = assignment.amt_assignment_id tsid = assignment.tsid + assert ( + tsid is not None + ), "Assignment must have a tsid to be handled in handle_assignment_w_work" + hit = HM.get_from_amt_id(amt_hit_id=assignment.amt_hit_id) tsr = get_task_status(tsid=tsid) @@ -289,6 +302,7 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: event_type = "assignment_submitted_quality_fail" else: event_type = "assignment_submitted_work_not_complete" + emit_error_event( event_type=event_type, amt_hit_type_id=hit.amt_hit_type_id, @@ -324,6 +338,8 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: ) return assignment + assert req.id + # We've approved the HIT payment, now update the db to reflect this, and approve the assignment assignment = approve_assignment( amt_assignment_id=amt_assignment_id, @@ -331,7 +347,9 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: amt_hit_type_id=hit.amt_hit_type_id, ) # We complete after the assignment is approved - complete_res = manage_pending_cashout(req.id, PayoutStatus.COMPLETE) + complete_res = manage_pending_cashout( + cashout_id=req.id, payout_status=PayoutStatus.COMPLETE + ) if complete_res.status != PayoutStatus.COMPLETE: # unclear wny this would happen raise ValueError(f"Failed to complete cashout: {req.id}") @@ -345,8 +363,9 @@ def submit_and_approve_amt_assignment_request( req = user_cashout_request( amt_worker_id=amt_worker_id, amount=amount, - cashout_method_id=AMT_ASSIGNMENT_CASHOUT_METHOD, + cashout_method_id=settings.amt_assignment_cashout_method, ) + assert req.id if req.status != PayoutStatus.PENDING: return None @@ -365,8 +384,9 @@ def submit_and_approve_amt_bonus_request( req = user_cashout_request( amt_worker_id=amt_worker_id, amount=amount, - cashout_method_id=AMT_BONUS_CASHOUT_METHOD, + cashout_method_id=settings.amt_bonus_cashout_method, ) + assert req.id if req.status != PayoutStatus.PENDING: return None @@ -404,6 +424,7 @@ def issue_worker_payment(assignment: Assignment) -> None: amt_hit_type_id=hit.amt_hit_type_id, ) return None + assert pe.id AMTManager.send_bonus( amt_worker_id=assignment.amt_worker_id, @@ -412,13 +433,26 @@ def issue_worker_payment(assignment: Assignment) -> None: reason=BONUS_MESSAGE, unique_request_token=pe.id, ) + # Confirm it was sent through amt bonus = AMTManager.get_bonus( amt_assignment_id=assignment.amt_assignment_id, payout_event_id=pe.id ) + + if bonus is None: + logging.warning( + f"Failed to find bonus after sending it: {amt_assignment_id} {pe.id}" + ) + emit_error_event( + event_type="bonus_not_found_after_sending", + amt_hit_type_id=hit.amt_hit_type_id, + ) + return None + # Create in DB - BM.create(bonus) + BM.create(bonus=bonus) emit_bonus_event(amount=amount, amt_hit_type_id=hit.amt_hit_type_id) + # Complete cashout res = manage_pending_cashout(pe.id, PayoutStatus.COMPLETE) if res.status != PayoutStatus.COMPLETE: diff --git a/jb/flow/events.py b/jb/flow/events.py index 3961a64..7b7bd32 100644 --- a/jb/flow/events.py +++ b/jb/flow/events.py @@ -1,8 +1,8 @@ import logging import time from concurrent import futures -from concurrent.futures import ThreadPoolExecutor, Executor, as_completed -from typing import Optional +from concurrent.futures import ThreadPoolExecutor, Executor +from typing import Optional, cast, TypedDict import redis @@ -16,6 +16,15 @@ from jb.decorators import REDIS from jb.flow.assignment_tasks import process_assignment_submitted from jb.models.event import MTurkEvent +StreamMessages = list[tuple[str, list[tuple[bytes, dict[bytes, bytes]]]]] + + +class PendingEntry(TypedDict): + message_id: bytes + consumer: bytes + time_since_delivered: int + times_delivered: int + def process_mturk_events_task(): executor = ThreadPoolExecutor(max_workers=5) @@ -57,24 +66,26 @@ def create_consumer_group(): def process_mturk_events_chunk(executor: Executor) -> Optional[int]: - msgs = REDIS.xreadgroup( + msgs_raw = REDIS.xreadgroup( groupname=CONSUMER_GROUP, consumername=CONSUMER_NAME, streams={JB_EVENTS_STREAM: ">"}, count=10, ) - if not msgs: + if not msgs_raw: return None - msgs = msgs[0][1] # the queue, we only have 1 + + msgs = cast(StreamMessages, msgs_raw)[0][1] # The queue, we only have 1 fs = [] for msg in msgs: msg_id, data = msg - msg_json = data["data"] - event = MTurkEvent.model_validate_json(msg_json) + msg_json: str = data["data"] + + event = MTurkEvent.model_validate_json(json_data=msg_json) if event.event_type == "AssignmentSubmitted": fs.append( - executor.submit(process_assignment_submitted_event, event, msg_id) + executor.submit(process_assignment_submitted_event, event, str(msg_id)) ) else: logging.info(f"Discarding {event}") @@ -93,9 +104,13 @@ def handle_pending_msgs(): # Looks in the redis queue for msgs that # are pending (read by a consumer but not ACK). These prob failed. # Below is from chatgpt, idk if it works - pending = REDIS.xpending_range( - JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10 + pending = cast( + list[PendingEntry], + REDIS.xpending_range( + JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10 + ), ) + for entry in pending: msg_id = entry["message_id"] # Claim message if idle > 10 sec diff --git a/jb/flow/maintenance.py b/jb/flow/maintenance.py index 5dc9cea..744ed1d 100644 --- a/jb/flow/maintenance.py +++ b/jb/flow/maintenance.py @@ -11,16 +11,20 @@ def check_hit_status( ) -> HitStatus: """ (this used to be called "process_hit") - Request information from Amazon regarding the status of a HIT ID. Update the local state from - that response. + Request information from Amazon regarding the status of a HIT ID. Update + the local state from that response. """ + hit_status = AMTManager.get_hit_status(amt_hit_id=amt_hit_id) - # We're assuming that in the db this Hit is marked as Assignable, or else we wouldn't - # have called this function. + + # We're assuming that in the db this Hit is marked as Assignable, or + # else we wouldn't have called this function. if hit_status != HitStatus.Assignable: - # todo: should update also assignment_pending_count, assignment_available_count, assignment_completed_count + # TODO: should update also assignment_pending_count, + # assignment_available_count, assignment_completed_count HM.update_status(amt_hit_id=amt_hit_id, hit_status=hit_status) emit_hit_event( status=hit_status, amt_hit_type_id=amt_hit_type_id, reason=reason ) + return hit_status diff --git a/jb/flow/monitoring.py b/jb/flow/monitoring.py index c8432bb..28f7271 100644 --- a/jb/flow/monitoring.py +++ b/jb/flow/monitoring.py @@ -5,11 +5,12 @@ from mypy_boto3_mturk.literals import EventTypeType from jb.config import settings from jb.decorators import influx_client -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus, AssignmentStatus -def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int): +def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int) -> None: + tags = { "host": socket.gethostname(), # could be "amt-jb-0" "service": "amt-jb", @@ -23,10 +24,14 @@ def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int): "fields": {"count": cnt}, } if influx_client: - influx_client.write_points([point]) + influx_client.write_points(points=[point]) + + return None -def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: int): +def write_assignment_gauge( + status: AssignmentStatus, amt_hit_type_id: str, cnt: int +) -> None: tags = { "host": socket.gethostname(), "service": "amt-jb", @@ -40,15 +45,17 @@ def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: "fields": {"count": cnt}, } if influx_client: - influx_client.write_points([point]) + influx_client.write_points(points=[point]) + + return None def emit_hit_event( status: HitStatus, amt_hit_type_id: str, reason: Optional[str] = None -): +) -> None: """ - e.g. a HIT was created, Reviewable, etc. We don't have a "created" HitStatus, - so it would just be when status=='Assignable' + e.g. a HIT was created, Reviewable, etc. We don't have a "created" + HitStatus, so it would just be when status=='Assignable' """ tags = { "host": socket.gethostname(), @@ -57,8 +64,10 @@ def emit_hit_event( "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + if reason: tags["reason"] = reason + point = { "measurement": "amt_jb.hit_events", "tags": tags, @@ -68,10 +77,12 @@ def emit_hit_event( if influx_client: influx_client.write_points([point]) + return None + def emit_assignment_event( status: AssignmentStatus, amt_hit_type_id: str, reason: Optional[str] = None -): +) -> None: """ e.g. an Assignment was accepted/approved/reject """ @@ -82,8 +93,10 @@ def emit_assignment_event( "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + if reason: tags["reason"] = reason + point = { "measurement": "amt_jb.assignment_events", "tags": tags, @@ -93,10 +106,15 @@ def emit_assignment_event( if influx_client: influx_client.write_points([point]) + return None + -def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: str): +def emit_mturk_notification_event( + event_type: EventTypeType, amt_hit_type_id: str +) -> None: """ - e.g. a Mturk notification was received. We just put it in redis, we haven't processed it yet. + e.g. a Mturk notification was received. We just put it in redis, we + haven't processed it yet. """ tags = { "host": socket.gethostname(), @@ -105,6 +123,7 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.mturk_notification_events", "tags": tags, @@ -114,8 +133,10 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st if influx_client: influx_client.write_points([point]) + return None + -def emit_error_event(event_type: str, amt_hit_type_id: str): +def emit_error_event(event_type: str, amt_hit_type_id: str) -> None: """ e.g. todo: structure the error_types """ @@ -126,6 +147,7 @@ def emit_error_event(event_type: str, amt_hit_type_id: str): "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.error_events", "tags": tags, @@ -135,8 +157,10 @@ def emit_error_event(event_type: str, amt_hit_type_id: str): if influx_client: influx_client.write_points([point]) + return None + -def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): +def emit_bonus_event(amount: USDCent, amt_hit_type_id: str) -> None: """ An AMT bonus was awarded """ @@ -146,6 +170,7 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.bonus_events", "tags": tags, @@ -154,3 +179,5 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): if influx_client: influx_client.write_points([point]) + + return None diff --git a/jb/flow/setup_tasks.py b/jb/flow/setup_tasks.py new file mode 100644 index 0000000..4664374 --- /dev/null +++ b/jb/flow/setup_tasks.py @@ -0,0 +1,36 @@ +from jb.config import TOPIC_ARN, SUBSCRIPTION +from jb.decorators import SNS_CLIENT, AMT_CLIENT +from jb.config import settings + + +def initial_setup(): + # Run once for initial setup. Not on each server start or anything + subscription = SNS_CLIENT.subscribe( # type: ignore + TopicArn=TOPIC_ARN, + Protocol="https", + Endpoint=f"https://jamesbillings67.com/{settings.sns_path}/", + ReturnSubscriptionArn=True, + ) + + +def check_sns_configuration(): + SNS_CLIENT.get_topic_attributes(TopicArn=TOPIC_ARN) + + # check this TOPIC_ARN exists + # (doesnt have permission, dont need this anyways) + # res = SNS_CLIENT.list_topics() + # arns = {x["TopicArn"] for x in res["Topics"]} + # assert TOPIC_ARN in arns, f"SNS Topic {TOPIC_ARN} doesn't exist!" + + subs = SNS_CLIENT.list_subscriptions_by_topic(TopicArn=TOPIC_ARN) + assert SUBSCRIPTION in subs["Subscriptions"] + + AMT_CLIENT.send_test_event_notification( + Notification={ + "Destination": TOPIC_ARN, + "Transport": "SNS", + "Version": "2006-05-05", + "EventTypes": ["AssignmentSubmitted"], + }, + TestEventType="AssignmentSubmitted", + ) diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py index e7c64b9..808c8f7 100644 --- a/jb/flow/tasks.py +++ b/jb/flow/tasks.py @@ -1,5 +1,6 @@ import logging import time +from typing import TypedDict, cast from generalresearchutils.config import is_debug @@ -15,6 +16,11 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +class HitRow(TypedDict): + amt_hit_id: str + amt_hit_type_id: str + + def check_stale_hits(): # Check live hits that haven't been modified in a long time. They may not # be expired yet, but maybe something is wrong? @@ -28,7 +34,7 @@ def check_stale_hits(): LIMIT 100;""", params={"status": HitStatus.Assignable.value}, ) - for hit in res: + for hit in cast(list[HitRow], res): logging.info(f"check_stale_hits: {hit["amt_hit_id"]}") check_hit_status( amt_hit_id=hit["amt_hit_id"], @@ -49,7 +55,7 @@ def check_expired_hits(): LIMIT 100;""", params={"status": HitStatus.Assignable.value}, ) - for hit in res: + for hit in cast(list[HitRow], res): logging.info(f"check_expired_hits: {hit["amt_hit_id"]}") check_hit_status( amt_hit_id=hit["amt_hit_id"], @@ -74,8 +80,12 @@ def create_hit_from_hittype(hit_type: HitType) -> Hit: def refill_hits() -> None: + for hit_type in HTM.filter_active(): - active_count = HM.get_active_count(hit_type.id) + assert hit_type.id + assert hit_type.amt_hit_type_id + + active_count = HM.get_active_count(hit_type_id=hit_type.id) logging.info( f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}" ) |
