diff options
| author | Max Nanis | 2026-02-24 17:26:15 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-24 17:26:15 -0500 |
| commit | 8c1940445503fd6678d0961600f2be81622793a2 (patch) | |
| tree | b9173562b8824b5eaa805e446d9d780e1f23fb2a | |
| parent | 25d8c3c214baf10f6520cc1351f78473150e5d7a (diff) | |
| download | amt-jb-8c1940445503fd6678d0961600f2be81622793a2.tar.gz amt-jb-8c1940445503fd6678d0961600f2be81622793a2.zip | |
Extensive use of type checking. Movement of pytest conf towards handling managers (for db agnostic unittest). Starting to organize pytests.
38 files changed, 604 insertions, 588 deletions
diff --git a/jb/decorators.py b/jb/decorators.py index 54d36c7..cbc28b5 100644 --- a/jb/decorators.py +++ b/jb/decorators.py @@ -66,9 +66,7 @@ AM = AssignmentManager( pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] ) -BM = BonusManager( - pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] -) +BM = BonusManager(pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE]) influx_client = None if settings.influx_db: diff --git a/jb/flow/assignment_tasks.py b/jb/flow/assignment_tasks.py index 4022716..bb0877d 100644 --- a/jb/flow/assignment_tasks.py +++ b/jb/flow/assignment_tasks.py @@ -5,6 +5,7 @@ from typing import Optional from generalresearchutils.models.thl.definitions import PayoutStatus, StatusCode1 from generalresearchutils.models.thl.wallet.cashout_method import CashoutRequestInfo +from generalresearchutils.currency import USDCent from jb.decorators import AM, HM, BM from jb.flow.monitoring import emit_error_event, emit_assignment_event, emit_bonus_event @@ -21,21 +22,20 @@ from jb.managers.thl import ( get_user_blocked, get_task_status, user_cashout_request, - AMT_ASSIGNMENT_CASHOUT_METHOD, manage_pending_cashout, get_user_blocked_or_not_exists, get_wallet_balance, - AMT_BONUS_CASHOUT_METHOD, ) from jb.models.assignment import Assignment -from jb.models.currency import USDCent from jb.models.definitions import AssignmentStatus from jb.models.event import MTurkEvent +from jb.config import settings def process_assignment_submitted(event: MTurkEvent) -> None: """ - Called either directly or from the SNS Notification that a HIT was submitted + Called either directly or from the SNS Notification that a + HIT was submitted :return: None """ @@ -137,13 +137,22 @@ def process_assignment_submitted(event: MTurkEvent) -> None: return issue_worker_payment(assignment) -def review_hit(assignment): +def review_hit(assignment: Assignment) -> None: # Reviewable to Reviewing AMTManager.update_hit_review_status(amt_hit_id=assignment.amt_hit_id, revert=False) hit, _ = AMTManager.get_hit_if_exists(amt_hit_id=assignment.amt_hit_id) + + if hit is None: + logging.warning( + f"Hit not found when trying to review hit: {assignment.amt_hit_id}" + ) + return None + # Update the db HM.update_hit(hit) + return None + def handle_assignment_w_no_work(assignment: Assignment) -> Assignment: """ @@ -267,6 +276,10 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: amt_worker_id = assignment.amt_worker_id amt_assignment_id = assignment.amt_assignment_id tsid = assignment.tsid + assert ( + tsid is not None + ), "Assignment must have a tsid to be handled in handle_assignment_w_work" + hit = HM.get_from_amt_id(amt_hit_id=assignment.amt_hit_id) tsr = get_task_status(tsid=tsid) @@ -289,6 +302,7 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: event_type = "assignment_submitted_quality_fail" else: event_type = "assignment_submitted_work_not_complete" + emit_error_event( event_type=event_type, amt_hit_type_id=hit.amt_hit_type_id, @@ -324,6 +338,8 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: ) return assignment + assert req.id + # We've approved the HIT payment, now update the db to reflect this, and approve the assignment assignment = approve_assignment( amt_assignment_id=amt_assignment_id, @@ -331,7 +347,9 @@ def handle_assignment_w_work(assignment: Assignment) -> Assignment: amt_hit_type_id=hit.amt_hit_type_id, ) # We complete after the assignment is approved - complete_res = manage_pending_cashout(req.id, PayoutStatus.COMPLETE) + complete_res = manage_pending_cashout( + cashout_id=req.id, payout_status=PayoutStatus.COMPLETE + ) if complete_res.status != PayoutStatus.COMPLETE: # unclear wny this would happen raise ValueError(f"Failed to complete cashout: {req.id}") @@ -345,8 +363,9 @@ def submit_and_approve_amt_assignment_request( req = user_cashout_request( amt_worker_id=amt_worker_id, amount=amount, - cashout_method_id=AMT_ASSIGNMENT_CASHOUT_METHOD, + cashout_method_id=settings.amt_assignment_cashout_method, ) + assert req.id if req.status != PayoutStatus.PENDING: return None @@ -365,8 +384,9 @@ def submit_and_approve_amt_bonus_request( req = user_cashout_request( amt_worker_id=amt_worker_id, amount=amount, - cashout_method_id=AMT_BONUS_CASHOUT_METHOD, + cashout_method_id=settings.amt_bonus_cashout_method, ) + assert req.id if req.status != PayoutStatus.PENDING: return None @@ -404,6 +424,7 @@ def issue_worker_payment(assignment: Assignment) -> None: amt_hit_type_id=hit.amt_hit_type_id, ) return None + assert pe.id AMTManager.send_bonus( amt_worker_id=assignment.amt_worker_id, @@ -412,13 +433,26 @@ def issue_worker_payment(assignment: Assignment) -> None: reason=BONUS_MESSAGE, unique_request_token=pe.id, ) + # Confirm it was sent through amt bonus = AMTManager.get_bonus( amt_assignment_id=assignment.amt_assignment_id, payout_event_id=pe.id ) + + if bonus is None: + logging.warning( + f"Failed to find bonus after sending it: {amt_assignment_id} {pe.id}" + ) + emit_error_event( + event_type="bonus_not_found_after_sending", + amt_hit_type_id=hit.amt_hit_type_id, + ) + return None + # Create in DB - BM.create(bonus) + BM.create(bonus=bonus) emit_bonus_event(amount=amount, amt_hit_type_id=hit.amt_hit_type_id) + # Complete cashout res = manage_pending_cashout(pe.id, PayoutStatus.COMPLETE) if res.status != PayoutStatus.COMPLETE: diff --git a/jb/flow/events.py b/jb/flow/events.py index 3961a64..7b7bd32 100644 --- a/jb/flow/events.py +++ b/jb/flow/events.py @@ -1,8 +1,8 @@ import logging import time from concurrent import futures -from concurrent.futures import ThreadPoolExecutor, Executor, as_completed -from typing import Optional +from concurrent.futures import ThreadPoolExecutor, Executor +from typing import Optional, cast, TypedDict import redis @@ -16,6 +16,15 @@ from jb.decorators import REDIS from jb.flow.assignment_tasks import process_assignment_submitted from jb.models.event import MTurkEvent +StreamMessages = list[tuple[str, list[tuple[bytes, dict[bytes, bytes]]]]] + + +class PendingEntry(TypedDict): + message_id: bytes + consumer: bytes + time_since_delivered: int + times_delivered: int + def process_mturk_events_task(): executor = ThreadPoolExecutor(max_workers=5) @@ -57,24 +66,26 @@ def create_consumer_group(): def process_mturk_events_chunk(executor: Executor) -> Optional[int]: - msgs = REDIS.xreadgroup( + msgs_raw = REDIS.xreadgroup( groupname=CONSUMER_GROUP, consumername=CONSUMER_NAME, streams={JB_EVENTS_STREAM: ">"}, count=10, ) - if not msgs: + if not msgs_raw: return None - msgs = msgs[0][1] # the queue, we only have 1 + + msgs = cast(StreamMessages, msgs_raw)[0][1] # The queue, we only have 1 fs = [] for msg in msgs: msg_id, data = msg - msg_json = data["data"] - event = MTurkEvent.model_validate_json(msg_json) + msg_json: str = data["data"] + + event = MTurkEvent.model_validate_json(json_data=msg_json) if event.event_type == "AssignmentSubmitted": fs.append( - executor.submit(process_assignment_submitted_event, event, msg_id) + executor.submit(process_assignment_submitted_event, event, str(msg_id)) ) else: logging.info(f"Discarding {event}") @@ -93,9 +104,13 @@ def handle_pending_msgs(): # Looks in the redis queue for msgs that # are pending (read by a consumer but not ACK). These prob failed. # Below is from chatgpt, idk if it works - pending = REDIS.xpending_range( - JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10 + pending = cast( + list[PendingEntry], + REDIS.xpending_range( + JB_EVENTS_STREAM, CONSUMER_GROUP, min="-", max="+", count=10 + ), ) + for entry in pending: msg_id = entry["message_id"] # Claim message if idle > 10 sec diff --git a/jb/flow/maintenance.py b/jb/flow/maintenance.py index 5dc9cea..744ed1d 100644 --- a/jb/flow/maintenance.py +++ b/jb/flow/maintenance.py @@ -11,16 +11,20 @@ def check_hit_status( ) -> HitStatus: """ (this used to be called "process_hit") - Request information from Amazon regarding the status of a HIT ID. Update the local state from - that response. + Request information from Amazon regarding the status of a HIT ID. Update + the local state from that response. """ + hit_status = AMTManager.get_hit_status(amt_hit_id=amt_hit_id) - # We're assuming that in the db this Hit is marked as Assignable, or else we wouldn't - # have called this function. + + # We're assuming that in the db this Hit is marked as Assignable, or + # else we wouldn't have called this function. if hit_status != HitStatus.Assignable: - # todo: should update also assignment_pending_count, assignment_available_count, assignment_completed_count + # TODO: should update also assignment_pending_count, + # assignment_available_count, assignment_completed_count HM.update_status(amt_hit_id=amt_hit_id, hit_status=hit_status) emit_hit_event( status=hit_status, amt_hit_type_id=amt_hit_type_id, reason=reason ) + return hit_status diff --git a/jb/flow/monitoring.py b/jb/flow/monitoring.py index c8432bb..28f7271 100644 --- a/jb/flow/monitoring.py +++ b/jb/flow/monitoring.py @@ -5,11 +5,12 @@ from mypy_boto3_mturk.literals import EventTypeType from jb.config import settings from jb.decorators import influx_client -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus, AssignmentStatus -def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int): +def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int) -> None: + tags = { "host": socket.gethostname(), # could be "amt-jb-0" "service": "amt-jb", @@ -23,10 +24,14 @@ def write_hit_gauge(status: HitStatus, amt_hit_type_id: str, cnt: int): "fields": {"count": cnt}, } if influx_client: - influx_client.write_points([point]) + influx_client.write_points(points=[point]) + + return None -def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: int): +def write_assignment_gauge( + status: AssignmentStatus, amt_hit_type_id: str, cnt: int +) -> None: tags = { "host": socket.gethostname(), "service": "amt-jb", @@ -40,15 +45,17 @@ def write_assignment_gauge(status: AssignmentStatus, amt_hit_type_id: str, cnt: "fields": {"count": cnt}, } if influx_client: - influx_client.write_points([point]) + influx_client.write_points(points=[point]) + + return None def emit_hit_event( status: HitStatus, amt_hit_type_id: str, reason: Optional[str] = None -): +) -> None: """ - e.g. a HIT was created, Reviewable, etc. We don't have a "created" HitStatus, - so it would just be when status=='Assignable' + e.g. a HIT was created, Reviewable, etc. We don't have a "created" + HitStatus, so it would just be when status=='Assignable' """ tags = { "host": socket.gethostname(), @@ -57,8 +64,10 @@ def emit_hit_event( "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + if reason: tags["reason"] = reason + point = { "measurement": "amt_jb.hit_events", "tags": tags, @@ -68,10 +77,12 @@ def emit_hit_event( if influx_client: influx_client.write_points([point]) + return None + def emit_assignment_event( status: AssignmentStatus, amt_hit_type_id: str, reason: Optional[str] = None -): +) -> None: """ e.g. an Assignment was accepted/approved/reject """ @@ -82,8 +93,10 @@ def emit_assignment_event( "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + if reason: tags["reason"] = reason + point = { "measurement": "amt_jb.assignment_events", "tags": tags, @@ -93,10 +106,15 @@ def emit_assignment_event( if influx_client: influx_client.write_points([point]) + return None + -def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: str): +def emit_mturk_notification_event( + event_type: EventTypeType, amt_hit_type_id: str +) -> None: """ - e.g. a Mturk notification was received. We just put it in redis, we haven't processed it yet. + e.g. a Mturk notification was received. We just put it in redis, we + haven't processed it yet. """ tags = { "host": socket.gethostname(), @@ -105,6 +123,7 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.mturk_notification_events", "tags": tags, @@ -114,8 +133,10 @@ def emit_mturk_notification_event(event_type: EventTypeType, amt_hit_type_id: st if influx_client: influx_client.write_points([point]) + return None + -def emit_error_event(event_type: str, amt_hit_type_id: str): +def emit_error_event(event_type: str, amt_hit_type_id: str) -> None: """ e.g. todo: structure the error_types """ @@ -126,6 +147,7 @@ def emit_error_event(event_type: str, amt_hit_type_id: str): "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.error_events", "tags": tags, @@ -135,8 +157,10 @@ def emit_error_event(event_type: str, amt_hit_type_id: str): if influx_client: influx_client.write_points([point]) + return None + -def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): +def emit_bonus_event(amount: USDCent, amt_hit_type_id: str) -> None: """ An AMT bonus was awarded """ @@ -146,6 +170,7 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): "amt_hit_type_id": amt_hit_type_id, "debug": settings.debug, } + point = { "measurement": "amt_jb.bonus_events", "tags": tags, @@ -154,3 +179,5 @@ def emit_bonus_event(amount: USDCent, amt_hit_type_id: str): if influx_client: influx_client.write_points([point]) + + return None diff --git a/jb/flow/setup_tasks.py b/jb/flow/setup_tasks.py new file mode 100644 index 0000000..4664374 --- /dev/null +++ b/jb/flow/setup_tasks.py @@ -0,0 +1,36 @@ +from jb.config import TOPIC_ARN, SUBSCRIPTION +from jb.decorators import SNS_CLIENT, AMT_CLIENT +from jb.config import settings + + +def initial_setup(): + # Run once for initial setup. Not on each server start or anything + subscription = SNS_CLIENT.subscribe( # type: ignore + TopicArn=TOPIC_ARN, + Protocol="https", + Endpoint=f"https://jamesbillings67.com/{settings.sns_path}/", + ReturnSubscriptionArn=True, + ) + + +def check_sns_configuration(): + SNS_CLIENT.get_topic_attributes(TopicArn=TOPIC_ARN) + + # check this TOPIC_ARN exists + # (doesnt have permission, dont need this anyways) + # res = SNS_CLIENT.list_topics() + # arns = {x["TopicArn"] for x in res["Topics"]} + # assert TOPIC_ARN in arns, f"SNS Topic {TOPIC_ARN} doesn't exist!" + + subs = SNS_CLIENT.list_subscriptions_by_topic(TopicArn=TOPIC_ARN) + assert SUBSCRIPTION in subs["Subscriptions"] + + AMT_CLIENT.send_test_event_notification( + Notification={ + "Destination": TOPIC_ARN, + "Transport": "SNS", + "Version": "2006-05-05", + "EventTypes": ["AssignmentSubmitted"], + }, + TestEventType="AssignmentSubmitted", + ) diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py index e7c64b9..808c8f7 100644 --- a/jb/flow/tasks.py +++ b/jb/flow/tasks.py @@ -1,5 +1,6 @@ import logging import time +from typing import TypedDict, cast from generalresearchutils.config import is_debug @@ -15,6 +16,11 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +class HitRow(TypedDict): + amt_hit_id: str + amt_hit_type_id: str + + def check_stale_hits(): # Check live hits that haven't been modified in a long time. They may not # be expired yet, but maybe something is wrong? @@ -28,7 +34,7 @@ def check_stale_hits(): LIMIT 100;""", params={"status": HitStatus.Assignable.value}, ) - for hit in res: + for hit in cast(list[HitRow], res): logging.info(f"check_stale_hits: {hit["amt_hit_id"]}") check_hit_status( amt_hit_id=hit["amt_hit_id"], @@ -49,7 +55,7 @@ def check_expired_hits(): LIMIT 100;""", params={"status": HitStatus.Assignable.value}, ) - for hit in res: + for hit in cast(list[HitRow], res): logging.info(f"check_expired_hits: {hit["amt_hit_id"]}") check_hit_status( amt_hit_id=hit["amt_hit_id"], @@ -74,8 +80,12 @@ def create_hit_from_hittype(hit_type: HitType) -> Hit: def refill_hits() -> None: + for hit_type in HTM.filter_active(): - active_count = HM.get_active_count(hit_type.id) + assert hit_type.id + assert hit_type.amt_hit_type_id + + active_count = HM.get_active_count(hit_type_id=hit_type.id) logging.info( f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}" ) @@ -1,6 +1,7 @@ from multiprocessing import Process +from typing import Any, Dict -from fastapi import FastAPI, Request +from fastapi import FastAPI from fastapi.responses import HTMLResponse from starlette.middleware.cors import CORSMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware @@ -35,7 +36,7 @@ app.include_router(router=common_router) @app.get("/robots.txt") @app.get("/sitemap.xml") @app.get("/favicon.ico") -def return_nothing(): +def return_nothing() -> Dict[str, Any]: return {} @@ -47,7 +48,7 @@ def serve_react_app(full_path: str): def schedule_tasks(): - from jb.flow.events import process_mturk_events_task, handle_pending_msgs_task + from jb.flow.events import process_mturk_events_task from jb.flow.tasks import refill_hits_task Process(target=process_mturk_events_task).start() diff --git a/jb/managers/__init__.py b/jb/managers/__init__.py index e2aab6d..e99569a 100644 --- a/jb/managers/__init__.py +++ b/jb/managers/__init__.py @@ -15,8 +15,8 @@ class PostgresManager: def __init__( self, pg_config: PostgresConfig, - permissions: Collection[Permission] = None, - **kwargs, + permissions: Collection[Permission] = None, # type: ignore + **kwargs, # type: ignore ): super().__init__(**kwargs) self.pg_config = pg_config diff --git a/jb/managers/amt.py b/jb/managers/amt.py index 79661c7..0ec70d3 100644 --- a/jb/managers/amt.py +++ b/jb/managers/amt.py @@ -10,7 +10,7 @@ from jb.decorators import AMT_CLIENT from jb.models import AMTAccount from jb.models.assignment import Assignment from jb.models.bonus import Bonus -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus from jb.models.hit import HitType, HitQuestion, Hit @@ -48,19 +48,24 @@ class AMTManager: return hit, None @classmethod - def get_hit_status(cls, amt_hit_id: str): + def get_hit_status(cls, amt_hit_id: str) -> HitStatus: res, msg = cls.get_hit_if_exists(amt_hit_id=amt_hit_id) + if res is None: + if msg is None: + return HitStatus.Unassignable + if " does not exist. (" in msg: return HitStatus.Disposed else: logging.warning(msg) return HitStatus.Unassignable + return res.status @staticmethod def create_hit_type(hit_type: HitType): - res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) + res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) # type: ignore hit_type.amt_hit_type_id = res["HITTypeId"] AMT_CLIENT.update_notification_settings( HITTypeId=hit_type.amt_hit_type_id, @@ -94,8 +99,10 @@ class AMTManager: @staticmethod def get_assignment(amt_assignment_id: str) -> Assignment: - # note, you CANNOT get an assignment if it has been only ACCEPTED (by the worker) - # the api is stupid. it will only show up once it is submitted + """ + You CANNOT get an Assignment if it has been only ACCEPTED (by the + worker). The api is stupid, it will only show up once it is Submitted + """ res = AMT_CLIENT.get_assignment(AssignmentId=amt_assignment_id) ass_res: AssignmentTypeDef = res["Assignment"] assignment = Assignment.from_amt_get_assignment(ass_res) @@ -158,6 +165,7 @@ class AMTManager: raise ValueError(error_msg) # elif "This HIT is currently in the state 'Reviewing'" in error_msg: # logging.warning(error_msg) + return None @staticmethod @@ -203,7 +211,7 @@ class AMTManager: return None @staticmethod - def expire_all_hits(): + def expire_all_hits() -> None: # used in testing only (or in an emergency I guess) now = datetime.now(tz=timezone.utc) paginator = AMT_CLIENT.get_paginator("list_hits") @@ -214,3 +222,4 @@ class AMTManager: AMT_CLIENT.update_expiration_for_hit( HITId=hit["HITId"], ExpireAt=now ) + return None diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py index fca72e8..dd3c866 100644 --- a/jb/managers/assignment.py +++ b/jb/managers/assignment.py @@ -28,7 +28,7 @@ class AssignmentManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() stub.id = pk return None @@ -62,7 +62,7 @@ class AssignmentManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() assignment.id = pk return None @@ -233,7 +233,7 @@ class AssignmentManager(PostgresManager): "lookback_interval": f"{lookback_hrs} hour", }, ) - return int(res[0]["c"]) + return int(res[0]["c"]) # type: ignore def rejected_count( self, amt_worker_id: str, lookback_hrs: int = 24 @@ -256,4 +256,4 @@ class AssignmentManager(PostgresManager): "status": AssignmentStatus.Rejected.value, }, ) - return int(res[0]["c"]) + return int(res[0]["c"]) # type: ignore diff --git a/jb/managers/bonus.py b/jb/managers/bonus.py index 0cb8b02..89b81f0 100644 --- a/jb/managers/bonus.py +++ b/jb/managers/bonus.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any from psycopg import sql @@ -37,12 +37,12 @@ class BonusManager(PostgresManager): c.execute(query, data) res = c.fetchone() conn.commit() - bonus.id = res["id"] - bonus.assignment_id = res["assignment_id"] + bonus.id = res["id"] # type: ignore + bonus.assignment_id = res["assignment_id"] # type: ignore return None def filter(self, amt_assignment_id: str) -> List[Bonus]: - res = self.pg_config.execute_sql_query( + res: List[Any] = self.pg_config.execute_sql_query( """ SELECT mb.*, ma.amt_assignment_id FROM mtwerk_bonus mb @@ -51,4 +51,5 @@ class BonusManager(PostgresManager): """, params={"amt_assignment_id": amt_assignment_id}, ) - return [Bonus.from_postgres(x) for x in res] + + return [Bonus.from_postgres(data=x) for x in res] diff --git a/jb/managers/hit.py b/jb/managers/hit.py index 3832418..ce8ffa5 100644 --- a/jb/managers/hit.py +++ b/jb/managers/hit.py @@ -24,7 +24,7 @@ class HitQuestionManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() question.id = pk return None @@ -67,6 +67,7 @@ class HitQuestionManager(PostgresManager): class HitTypeManager(PostgresManager): + def create(self, hit_type: HitType) -> None: assert hit_type.amt_hit_type_id is not None data = hit_type.to_postgres() @@ -99,9 +100,10 @@ class HitTypeManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() hit_type.id = pk + return None def filter_active(self) -> List[HitType]: @@ -137,13 +139,13 @@ class HitTypeManager(PostgresManager): except AssertionError: return None - def get_or_create(self, hit_type: HitType) -> None: + def get_or_create(self, hit_type: HitType) -> HitType: res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id) if res: - hit_type.id = res.id - if res is None: - self.create(hit_type=hit_type) - return None + return res + + self.create(hit_type=hit_type) + return self.get(amt_hit_type_id=hit_type.amt_hit_type_id) def set_min_active(self, hit_type: HitType) -> None: assert hit_type.id, "must be in the db first!" @@ -164,6 +166,7 @@ class HitTypeManager(PostgresManager): class HitManager(PostgresManager): + def create(self, hit: Hit): assert hit.amt_hit_id is not None assert hit.id is None @@ -209,7 +212,7 @@ class HitManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() hit.id = pk return hit @@ -267,7 +270,7 @@ class HitManager(PostgresManager): c.execute(query, data) conn.commit() assert c.rowcount == 1, c.rowcount - hit.id = c.fetchone()["id"] + hit.id = c.fetchone()["id"] # type: ignore return None def get_from_amt_id(self, amt_hit_id: str) -> Hit: @@ -305,17 +308,26 @@ class HitManager(PostgresManager): """, params={"amt_hit_id": amt_hit_id}, ) - assert len(res) == 1 + assert len(res) == 1, "Incorrect number of results" res = res[0] + question_xml = HitQuestion.model_validate( {"height": res.pop("height"), "url": res.pop("url")} ).xml + res["question_id"] = res["question_id"] res["hit_question_xml"] = question_xml return Hit.from_postgres(res) - def get_active_count(self, hit_type_id: int): + def get_from_amt_id_if_exists(self, amt_hit_id: str) -> Optional[Hit]: + try: + return self.get_from_amt_id(amt_hit_id=amt_hit_id) + + except (AssertionError, Exception): + return None + + def get_active_count(self, hit_type_id: int) -> int: return self.pg_config.execute_sql_query( """ SELECT COUNT(1) as active_count @@ -326,7 +338,7 @@ class HitManager(PostgresManager): params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id}, )[0]["active_count"] - def filter_active_ids(self, hit_type_id: int): + def filter_active_ids(self, hit_type_id: int) -> set[str]: res = self.pg_config.execute_sql_query( """ SELECT mh.amt_hit_id diff --git a/jb/managers/thl.py b/jb/managers/thl.py index b1dcbde..83f49f6 100644 --- a/jb/managers/thl.py +++ b/jb/managers/thl.py @@ -1,7 +1,3 @@ -from decimal import Decimal -from typing import Dict, Optional - -import requests from generalresearchutils.models.thl.payout import UserPayoutEvent from generalresearchutils.models.thl.task_status import TaskStatusResponse from generalresearchutils.models.thl.wallet.cashout_method import ( @@ -9,45 +5,58 @@ from generalresearchutils.models.thl.wallet.cashout_method import ( CashoutRequestInfo, ) +from generalresearchutils.models.thl.user_profile import UserProfile +from generalresearchutils.currency import USDCent + from jb.config import settings -from jb.models.currency import USDCent -from jb.models.definitions import PayoutStatus -# TODO: Organize this more with other endpoints (offerwall, cashout requests/approvals, etc). +from generalresearchutils.models.thl.definitions import PayoutStatus + + +from typing import Optional +import requests + +# TODO: Organize this more with other endpoints (offerwall, cashout +# requests/approvals, etc). -def get_user_profile(amt_worker_id: str) -> Dict: +def get_user_profile(amt_worker_id: str) -> UserProfile: url = f"{settings.fsb_host}{settings.product_id}/user/{amt_worker_id}/profile/" res = requests.get(url).json() if res.get("detail") == "user not found": raise ValueError("user not found") - return res["user_profile"] + + return UserProfile.model_validate(res["user_profile"]) def get_user_blocked(amt_worker_id: str) -> bool: + # Not blocked if None res = get_user_profile(amt_worker_id=amt_worker_id) - return res["user"]["blocked"] + return res.user.blocked if res.user.blocked is not None else False -def get_user_blocked_or_not_exists(amt_worker_id: str) -> bool: +def get_user_blocked_or_not_exists(amt_worker_id: str) -> Optional[bool]: try: res = get_user_profile(amt_worker_id=amt_worker_id) - return res["user"]["blocked"] + return res.user.blocked if res.user.blocked is not None else False except ValueError as e: if e.args[0] == "user not found": return True + return None + def get_task_status(tsid: str) -> Optional[TaskStatusResponse]: url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/" d = requests.get(url).json() if d.get("msg") == "invalid tsid": return None + return TaskStatusResponse.model_validate(d) def user_cashout_request( - amt_worker_id: str, amount: USDCent, cashout_method_id + amt_worker_id: str, amount: USDCent, cashout_method_id: str ) -> CashoutRequestInfo: assert cashout_method_id in { settings.amt_assignment_cashout_method, @@ -56,7 +65,8 @@ def user_cashout_request( assert isinstance(amount, USDCent) assert USDCent(0) < amount < USDCent(10_00) url = f"{settings.fsb_host}{settings.product_id}/cashout/" - body = { + + body: dict[str, str | int] = { "bpuid": amt_worker_id, "amount": int(amount), "cashout_method_id": cashout_method_id, @@ -81,7 +91,7 @@ def manage_pending_cashout( return UserPayoutEvent.model_validate(d) -def get_wallet_balance(amt_worker_id: str): +def get_wallet_balance(amt_worker_id: str) -> USDCent: url = f"{settings.fsb_host}{settings.product_id}/wallet/" params = {"bpuid": amt_worker_id} return USDCent(requests.get(url, params=params).json()["wallet"]["amount"]) diff --git a/jb/models/assignment.py b/jb/models/assignment.py index 39ae47c..5dd0167 100644 --- a/jb/models/assignment.py +++ b/jb/models/assignment.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timezone -from typing import Optional, TypedDict +from typing import Optional, TypedDict, Any from xml.etree import ElementTree from mypy_boto3_mturk.type_defs import AssignmentTypeDef @@ -10,7 +10,6 @@ from pydantic import ( ConfigDict, model_validator, PositiveInt, - computed_field, TypeAdapter, ValidationError, ) @@ -116,10 +115,12 @@ class Assignment(AssignmentStub): default=None, min_length=3, max_length=2_000, - help_text="The feedback string included with the call to the " - "ApproveAssignment operation or the RejectAssignment " - "operation, if the Requester approved or rejected the " - "assignment and specified feedback.", + json_schema_extra={ + "help_text": "The feedback string included with the call to the " + "ApproveAssignment operation or the RejectAssignment " + "operation, if the Requester approved or rejected the " + "assignment and specified feedback." + }, ) answer_xml: Optional[str] = Field(default=None, exclude=True) @@ -131,7 +132,7 @@ class Assignment(AssignmentStub): # --- Validators --- @model_validator(mode="before") - def set_tsid(cls, values: dict): + def set_tsid(cls, values: dict[str, Any]) -> dict[str, Any]: if values.get("tsid") is None and (answer_xml := values.get("answer_xml")): answer_dict = cls.parse_answer_xml(answer_xml) tsid = answer_dict.get("tsid") @@ -175,10 +176,10 @@ class Assignment(AssignmentStub): if self.answer_xml is None: return None - return self.parse_answer_xml(self.answer_xml) + return self.parse_answer_xml(self.answer_xml) # type: ignore @staticmethod - def parse_answer_xml(answer_xml: str): + def parse_answer_xml(answer_xml: str) -> dict[str, Any]: root = ElementTree.fromstring(answer_xml) ns = { "mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd" @@ -186,8 +187,8 @@ class Assignment(AssignmentStub): res = {} for a in root.findall("mt:Answer", ns): - name = a.find("mt:QuestionIdentifier", ns).text - value = a.find("mt:FreeText", ns).text + name = a.find("mt:QuestionIdentifier", ns).text # type: ignore + value = a.find("mt:FreeText", ns).text # type: ignore res[name] = value or "" EXPECTED_KEYS = {"amt_assignment_id", "amt_worker_id", "tsid"} diff --git a/jb/models/bonus.py b/jb/models/bonus.py index 564a32d..a536dd1 100644 --- a/jb/models/bonus.py +++ b/jb/models/bonus.py @@ -1,11 +1,10 @@ -from typing import Optional, Dict +from typing import Optional, Dict, Any from pydantic import BaseModel, Field, ConfigDict, PositiveInt from typing_extensions import Self -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr -from jb.models.definitions import PayoutStatus class Bonus(BaseModel): @@ -41,7 +40,7 @@ class Bonus(BaseModel): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["amount"] = USDCent(round(data["amount"] * 100)) fields = set(cls.model_fields.keys()) data = {k: v for k, v in data.items() if k in fields} diff --git a/jb/models/currency.py b/jb/models/currency.py deleted file mode 100644 index 3094e2a..0000000 --- a/jb/models/currency.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings -from decimal import Decimal -from typing import Any - -from pydantic import GetCoreSchemaHandler, NonNegativeInt -from pydantic_core import CoreSchema, core_schema - - -class USDCent(int): - def __new__(cls, value, *args, **kwargs): - - if isinstance(value, float): - warnings.warn( - "USDCent init with a float. Rounding behavior may " "be unexpected" - ) - - if isinstance(value, Decimal): - warnings.warn( - "USDCent init with a Decimal. Rounding behavior may " "be unexpected" - ) - - if value < 0: - raise ValueError("USDCent not be less than zero") - - return super(cls, cls).__new__(cls, value) - - def __add__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__add__(other) - return self.__class__(res) - - def __sub__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__sub__(other) - return self.__class__(res) - - def __mul__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__mul__(other) - return self.__class__(res) - - def __abs__(self): - res = super(USDCent, self).__abs__() - return self.__class__(res) - - def __truediv__(self, other): - raise ValueError("Division not allowed for USDCent") - - def __str__(self): - return "%d" % int(self) - - def __repr__(self): - return "USDCent(%d)" % int(self) - - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> CoreSchema: - """ - https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ - """ - return core_schema.no_info_after_validator_function( - cls, handler(NonNegativeInt) - ) - - def to_usd(self) -> Decimal: - return Decimal(int(self) / 100).quantize(Decimal(".01")) - - def to_usd_str(self) -> str: - return "${:,.2f}".format(float(self.to_usd())) diff --git a/jb/models/custom_types.py b/jb/models/custom_types.py index 70bc5c1..10bc9d1 100644 --- a/jb/models/custom_types.py +++ b/jb/models/custom_types.py @@ -34,8 +34,7 @@ def convert_str_dt(v: Any) -> Optional[AwareDatetime]: def assert_utc(v: AwareDatetime) -> AwareDatetime: - if isinstance(v, datetime): - assert v.tzinfo == timezone.utc, "Timezone is not UTC" + assert v.tzinfo == timezone.utc, "Timezone is not UTC" return v diff --git a/jb/models/definitions.py b/jb/models/definitions.py index a3d27ba..4ae7a21 100644 --- a/jb/models/definitions.py +++ b/jb/models/definitions.py @@ -1,4 +1,4 @@ -from enum import IntEnum, StrEnum +from enum import IntEnum class AssignmentStatus(IntEnum): @@ -37,32 +37,6 @@ class HitReviewStatus(IntEnum): ReviewedInappropriate = 3 -class PayoutStatus(StrEnum): - """These are GRL's payout statuses""" - - # The user has requested a payout. The money is taken from their - # wallet. A PENDING request can either be APPROVED, REJECTED, or - # CANCELLED. We can also implicitly skip the APPROVED step and go - # straight to COMPLETE or FAILED. - PENDING = "PENDING" - # The request is approved (by us or automatically). Once approved, - # it can be FAILED or COMPLETE. - APPROVED = "APPROVED" - # The request is rejected. The user loses the money. - REJECTED = "REJECTED" - # The user requests to cancel the request, the money goes back into their wallet. - CANCELLED = "CANCELLED" - # The payment was approved, but failed within external payment provider. - # This is an "error" state, as the money won't have moved anywhere. A - # FAILED payment can be tried again and be COMPLETE. - FAILED = "FAILED" - # The payment was sent successfully and (usually) a fee was charged - # to us for it. - COMPLETE = "COMPLETE" - # Not supported # REFUNDED: I'm not sure if this is possible or - # if we'd want to allow it. - - class ReportValue(IntEnum): """ The reason a user reported a task. diff --git a/jb/models/event.py b/jb/models/event.py index c357772..c167420 100644 --- a/jb/models/event.py +++ b/jb/models/event.py @@ -11,13 +11,22 @@ class MTurkEvent(BaseModel): What AWS SNS will POST to our mturk_notifications endpoint (inside the request body) """ - event_type: EventTypeType = Field(example="AssignmentSubmitted") - event_timestamp: AwareDatetimeISO = Field(example="2025-10-16T18:45:51Z") - amt_hit_id: AMTBoto3ID = Field(example="12345678901234567890") + event_type: EventTypeType = Field( + json_schema_extra={"example": "AssignmentSubmitted"} + ) + event_timestamp: AwareDatetimeISO = Field( + json_schema_extra={"example": "2025-10-16T18:45:51Z"} + ) + amt_hit_id: AMTBoto3ID = Field( + json_schema_extra={"example": "12345678901234567890"} + ) amt_assignment_id: str = Field( - max_length=64, example="1234567890123456789012345678901234567890" + max_length=64, + json_schema_extra={"example": "1234567890123456789012345678901234567890"}, + ) + amt_hit_type_id: AMTBoto3ID = Field( + json_schema_extra={"example": "09876543210987654321"} ) - amt_hit_type_id: AMTBoto3ID = Field(example="09876543210987654321") @classmethod def from_sns(cls, data: Dict): diff --git a/jb/models/hit.py b/jb/models/hit.py index c3734fa..fba2ecf 100644 --- a/jb/models/hit.py +++ b/jb/models/hit.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone, timedelta -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Any from uuid import uuid4 from xml.etree import ElementTree @@ -13,7 +13,7 @@ from pydantic import ( ) from typing_extensions import Self -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.custom_types import AMTBoto3ID, HttpsUrlStr, AwareDatetimeISO from jb.models.definitions import HitStatus, HitReviewStatus @@ -104,11 +104,11 @@ class HitType(HitTypeCommon): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["reward"] = USDCent(round(data["reward"] * 100)) return cls.model_validate(data) - def generate_hit_amt_request(self, question: HitQuestion): + def generate_hit_amt_request(self, question: HitQuestion) -> Dict[str, Any]: d = dict() d["HITTypeId"] = self.amt_hit_type_id d["MaxAssignments"] = 1 @@ -135,7 +135,12 @@ class Hit(HitTypeCommon): status: HitStatus = Field() review_status: HitReviewStatus = Field() - creation_time: AwareDatetimeISO = Field(default=None, description="From aws") + + # TODO: Check if this is actually ever going to be None. I type fixed it, + # but I don't have anything to suggest it isn't requred. -- Max 2026-02-24 + creation_time: Optional[AwareDatetimeISO] = Field( + default=None, description="From aws" + ) expiration: Optional[AwareDatetimeISO] = Field(default=None) # GRL Specific @@ -150,7 +155,7 @@ class Hit(HitTypeCommon): # -- Hit specific - qualification_requirements: Optional[List[Dict]] = Field(default=None) + qualification_requirements: Optional[List[Dict[str, Any]]] = Field(default=None) max_assignments: int = Field() # # this comes back as expiration. only for the request @@ -171,7 +176,7 @@ class Hit(HitTypeCommon): assert hit_type.id is not None assert hit_type.amt_hit_type_id is not None - h = Hit.model_validate( + h = cls.model_validate( dict( amt_hit_id=data["HITId"], amt_hit_type_id=data["HITTypeId"], @@ -194,11 +199,12 @@ class Hit(HitTypeCommon): hit_type_id=hit_type.id, ) ) + return h @classmethod def from_amt_get_hit(cls, data: HITTypeDef) -> Self: - h = Hit.model_validate( + h = cls.model_validate( dict( amt_hit_id=data["HITId"], amt_hit_type_id=data["HITTypeId"], @@ -229,7 +235,7 @@ class Hit(HitTypeCommon): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["reward"] = USDCent(round(data["reward"] * 100)) return cls.model_validate(data) diff --git a/jb/settings.py b/jb/settings.py index 538b89f..5754add 100644 --- a/jb/settings.py +++ b/jb/settings.py @@ -45,7 +45,7 @@ class Settings(AmtJbBaseSettings): debug: bool = False app_name: str = "AMT JB API" - fsb_host: HttpUrl = Field(default="https://fsb.generalresearch.com/") + fsb_host: HttpUrl = Field(default=HttpUrl("https://fsb.generalresearch.com/")) # Needed for admin function on fsb w/o authentication fsb_host_private_route: Optional[str] = Field(default=None) diff --git a/jb/views/common.py b/jb/views/common.py index 46ac608..0dc8b56 100644 --- a/jb/views/common.py +++ b/jb/views/common.py @@ -11,7 +11,7 @@ from starlette.responses import RedirectResponse from jb.config import settings, JB_EVENTS_STREAM from jb.decorators import REDIS, HM from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import ReportValue, AssignmentStatus from jb.models.event import MTurkEvent from jb.settings import BASE_HTML @@ -71,6 +71,7 @@ async def work(request: Request): url=f"/preview/?{request.url.query}" if request.url.query else "/preview/", status_code=302, ) + if amt_assignment_id is None or amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE": # Worker is previewing the HIT amt_hit_type_id = "unknown" @@ -91,71 +92,6 @@ async def work(request: Request): return HTMLResponse(BASE_HTML) -@common_router.get(path="/survey/", response_class=JSONResponse) -def survey( - request: Request, - worker_id: str = Query(), - duration: int = Query(default=1200), -): - if not worker_id: - raise HTTPException(status_code=400, detail="Missing worker_id") - - # (1) Check wallet - wallet_url = f"{settings.fsb_host}{settings.product_id}/wallet/" - wallet_res = requests.get(wallet_url, params={"bpuid": worker_id}) - if wallet_res.status_code != 200: - raise HTTPException(status_code=502, detail="Wallet check failed") - - wallet_data = wallet_res.json() - wallet_balance = wallet_data["wallet"]["amount"] - if wallet_balance < -100: - return JSONResponse( - { - "total_surveys": 0, - "link": None, - "duration": None, - "payout": None, - } - ) - - # (2) Get offerwall - client_ip = "69.253.144.55" if settings.debug else request.client.host - offerwall_url = f"{settings.fsb_host}{settings.product_id}/offerwall/d48cce47/" - offerwall_res = requests.get( - offerwall_url, - params={ - "bpuid": worker_id, - "ip": client_ip, - "n_bins": 1, - "duration": duration, - }, - ) - - if offerwall_res.status_code != 200: - raise HTTPException(status_code=502, detail="Offerwall request failed") - - try: - rj = offerwall_res.json() - bucket = rj["offerwall"]["buckets"][0] - return JSONResponse( - { - "total_surveys": rj["offerwall"]["availability_count"], - "link": bucket["uri"], - "duration": round(bucket["duration"]["q2"] / 60), - "payout": USDCent(bucket["payout"]["q2"]).to_usd_str(), - } - ) - except Exception: - return JSONResponse( - { - "total_surveys": 0, - "link": None, - "duration": None, - "payout": None, - } - ) - - @common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False) async def mturk_notifications(request: Request): """ diff --git a/jb/views/tasks.py b/jb/views/tasks.py index 7176b29..313295f 100644 --- a/jb/views/tasks.py +++ b/jb/views/tasks.py @@ -18,18 +18,23 @@ def process_request(request: Request) -> None: amt_assignment_id = request.query_params.get("assignmentId", None) if amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE": raise ValueError("shouldn't happen") + amt_hit_id = request.query_params.get("hitId", None) amt_worker_id = request.query_params.get("workerId", None) print(f"process_request: {amt_assignment_id=} {amt_worker_id=} {amt_hit_id=}") assert amt_worker_id and amt_hit_id and amt_assignment_id # Check that the HIT is still valid - hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id) + hit = HM.get_from_amt_id_if_exists(amt_hit_id=amt_hit_id) + if not hit: + raise ValueError(f"Hit {amt_hit_id} not found in DB") + _ = check_hit_status(amt_hit_id=amt_hit_id, amt_hit_type_id=hit.amt_hit_type_id) emit_assignment_event( status=AssignmentStatus.Accepted, amt_hit_type_id=hit.amt_hit_type_id, ) + # I think it won't be assignable anymore? idk # assert hit_status == HitStatus.Assignable, f"hit {amt_hit_id} {hit_status=}. Expected Assignable" @@ -53,7 +58,7 @@ def process_request(request: Request) -> None: assert assignment_stub.amt_worker_id == amt_worker_id assert assignment_stub.amt_assignment_id == amt_assignment_id assert assignment_stub.created_at > ( - datetime.now(tz=timezone.utc) - timedelta(minutes=90) + datetime.now(tz=timezone.utc) - timedelta(minutes=90) ) return None diff --git a/jb/views/utils.py b/jb/views/utils.py index 39db5d2..0d08e9b 100644 --- a/jb/views/utils.py +++ b/jb/views/utils.py @@ -8,8 +8,11 @@ def get_client_ip(request: Request) -> str: """ ip = request.headers.get("X-Forwarded-For") if not ip: - ip = request.client.host + ip = request.client.host # type: ignore elif ip == "testclient" or ip.startswith("10."): forwarded = request.headers.get("X-Forwarded-For") - ip = forwarded.split(",")[0].strip() if forwarded else request.client.host + ip = ( + forwarded.split(",")[0].strip() if forwarded else request.client.host # type: ignore + ) + return ip diff --git a/tests/__init__.py b/tests/__init__.py index 469eda2..f37e785 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,6 +2,6 @@ import random import string -def generate_amt_id(length=30): +def generate_amt_id(length=30) -> str: chars = string.ascii_uppercase + string.digits return "".join(random.choices(chars, k=length)) diff --git a/tests/amt/test_models.py b/tests/amt/test_models.py index c2a61b5..cecd948 100644 --- a/tests/amt/test_models.py +++ b/tests/amt/test_models.py @@ -22,20 +22,26 @@ def get_assignment_response_bad_tsid( ) return res -def test_get_assignment(get_assignment_response): - assignment = Assignment.from_amt_get_assignment( - get_assignment_response["Assignment"] - ) - assert assignment.tsid is not None -def test_get_assignment_no_tsid(get_assignment_response_no_tsid): - assignment = Assignment.from_amt_get_assignment( - get_assignment_response_no_tsid["Assignment"] - ) - assert assignment.tsid is None +class TestAssignment: -def test_get_assignment_bad_tsid(get_assignment_response_bad_tsid): - assignment = Assignment.from_amt_get_assignment( - get_assignment_response_bad_tsid["Assignment"] - ) - assert assignment.tsid is None
\ No newline at end of file + @pytest.mark.anyio + def test_get_assignment(get_assignment_response): + assignment = Assignment.from_amt_get_assignment( + get_assignment_response["Assignment"] + ) + assert assignment.tsid is not None + + @pytest.mark.anyio + def test_get_assignment_no_tsid(get_assignment_response_no_tsid): + assignment = Assignment.from_amt_get_assignment( + get_assignment_response_no_tsid["Assignment"] + ) + assert assignment.tsid is None + + @pytest.mark.anyio + def test_get_assignment_bad_tsid(get_assignment_response_bad_tsid): + assignment = Assignment.from_amt_get_assignment( + get_assignment_response_bad_tsid["Assignment"] + ) + assert assignment.tsid is None diff --git a/tests/conftest.py b/tests/conftest.py index 985c9dc..3318f1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ import copy from datetime import datetime, timezone, timedelta +import os from typing import Optional from uuid import uuid4 - +from dotenv import load_dotenv import pytest from dateutil.tz import tzlocal from mypy_boto3_mturk.type_defs import ( @@ -11,62 +12,179 @@ from mypy_boto3_mturk.type_defs import ( CreateHITWithHITTypeResponseTypeDef, GetAssignmentResponseTypeDef, ) - -from jb.decorators import HQM, HTM, HM, AM +from jb.managers import Permission +from generalresearchutils.pg_helper import PostgresConfig from jb.managers.amt import AMTManager, APPROVAL_MESSAGE, NO_WORK_APPROVAL_MESSAGE from jb.models.assignment import AssignmentStub, Assignment -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus, HitReviewStatus, AssignmentStatus from jb.models.hit import HitType, HitQuestion, Hit from tests import generate_amt_id @pytest.fixture -def amt_hit_type_id(): - return generate_amt_id() - - -@pytest.fixture -def amt_hit_id(): +def amt_hit_type_id() -> str: return generate_amt_id() @pytest.fixture -def amt_assignment_id(): +def amt_assignment_id() -> str: return generate_amt_id() @pytest.fixture -def amt_worker_id(): +def amt_worker_id() -> str: return generate_amt_id(length=21) @pytest.fixture -def amt_group_id(): +def amt_group_id() -> str: return generate_amt_id() @pytest.fixture -def tsid(): +def tsid() -> str: return uuid4().hex @pytest.fixture -def tsid1(): +def tsid1() -> str: return uuid4().hex @pytest.fixture -def tsid2(): +def tsid2() -> str: return uuid4().hex @pytest.fixture -def pe_id(): +def pe_id() -> str: # payout event / cashout request UUID return uuid4().hex +# --- Settings --- + + +@pytest.fixture(scope="session") +def env_file_path(pytestconfig): + root_path = pytestconfig.rootpath + env_path = os.path.join(root_path, ".env.test") + + if os.path.exists(env_path): + load_dotenv(dotenv_path=env_path, override=True) + + return env_path + + +@pytest.fixture(scope="session") +def settings(env_file_path) -> "Settings": + from jb.settings import Settings as JBSettings + + s = JBSettings(_env_file=env_file_path) + + return s + + +# --- Database Connectors --- + + +@pytest.fixture(scope="session") +def redis(settings): + from generalresearchutils.redis_helper import RedisConfig + + redis_config = RedisConfig( + dsn=settings.redis, + decode_responses=True, + socket_timeout=settings.redis_timeout, + socket_connect_timeout=settings.redis_timeout, + ) + return redis_config.create_redis_client() + + +@pytest.fixture(scope="session") +def pg_config(settings) -> PostgresConfig: + return PostgresConfig( + dsn=settings.amt_jb_db, + connect_timeout=1, + statement_timeout=1, + ) + + +# --- Managers --- + + +@pytest.fixture(scope="session") +def hqm(pg_config) -> "HitQuestionManager": + assert "/unittest-" in pg_config.dsn.path + + from jb.managers.hit import HitQuestionManager + + return HitQuestionManager( + pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] + ) + + +@pytest.fixture(scope="session") +def htm(pg_config) -> "HitTypeManager": + assert "/unittest-" in pg_config.dsn.path + + from jb.managers.hit import HitTypeManager + + return HitTypeManager( + pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] + ) + + +@pytest.fixture(scope="session") +def hm(pg_config) -> "HitManager": + assert "/unittest-" in pg_config.dsn.path + + from jb.managers.hit import HitManager + + return HitManager( + pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] + ) + + +@pytest.fixture(scope="session") +def am(pg_config) -> "AssignmentManager": + assert "/unittest-" in pg_config.dsn.path + + from jb.managers.assignment import AssignmentManager + + return AssignmentManager( + pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] + ) + + +@pytest.fixture(scope="session") +def bm(pg_config) -> "BonusManager": + assert "/unittest-" in pg_config.dsn.path + + from jb.managers.bonus import BonusManager + + return BonusManager( + pg_config=pg_config, permissions=[Permission.READ, Permission.CREATE] + ) + + +# --- Question --- + + +@pytest.fixture +def question() -> HitQuestion: + return HitQuestion(url="https://jamesbillings67.com/work/", height=1200) + + +@pytest.fixture +def question_record(hqm, question) -> HitQuestion: + return hqm.get_or_create(question) + + +# --- HITType --- + + @pytest.fixture def hit_type() -> HitType: return HitType( @@ -78,42 +196,38 @@ def hit_type() -> HitType: ) -from jb.models.hit import HitType +@pytest.fixture +def hit_type_record(htm, hit_type) -> HitType: + hit_type.amt_hit_type_id = generate_amt_id() + + return htm.get_or_create(hit_type) @pytest.fixture -def hit_type_with_amt_id(hit_type: HitType) -> HitType: +def hit_type_with_amt_id(htm, hit_type: HitType) -> HitType: # This is a real hit type I've previously registered with amt (sandbox). # It will always exist hit_type.amt_hit_type_id = "3217B3DC4P5YW9DRV9R3X8O56V041J" # Get or create our db - HTM.get_or_create(hit_type) + htm.get_or_create(hit_type) # this call adds the pk int id ---^ return hit_type -@pytest.fixture -def question(): - return HitQuestion(url="https://jamesbillings67.com/work/", height=1200) +# --- HIT --- @pytest.fixture -def hit_in_amt(hit_type_with_amt_id: HitType, question: HitQuestion) -> Hit: - # Actually create a new HIT in amt (sandbox) - question = HQM.get_or_create(question) - hit = AMTManager.create_hit_with_hit_type( - hit_type=hit_type_with_amt_id, question=question - ) - # Create it in the DB - HM.create(hit) - return hit +def amt_hit_id() -> str: + return generate_amt_id() @pytest.fixture -def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question): +def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question) -> Hit: now = datetime.now(tz=timezone.utc) + return Hit.model_validate( dict( amt_hit_id=amt_hit_id, @@ -140,24 +254,41 @@ def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question): @pytest.fixture -def hit_in_db( - hit_type: HitType, amt_hit_type_id, amt_hit_id, question: HitQuestion, hit: Hit +def hit_record( + hm, + question_record, + hit_type_record, + hit, + amt_hit_id, ) -> Hit: """ Returns a hit that exists in our db, but does not in amazon (the amt ids are random). The mtwerk_hittype and mtwerk_question records will also exist (in the db) """ - question = HQM.get_or_create(question) - hit_type.amt_hit_type_id = amt_hit_type_id - HTM.create(hit_type) - hit.hit_type_id = hit_type.id + + hit.hit_type_id = hit_type_record.id hit.amt_hit_id = amt_hit_id - hit.question_id = question.id - HM.create(hit) + hit.question_id = question_record.id + + hm.create(hit) + return hit + + +@pytest.fixture +def hit_in_amt(hm, question_record, hit_type_with_amt_id: HitType) -> Hit: + # Actually create a new HIT in amt (sandbox) + hit = AMTManager.create_hit_with_hit_type( + hit_type=hit_type_with_amt_id, question=question_record + ) + # Create it in the DB + hm.create(hit) return hit +# --- Assignment --- + + @pytest.fixture def assignment_stub(hit: Hit, amt_assignment_id, amt_worker_id): now = datetime.now(tz=timezone.utc) @@ -193,28 +324,31 @@ def assignment_factory(hit: Hit): @pytest.fixture -def assignment_in_db_factory(assignment_factory): +def assignment_in_db_factory(am, assignment_factory): def inner(hit_id: int, amt_worker_id: Optional[str] = None): a = assignment_factory(amt_worker_id=amt_worker_id) a.hit_id = hit_id - AM.create_stub(a) - AM.update_answer(a) + am.create_stub(a) + am.update_answer(a) return a return inner @pytest.fixture -def assignment_stub_in_db(hit_in_db, assignment_stub) -> AssignmentStub: +def assignment_stub_in_db(am, hit_record, assignment_stub) -> AssignmentStub: """ Returns an AssignmentStub that exists in our db, but does not in amazon (the amt ids are random). The mtwerk_hit, mtwerk_hittype, and mtwerk_question records will also exist (in the db) """ - assignment_stub.hit_id = hit_in_db.id - AM.create_stub(assignment_stub) + assignment_stub.hit_id = hit_record.id + am.create_stub(assignment_stub) return assignment_stub +# --- HIT --- + + @pytest.fixture def amt_response_metadata(): req_id = str(uuid4()) diff --git a/tests/flow/test_tasks.py b/tests/flow/test_tasks.py index 37391d1..b708cf9 100644 --- a/tests/flow/test_tasks.py +++ b/tests/flow/test_tasks.py @@ -26,8 +26,9 @@ from jb.managers.amt import ( BONUS_MESSAGE, NO_WORK_APPROVAL_MESSAGE, ) -from jb.models.currency import USDCent -from jb.models.definitions import AssignmentStatus, PayoutStatus +from generalresearchutils.currency import USDCent +from jb.models.definitions import AssignmentStatus +from generalresearchutils.models.thl.definitions import PayoutStatus from jb.models.event import MTurkEvent @@ -336,7 +337,7 @@ class TestProcessAssignmentSubmitted: amt_assignment_id, get_assignment_response: Dict, caplog, - hit_in_db, + hit_record, rejected_assignment_stubs, ): # An assignment is submitted. The hit exists in the DB. The amt assignment id is valid, @@ -442,7 +443,7 @@ class TestProcessAssignmentSubmitted: get_assignment_response_rejected_no_tsid, get_assignment_response_no_tsid, assignment_in_db_factory, - hit_in_db, + hit_record, amt_worker_id, ): # An assignment is submitted. The hit and assignment stub exist in the DB. @@ -451,9 +452,9 @@ class TestProcessAssignmentSubmitted: # Going to create and submit 3 assignments w no work # (all on the same hit, which we don't do in JB for real, # but doesn't matter here) - a1 = assignment_in_db_factory(hit_id=hit_in_db.id, amt_worker_id=amt_worker_id) - a2 = assignment_in_db_factory(hit_id=hit_in_db.id, amt_worker_id=amt_worker_id) - a3 = assignment_in_db_factory(hit_id=hit_in_db.id, amt_worker_id=amt_worker_id) + a1 = assignment_in_db_factory(hit_id=hit_record.id, amt_worker_id=amt_worker_id) + a2 = assignment_in_db_factory(hit_id=hit_record.id, amt_worker_id=amt_worker_id) + a3 = assignment_in_db_factory(hit_id=hit_record.id, amt_worker_id=amt_worker_id) assert AM.missing_tsid_count(amt_worker_id=amt_worker_id) == 3 # So now, we'll reject, b/c they've already gotten 3 warnings diff --git a/tests/http/test_basic.py b/tests/http/test_basic.py index 7b03a1e..18359da 100644 --- a/tests/http/test_basic.py +++ b/tests/http/test_basic.py @@ -22,14 +22,3 @@ async def test_static_file_alias(httpxclient: AsyncClient): res = await client.get(p) assert res.status_code == 200, p assert res.json() == {} - - -@pytest.mark.anyio -async def test_health(httpxclient: AsyncClient): - client = httpxclient - """ - These are here for site crawlers and stuff.. - """ - res = await client.get("/health/") - assert res.status_code == 200 - assert res.json() == {} diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py index 70458b8..6770044 100644 --- a/tests/http/test_notifications.py +++ b/tests/http/test_notifications.py @@ -5,7 +5,6 @@ from httpx import AsyncClient import secrets from jb.config import JB_EVENTS_STREAM, settings -from jb.decorators import REDIS from jb.models.event import MTurkEvent from tests import generate_amt_id @@ -40,16 +39,17 @@ def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id): @pytest.fixture() -def clean_mturk_events_redis_stream(): - REDIS.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert REDIS.xlen(JB_EVENTS_STREAM) == 0 +def clean_mturk_events_redis_stream(redis): + redis.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert redis.xlen(JB_EVENTS_STREAM) == 0 yield - REDIS.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert REDIS.xlen(JB_EVENTS_STREAM) == 0 + redis.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert redis.xlen(JB_EVENTS_STREAM) == 0 @pytest.mark.anyio async def test_mturk_notifications( + redis, httpxclient: AsyncClient, no_limit, example_mturk_event_body, @@ -61,10 +61,10 @@ async def test_mturk_notifications( res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body) res.raise_for_status() - msg_res = REDIS.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) + msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) msg_res = msg_res[0][1][0] msg_id, msg = msg_res - REDIS.xdel(JB_EVENTS_STREAM, msg_id) + redis.xdel(JB_EVENTS_STREAM, msg_id) msg_json = msg["data"] event = MTurkEvent.model_validate_json(msg_json) diff --git a/tests/http/test_preview.py b/tests/http/test_preview.py new file mode 100644 index 0000000..2bdf265 --- /dev/null +++ b/tests/http/test_preview.py @@ -0,0 +1,41 @@ +# There are two types of "preview" - one is where we navigate direclty +# to it and one is where we redirect possibly + +import pytest +from httpx import AsyncClient + +from jb.models.hit import Hit + + +class TestPreview: + + @pytest.mark.anyio + async def test_preview_direct(self, httpxclient: AsyncClient): + client = httpxclient + res = await client.get("/preview/") + + assert res.status_code == 200 + # the response is an html page + + assert res.headers["content-type"] == "text/html; charset=utf-8" + assert res.num_bytes_downloaded == 507 + + assert "James Billings loves you." in res.text + assert "https://cdn.jamesbillings67.com/james-has-style.css" in res.text + assert "https://cdn.jamesbillings67.com/james-billings.js" in res.text + + @pytest.mark.anyio + async def test_preview_redirect_from_work( + self, httpxclient: AsyncClient, amt_hit_id, amt_assignment_id + ): + client = httpxclient + + params = { + "workerId": None, + "assignmentId": amt_assignment_id, + "hitId": amt_hit_id, + } + res = await client.get("/work/", params=params) + + assert res.status_code == 302 + assert "/preview/" in res.headers["location"] diff --git a/tests/http/test_report.py b/tests/http/test_report.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/http/test_report.py diff --git a/tests/http/test_status.py b/tests/http/test_status.py deleted file mode 100644 index d88ff65..0000000 --- a/tests/http/test_status.py +++ /dev/null @@ -1,78 +0,0 @@ -from uuid import uuid4 - -import pytest -from httpx import AsyncClient - -from jb.config import settings -from tests import generate_amt_id - - -@pytest.mark.anyio -async def test_get_status_args(httpxclient: AsyncClient, no_limit): - client = httpxclient - - # tsid misformatted - res = await client.get(f"/status/{uuid4().hex[:-1]}/") - assert res.status_code == 422 - assert "String should have at least 32 characters" in res.text - - -@pytest.mark.anyio -async def test_get_status_error(httpxclient: AsyncClient, no_limit): - # Expects settings.fsb_host to point to a functional thl-fsb - client = httpxclient - - # tsid doesn't exist - res = await client.get(f"/status/{uuid4().hex}/") - assert res.status_code == 502 - assert res.json()["detail"] == "Failed to fetch status" - - -@pytest.mark.anyio -async def test_get_status_complete(httpxclient: AsyncClient, no_limit, mock_requests): - client = httpxclient - - tsid = uuid4().hex - url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/" - - mock_response = { - "tsid": tsid, - "product_id": settings.product_id, - "bpuid": generate_amt_id(length=21), - "started": "2022-06-29T23:43:48.247777Z", - "finished": "2022-06-29T23:56:57.632634Z", - "status": 3, - "payout": 81, - "user_payout": 77, - "payout_format": "${payout/100:.2f}", - "user_payout_string": "$0.77", - "kwargs": {}, - } - mock_requests.get(url, json=mock_response, status_code=200) - res = await client.get(f"/status/{tsid}/") - assert res.status_code == 200 - assert res.json() == {"status": 3, "payout": "$0.77"} - - -@pytest.mark.anyio -async def test_get_status_failure(httpxclient: AsyncClient, no_limit, mock_requests): - client = httpxclient - - tsid = uuid4().hex - url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/" - - mock_response = { - "tsid": tsid, - "product_id": settings.product_id, - "bpuid": "123ABC", - "status": 2, - "payout": 0, - "user_payout": 0, - "payout_format": "${payout/100:.2f}", - "user_payout_string": None, - "kwargs": {}, - } - mock_requests.get(url, json=mock_response, status_code=200) - res = await client.get(f"/status/{tsid}/") - assert res.status_code == 200 - assert res.json() == {"status": 2, "payout": None} diff --git a/tests/http/test_statuses.py b/tests/http/test_statuses.py deleted file mode 100644 index ffc98fd..0000000 --- a/tests/http/test_statuses.py +++ /dev/null @@ -1,102 +0,0 @@ -from datetime import datetime, timezone, timedelta -from urllib.parse import urlencode - -import pytest -from uuid import uuid4 -from httpx import AsyncClient - -from jb.config import settings - - -@pytest.mark.anyio -async def test_get_statuses(httpxclient: AsyncClient, no_limit, amt_worker_id): - # Expects settings.fsb_host to point to a functional thl-fsb - client = httpxclient - now = datetime.now(tz=timezone.utc) - - params = {"worker_id": amt_worker_id} - res = await client.get(f"/statuses/", params=params) - assert res.status_code == 200 - assert res.json() == [] - - params = {"worker_id": amt_worker_id, "started_after": now.isoformat()} - res = await client.get(f"/statuses/", params=params) - assert res.status_code == 422 - assert "Input should be a valid integer" in res.text - - -@pytest.fixture -def fsb_get_statuses_example_response(amt_worker_id, tsid1, tsid2): - return { - "tasks_status": [ - { - "tsid": tsid1, - "product_id": settings.product_id, - "bpuid": amt_worker_id, - "started": "2025-06-12T03:27:24.902280Z", - "finished": "2025-06-12T03:29:37.626481Z", - "status": 2, - "payout": 0, - "user_payout": None, - "payout_format": None, - "user_payout_string": None, - "kwargs": {}, - "status_code_1": "SESSION_START_QUALITY_FAIL", - "status_code_2": "ENTRY_URL_MODIFICATION", - }, - { - "tsid": tsid2, - "product_id": settings.product_id, - "bpuid": amt_worker_id, - "started": "2025-06-12T03:30:18.176826Z", - "finished": "2025-06-12T03:36:58.789059Z", - "status": 2, - "payout": 0, - "user_payout": None, - "payout_format": None, - "user_payout_string": None, - "kwargs": {}, - "status_code_1": "BUYER_QUALITY_FAIL", - "status_code_2": None, - }, - ] - } - - -@pytest.mark.anyio -async def test_get_statuses_mock( - httpxclient: AsyncClient, - no_limit, - amt_worker_id, - mock_requests, - fsb_get_statuses_example_response, - tsid1, - tsid2, -): - client = httpxclient - now = datetime.now(tz=timezone.utc) - started_after = now - timedelta(minutes=5) - - # The fsb call we are mocking ------v - params = { - "bpuid": amt_worker_id, - "started_after": round(started_after.timestamp()), - "started_before": round(now.timestamp()), - } - url = f"{settings.fsb_host}{settings.product_id}/status/" + "?" + urlencode(params) - mock_requests.get(url, json=fsb_get_statuses_example_response, status_code=200) - # ---- end mock - - params = { - "worker_id": amt_worker_id, - "started_after": round(started_after.timestamp()), - "started_before": round(now.timestamp()), - } - result = await client.get(f"/statuses/", params=params) - assert result.status_code == 200 - res = result.json() - assert len(res) == 2 - assert res == [ - {"status": 2, "tsid": tsid1}, - {"status": 2, "tsid": tsid2}, - ] diff --git a/tests/http/test_work.py b/tests/http/test_work.py index 59b8830..c69118b 100644 --- a/tests/http/test_work.py +++ b/tests/http/test_work.py @@ -1,24 +1,24 @@ import pytest from httpx import AsyncClient -from jb.models.hit import Hit +class TestWork: -@pytest.mark.skip(reason="hits live api, need to look at this") -async def test_work( - httpxclient: AsyncClient, - no_limit, - amt_worker_id, - amt_assignment_id, - hit_in_amt: Hit, -): - client = httpxclient - params = { - "workerId": amt_worker_id, - "assignmentId": amt_assignment_id, - "hitId": hit_in_amt.amt_hit_id, - } - res = await client.get(f"/work/", params=params) - assert res.status_code == 200 - # the response is an html page - assert res.headers["content-type"] == "text/html; charset=utf-8" + @pytest.mark.anyio + async def test_work( + self, + httpxclient: AsyncClient, + hit_record, + amt_assignment_id, + amt_worker_id, + ): + client = httpxclient + + params = { + "workerId": amt_worker_id, + "assignmentId": amt_assignment_id, + "hitId": hit_record.amt_hit_id, + } + res = await client.get("/work/", params=params) + + assert res.status_code == 200 diff --git a/tests/managers/amt.py b/tests/managers/amt.py index 0b2e501..a847582 100644 --- a/tests/managers/amt.py +++ b/tests/managers/amt.py @@ -1,4 +1,3 @@ -from jb.decorators import HTM, HM, HQM from jb.managers.amt import AMTManager @@ -8,16 +7,16 @@ def test_create_hit_type(hit_type): assert hit_type.amt_hit_type_id is not None -def test_create_hit_with_hit_type(hit_type_with_amt_id, question): - question = HQM.get_or_create(question) +def test_create_hit_with_hit_type(hqm, htm, hm, hit_type_with_amt_id, question): + question = hqm.get_or_create(question) hit_type = hit_type_with_amt_id hit_type = [ - x for x in HTM.filter_active() if x.amt_hit_type_id == hit_type.amt_hit_type_id + x for x in htm.filter_active() if x.amt_hit_type_id == hit_type.amt_hit_type_id ][0] hit = AMTManager.create_hit_with_hit_type(hit_type=hit_type, question=question) assert hit.amt_hit_id is not None assert hit.id is None - HM.create(hit) + hm.create(hit) assert hit.id is not None diff --git a/tests/managers/hit.py b/tests/managers/hit.py index 8fcd673..cb2b35a 100644 --- a/tests/managers/hit.py +++ b/tests/managers/hit.py @@ -1,18 +1,25 @@ -from jb.decorators import HTM +from jb.models import Question + + +class TestHitQuestionManager: + + def test_base(self, question_record): + assert isinstance(question_record, Question) + assert question_record.id is None class TestHitTypeManager: - def test_create(self, hit_type_with_amt_id): + def test_create(self, htm, hit_type_with_amt_id): assert hit_type_with_amt_id.id is None - HTM.create(hit_type_with_amt_id) + htm.create(hit_type_with_amt_id) assert hit_type_with_amt_id.id is not None - res = HTM.filter_active() + res = htm.filter_active() assert len(res) == 1 hit_type_with_amt_id.min_active = 0 - HTM.set_min_active(hit_type_with_amt_id) + htm.set_min_active(hit_type_with_amt_id) - res = HTM.filter_active() + res = htm.filter_active() assert len(res) == 0 |
