From 8c1940445503fd6678d0961600f2be81622793a2 Mon Sep 17 00:00:00 2001 From: Max Nanis Date: Tue, 24 Feb 2026 17:26:15 -0500 Subject: Extensive use of type checking. Movement of pytest conf towards handling managers (for db agnostic unittest). Starting to organize pytests. --- jb/decorators.py | 4 +-- jb/flow/assignment_tasks.py | 52 +++++++++++++++++++++++++++------ jb/flow/events.py | 35 ++++++++++++++++------- jb/flow/maintenance.py | 14 +++++---- jb/flow/monitoring.py | 53 +++++++++++++++++++++++++--------- jb/flow/setup_tasks.py | 36 +++++++++++++++++++++++ jb/flow/tasks.py | 16 +++++++++-- jb/main.py | 7 +++-- jb/managers/__init__.py | 4 +-- jb/managers/amt.py | 21 ++++++++++---- jb/managers/assignment.py | 8 +++--- jb/managers/bonus.py | 11 +++---- jb/managers/hit.py | 36 +++++++++++++++-------- jb/managers/thl.py | 40 ++++++++++++++++---------- jb/models/assignment.py | 23 ++++++++------- jb/models/bonus.py | 7 ++--- jb/models/currency.py | 70 --------------------------------------------- jb/models/custom_types.py | 3 +- jb/models/definitions.py | 28 +----------------- jb/models/event.py | 19 ++++++++---- jb/models/hit.py | 24 ++++++++++------ jb/settings.py | 2 +- jb/views/common.py | 68 ++----------------------------------------- jb/views/tasks.py | 9 ++++-- jb/views/utils.py | 7 +++-- 25 files changed, 308 insertions(+), 289 deletions(-) create mode 100644 jb/flow/setup_tasks.py delete mode 100644 jb/models/currency.py (limited to 'jb') diff --git a/jb/decorators.py b/jb/decorators.py index 54d36c7..cbc28b5 100644 --- a/jb/decorators.py +++ b/jb/decorators.py @@ -66,9 +66,7 @@ AM = AssignmentManager( pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] ) -BM = BonusManager( - pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] -) +BM = BonusManager(pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE]) influx_client = None if settings.influx_db: diff --git a/jb/flow/assignment_tasks.py b/jb/flow/assignment_tasks.py index 4022716..bb0877d 100644 --- a/jb/flow/assignment_tasks.py +++ b/jb/flow/assignment_tasks.py @@ -5,6 +5,7 @@ from typing import Optional from generalresearchutils.models.thl.definitions import PayoutStatus, StatusCode1 from generalresearchutils.models.thl.wallet.cashout_method import CashoutRequestInfo +from generalresearchutils.currency import USDCent from jb.decorators import AM, HM, BM from jb.flow.monitoring import emit_error_event, emit_assignment_event, emit_bonus_event @@ -21,21 +22,20 @@ from jb.managers.thl import ( get_user_blocked, get_task_status, user_cashout_request, - AMT_ASSIGNMENT_CASHOUT_METHOD, manage_pending_cashout, get_user_blocked_or_not_exists, get_wallet_balance, - AMT_BONUS_CASHOUT_METHOD, ) from jb.models.assignment import Assignment -from jb.models.currency import USDCent from jb.models.definitions import AssignmentStatus from jb.models.event import MTurkEvent +from jb.config import settings def process_assignment_submitted(event: MTurkEvent) -> None: """ - Called either directly or from the SNS Notification that a HIT was submitted + Called either directly or from the SNS Notification that a + HIT was submitted :return: None """ @@ -137,13 +137,22 @@ def process_assignment_submitted(event: MTurkEvent) -> None: return issue_worker_payment(assignment) -def review_hit(assignment): +def review_hit(assignment: Assignment) -> None: # Reviewable to Reviewing AMTManager.update_hit_review_status(amt_hit_id=assignment.amt_hit_id, revert=False) hit, _ = AMTManager.get_hit_if_exists(amt_hit_id=assignment.amt_hit_id) + + if hit is None: + logging.warning( + f"Hit not found when trying to review hit: {assignment.amt_hit_id}" + ) + return None + # Update the db HM.update_hit(hit) + return None + def handle_assignment_w_no_work(assignment: Assignment) -> Assignment: """ @@ -267,6 +276,10 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: amt_worker_id = assignment.amt_worker_id amt_assignment_id = assignment.amt_assignment_id tsid = assignment.tsid + assert ( + tsid is not None + ), "Assignment must have a tsid to be handled in handle_assignment_w_work" + hit = HM.get_from_amt_id(amt_hit_id=assignment.amt_hit_id) tsr = get_task_status(tsid=tsid) @@ -289,6 +302,7 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: event_type = "assignment_submitted_quality_fail" else: event_type = "assignment_submitted_work_not_complete" + emit_error_event( event_type=event_type, amt_hit_type_id=hit.amt_hit_type_id, @@ -324,6 +338,8 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: ) return assignment + assert req.id + # We've approved the HIT payment, now update the db to reflect this, and approve the assignment assignment = approve_assignment( amt_assignment_id=amt_assignment_id, @@ -331,7 +347,9 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: amt_hit_type_id=hit.amt_hit_type_id, ) # We complete after the assignment is approved - complete_res = manage_pending_cashout(req.id, PayoutStatus.COMPLETE) + complete_res = manage_pending_cashout( + cashout_id=req.id, payout_status=PayoutStatus.COMPLETE + ) if complete_res.status != PayoutStatus.COMPLETE: # unclear wny this would happen raise ValueError(f"Failed to complete cashout: {req.id}") @@ -345,8 +363,9 @@ def submit_and_approve_amt_assignment_request( req = user_cashout_request( amt_worker_id=amt_worker_id, amount=amount, - cashout_method_id=AMT_ASSIGNMENT_CASHOUT_METHOD, + cashout_method_id=settings.amt_assignment_cashout_method, ) + assert req.id if req.status != PayoutStatus.PENDING: return None @@ -365,8 +384,9 @@ def submit_and_approve_amt_bonus_request( req = user_cashout_request( amt_worker_id=amt_worker_id, amount=amount, - cashout_method_id=AMT_BONUS_CASHOUT_METHOD, + cashout_method_id=settings.amt_bonus_cashout_method, ) + assert req.id if req.status != PayoutStatus.PENDING: return None @@ -404,6 +424,7 @@ def issue_worker_payment(assignment: Assignment) -> None: amt_hit_type_id=hit.amt_hit_type_id, ) return None + assert pe.id AMTManager.send_bonus( amt_worker_id=assignment.amt_worker_id, @@ -412,13 +433,26 @@ def issue_worker_payment(assignment: Assignment) -> None: reason=BONUS_MESSAGE, unique_request_token=pe.id, ) + # Confirm it was sent through amt bonus = AMTManager.get_bonus( amt_assignment_id=assignment.amt_assignment_id, payout_event_id=pe.id ) + + if bonus is None: + logging.warning( + f"Failed to find bonus after sending it: {amt_assignment_id} {pe.id}" + ) + emit_error_event( + event_type="bonus_not_found_after_sending", + amt_hit_type_id=hit.amt_hit_type_id, + ) + return None + # Create in DB - BM.create(bonus) + BM.create(bonus=bonus) emit_bonus_event(amount=amount, amt_hit_type_id=hit.amt_hit_type_id) + # Complete cashout res = manage_pending_cashout(pe.id, PayoutStatus.COMPLETE) if res.status != PayoutStatus.COMPLETE: diff --git a/jb/flow/events.py b/jb/flow/events.py index 3961a64..7b7bd32 100644 --- a/jb/flow/events.py +++ b/jb/flow/events.py @@ -1,8 +1,8 @@ import logging import time from concurrent import futures -from concurrent.futures import ThreadPoolExecutor, Executor, as_completed -from typing import Optional +from concurrent.futures import ThreadPoolExecutor, Executor +from typing import Optional, cast, TypedDict import redis @@ -16,6 +16,15 @@ from jb.decorators import REDIS from jb.flow.assignment_tasks import process_assignment_submitted from jb.models.event import MTurkEvent +StreamMessages = list[tuple[str, list[tuple[bytes, dict[bytes, bytes]]]]] + + +class PendingEntry(TypedDict): + message_id: bytes + consumer: bytes + time_since_delivered: int + times_delivered: int + def process_mturk_events_task(): executor = ThreadPoolExecutor(max_workers=5) @@ -57,24 +66,26 @@ def create_consumer_group(): def process_mturk_events_chunk(executor: Executor) -> Optional[int]: - msgs = REDIS.xreadgroup( + msgs_raw = REDIS.xreadgroup( groupname=CONSUMER_GROUP, consumername=CONSUMER_NAME, streams={JB_EVENTS_STREAM: ">"}, count=10, ) - if not msgs: + if not msgs_raw: return None - msgs = msgs[0][1] # the queue, we only have 1 + + msgs = cast(StreamMessages, msgs_raw)[0][1] # The queue, we only have 1 fs = [] for msg in msgs: msg_id, data = msg - msg_json = data["data"] - event = MTurkEvent.model_validate_json(msg_json) + msg_json: str = data["data"] + + event = MTurkEvent.model_validate_json(json_data=msg_json) if event.event_type == "AssignmentSubmitted": fs.append( - executor.submit(process_assignment_submitted_event, event, msg_id) + executor.submit(process_assignment_submitted_event, event, str(msg_id)) ) else: logging.info(f"Discarding {event}") @@ -93,9 +104,13 @@ def handle_pending_msgs(): # Looks in the redis queue for msgs that # are pending (read by a consumer but not ACK). These prob failed. # Below is from chatgpt, idk if it works - pending = REDIS.xpending_range( - JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10 + pending = cast( + list[PendingEntry], + REDIS.xpending_range( + JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10 + ), ) + for entry in pending: msg_id = entry["message_id"] # Claim message if idle > 10 sec diff --git a/jb/flow/maintenance.py b/jb/flow/maintenance.py index 5dc9cea..744ed1d 100644 --- a/jb/flow/maintenance.py +++ b/jb/flow/maintenance.py @@ -11,16 +11,20 @@ def check_hit_status( ) -> HitStatus: """ (this used to be called "process_hit") - Request information from Amazon regarding the status of a HIT ID. Update the local state from - that response. + Request information from Amazon regarding the status of a HIT ID. Update + the local state from that response. """ + hit_status = AMTManager.get_hit_status(amt_hit_id=amt_hit_id) - # We're assuming that in the db this Hit is marked as Assignable, or else we wouldn't - # have called this function. + + # We're assuming that in the db this Hit is marked as Assignable, or + # else we wouldn't have called this function. if hit_status != HitStatus.Assignable: - # todo: should update also assignment_pending_count, assignment_available_count, assignment_completed_count + # TODO: should update also assignment_pending_count, + # assignment_available_count, assignment_completed_count HM.update_status(amt_hit_id=amt_hit_id, hit_status=hit_status) emit_hit_event( status=hit_status, amt_hit_type_id=amt_hit_type_id, reason=reason ) + return hit_status diff --git a/jb/flow/monitoring.py b/jb/flow/monitoring.py index c8432bb..28f7271 100644 --- a/jb/flow/monitoring.py +++ b/jb/flow/monitoring.py @@ -5,11 +5,12 @@ from mypy_boto3_mturk.literals import EventTypeType from jb.config import settings from jb.decorators import influx_client -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus, AssignmentStatus -def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int): +def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int) -> None: + tags = { "host": socket.gethostname(), # could be "amt-jb-0" "service": "amt-jb", @@ -23,10 +24,14 @@ def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int): "fields": {"count": cnt}, } if influx_client: - influx_client.write_points([point]) + influx_client.write_points(points=[point]) + + return None -def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: int): +def write_assignment_gauge( + status: AssignmentStatus, amt_hit_type_id: str, cnt: int +) -> None: tags = { "host": socket.gethostname(), "service": "amt-jb", @@ -40,15 +45,17 @@ def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: "fields": {"count": cnt}, } if influx_client: - influx_client.write_points([point]) + influx_client.write_points(points=[point]) + + return None def emit_hit_event( status: HitStatus, amt_hit_type_id: str, reason: Optional[str] = None -): +) -> None: """ - e.g. a HIT was created, Reviewable, etc. We don't have a "created" HitStatus, - so it would just be when status=='Assignable' + e.g. a HIT was created, Reviewable, etc. We don't have a "created" + HitStatus, so it would just be when status=='Assignable' """ tags = { "host": socket.gethostname(), @@ -57,8 +64,10 @@ def emit_hit_event( "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + if reason: tags["reason"] = reason + point = { "measurement": "amt_jb.hit_events", "tags": tags, @@ -68,10 +77,12 @@ def emit_hit_event( if influx_client: influx_client.write_points([point]) + return None + def emit_assignment_event( status: AssignmentStatus, amt_hit_type_id: str, reason: Optional[str] = None -): +) -> None: """ e.g. an Assignment was accepted/approved/reject """ @@ -82,8 +93,10 @@ def emit_assignment_event( "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + if reason: tags["reason"] = reason + point = { "measurement": "amt_jb.assignment_events", "tags": tags, @@ -93,10 +106,15 @@ def emit_assignment_event( if influx_client: influx_client.write_points([point]) + return None + -def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: str): +def emit_mturk_notification_event( + event_type: EventTypeType, amt_hit_type_id: str +) -> None: """ - e.g. a Mturk notification was received. We just put it in redis, we haven't processed it yet. + e.g. a Mturk notification was received. We just put it in redis, we + haven't processed it yet. """ tags = { "host": socket.gethostname(), @@ -105,6 +123,7 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.mturk_notification_events", "tags": tags, @@ -114,8 +133,10 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st if influx_client: influx_client.write_points([point]) + return None + -def emit_error_event(event_type: str, amt_hit_type_id: str): +def emit_error_event(event_type: str, amt_hit_type_id: str) -> None: """ e.g. todo: structure the error_types """ @@ -126,6 +147,7 @@ def emit_error_event(event_type: str, amt_hit_type_id: str): "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.error_events", "tags": tags, @@ -135,8 +157,10 @@ def emit_error_event(event_type: str, amt_hit_type_id: str): if influx_client: influx_client.write_points([point]) + return None + -def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): +def emit_bonus_event(amount: USDCent, amt_hit_type_id: str) -> None: """ An AMT bonus was awarded """ @@ -146,6 +170,7 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.bonus_events", "tags": tags, @@ -154,3 +179,5 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): if influx_client: influx_client.write_points([point]) + + return None diff --git a/jb/flow/setup_tasks.py b/jb/flow/setup_tasks.py new file mode 100644 index 0000000..4664374 --- /dev/null +++ b/jb/flow/setup_tasks.py @@ -0,0 +1,36 @@ +from jb.config import TOPIC_ARN, SUBSCRIPTION +from jb.decorators import SNS_CLIENT, AMT_CLIENT +from jb.config import settings + + +def initial_setup(): + # Run once for initial setup. Not on each server start or anything + subscription = SNS_CLIENT.subscribe( # type: ignore + TopicArn=TOPIC_ARN, + Protocol="https", + Endpoint=f"https://jamesbillings67.com/{settings.sns_path}/", + ReturnSubscriptionArn=True, + ) + + +def check_sns_configuration(): + SNS_CLIENT.get_topic_attributes(TopicArn=TOPIC_ARN) + + # check this TOPIC_ARN exists + # (doesnt have permission, dont need this anyways) + # res = SNS_CLIENT.list_topics() + # arns = {x["TopicArn"] for x in res["Topics"]} + # assert TOPIC_ARN in arns, f"SNS Topic {TOPIC_ARN} doesn't exist!" + + subs = SNS_CLIENT.list_subscriptions_by_topic(TopicArn=TOPIC_ARN) + assert SUBSCRIPTION in subs["Subscriptions"] + + AMT_CLIENT.send_test_event_notification( + Notification={ + "Destination": TOPIC_ARN, + "Transport": "SNS", + "Version": "2006-05-05", + "EventTypes": ["AssignmentSubmitted"], + }, + TestEventType="AssignmentSubmitted", + ) diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py index e7c64b9..808c8f7 100644 --- a/jb/flow/tasks.py +++ b/jb/flow/tasks.py @@ -1,5 +1,6 @@ import logging import time +from typing import TypedDict, cast from generalresearchutils.config import is_debug @@ -15,6 +16,11 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +class HitRow(TypedDict): + amt_hit_id: str + amt_hit_type_id: str + + def check_stale_hits(): # Check live hits that haven't been modified in a long time. They may not # be expired yet, but maybe something is wrong? @@ -28,7 +34,7 @@ def check_stale_hits(): LIMIT 100;""", params={"status": HitStatus.Assignable.value}, ) - for hit in res: + for hit in cast(list[HitRow], res): logging.info(f"check_stale_hits: {hit["amt_hit_id"]}") check_hit_status( amt_hit_id=hit["amt_hit_id"], @@ -49,7 +55,7 @@ def check_expired_hits(): LIMIT 100;""", params={"status": HitStatus.Assignable.value}, ) - for hit in res: + for hit in cast(list[HitRow], res): logging.info(f"check_expired_hits: {hit["amt_hit_id"]}") check_hit_status( amt_hit_id=hit["amt_hit_id"], @@ -74,8 +80,12 @@ def create_hit_from_hittype(hit_type: HitType) -> Hit: def refill_hits() -> None: + for hit_type in HTM.filter_active(): - active_count = HM.get_active_count(hit_type.id) + assert hit_type.id + assert hit_type.amt_hit_type_id + + active_count = HM.get_active_count(hit_type_id=hit_type.id) logging.info( f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}" ) diff --git a/jb/main.py b/jb/main.py index 8c1dbed..fa59167 100644 --- a/jb/main.py +++ b/jb/main.py @@ -1,6 +1,7 @@ from multiprocessing import Process +from typing import Any, Dict -from fastapi import FastAPI, Request +from fastapi import FastAPI from fastapi.responses import HTMLResponse from starlette.middleware.cors import CORSMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware @@ -35,7 +36,7 @@ app.include_router(router=common_router) @app.get("/robots.txt") @app.get("/sitemap.xml") @app.get("/favicon.ico") -def return_nothing(): +def return_nothing() -> Dict[str, Any]: return {} @@ -47,7 +48,7 @@ def serve_react_app(full_path: str): def schedule_tasks(): - from jb.flow.events import process_mturk_events_task, handle_pending_msgs_task + from jb.flow.events import process_mturk_events_task from jb.flow.tasks import refill_hits_task Process(target=process_mturk_events_task).start() diff --git a/jb/managers/__init__.py b/jb/managers/__init__.py index e2aab6d..e99569a 100644 --- a/jb/managers/__init__.py +++ b/jb/managers/__init__.py @@ -15,8 +15,8 @@ class PostgresManager: def __init__( self, pg_config: PostgresConfig, - permissions: Collection[Permission] = None, - **kwargs, + permissions: Collection[Permission] = None, # type: ignore + **kwargs, # type: ignore ): super().__init__(**kwargs) self.pg_config = pg_config diff --git a/jb/managers/amt.py b/jb/managers/amt.py index 79661c7..0ec70d3 100644 --- a/jb/managers/amt.py +++ b/jb/managers/amt.py @@ -10,7 +10,7 @@ from jb.decorators import AMT_CLIENT from jb.models import AMTAccount from jb.models.assignment import Assignment from jb.models.bonus import Bonus -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus from jb.models.hit import HitType, HitQuestion, Hit @@ -48,19 +48,24 @@ class AMTManager: return hit, None @classmethod - def get_hit_status(cls, amt_hit_id: str): + def get_hit_status(cls, amt_hit_id: str) -> HitStatus: res, msg = cls.get_hit_if_exists(amt_hit_id=amt_hit_id) + if res is None: + if msg is None: + return HitStatus.Unassignable + if " does not exist. (" in msg: return HitStatus.Disposed else: logging.warning(msg) return HitStatus.Unassignable + return res.status @staticmethod def create_hit_type(hit_type: HitType): - res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) + res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) # type: ignore hit_type.amt_hit_type_id = res["HITTypeId"] AMT_CLIENT.update_notification_settings( HITTypeId=hit_type.amt_hit_type_id, @@ -94,8 +99,10 @@ class AMTManager: @staticmethod def get_assignment(amt_assignment_id: str) -> Assignment: - # note, you CANNOT get an assignment if it has been only ACCEPTED (by the worker) - # the api is stupid. it will only show up once it is submitted + """ + You CANNOT get an Assignment if it has been only ACCEPTED (by the + worker). The api is stupid, it will only show up once it is Submitted + """ res = AMT_CLIENT.get_assignment(AssignmentId=amt_assignment_id) ass_res: AssignmentTypeDef = res["Assignment"] assignment = Assignment.from_amt_get_assignment(ass_res) @@ -158,6 +165,7 @@ class AMTManager: raise ValueError(error_msg) # elif "This HIT is currently in the state 'Reviewing'" in error_msg: # logging.warning(error_msg) + return None @staticmethod @@ -203,7 +211,7 @@ class AMTManager: return None @staticmethod - def expire_all_hits(): + def expire_all_hits() -> None: # used in testing only (or in an emergency I guess) now = datetime.now(tz=timezone.utc) paginator = AMT_CLIENT.get_paginator("list_hits") @@ -214,3 +222,4 @@ class AMTManager: AMT_CLIENT.update_expiration_for_hit( HITId=hit["HITId"], ExpireAt=now ) + return None diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py index fca72e8..dd3c866 100644 --- a/jb/managers/assignment.py +++ b/jb/managers/assignment.py @@ -28,7 +28,7 @@ class AssignmentManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() stub.id = pk return None @@ -62,7 +62,7 @@ class AssignmentManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() assignment.id = pk return None @@ -233,7 +233,7 @@ class AssignmentManager(PostgresManager): "lookback_interval": f"{lookback_hrs} hour", }, ) - return int(res[0]["c"]) + return int(res[0]["c"]) # type: ignore def rejected_count( self, amt_worker_id: str, lookback_hrs: int = 24 @@ -256,4 +256,4 @@ class AssignmentManager(PostgresManager): "status": AssignmentStatus.Rejected.value, }, ) - return int(res[0]["c"]) + return int(res[0]["c"]) # type: ignore diff --git a/jb/managers/bonus.py b/jb/managers/bonus.py index 0cb8b02..89b81f0 100644 --- a/jb/managers/bonus.py +++ b/jb/managers/bonus.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any from psycopg import sql @@ -37,12 +37,12 @@ class BonusManager(PostgresManager): c.execute(query, data) res = c.fetchone() conn.commit() - bonus.id = res["id"] - bonus.assignment_id = res["assignment_id"] + bonus.id = res["id"] # type: ignore + bonus.assignment_id = res["assignment_id"] # type: ignore return None def filter(self, amt_assignment_id: str) -> List[Bonus]: - res = self.pg_config.execute_sql_query( + res: List[Any] = self.pg_config.execute_sql_query( """ SELECT mb.*, ma.amt_assignment_id FROM mtwerk_bonus mb @@ -51,4 +51,5 @@ class BonusManager(PostgresManager): """, params={"amt_assignment_id": amt_assignment_id}, ) - return [Bonus.from_postgres(x) for x in res] + + return [Bonus.from_postgres(data=x) for x in res] diff --git a/jb/managers/hit.py b/jb/managers/hit.py index 3832418..ce8ffa5 100644 --- a/jb/managers/hit.py +++ b/jb/managers/hit.py @@ -24,7 +24,7 @@ class HitQuestionManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() question.id = pk return None @@ -67,6 +67,7 @@ class HitQuestionManager(PostgresManager): class HitTypeManager(PostgresManager): + def create(self, hit_type: HitType) -> None: assert hit_type.amt_hit_type_id is not None data = hit_type.to_postgres() @@ -99,9 +100,10 @@ class HitTypeManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() hit_type.id = pk + return None def filter_active(self) -> List[HitType]: @@ -137,13 +139,13 @@ class HitTypeManager(PostgresManager): except AssertionError: return None - def get_or_create(self, hit_type: HitType) -> None: + def get_or_create(self, hit_type: HitType) -> HitType: res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id) if res: - hit_type.id = res.id - if res is None: - self.create(hit_type=hit_type) - return None + return res + + self.create(hit_type=hit_type) + return self.get(amt_hit_type_id=hit_type.amt_hit_type_id) def set_min_active(self, hit_type: HitType) -> None: assert hit_type.id, "must be in the db first!" @@ -164,6 +166,7 @@ class HitTypeManager(PostgresManager): class HitManager(PostgresManager): + def create(self, hit: Hit): assert hit.amt_hit_id is not None assert hit.id is None @@ -209,7 +212,7 @@ class HitManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() hit.id = pk return hit @@ -267,7 +270,7 @@ class HitManager(PostgresManager): c.execute(query, data) conn.commit() assert c.rowcount == 1, c.rowcount - hit.id = c.fetchone()["id"] + hit.id = c.fetchone()["id"] # type: ignore return None def get_from_amt_id(self, amt_hit_id: str) -> Hit: @@ -305,17 +308,26 @@ class HitManager(PostgresManager): """, params={"amt_hit_id": amt_hit_id}, ) - assert len(res) == 1 + assert len(res) == 1, "Incorrect number of results" res = res[0] + question_xml = HitQuestion.model_validate( {"height": res.pop("height"), "url": res.pop("url")} ).xml + res["question_id"] = res["question_id"] res["hit_question_xml"] = question_xml return Hit.from_postgres(res) - def get_active_count(self, hit_type_id: int): + def get_from_amt_id_if_exists(self, amt_hit_id: str) -> Optional[Hit]: + try: + return self.get_from_amt_id(amt_hit_id=amt_hit_id) + + except (AssertionError, Exception): + return None + + def get_active_count(self, hit_type_id: int) -> int: return self.pg_config.execute_sql_query( """ SELECT COUNT(1) as active_count @@ -326,7 +338,7 @@ class HitManager(PostgresManager): params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id}, )[0]["active_count"] - def filter_active_ids(self, hit_type_id: int): + def filter_active_ids(self, hit_type_id: int) -> set[str]: res = self.pg_config.execute_sql_query( """ SELECT mh.amt_hit_id diff --git a/jb/managers/thl.py b/jb/managers/thl.py index b1dcbde..83f49f6 100644 --- a/jb/managers/thl.py +++ b/jb/managers/thl.py @@ -1,7 +1,3 @@ -from decimal import Decimal -from typing import Dict, Optional - -import requests from generalresearchutils.models.thl.payout import UserPayoutEvent from generalresearchutils.models.thl.task_status import TaskStatusResponse from generalresearchutils.models.thl.wallet.cashout_method import ( @@ -9,45 +5,58 @@ from generalresearchutils.models.thl.wallet.cashout_method import ( CashoutRequestInfo, ) +from generalresearchutils.models.thl.user_profile import UserProfile +from generalresearchutils.currency import USDCent + from jb.config import settings -from jb.models.currency import USDCent -from jb.models.definitions import PayoutStatus -# TODO: Organize this more with other endpoints (offerwall, cashout requests/approvals, etc). +from generalresearchutils.models.thl.definitions import PayoutStatus + + +from typing import Optional +import requests + +# TODO: Organize this more with other endpoints (offerwall, cashout +# requests/approvals, etc). -def get_user_profile(amt_worker_id: str) -> Dict: +def get_user_profile(amt_worker_id: str) -> UserProfile: url = f"{settings.fsb_host}{settings.product_id}/user/{amt_worker_id}/profile/" res = requests.get(url).json() if res.get("detail") == "user not found": raise ValueError("user not found") - return res["user_profile"] + + return UserProfile.model_validate(res["user_profile"]) def get_user_blocked(amt_worker_id: str) -> bool: + # Not blocked if None res = get_user_profile(amt_worker_id=amt_worker_id) - return res["user"]["blocked"] + return res.user.blocked if res.user.blocked is not None else False -def get_user_blocked_or_not_exists(amt_worker_id: str) -> bool: +def get_user_blocked_or_not_exists(amt_worker_id: str) -> Optional[bool]: try: res = get_user_profile(amt_worker_id=amt_worker_id) - return res["user"]["blocked"] + return res.user.blocked if res.user.blocked is not None else False except ValueError as e: if e.args[0] == "user not found": return True + return None + def get_task_status(tsid: str) -> Optional[TaskStatusResponse]: url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/" d = requests.get(url).json() if d.get("msg") == "invalid tsid": return None + return TaskStatusResponse.model_validate(d) def user_cashout_request( - amt_worker_id: str, amount: USDCent, cashout_method_id + amt_worker_id: str, amount: USDCent, cashout_method_id: str ) -> CashoutRequestInfo: assert cashout_method_id in { settings.amt_assignment_cashout_method, @@ -56,7 +65,8 @@ def user_cashout_request( assert isinstance(amount, USDCent) assert USDCent(0) < amount < USDCent(10_00) url = f"{settings.fsb_host}{settings.product_id}/cashout/" - body = { + + body: dict[str, str | int] = { "bpuid": amt_worker_id, "amount": int(amount), "cashout_method_id": cashout_method_id, @@ -81,7 +91,7 @@ def manage_pending_cashout( return UserPayoutEvent.model_validate(d) -def get_wallet_balance(amt_worker_id: str): +def get_wallet_balance(amt_worker_id: str) -> USDCent: url = f"{settings.fsb_host}{settings.product_id}/wallet/" params = {"bpuid": amt_worker_id} return USDCent(requests.get(url, params=params).json()["wallet"]["amount"]) diff --git a/jb/models/assignment.py b/jb/models/assignment.py index 39ae47c..5dd0167 100644 --- a/jb/models/assignment.py +++ b/jb/models/assignment.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timezone -from typing import Optional, TypedDict +from typing import Optional, TypedDict, Any from xml.etree import ElementTree from mypy_boto3_mturk.type_defs import AssignmentTypeDef @@ -10,7 +10,6 @@ from pydantic import ( ConfigDict, model_validator, PositiveInt, - computed_field, TypeAdapter, ValidationError, ) @@ -116,10 +115,12 @@ class Assignment(AssignmentStub): default=None, min_length=3, max_length=2_000, - help_text="The feedback string included with the call to the " - "ApproveAssignment operation or the RejectAssignment " - "operation, if the Requester approved or rejected the " - "assignment and specified feedback.", + json_schema_extra={ + "help_text": "The feedback string included with the call to the " + "ApproveAssignment operation or the RejectAssignment " + "operation, if the Requester approved or rejected the " + "assignment and specified feedback." + }, ) answer_xml: Optional[str] = Field(default=None, exclude=True) @@ -131,7 +132,7 @@ class Assignment(AssignmentStub): # --- Validators --- @model_validator(mode="before") - def set_tsid(cls, values: dict): + def set_tsid(cls, values: dict[str, Any]) -> dict[str, Any]: if values.get("tsid") is None and (answer_xml := values.get("answer_xml")): answer_dict = cls.parse_answer_xml(answer_xml) tsid = answer_dict.get("tsid") @@ -175,10 +176,10 @@ class Assignment(AssignmentStub): if self.answer_xml is None: return None - return self.parse_answer_xml(self.answer_xml) + return self.parse_answer_xml(self.answer_xml) # type: ignore @staticmethod - def parse_answer_xml(answer_xml: str): + def parse_answer_xml(answer_xml: str) -> dict[str, Any]: root = ElementTree.fromstring(answer_xml) ns = { "mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd" @@ -186,8 +187,8 @@ class Assignment(AssignmentStub): res = {} for a in root.findall("mt:Answer", ns): - name = a.find("mt:QuestionIdentifier", ns).text - value = a.find("mt:FreeText", ns).text + name = a.find("mt:QuestionIdentifier", ns).text # type: ignore + value = a.find("mt:FreeText", ns).text # type: ignore res[name] = value or "" EXPECTED_KEYS = {"amt_assignment_id", "amt_worker_id", "tsid"} diff --git a/jb/models/bonus.py b/jb/models/bonus.py index 564a32d..a536dd1 100644 --- a/jb/models/bonus.py +++ b/jb/models/bonus.py @@ -1,11 +1,10 @@ -from typing import Optional, Dict +from typing import Optional, Dict, Any from pydantic import BaseModel, Field, ConfigDict, PositiveInt from typing_extensions import Self -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr -from jb.models.definitions import PayoutStatus class Bonus(BaseModel): @@ -41,7 +40,7 @@ class Bonus(BaseModel): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["amount"] = USDCent(round(data["amount"] * 100)) fields = set(cls.model_fields.keys()) data = {k: v for k, v in data.items() if k in fields} diff --git a/jb/models/currency.py b/jb/models/currency.py deleted file mode 100644 index 3094e2a..0000000 --- a/jb/models/currency.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings -from decimal import Decimal -from typing import Any - -from pydantic import GetCoreSchemaHandler, NonNegativeInt -from pydantic_core import CoreSchema, core_schema - - -class USDCent(int): - def __new__(cls, value, *args, **kwargs): - - if isinstance(value, float): - warnings.warn( - "USDCent init with a float. Rounding behavior may " "be unexpected" - ) - - if isinstance(value, Decimal): - warnings.warn( - "USDCent init with a Decimal. Rounding behavior may " "be unexpected" - ) - - if value < 0: - raise ValueError("USDCent not be less than zero") - - return super(cls, cls).__new__(cls, value) - - def __add__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__add__(other) - return self.__class__(res) - - def __sub__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__sub__(other) - return self.__class__(res) - - def __mul__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__mul__(other) - return self.__class__(res) - - def __abs__(self): - res = super(USDCent, self).__abs__() - return self.__class__(res) - - def __truediv__(self, other): - raise ValueError("Division not allowed for USDCent") - - def __str__(self): - return "%d" % int(self) - - def __repr__(self): - return "USDCent(%d)" % int(self) - - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> CoreSchema: - """ - https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ - """ - return core_schema.no_info_after_validator_function( - cls, handler(NonNegativeInt) - ) - - def to_usd(self) -> Decimal: - return Decimal(int(self) / 100).quantize(Decimal(".01")) - - def to_usd_str(self) -> str: - return "${:,.2f}".format(float(self.to_usd())) diff --git a/jb/models/custom_types.py b/jb/models/custom_types.py index 70bc5c1..10bc9d1 100644 --- a/jb/models/custom_types.py +++ b/jb/models/custom_types.py @@ -34,8 +34,7 @@ def convert_str_dt(v: Any) -> Optional[AwareDatetime]: def assert_utc(v: AwareDatetime) -> AwareDatetime: - if isinstance(v, datetime): - assert v.tzinfo == timezone.utc, "Timezone is not UTC" + assert v.tzinfo == timezone.utc, "Timezone is not UTC" return v diff --git a/jb/models/definitions.py b/jb/models/definitions.py index a3d27ba..4ae7a21 100644 --- a/jb/models/definitions.py +++ b/jb/models/definitions.py @@ -1,4 +1,4 @@ -from enum import IntEnum, StrEnum +from enum import IntEnum class AssignmentStatus(IntEnum): @@ -37,32 +37,6 @@ class HitReviewStatus(IntEnum): ReviewedInappropriate = 3 -class PayoutStatus(StrEnum): - """These are GRL's payout statuses""" - - # The user has requested a payout. The money is taken from their - # wallet. A PENDING request can either be APPROVED, REJECTED, or - # CANCELLED. We can also implicitly skip the APPROVED step and go - # straight to COMPLETE or FAILED. - PENDING = "PENDING" - # The request is approved (by us or automatically). Once approved, - # it can be FAILED or COMPLETE. - APPROVED = "APPROVED" - # The request is rejected. The user loses the money. - REJECTED = "REJECTED" - # The user requests to cancel the request, the money goes back into their wallet. - CANCELLED = "CANCELLED" - # The payment was approved, but failed within external payment provider. - # This is an "error" state, as the money won't have moved anywhere. A - # FAILED payment can be tried again and be COMPLETE. - FAILED = "FAILED" - # The payment was sent successfully and (usually) a fee was charged - # to us for it. - COMPLETE = "COMPLETE" - # Not supported # REFUNDED: I'm not sure if this is possible or - # if we'd want to allow it. - - class ReportValue(IntEnum): """ The reason a user reported a task. diff --git a/jb/models/event.py b/jb/models/event.py index c357772..c167420 100644 --- a/jb/models/event.py +++ b/jb/models/event.py @@ -11,13 +11,22 @@ class MTurkEvent(BaseModel): What AWS SNS will POST to our mturk_notifications endpoint (inside the request body) """ - event_type: EventTypeType = Field(example="AssignmentSubmitted") - event_timestamp: AwareDatetimeISO = Field(example="2025-10-16T18:45:51Z") - amt_hit_id: AMTBoto3ID = Field(example="12345678901234567890") + event_type: EventTypeType = Field( + json_schema_extra={"example": "AssignmentSubmitted"} + ) + event_timestamp: AwareDatetimeISO = Field( + json_schema_extra={"example": "2025-10-16T18:45:51Z"} + ) + amt_hit_id: AMTBoto3ID = Field( + json_schema_extra={"example": "12345678901234567890"} + ) amt_assignment_id: str = Field( - max_length=64, example="1234567890123456789012345678901234567890" + max_length=64, + json_schema_extra={"example": "1234567890123456789012345678901234567890"}, + ) + amt_hit_type_id: AMTBoto3ID = Field( + json_schema_extra={"example": "09876543210987654321"} ) - amt_hit_type_id: AMTBoto3ID = Field(example="09876543210987654321") @classmethod def from_sns(cls, data: Dict): diff --git a/jb/models/hit.py b/jb/models/hit.py index c3734fa..fba2ecf 100644 --- a/jb/models/hit.py +++ b/jb/models/hit.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone, timedelta -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Any from uuid import uuid4 from xml.etree import ElementTree @@ -13,7 +13,7 @@ from pydantic import ( ) from typing_extensions import Self -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.custom_types import AMTBoto3ID, HttpsUrlStr, AwareDatetimeISO from jb.models.definitions import HitStatus, HitReviewStatus @@ -104,11 +104,11 @@ class HitType(HitTypeCommon): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["reward"] = USDCent(round(data["reward"] * 100)) return cls.model_validate(data) - def generate_hit_amt_request(self, question: HitQuestion): + def generate_hit_amt_request(self, question: HitQuestion) -> Dict[str, Any]: d = dict() d["HITTypeId"] = self.amt_hit_type_id d["MaxAssignments"] = 1 @@ -135,7 +135,12 @@ class Hit(HitTypeCommon): status: HitStatus = Field() review_status: HitReviewStatus = Field() - creation_time: AwareDatetimeISO = Field(default=None, description="From aws") + + # TODO: Check if this is actually ever going to be None. I type fixed it, + # but I don't have anything to suggest it isn't requred. -- Max 2026-02-24 + creation_time: Optional[AwareDatetimeISO] = Field( + default=None, description="From aws" + ) expiration: Optional[AwareDatetimeISO] = Field(default=None) # GRL Specific @@ -150,7 +155,7 @@ class Hit(HitTypeCommon): # -- Hit specific - qualification_requirements: Optional[List[Dict]] = Field(default=None) + qualification_requirements: Optional[List[Dict[str, Any]]] = Field(default=None) max_assignments: int = Field() # # this comes back as expiration. only for the request @@ -171,7 +176,7 @@ class Hit(HitTypeCommon): assert hit_type.id is not None assert hit_type.amt_hit_type_id is not None - h = Hit.model_validate( + h = cls.model_validate( dict( amt_hit_id=data["HITId"], amt_hit_type_id=data["HITTypeId"], @@ -194,11 +199,12 @@ class Hit(HitTypeCommon): hit_type_id=hit_type.id, ) ) + return h @classmethod def from_amt_get_hit(cls, data: HITTypeDef) -> Self: - h = Hit.model_validate( + h = cls.model_validate( dict( amt_hit_id=data["HITId"], amt_hit_type_id=data["HITTypeId"], @@ -229,7 +235,7 @@ class Hit(HitTypeCommon): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["reward"] = USDCent(round(data["reward"] * 100)) return cls.model_validate(data) diff --git a/jb/settings.py b/jb/settings.py index 538b89f..5754add 100644 --- a/jb/settings.py +++ b/jb/settings.py @@ -45,7 +45,7 @@ class Settings(AmtJbBaseSettings): debug: bool = False app_name: str = "AMT JB API" - fsb_host: HttpUrl = Field(default="https://fsb.generalresearch.com/") + fsb_host: HttpUrl = Field(default=HttpUrl("https://fsb.generalresearch.com/")) # Needed for admin function on fsb w/o authentication fsb_host_private_route: Optional[str] = Field(default=None) diff --git a/jb/views/common.py b/jb/views/common.py index 46ac608..0dc8b56 100644 --- a/jb/views/common.py +++ b/jb/views/common.py @@ -11,7 +11,7 @@ from starlette.responses import RedirectResponse from jb.config import settings, JB_EVENTS_STREAM from jb.decorators import REDIS, HM from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import ReportValue, AssignmentStatus from jb.models.event import MTurkEvent from jb.settings import BASE_HTML @@ -71,6 +71,7 @@ async def work(request: Request): url=f"/preview/?{request.url.query}" if request.url.query else "/preview/", status_code=302, ) + if amt_assignment_id is None or amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE": # Worker is previewing the HIT amt_hit_type_id = "unknown" @@ -91,71 +92,6 @@ async def work(request: Request): return HTMLResponse(BASE_HTML) -@common_router.get(path="/survey/", response_class=JSONResponse) -def survey( - request: Request, - worker_id: str = Query(), - duration: int = Query(default=1200), -): - if not worker_id: - raise HTTPException(status_code=400, detail="Missing worker_id") - - # (1) Check wallet - wallet_url = f"{settings.fsb_host}{settings.product_id}/wallet/" - wallet_res = requests.get(wallet_url, params={"bpuid": worker_id}) - if wallet_res.status_code != 200: - raise HTTPException(status_code=502, detail="Wallet check failed") - - wallet_data = wallet_res.json() - wallet_balance = wallet_data["wallet"]["amount"] - if wallet_balance < -100: - return JSONResponse( - { - "total_surveys": 0, - "link": None, - "duration": None, - "payout": None, - } - ) - - # (2) Get offerwall - client_ip = "69.253.144.55" if settings.debug else request.client.host - offerwall_url = f"{settings.fsb_host}{settings.product_id}/offerwall/d48cce47/" - offerwall_res = requests.get( - offerwall_url, - params={ - "bpuid": worker_id, - "ip": client_ip, - "n_bins": 1, - "duration": duration, - }, - ) - - if offerwall_res.status_code != 200: - raise HTTPException(status_code=502, detail="Offerwall request failed") - - try: - rj = offerwall_res.json() - bucket = rj["offerwall"]["buckets"][0] - return JSONResponse( - { - "total_surveys": rj["offerwall"]["availability_count"], - "link": bucket["uri"], - "duration": round(bucket["duration"]["q2"] / 60), - "payout": USDCent(bucket["payout"]["q2"]).to_usd_str(), - } - ) - except Exception: - return JSONResponse( - { - "total_surveys": 0, - "link": None, - "duration": None, - "payout": None, - } - ) - - @common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False) async def mturk_notifications(request: Request): """ diff --git a/jb/views/tasks.py b/jb/views/tasks.py index 7176b29..313295f 100644 --- a/jb/views/tasks.py +++ b/jb/views/tasks.py @@ -18,18 +18,23 @@ def process_request(request: Request) -> None: amt_assignment_id = request.query_params.get("assignmentId", None) if amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE": raise ValueError("shouldn't happen") + amt_hit_id = request.query_params.get("hitId", None) amt_worker_id = request.query_params.get("workerId", None) print(f"process_request: {amt_assignment_id=} {amt_worker_id=} {amt_hit_id=}") assert amt_worker_id and amt_hit_id and amt_assignment_id # Check that the HIT is still valid - hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id) + hit = HM.get_from_amt_id_if_exists(amt_hit_id=amt_hit_id) + if not hit: + raise ValueError(f"Hit {amt_hit_id} not found in DB") + _ = check_hit_status(amt_hit_id=amt_hit_id, amt_hit_type_id=hit.amt_hit_type_id) emit_assignment_event( status=AssignmentStatus.Accepted, amt_hit_type_id=hit.amt_hit_type_id, ) + # I think it won't be assignable anymore? idk # assert hit_status == HitStatus.Assignable, f"hit {amt_hit_id} {hit_status=}. Expected Assignable" @@ -53,7 +58,7 @@ def process_request(request: Request) -> None: assert assignment_stub.amt_worker_id == amt_worker_id assert assignment_stub.amt_assignment_id == amt_assignment_id assert assignment_stub.created_at > ( - datetime.now(tz=timezone.utc) - timedelta(minutes=90) + datetime.now(tz=timezone.utc) - timedelta(minutes=90) ) return None diff --git a/jb/views/utils.py b/jb/views/utils.py index 39db5d2..0d08e9b 100644 --- a/jb/views/utils.py +++ b/jb/views/utils.py @@ -8,8 +8,11 @@ def get_client_ip(request: Request) -> str: """ ip = request.headers.get("X-Forwarded-For") if not ip: - ip = request.client.host + ip = request.client.host # type: ignore elif ip == "testclient" or ip.startswith("10."): forwarded = request.headers.get("X-Forwarded-For") - ip = forwarded.split(",")[0].strip() if forwarded else request.client.host + ip = ( + forwarded.split(",")[0].strip() if forwarded else request.client.host # type: ignore + ) + return ip -- cgit v1.2.3