From 67ab724561e4ceb8fe8fb4031de277168f7d9724 Mon Sep 17 00:00:00 2001 From: Max Nanis Date: Sat, 21 Feb 2026 02:15:52 -0500 Subject: More pytest conf, some views, and defining more attrs on the settings config --- jb-ui/vite.config.ts | 2 +- jb/config.py | 36 ++++ jb/flow/tasks.py | 103 ++++++++++ jb/settings.py | 75 +++++++ jb/views/common.py | 186 +++++++++++++++++ tests/__init__.py | 7 + tests/conftest.py | 421 +++++++++++++++++++++++++++++++++++++++ tests/http/test_notifications.py | 71 +++++++ tests/http/test_status.py | 78 ++++++++ tests/http/test_statuses.py | 102 ++++++++++ 10 files changed, 1080 insertions(+), 1 deletion(-) create mode 100644 jb/config.py create mode 100644 jb/flow/tasks.py create mode 100644 jb/settings.py create mode 100644 jb/views/common.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/http/test_notifications.py create mode 100644 tests/http/test_status.py create mode 100644 tests/http/test_statuses.py diff --git a/jb-ui/vite.config.ts b/jb-ui/vite.config.ts index eb074fb..6b0e530 100644 --- a/jb-ui/vite.config.ts +++ b/jb-ui/vite.config.ts @@ -24,7 +24,7 @@ export default defineConfig({ // This forwards requests to the FastAPI development server // that must also be running proxy: { - '^/(status|statuses|report|survey|1393610267ad483387705ac279302143)(/|$)': { + '^/(status|statuses|report|survey)(/|$)': { target: 'http://localhost:8000', changeOrigin: true, } diff --git a/jb/config.py b/jb/config.py new file mode 100644 index 0000000..c7d07e5 --- /dev/null +++ b/jb/config.py @@ -0,0 +1,36 @@ +import logging + +from generalresearchutils.config import is_debug + +from jb.settings import get_settings, get_test_settings + +if is_debug(): + print("running using TEST settings") + settings = get_test_settings() + assert settings.debug is True +else: + print("running using PROD settings") + settings = get_settings() + assert settings.debug is False + +if settings.debug: + LOG_LEVEL = 
logging.DEBUG +else: + LOG_LEVEL = logging.WARNING + +# The SNS topic that 1) JB Mturk will send notifications to, 2) will make http POSTs +# back to us (here) +TOPIC_ARN = f"arn:aws:sns:us-east-2:{settings.aws_owner_id}:amt-jb" +SUBSCRIPTION = { + "SubscriptionArn": settings.aws_subscription_arn, + "Owner": settings.aws_owner_id, + "Protocol": "https", + "Endpoint": f"https://jamesbillings67.com/{settings.sns_path}/", + "TopicArn": TOPIC_ARN, +} + +JB_EVENTS_STREAM = "amt_jb_events" +JB_EVENTS_FAILED_STREAM = "amt_jb_events_failed" +CONSUMER_GROUP = "amt-jb-0" +# We'll only have 1 consumer atm, change this if we don't +CONSUMER_NAME = "amt-jb-0" diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py new file mode 100644 index 0000000..e7c64b9 --- /dev/null +++ b/jb/flow/tasks.py @@ -0,0 +1,103 @@ +import logging +import time + +from generalresearchutils.config import is_debug + +from jb.decorators import HTM, HM, HQM, pg_config +from jb.flow.maintenance import check_hit_status +from jb.flow.monitoring import write_hit_gauge, emit_hit_event +from jb.managers.amt import AMTManager +from jb.models.definitions import HitStatus +from jb.models.hit import HitType, HitQuestion, Hit + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def check_stale_hits(): + # Check live hits that haven't been modified in a long time. They may not + # be expired yet, but maybe something is wrong? 
+ res = pg_config.execute_sql_query( + """ + SELECT amt_hit_id, amt_hit_type_id + FROM mtwerk_hit mh + JOIN mtwerk_hittype mht ON mht.id = mh.hit_type_id + WHERE status = %(status)s + ORDER BY modified_at + LIMIT 100;""", + params={"status": HitStatus.Assignable.value}, + ) + for hit in res: + logging.info(f"check_stale_hits: {hit["amt_hit_id"]}") + check_hit_status( + amt_hit_id=hit["amt_hit_id"], + amt_hit_type_id=hit["amt_hit_type_id"], + reason="cleanup", + ) + + +def check_expired_hits(): + # Check live/assignable hits that are expired (based on AMT's expiration time) + res = pg_config.execute_sql_query( + """ + SELECT amt_hit_id, amt_hit_type_id + FROM mtwerk_hit mh + JOIN mtwerk_hittype mht ON mht.id = mh.hit_type_id + WHERE status = %(status)s + AND expiration < now() + LIMIT 100;""", + params={"status": HitStatus.Assignable.value}, + ) + for hit in res: + logging.info(f"check_expired_hits: {hit["amt_hit_id"]}") + check_hit_status( + amt_hit_id=hit["amt_hit_id"], + amt_hit_type_id=hit["amt_hit_type_id"], + reason="expired", + ) + + +def create_hit_from_hittype(hit_type: HitType) -> Hit: + if is_debug(): + raise Exception("Handle AMT Sandbox issues.") + + else: + question = HQM.get_or_create( + HitQuestion(height=800, url="https://jamesbillings67.com/work/") + ) + + hit = AMTManager.create_hit_with_hit_type(hit_type=hit_type, question=question) + HM.create(hit) + emit_hit_event(status=hit.status, amt_hit_type_id=hit.amt_hit_type_id) + return hit + + +def refill_hits() -> None: + for hit_type in HTM.filter_active(): + active_count = HM.get_active_count(hit_type.id) + logging.info( + f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}" + ) + write_hit_gauge( + status=HitStatus.Assignable, + amt_hit_type_id=hit_type.amt_hit_type_id, + cnt=active_count, + ) + if active_count < hit_type.min_active: + cnt_todo = hit_type.min_active - active_count + logging.info(f"Refilling {cnt_todo} hits") + for _ in range(cnt_todo): + 
create_hit_from_hittype(hit_type) + + +def refill_hits_task(): + while True: + try: + check_expired_hits() + check_stale_hits() + refill_hits() + except Exception as e: + logging.exception(e) + finally: + time.sleep(5 * 60) diff --git a/jb/settings.py b/jb/settings.py new file mode 100644 index 0000000..28402f3 --- /dev/null +++ b/jb/settings.py @@ -0,0 +1,75 @@ +import os +from functools import lru_cache +from pathlib import Path +from typing import Optional + +from generalresearchutils.models.custom_types import InfluxDsn +from pydantic import Field, PostgresDsn, HttpUrl, RedisDsn +from pydantic_settings import BaseSettings, SettingsConfigDict + +from jb.models.custom_types import UUIDStr + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +BASE_HTML_PATH = Path(BASE_DIR) / "templates" / "base.html" +BASE_HTML = BASE_HTML_PATH.read_text() + + +class AmtJbBaseSettings(BaseSettings): + debug: bool = Field(default=True) + + redis: Optional[RedisDsn] = Field(default=None) + redis_timeout: float = Field(default=0.10) + + amt_jb_db: PostgresDsn = Field() + + amt_endpoint: Optional[HttpUrl] = Field(default=None) + amt_access_id: Optional[str] = Field(default=None) + amt_secret_key: Optional[str] = Field(default=None) + + aws_owner_id: str = Field() + aws_subscription_arn: str = Field() + + +class Settings(AmtJbBaseSettings): + model_config = SettingsConfigDict( + env_prefix="", + case_sensitive=False, + env_file=os.path.join(BASE_DIR, ".env"), + extra="allow", + cli_parse_args=False, + ) + debug: bool = False + app_name: str = "AMT JB API" + + fsb_host: HttpUrl = Field(default="https://fsb.generalresearch.com/") + # Needed for admin function on fsb w/o authentication + fsb_host_private_route: Optional[str] = Field(default=None) + + product_id: UUIDStr = Field() + + influx_db: Optional[InfluxDsn] = Field(default=None) + + sns_path: str = Field() + + +class TestSettings(Settings): + model_config = SettingsConfigDict( + env_prefix="", + 
case_sensitive=False,
+        env_file=os.path.join(BASE_DIR, ".env.dev"),
+        extra="allow",
+        cli_parse_args=False,
+    )
+    debug: bool = True
+    app_name: str = "AMT JB API Development"
+
+
+@lru_cache
+def get_settings():
+    return Settings()
+
+
+@lru_cache
+def get_test_settings():
+    return TestSettings()
diff --git a/jb/views/common.py b/jb/views/common.py
new file mode 100644
index 0000000..46ac608
--- /dev/null
+++ b/jb/views/common.py
@@ -0,0 +1,205 @@
+import json
+from typing import List
+from urllib.parse import urlparse
+from uuid import uuid4
+
+import requests
+from fastapi import Request, APIRouter, Response, HTTPException, Query
+from fastapi.responses import HTMLResponse, JSONResponse
+from pydantic import BaseModel, ConfigDict, Field
+from starlette.responses import RedirectResponse
+
+from jb.config import settings, JB_EVENTS_STREAM
+from jb.decorators import REDIS, HM
+from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event
+from jb.models.currency import USDCent
+from jb.models.definitions import ReportValue, AssignmentStatus
+from jb.models.event import MTurkEvent
+from jb.settings import BASE_HTML
+from jb.config import settings
+from jb.views.tasks import process_request
+
+common_router = APIRouter(prefix="", tags=["API"], include_in_schema=True)
+
+
+class ReportTask(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    worker_id: str = Field()
+
+    reasons: List[ReportValue] = Field(
+        examples=[[3, 4]],
+        default_factory=list,
+    )
+
+    notes: str = Field(
+        default="", examples=["The survey wanted to watch me eat Haejang-guk"]
+    )
+
+
+@common_router.post("/report/")
+def report(request: Request, data: ReportTask):
+    """
+    Forward a worker's report about a task to the FSB service.
+
+    On a 200 from FSB the upstream body is relayed unchanged; any other
+    upstream status is raised as an HTTPException carrying that status.
+    """
+    url = f"{settings.fsb_host}{settings.product_id}/report/"
+    params = {
+        "bpuid": data.worker_id,
+        "reasons": [x.value for x in data.reasons],
+        "notes": data.notes,
+    }
+
+    req = requests.post(url, json=params)
+    # The status code lives on the requests Response itself, not on the
+    # decoded JSON body, so it must be read from `req`.
+    if req.status_code != 200:
+        raise HTTPException(
+            status_code=req.status_code, detail="Failed to submit report"
+        )
+
+    # Response() needs str/bytes content (a dict cannot be rendered), so
+    # pass the upstream bytes through along with their content type.
+    return Response(
+        content=req.content, media_type=req.headers.get("content-type")
+    )
+
+
+@common_router.get(path="/work/", response_class=HTMLResponse)
+async def work(request: Request):
+    """
+    HTML Page that the worker lands on in an iFrame.
+    They can either be previewing the HIT, or have already accepted it.
+    """
+    amt_assignment_id = request.query_params.get("assignmentId", None)
+    worker_id = request.query_params.get("workerId", None)
+    amt_hit_id = request.query_params.get("hitId", None)
+    print(f"work: {amt_assignment_id=} {worker_id=} {amt_hit_id=}")
+
+    if not worker_id:
+        return RedirectResponse(
+            url=f"/preview/?{request.url.query}" if request.url.query else "/preview/",
+            status_code=302,
+        )
+    if amt_assignment_id is None or amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE":
+        # Worker is previewing the HIT
+        amt_hit_type_id = "unknown"
+        if amt_hit_id:
+            hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id)
+            amt_hit_type_id = hit.amt_hit_type_id
+        emit_assignment_event(
+            status=AssignmentStatus.PreviewState, amt_hit_type_id=amt_hit_type_id
+        )
+        return RedirectResponse(
+            url=f"/preview/?{request.url.query}" if request.url.query else "/preview/",
+            status_code=302,
+        )
+
+    # The Worker has accepted the HIT
+    process_request(request)
+
+    return HTMLResponse(BASE_HTML)
+
+
+@common_router.get(path="/survey/", response_class=JSONResponse)
+def survey(
+    request: Request,
+    worker_id: str = Query(),
+    duration: int = Query(default=1200),
+):
+    if not worker_id:
+        raise HTTPException(status_code=400, detail="Missing worker_id")
+
+    # (1) Check wallet
+    wallet_url = f"{settings.fsb_host}{settings.product_id}/wallet/"
+    wallet_res = requests.get(wallet_url, params={"bpuid": worker_id})
+    if wallet_res.status_code != 200:
+        raise HTTPException(status_code=502, detail="Wallet check failed")
+
+    wallet_data = wallet_res.json()
+    wallet_balance = wallet_data["wallet"]["amount"]
+    if wallet_balance < -100:
+        return JSONResponse(
+            {
+                "total_surveys": 0,
+                "link": None,
+                "duration": None,
+                "payout": None,
+            }
+        )
+
+    # (2) Get offerwall
+    client_ip = "69.253.144.55" if settings.debug else request.client.host
+    offerwall_url = f"{settings.fsb_host}{settings.product_id}/offerwall/d48cce47/"
+    offerwall_res = requests.get(
+        offerwall_url,
+        params={
+            "bpuid": worker_id,
+            "ip": client_ip,
+            "n_bins": 1,
+            "duration": duration,
+        },
+    )
+
+    if offerwall_res.status_code != 200:
+        raise HTTPException(status_code=502, detail="Offerwall request failed")
+
+    try:
+        rj = offerwall_res.json()
+        bucket = rj["offerwall"]["buckets"][0]
+        return JSONResponse(
+            {
+                "total_surveys": rj["offerwall"]["availability_count"],
+                "link": bucket["uri"],
+                "duration": round(bucket["duration"]["q2"] / 60),
+                "payout": USDCent(bucket["payout"]["q2"]).to_usd_str(),
+            }
+        )
+    except Exception:
+        return JSONResponse(
+            {
+                "total_surveys": 0,
+                "link": None,
+                "duration": None,
+                "payout": None,
+            }
+        )
+
+
+@common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False)
+async def mturk_notifications(request: Request):
+    """
+    Our SNS topic will POST to this endpoint whenever we get a new message
+    """
+    message = await request.json()
+    msg_type = message.get("Type")
+
+    if msg_type == "SubscriptionConfirmation":
+        subscribe_url = message["SubscribeURL"]
+        # Nothing in this handler verifies the SNS signature, so the URL is
+        # untrusted input. Only follow it when it is an https URL on an AWS
+        # host; otherwise a caller could make this server GET any address.
+        parsed = urlparse(subscribe_url)
+        hostname = parsed.hostname or ""
+        if parsed.scheme != "https" or not hostname.endswith(".amazonaws.com"):
+            raise HTTPException(status_code=400, detail="Invalid SubscribeURL")
+        print("Confirming SNS subscription...")
+        requests.get(subscribe_url)
+
+    elif msg_type == "Notification":
+        msg = json.loads(message["Message"])
+        print("Received MTurk event:", msg)
+        enqueue_mturk_notifications(msg)
+
+    return {"status": "ok"}
+
+
+def enqueue_mturk_notifications(msg) -> None:
+    for evt in msg["Events"]:
+        event = MTurkEvent.from_sns(evt)
+        emit_mturk_notification_event(
+            event_type=event.event_type, amt_hit_type_id=event.amt_hit_type_id
+        )
+        REDIS.xadd(JB_EVENTS_STREAM, {"data": event.model_dump_json()})
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..469eda2
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,7 @@
+import random
+import string
+
+
+def 
generate_amt_id(length=30): + chars = string.ascii_uppercase + string.digits + return "".join(random.choices(chars, k=length)) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..985c9dc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,421 @@ +import copy +from datetime import datetime, timezone, timedelta +from typing import Optional +from uuid import uuid4 + +import pytest +from dateutil.tz import tzlocal +from mypy_boto3_mturk.type_defs import ( + GetHITResponseTypeDef, + CreateHITTypeResponseTypeDef, + CreateHITWithHITTypeResponseTypeDef, + GetAssignmentResponseTypeDef, +) + +from jb.decorators import HQM, HTM, HM, AM +from jb.managers.amt import AMTManager, APPROVAL_MESSAGE, NO_WORK_APPROVAL_MESSAGE +from jb.models.assignment import AssignmentStub, Assignment +from jb.models.currency import USDCent +from jb.models.definitions import HitStatus, HitReviewStatus, AssignmentStatus +from jb.models.hit import HitType, HitQuestion, Hit +from tests import generate_amt_id + + +@pytest.fixture +def amt_hit_type_id(): + return generate_amt_id() + + +@pytest.fixture +def amt_hit_id(): + return generate_amt_id() + + +@pytest.fixture +def amt_assignment_id(): + return generate_amt_id() + + +@pytest.fixture +def amt_worker_id(): + return generate_amt_id(length=21) + + +@pytest.fixture +def amt_group_id(): + return generate_amt_id() + + +@pytest.fixture +def tsid(): + return uuid4().hex + + +@pytest.fixture +def tsid1(): + return uuid4().hex + + +@pytest.fixture +def tsid2(): + return uuid4().hex + + +@pytest.fixture +def pe_id(): + # payout event / cashout request UUID + return uuid4().hex + + +@pytest.fixture +def hit_type() -> HitType: + return HitType( + title="Awesome Surveys!", + description="Give us your opinion", + reward=USDCent(5), + keywords="market,research,amazing", + min_active=10, + ) + + +from jb.models.hit import HitType + + +@pytest.fixture +def hit_type_with_amt_id(hit_type: HitType) -> HitType: + # This is a real hit type 
I've previously registered with amt (sandbox). + # It will always exist + hit_type.amt_hit_type_id = "3217B3DC4P5YW9DRV9R3X8O56V041J" + + # Get or create our db + HTM.get_or_create(hit_type) + # this call adds the pk int id ---^ + + return hit_type + + +@pytest.fixture +def question(): + return HitQuestion(url="https://jamesbillings67.com/work/", height=1200) + + +@pytest.fixture +def hit_in_amt(hit_type_with_amt_id: HitType, question: HitQuestion) -> Hit: + # Actually create a new HIT in amt (sandbox) + question = HQM.get_or_create(question) + hit = AMTManager.create_hit_with_hit_type( + hit_type=hit_type_with_amt_id, question=question + ) + # Create it in the DB + HM.create(hit) + return hit + + +@pytest.fixture +def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question): + now = datetime.now(tz=timezone.utc) + return Hit.model_validate( + dict( + amt_hit_id=amt_hit_id, + amt_hit_type_id=amt_hit_type_id, + amt_group_id=amt_group_id, + status=HitStatus.Assignable, + review_status=HitReviewStatus.NotReviewed, + creation_time=now, + expiration=now + timedelta(days=3), + hit_question_xml=question.xml, + qualification_requirements=[], + max_assignments=1, + assignment_pending_count=0, + assignment_available_count=1, + assignment_completed_count=0, + description="Description", + keywords="Keywords", + reward=USDCent(5), + title="Title", + question_id=question.id, + hit_type_id=None, + ) + ) + + +@pytest.fixture +def hit_in_db( + hit_type: HitType, amt_hit_type_id, amt_hit_id, question: HitQuestion, hit: Hit +) -> Hit: + """ + Returns a hit that exists in our db, but does not in amazon (the amt ids + are random). 
The mtwerk_hittype and mtwerk_question records will also + exist (in the db) + """ + question = HQM.get_or_create(question) + hit_type.amt_hit_type_id = amt_hit_type_id + HTM.create(hit_type) + hit.hit_type_id = hit_type.id + hit.amt_hit_id = amt_hit_id + hit.question_id = question.id + HM.create(hit) + return hit + + +@pytest.fixture +def assignment_stub(hit: Hit, amt_assignment_id, amt_worker_id): + now = datetime.now(tz=timezone.utc) + return AssignmentStub( + amt_assignment_id=amt_assignment_id, + amt_hit_id=hit.amt_hit_id, + amt_worker_id=amt_worker_id, + status=AssignmentStatus.Submitted, + modified_at=now, + created_at=now, + ) + + +@pytest.fixture +def assignment_factory(hit: Hit): + def inner(amt_worker_id: str = None): + now = datetime.now(tz=timezone.utc) + amt_assignment_id = generate_amt_id() + amt_worker_id = amt_worker_id or generate_amt_id() + return Assignment( + amt_assignment_id=amt_assignment_id, + amt_hit_id=hit.amt_hit_id, + amt_worker_id=amt_worker_id, + status=AssignmentStatus.Submitted, + modified_at=now, + created_at=now, + accept_time=now, + auto_approval_time=now, + submit_time=now, + ) + + return inner + + +@pytest.fixture +def assignment_in_db_factory(assignment_factory): + def inner(hit_id: int, amt_worker_id: Optional[str] = None): + a = assignment_factory(amt_worker_id=amt_worker_id) + a.hit_id = hit_id + AM.create_stub(a) + AM.update_answer(a) + return a + + return inner + + +@pytest.fixture +def assignment_stub_in_db(hit_in_db, assignment_stub) -> AssignmentStub: + """ + Returns an AssignmentStub that exists in our db, but does not in amazon (the amt ids are random). 
+ The mtwerk_hit, mtwerk_hittype, and mtwerk_question records will also exist (in the db) + """ + assignment_stub.hit_id = hit_in_db.id + AM.create_stub(assignment_stub) + return assignment_stub + + +@pytest.fixture +def amt_response_metadata(): + req_id = str(uuid4()) + return { + "RequestId": req_id, + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": req_id, + "content-type": "application/x-amz-json-1.1", + "content-length": "46", + "date": "Wed, 15 Oct 2025 02:16:16 GMT", + }, + "RetryAttempts": 0, + } + + +@pytest.fixture +def create_hit_type_response( + amt_hit_type_id, amt_response_metadata +) -> CreateHITTypeResponseTypeDef: + return { + "HITTypeId": amt_hit_type_id, + "ResponseMetadata": amt_response_metadata, + } + + +@pytest.fixture +def create_hit_with_hit_type_response( + amt_hit_type_id, amt_hit_id, amt_response_metadata +) -> CreateHITWithHITTypeResponseTypeDef: + amt_group_id = generate_amt_id(length=30) + return { + "HIT": { + "HITId": amt_hit_id, + "HITTypeId": amt_hit_type_id, + "HITGroupId": amt_group_id, + "CreationTime": datetime(2025, 10, 14, 20, 22, tzinfo=tzlocal()), + "Title": "Test", + "Description": "test", + "Question": '\n\n https://jamesbillings67.com/work/\n 1200\n ', + "HITStatus": "Assignable", + "MaxAssignments": 1, + "Reward": "0.05", + "AutoApprovalDelayInSeconds": 2_592_000, + "Expiration": datetime(2025, 10, 14, 20, 24, 3, tzinfo=tzlocal()), + "AssignmentDurationInSeconds": 123, + "QualificationRequirements": [], + "HITReviewStatus": "NotReviewed", + "NumberOfAssignmentsPending": 0, + "NumberOfAssignmentsAvailable": 1, + "NumberOfAssignmentsCompleted": 0, + }, + "ResponseMetadata": amt_response_metadata, + } + + +@pytest.fixture +def get_hit_response( + amt_hit_type_id, amt_hit_id, amt_response_metadata +) -> GetHITResponseTypeDef: + amt_group_id = generate_amt_id(length=30) + return { + "HIT": { + "HITId": amt_hit_id, + "HITTypeId": amt_hit_type_id, + "HITGroupId": amt_group_id, + "CreationTime": datetime(2025, 
10, 13, 23, 0, 3, tzinfo=tzlocal()), + "Title": "Awesome Surveys!", + "Description": "Give us your opinion", + "Question": '\n\n https://jamesbillings67.com/work/\n 1200\n ', + "Keywords": "market,research,amazing", + "HITStatus": "Assignable", + "MaxAssignments": 1, + "Reward": "0.05", + "AutoApprovalDelayInSeconds": 604_800, + "Expiration": datetime(2025, 10, 27, 23, 0, 3, tzinfo=tzlocal()), + "AssignmentDurationInSeconds": 5_400, + "QualificationRequirements": [], + "HITReviewStatus": "NotReviewed", + "NumberOfAssignmentsPending": 0, + "NumberOfAssignmentsAvailable": 1, + "NumberOfAssignmentsCompleted": 0, + }, + "ResponseMetadata": amt_response_metadata, + } + + +@pytest.fixture +def get_hit_response_reviewing(get_hit_response): + res = copy.deepcopy(get_hit_response) + res["HIT"]["NumberOfAssignmentsAvailable"] = 0 + res["HIT"]["NumberOfAssignmentsCompleted"] = 1 + res["HIT"]["HITStatus"] = "Reviewing" + return res + + +@pytest.fixture +def get_assignment_response( + amt_hit_type_id, + amt_hit_id, + amt_assignment_id, + amt_worker_id, + get_hit_response, + amt_response_metadata, + tsid, +) -> GetAssignmentResponseTypeDef: + hit_response = get_hit_response["HIT"] + local_now = datetime.now(tz=tzlocal()) + return { + "Assignment": { + "AssignmentId": amt_assignment_id, + "WorkerId": amt_worker_id, + "HITId": amt_hit_id, + "AssignmentStatus": "Submitted", + "AutoApprovalTime": local_now + timedelta(days=7), + "AcceptTime": local_now - timedelta(minutes=10), + "SubmitTime": local_now, + "Deadline": local_now + timedelta(minutes=90), + "Answer": '\n' + '\n ' + "\n amt_worker_id\n " + f" {amt_worker_id}\n \n \n " + " amt_assignment_id\n " + f" {amt_assignment_id}\n \n \n " + f" tsid\n {tsid}\n " + " \n", + "RequesterFeedback": "Good work", + }, + "HIT": hit_response, + "ResponseMetadata": amt_response_metadata, + } + + +@pytest.fixture +def get_assignment_response_no_tsid( + get_assignment_response, amt_worker_id, amt_assignment_id +): + res = 
copy.deepcopy(get_assignment_response) + res["Assignment"]["Answer"] = ( + '\n' + '\n ' + "\n amt_worker_id\n " + f" {amt_worker_id}\n \n \n " + " amt_assignment_id\n " + f" {amt_assignment_id}\n \n " + # f"\n tsid\n {tsid}\n \n" + f"" + ) + return res + + +@pytest.fixture +def get_assignment_response_approved( + get_assignment_response: GetAssignmentResponseTypeDef, +): + def inner(feedback: str = APPROVAL_MESSAGE) -> GetAssignmentResponseTypeDef: + res = copy.deepcopy(get_assignment_response) + res["Assignment"]["AssignmentStatus"] = "Approved" + res["Assignment"]["RequesterFeedback"] = feedback + res["Assignment"]["ApprovalTime"] = res["Assignment"]["SubmitTime"] + return res + + return inner + + +@pytest.fixture +def get_assignment_response_rejected( + get_assignment_response: GetAssignmentResponseTypeDef, +): + + def inner(reject_reason: str = "reject reason") -> GetAssignmentResponseTypeDef: + res = copy.deepcopy(get_assignment_response) + res["Assignment"]["AssignmentStatus"] = "Rejected" + res["Assignment"]["RequesterFeedback"] = reject_reason + res["Assignment"]["RejectionTime"] = res["Assignment"]["SubmitTime"] + return res + + return inner + + +@pytest.fixture +def get_assignment_response_rejected_no_tsid( + get_assignment_response_no_tsid: GetAssignmentResponseTypeDef, +): + + def inner(reject_reason: str = "reject reason") -> GetAssignmentResponseTypeDef: + res = copy.deepcopy(get_assignment_response_no_tsid) + res["Assignment"]["AssignmentStatus"] = "Rejected" + res["Assignment"]["RequesterFeedback"] = reject_reason + res["Assignment"]["RejectionTime"] = res["Assignment"]["SubmitTime"] + return res + + return inner + + +@pytest.fixture +def get_assignment_response_approved_no_tsid( + get_assignment_response_no_tsid: GetAssignmentResponseTypeDef, +): + res = copy.deepcopy(get_assignment_response_no_tsid) + res["Assignment"]["AssignmentStatus"] = "Approved" + res["Assignment"]["RequesterFeedback"] = NO_WORK_APPROVAL_MESSAGE + 
res["Assignment"]["ApprovalTime"] = res["Assignment"]["SubmitTime"] + return res diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py new file mode 100644 index 0000000..70458b8 --- /dev/null +++ b/tests/http/test_notifications.py @@ -0,0 +1,71 @@ +import json + +import pytest +from httpx import AsyncClient +import secrets + +from jb.config import JB_EVENTS_STREAM, settings +from jb.decorators import REDIS +from jb.models.event import MTurkEvent +from tests import generate_amt_id + + +def generate_hex_id(length: int = 40) -> str: + # length is number of hex chars, so we need length//2 bytes + return secrets.token_hex(length // 2) + + +@pytest.fixture +def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id): + return { + "Type": "Notification", + "Message": json.dumps( + { + "Events": [ + { + "EventType": "AssignmentSubmitted", + "EventTimestamp": "2025-10-16T18:45:51.000000Z", + "HITId": amt_hit_id, + "AssignmentId": amt_assignment_id, + "HITTypeId": amt_hit_type_id, + } + ], + "EventDocId": generate_hex_id(), + "SourceAccount": settings.aws_owner_id, + "CustomerId": generate_amt_id(length=14), + "EventDocVersion": "2006-05-05", + } + ), + } + + +@pytest.fixture() +def clean_mturk_events_redis_stream(): + REDIS.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert REDIS.xlen(JB_EVENTS_STREAM) == 0 + yield + REDIS.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert REDIS.xlen(JB_EVENTS_STREAM) == 0 + + +@pytest.mark.anyio +async def test_mturk_notifications( + httpxclient: AsyncClient, + no_limit, + example_mturk_event_body, + amt_assignment_id, + clean_mturk_events_redis_stream, +): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body) + res.raise_for_status() + + msg_res = REDIS.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) + msg_res = msg_res[0][1][0] + msg_id, msg = msg_res + REDIS.xdel(JB_EVENTS_STREAM, msg_id) + + msg_json = msg["data"] + event = 
MTurkEvent.model_validate_json(msg_json) + assert event.amt_assignment_id == amt_assignment_id diff --git a/tests/http/test_status.py b/tests/http/test_status.py new file mode 100644 index 0000000..d88ff65 --- /dev/null +++ b/tests/http/test_status.py @@ -0,0 +1,78 @@ +from uuid import uuid4 + +import pytest +from httpx import AsyncClient + +from jb.config import settings +from tests import generate_amt_id + + +@pytest.mark.anyio +async def test_get_status_args(httpxclient: AsyncClient, no_limit): + client = httpxclient + + # tsid misformatted + res = await client.get(f"/status/{uuid4().hex[:-1]}/") + assert res.status_code == 422 + assert "String should have at least 32 characters" in res.text + + +@pytest.mark.anyio +async def test_get_status_error(httpxclient: AsyncClient, no_limit): + # Expects settings.fsb_host to point to a functional thl-fsb + client = httpxclient + + # tsid doesn't exist + res = await client.get(f"/status/{uuid4().hex}/") + assert res.status_code == 502 + assert res.json()["detail"] == "Failed to fetch status" + + +@pytest.mark.anyio +async def test_get_status_complete(httpxclient: AsyncClient, no_limit, mock_requests): + client = httpxclient + + tsid = uuid4().hex + url = f"{settings.fsb_host}{settings.product_id}/status/{tsid}/" + + mock_response = { + "tsid": tsid, + "product_id": settings.product_id, + "bpuid": generate_amt_id(length=21), + "started": "2022-06-29T23:43:48.247777Z", + "finished": "2022-06-29T23:56:57.632634Z", + "status": 3, + "payout": 81, + "user_payout": 77, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.77", + "kwargs": {}, + } + mock_requests.get(url, json=mock_response, status_code=200) + res = await client.get(f"/status/{tsid}/") + assert res.status_code == 200 + assert res.json() == {"status": 3, "payout": "$0.77"} + + +@pytest.mark.anyio +async def test_get_status_failure(httpxclient: AsyncClient, no_limit, mock_requests): + client = httpxclient + + tsid = uuid4().hex + url = 
f"{settings.fsb_host}{settings.product_id}/status/{tsid}/" + + mock_response = { + "tsid": tsid, + "product_id": settings.product_id, + "bpuid": "123ABC", + "status": 2, + "payout": 0, + "user_payout": 0, + "payout_format": "${payout/100:.2f}", + "user_payout_string": None, + "kwargs": {}, + } + mock_requests.get(url, json=mock_response, status_code=200) + res = await client.get(f"/status/{tsid}/") + assert res.status_code == 200 + assert res.json() == {"status": 2, "payout": None} diff --git a/tests/http/test_statuses.py b/tests/http/test_statuses.py new file mode 100644 index 0000000..ffc98fd --- /dev/null +++ b/tests/http/test_statuses.py @@ -0,0 +1,102 @@ +from datetime import datetime, timezone, timedelta +from urllib.parse import urlencode + +import pytest +from uuid import uuid4 +from httpx import AsyncClient + +from jb.config import settings + + +@pytest.mark.anyio +async def test_get_statuses(httpxclient: AsyncClient, no_limit, amt_worker_id): + # Expects settings.fsb_host to point to a functional thl-fsb + client = httpxclient + now = datetime.now(tz=timezone.utc) + + params = {"worker_id": amt_worker_id} + res = await client.get(f"/statuses/", params=params) + assert res.status_code == 200 + assert res.json() == [] + + params = {"worker_id": amt_worker_id, "started_after": now.isoformat()} + res = await client.get(f"/statuses/", params=params) + assert res.status_code == 422 + assert "Input should be a valid integer" in res.text + + +@pytest.fixture +def fsb_get_statuses_example_response(amt_worker_id, tsid1, tsid2): + return { + "tasks_status": [ + { + "tsid": tsid1, + "product_id": settings.product_id, + "bpuid": amt_worker_id, + "started": "2025-06-12T03:27:24.902280Z", + "finished": "2025-06-12T03:29:37.626481Z", + "status": 2, + "payout": 0, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": "SESSION_START_QUALITY_FAIL", + "status_code_2": "ENTRY_URL_MODIFICATION", + }, + { + "tsid": 
tsid2, + "product_id": settings.product_id, + "bpuid": amt_worker_id, + "started": "2025-06-12T03:30:18.176826Z", + "finished": "2025-06-12T03:36:58.789059Z", + "status": 2, + "payout": 0, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": "BUYER_QUALITY_FAIL", + "status_code_2": None, + }, + ] + } + + +@pytest.mark.anyio +async def test_get_statuses_mock( + httpxclient: AsyncClient, + no_limit, + amt_worker_id, + mock_requests, + fsb_get_statuses_example_response, + tsid1, + tsid2, +): + client = httpxclient + now = datetime.now(tz=timezone.utc) + started_after = now - timedelta(minutes=5) + + # The fsb call we are mocking ------v + params = { + "bpuid": amt_worker_id, + "started_after": round(started_after.timestamp()), + "started_before": round(now.timestamp()), + } + url = f"{settings.fsb_host}{settings.product_id}/status/" + "?" + urlencode(params) + mock_requests.get(url, json=fsb_get_statuses_example_response, status_code=200) + # ---- end mock + + params = { + "worker_id": amt_worker_id, + "started_after": round(started_after.timestamp()), + "started_before": round(now.timestamp()), + } + result = await client.get(f"/statuses/", params=params) + assert result.status_code == 200 + res = result.json() + assert len(res) == 2 + assert res == [ + {"status": 2, "tsid": tsid1}, + {"status": 2, "tsid": tsid2}, + ] -- cgit v1.2.3