diff options
| -rw-r--r-- | jb/managers/assignment.py | 3 | ||||
| -rw-r--r-- | jb/settings.py | 4 | ||||
| -rw-r--r-- | jb/views/common.py | 81 | ||||
| -rw-r--r-- | tests/conftest.py | 131 | ||||
| -rw-r--r-- | tests/http/conftest.py | 49 | ||||
| -rw-r--r-- | tests/http/test_basic.py | 37 | ||||
| -rw-r--r-- | tests/http/test_notifications.py | 163 | ||||
| -rw-r--r-- | tests/http/test_preview.py | 39 | ||||
| -rw-r--r-- | tests/http/test_work.py | 98 |
9 files changed, 406 insertions, 199 deletions
diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py index dd3c866..c0a168e 100644 --- a/jb/managers/assignment.py +++ b/jb/managers/assignment.py @@ -17,7 +17,8 @@ class AssignmentManager(PostgresManager): query = sql.SQL( """ INSERT INTO mtwerk_assignment - (amt_assignment_id, amt_worker_id, status, created_at, modified_at, hit_id) + (amt_assignment_id, amt_worker_id, status, + created_at, modified_at, hit_id) VALUES (%(amt_assignment_id)s, %(amt_worker_id)s, %(status)s, %(created_at)s, %(modified_at)s, %(hit_id)s) diff --git a/jb/settings.py b/jb/settings.py index 5754add..d529591 100644 --- a/jb/settings.py +++ b/jb/settings.py @@ -60,12 +60,12 @@ class TestSettings(Settings): model_config = SettingsConfigDict( env_prefix="", case_sensitive=False, - env_file=os.path.join(BASE_DIR, ".env.dev"), + env_file=os.path.join(BASE_DIR, ".env.test"), extra="allow", cli_parse_args=False, ) debug: bool = True - app_name: str = "AMT JB API Development" + app_name: str = "AMT JB API Test" @lru_cache diff --git a/jb/views/common.py b/jb/views/common.py index 0dc8b56..c87b1a5 100644 --- a/jb/views/common.py +++ b/jb/views/common.py @@ -1,18 +1,16 @@ import json -from typing import List -from uuid import uuid4 +from unittest import case +from typing import Dict, Any import requests -from fastapi import Request, APIRouter, Response, HTTPException, Query -from fastapi.responses import HTMLResponse, JSONResponse -from pydantic import BaseModel, ConfigDict, Field +from fastapi import Request, APIRouter, HTTPException +from fastapi.responses import HTMLResponse from starlette.responses import RedirectResponse from jb.config import settings, JB_EVENTS_STREAM from jb.decorators import REDIS, HM from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event -from generalresearchutils.currency import USDCent -from jb.models.definitions import ReportValue, AssignmentStatus +from jb.models.definitions import AssignmentStatus from jb.models.event import MTurkEvent from jb.settings import BASE_HTML from jb.config import settings @@ -21,40 +19,6 @@ from jb.views.tasks import process_request common_router = APIRouter(prefix="", tags=["API"], include_in_schema=True) -class ReportTask(BaseModel): - model_config = ConfigDict(extra="forbid") - - worker_id: str = Field() - - reasons: List[ReportValue] = Field( - examples=[[3, 4]], - default_factory=list, - ) - - notes: str = Field( - default="", examples=["The survey wanted to watch me eat Haejang-guk"] - ) - - -@common_router.post("/report/") -def report(request: Request, data: ReportTask): - url = f"{settings.fsb_host}{settings.product_id}/report/" - params = { - "bpuid": data.worker_id, - "reasons": [x.value for x in data.reasons], - "notes": data.notes, - } - - req = requests.post(url, json=params) - res = req.json() - if res.status_code != 200: - raise HTTPException( - status_code=res.status_code, detail="Failed to submit report" - ) - - return Response(res) - - @common_router.get(path="/work/", response_class=HTMLResponse) async def work(request: Request): """ @@ -86,8 +50,11 @@ async def work(request: Request): status_code=302, ) - # The Worker has accepted the HIT - process_request(request) + try: + # The Worker has accepted the HIT + process_request(request) + except Exception: + raise HTTPException(status_code=500, detail="Error processing request") return HTMLResponse(BASE_HTML) @@ -97,23 +64,29 @@ async def mturk_notifications(request: Request): """ Our SNS topic will POST to this endpoint whenever we get a new message """ - message = await request.json() - msg_type = message.get("Type") + try: + message = await request.json() + except Exception: + raise HTTPException(status_code=500, detail="No POST data") + + match message.get("Type"): + case "SubscriptionConfirmation": + subscribe_url = message["SubscribeURL"] + print("Confirming SNS subscription...") + requests.get(subscribe_url) - if msg_type == "SubscriptionConfirmation": - subscribe_url = message["SubscribeURL"] - print("Confirming SNS subscription...") - requests.get(subscribe_url) + case "Notification": + msg = json.loads(message["Message"]) + print("Received MTurk event:", msg) + enqueue_mturk_notifications(msg) - elif msg_type == "Notification": - msg = json.loads(message["Message"]) - print("Received MTurk event:", msg) - enqueue_mturk_notifications(msg) + case _: + raise HTTPException(status_code=500, detail="Invalid JSON") return {"status": "ok"} -def enqueue_mturk_notifications(msg) -> None: +def enqueue_mturk_notifications(msg: Dict[str, Any]) -> None: for evt in msg["Events"]: event = MTurkEvent.from_sns(evt) emit_mturk_notification_event( diff --git a/tests/conftest.py b/tests/conftest.py index 3318f1c..a33b149 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import copy from datetime import datetime, timezone, timedelta import os -from typing import Optional +from typing import Optional, TYPE_CHECKING, Callable, Dict, Any from uuid import uuid4 from dotenv import load_dotenv import pytest @@ -20,6 +20,16 @@ from generalresearchutils.currency import USDCent from jb.models.definitions import HitStatus, HitReviewStatus, AssignmentStatus from jb.models.hit import HitType, HitQuestion, Hit from tests import generate_amt_id +from _pytest.config import Config + +if TYPE_CHECKING: + from jb.settings import Settings + from jb.managers.hit import HitQuestionManager, HitTypeManager, HitManager + from jb.managers.assignment import AssignmentManager + from jb.managers.bonus import BonusManager + + +# --- IDs and Identifiers --- @pytest.fixture @@ -67,7 +77,7 @@ def pe_id() -> str: @pytest.fixture(scope="session") -def env_file_path(pytestconfig): +def env_file_path(pytestconfig: Config) -> str: root_path = pytestconfig.rootpath env_path = os.path.join(root_path, ".env.test") @@ -78,7 +88,7 @@ def env_file_path(pytestconfig): @pytest.fixture(scope="session") -def settings(env_file_path) -> "Settings": +def settings(env_file_path: str) -> "Settings": from jb.settings import Settings as JBSettings s = JBSettings(_env_file=env_file_path) @@ -90,7 +100,7 @@ def settings(env_file_path) -> "Settings": @pytest.fixture(scope="session") -def redis(settings): +def redis(settings: "Settings"): from generalresearchutils.redis_helper import RedisConfig redis_config = RedisConfig( @@ -103,7 +113,7 @@ def redis(settings): @pytest.fixture(scope="session") -def pg_config(settings) -> PostgresConfig: +def pg_config(settings: "Settings") -> PostgresConfig: return PostgresConfig( dsn=settings.amt_jb_db, connect_timeout=1, @@ -115,8 +125,10 @@ def pg_config(settings) -> PostgresConfig: @pytest.fixture(scope="session") -def hqm(pg_config) -> "HitQuestionManager": - assert "/unittest-" in pg_config.dsn.path +def hqm(pg_config: PostgresConfig) -> "HitQuestionManager": + assert ( + pg_config.dsn.path and "/unittest-" in pg_config.dsn.path + ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')" from jb.managers.hit import HitQuestionManager @@ -126,8 +138,10 @@ def hqm(pg_config) -> "HitQuestionManager": @pytest.fixture(scope="session") -def htm(pg_config) -> "HitTypeManager": - assert "/unittest-" in pg_config.dsn.path +def htm(pg_config: PostgresConfig) -> "HitTypeManager": + assert ( + pg_config.dsn.path and "/unittest-" in pg_config.dsn.path + ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')" from jb.managers.hit import HitTypeManager @@ -137,8 +151,10 @@ def htm(pg_config) -> "HitTypeManager": @pytest.fixture(scope="session") -def hm(pg_config) -> "HitManager": - assert "/unittest-" in pg_config.dsn.path +def hm(pg_config: PostgresConfig) -> "HitManager": + assert ( + pg_config.dsn.path and "/unittest-" in pg_config.dsn.path + ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')" from jb.managers.hit import HitManager @@ -148,8 +164,10 @@ def hm(pg_config) -> "HitManager": @pytest.fixture(scope="session") -def am(pg_config) -> "AssignmentManager": - assert "/unittest-" in pg_config.dsn.path +def am(pg_config: PostgresConfig) -> "AssignmentManager": + assert ( + pg_config.dsn.path and "/unittest-" in pg_config.dsn.path + ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')" from jb.managers.assignment import AssignmentManager @@ -159,8 +177,10 @@ def am(pg_config) -> "AssignmentManager": @pytest.fixture(scope="session") -def bm(pg_config) -> "BonusManager": - assert "/unittest-" in pg_config.dsn.path +def bm(pg_config: PostgresConfig) -> "BonusManager": + assert ( + pg_config.dsn.path and "/unittest-" in pg_config.dsn.path + ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')" from jb.managers.bonus import BonusManager @@ -178,7 +198,7 @@ def question() -> HitQuestion: @pytest.fixture -def question_record(hqm, question) -> HitQuestion: +def question_record(hqm: "HitQuestionManager", question: HitQuestion) -> HitQuestion: return hqm.get_or_create(question) @@ -197,14 +217,14 @@ def hit_type() -> HitType: @pytest.fixture -def hit_type_record(htm, hit_type) -> HitType: +def hit_type_record(htm: "HitTypeManager", hit_type: HitType) -> HitType: hit_type.amt_hit_type_id = generate_amt_id() return htm.get_or_create(hit_type) @pytest.fixture -def hit_type_with_amt_id(htm, hit_type: HitType) -> HitType: +def hit_type_with_amt_id(htm: "HitTypeManager", hit_type: HitType) -> HitType: # This is a real hit type I've previously registered with amt (sandbox). # It will always exist hit_type.amt_hit_type_id = "3217B3DC4P5YW9DRV9R3X8O56V041J" @@ -225,7 +245,9 @@ def amt_hit_id() -> str: @pytest.fixture -def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question) -> Hit: +def hit( + amt_hit_id: str, amt_hit_type_id: str, amt_group_id: str, question: HitQuestion +) -> Hit: now = datetime.now(tz=timezone.utc) return Hit.model_validate( @@ -255,11 +277,11 @@ def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question) -> Hit: @pytest.fixture def hit_record( - hm, - question_record, - hit_type_record, - hit, - amt_hit_id, + hm: "HitManager", + question_record: HitQuestion, + hit_type_record: HitType, + hit: Hit, + amt_hit_id: str, ) -> Hit: """ Returns a hit that exists in our db, but does not in amazon (the amt ids @@ -276,7 +298,9 @@ def hit_record( @pytest.fixture -def hit_in_amt(hm, question_record, hit_type_with_amt_id: HitType) -> Hit: +def hit_in_amt( + hm: "HitManager", question_record: HitQuestion, hit_type_with_amt_id: HitType +) -> Hit: # Actually create a new HIT in amt (sandbox) hit = AMTManager.create_hit_with_hit_type( hit_type=hit_type_with_amt_id, question=question_record @@ -290,7 +314,7 @@ def hit_in_amt(hm, question_record, hit_type_with_amt_id: HitType) -> Hit: @pytest.fixture -def assignment_stub(hit: Hit, amt_assignment_id, amt_worker_id): +def assignment_stub(hit: Hit, amt_assignment_id: str, amt_worker_id: str): now = datetime.now(tz=timezone.utc) return AssignmentStub( amt_assignment_id=amt_assignment_id, @@ -303,11 +327,27 @@ def assignment_stub(hit: Hit, amt_assignment_id, amt_worker_id): @pytest.fixture +def assignment_stub_record( + am: "AssignmentManager", hit_record: Hit, assignment_stub: AssignmentStub +) -> AssignmentStub: + """ + Returns an AssignmentStub that exists in our db, but does not in + amazon (the amt ids are random). The mtwerk_hit, mtwerk_hittype, and + mtwerk_question records will also exist (in the db) + """ + assignment_stub.hit_id = hit_record.id + am.create_stub(stub=assignment_stub) + return assignment_stub + + +@pytest.fixture def assignment_factory(hit: Hit): - def inner(amt_worker_id: str = None): + + def inner(amt_worker_id: str = None) -> Assignment: now = datetime.now(tz=timezone.utc) amt_assignment_id = generate_amt_id() amt_worker_id = amt_worker_id or generate_amt_id() + return Assignment( amt_assignment_id=amt_assignment_id, amt_hit_id=hit.amt_hit_id, @@ -324,7 +364,10 @@ def assignment_factory(hit: Hit): @pytest.fixture -def assignment_in_db_factory(am, assignment_factory): +def assignment_record_factory( + am: "AssignmentManager", assignment_factory: Callable[..., Assignment] +): + def inner(hit_id: int, amt_worker_id: Optional[str] = None): a = assignment_factory(amt_worker_id=amt_worker_id) a.hit_id = hit_id @@ -335,22 +378,11 @@ def assignment_in_db_factory(am, assignment_factory): return inner -@pytest.fixture -def assignment_stub_in_db(am, hit_record, assignment_stub) -> AssignmentStub: - """ - Returns an AssignmentStub that exists in our db, but does not in amazon (the amt ids are random). - The mtwerk_hit, mtwerk_hittype, and mtwerk_question records will also exist (in the db) - """ - assignment_stub.hit_id = hit_record.id - am.create_stub(assignment_stub) - return assignment_stub - - -# --- HIT --- +# --- Response --- @pytest.fixture -def amt_response_metadata(): +def amt_response_metadata() -> Dict[str, Any]: req_id = str(uuid4()) return { "RequestId": req_id, @@ -367,7 +399,7 @@ def amt_response_metadata(): @pytest.fixture def create_hit_type_response( - amt_hit_type_id, amt_response_metadata + amt_hit_type_id: str, amt_response_metadata: Dict[str, Any] ) -> CreateHITTypeResponseTypeDef: return { "HITTypeId": amt_hit_type_id, @@ -377,7 +409,7 @@ def create_hit_type_response( @pytest.fixture def create_hit_with_hit_type_response( - amt_hit_type_id, amt_hit_id, amt_response_metadata + amt_hit_type_id: str, amt_hit_id: str, amt_response_metadata ) -> CreateHITWithHITTypeResponseTypeDef: amt_group_id = generate_amt_id(length=30) return { @@ -407,7 +439,7 @@ def create_hit_with_hit_type_response( @pytest.fixture def get_hit_response( - amt_hit_type_id, amt_hit_id, amt_response_metadata + amt_hit_type_id: str, amt_hit_id: str, amt_response_metadata ) -> GetHITResponseTypeDef: amt_group_id = generate_amt_id(length=30) return { @@ -447,13 +479,12 @@ def get_hit_response_reviewing(get_hit_response): @pytest.fixture def get_assignment_response( - amt_hit_type_id, - amt_hit_id, - amt_assignment_id, - amt_worker_id, + amt_hit_id: str, + amt_assignment_id: str, + amt_worker_id: str, get_hit_response, amt_response_metadata, - tsid, + tsid: str, ) -> GetAssignmentResponseTypeDef: hit_response = get_hit_response["HIT"] local_now = datetime.now(tz=tzlocal()) @@ -484,7 +515,7 @@ def get_assignment_response( @pytest.fixture def get_assignment_response_no_tsid( - get_assignment_response, amt_worker_id, amt_assignment_id + get_assignment_response, amt_worker_id: str, amt_assignment_id: str ): res = copy.deepcopy(get_assignment_response) res["Assignment"]["Answer"] = ( diff --git a/tests/http/conftest.py b/tests/http/conftest.py index 200bf1c..4f11fde 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -1,10 +1,19 @@ import httpx +import redis import pytest import requests_mock from asgi_lifespan import LifespanManager from httpx import AsyncClient, ASGITransport +from typing import Dict, Any from jb.main import app +import json + +from httpx import AsyncClient +import secrets + +from jb.config import JB_EVENTS_STREAM, settings +from tests import generate_amt_id @pytest.fixture(scope="session") @@ -48,3 +57,43 @@ def httpxclient_ip(httpxclient): def mock_requests(): with requests_mock.Mocker() as m: yield m + + +def generate_hex_id(length: int = 40) -> str: + # length is number of hex chars, so we need length//2 bytes + return secrets.token_hex(length // 2) + + +@pytest.fixture +def mturk_event_body_record( + hit_record: Hit, assignment_stub_record: AssignmentStub +) -> Dict[str, Any]: + return { + "Type": "Notification", + "Message": json.dumps( + { + "Events": [ + { + "EventType": "AssignmentSubmitted", + "EventTimestamp": "2025-10-16T18:45:51.000000Z", + "HITId": hit_record.amt_hit_id, + "AssignmentId": assignment_stub_record.amt_assignment_id, + "HITTypeId": hit_record.amt_hit_type_id, + } + ], + "EventDocId": generate_hex_id(), + "SourceAccount": settings.aws_owner_id, + "CustomerId": generate_amt_id(length=14), + "EventDocVersion": "2006-05-05", + } + ), + } + + +@pytest.fixture() +def clean_mturk_events_redis_stream(redis: redis.Redis): + redis.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert redis.xlen(JB_EVENTS_STREAM) == 0 + yield + redis.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert redis.xlen(JB_EVENTS_STREAM) == 0 diff --git a/tests/http/test_basic.py b/tests/http/test_basic.py index 18359da..e806fa9 100644 --- a/tests/http/test_basic.py +++ b/tests/http/test_basic.py @@ -2,23 +2,24 @@ import pytest from httpx import AsyncClient -@pytest.mark.anyio -async def test_base(httpxclient: AsyncClient): - client = httpxclient - res = await client.get("/") - # actually returns 404. old test expects 401. idk what is should be - print(res.text) - # assert res.status_code == 404 - assert res.status_code == 200 +class TestBase: + @pytest.mark.anyio + async def test_base(self, httpxclient: AsyncClient): + client = httpxclient + res = await client.get("/") + # actually returns 404. old test expects 401. idk what is should be + print(res.text) + # assert res.status_code == 404 + assert res.status_code == 200 -@pytest.mark.anyio -async def test_static_file_alias(httpxclient: AsyncClient): - client = httpxclient - """ - These are here for site crawlers and stuff.. - """ - for p in ["/robots.txt", "/sitemap.xml", "/favicon.ico"]: - res = await client.get(p) - assert res.status_code == 200, p - assert res.json() == {} + @pytest.mark.anyio + async def test_static_file_alias(self, httpxclient: AsyncClient): + client = httpxclient + """ + These are here for site crawlers and stuff.. + """ + for p in ["/robots.txt", "/sitemap.xml", "/favicon.ico"]: + res = await client.get(p) + assert res.status_code == 200, p + assert res.json() == {} diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py index 6770044..4386863 100644 --- a/tests/http/test_notifications.py +++ b/tests/http/test_notifications.py @@ -1,71 +1,102 @@ -import json - import pytest +import json +import redis +from typing import Dict, Any from httpx import AsyncClient -import secrets +from uuid import uuid4 from jb.config import JB_EVENTS_STREAM, settings from jb.models.event import MTurkEvent -from tests import generate_amt_id - - -def generate_hex_id(length: int = 40) -> str: - # length is number of hex chars, so we need length//2 bytes - return secrets.token_hex(length // 2) - - -@pytest.fixture -def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id): - return { - "Type": "Notification", - "Message": json.dumps( - { - "Events": [ - { - "EventType": "AssignmentSubmitted", - "EventTimestamp": "2025-10-16T18:45:51.000000Z", - "HITId": amt_hit_id, - "AssignmentId": amt_assignment_id, - "HITTypeId": amt_hit_type_id, - } - ], - "EventDocId": generate_hex_id(), - "SourceAccount": settings.aws_owner_id, - "CustomerId": generate_amt_id(length=14), - "EventDocVersion": "2006-05-05", - } - ), - } - - -@pytest.fixture() -def clean_mturk_events_redis_stream(redis): - redis.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert redis.xlen(JB_EVENTS_STREAM) == 0 - yield - redis.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert redis.xlen(JB_EVENTS_STREAM) == 0 - - -@pytest.mark.anyio -async def test_mturk_notifications( - redis, - httpxclient: AsyncClient, - no_limit, - example_mturk_event_body, - amt_assignment_id, - clean_mturk_events_redis_stream, -): - client = httpxclient - - res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body) - res.raise_for_status() - - msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) - msg_res = msg_res[0][1][0] - msg_id, msg = msg_res - redis.xdel(JB_EVENTS_STREAM, msg_id) - - msg_json = msg["data"] - event = MTurkEvent.model_validate_json(msg_json) - assert event.amt_assignment_id == amt_assignment_id +from jb.models.hit import Hit + + +class TestNotifications: + + @pytest.mark.anyio + async def test_no_post_keys( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/", json={}) + assert res.status_code == 500 + assert res.json() == {"detail": "Invalid JSON"} + + @pytest.mark.anyio + async def test_no_post_data( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/") + assert res.status_code == 500 + assert res.json() == {"detail": "No POST data"} + + @pytest.mark.anyio + async def test_invalid_post_data( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post( + url=f"/{settings.sns_path}/", json={uuid4().hex: uuid4().hex} + ) + assert res.status_code == 500 + assert res.json() == {"detail": "Invalid JSON"} + + @pytest.mark.anyio + async def test_mturk_notifications( + self, + redis: redis.Redis, + httpxclient: AsyncClient, + hit_record: Hit, + assignment_stub_record: AssignmentStub, + mturk_event_body_record: Dict[str, Any], + ): + client = httpxclient + + json_msg = json.loads(mturk_event_body_record["Message"]) + # Assert the mturk event is owned by the correct account + assert json_msg["SourceAccount"] == settings.aws_owner_id + + # Assert the HIT and Assignment are "connected" + assert assignment_stub_record.hit_id == hit_record.id + assert assignment_stub_record.amt_hit_id == hit_record.amt_hit_id + + # Assert the event body and the Assignment/HIT are "connected" + assert json_msg["Events"][0]["HITId"] == hit_record.amt_hit_id + assert json_msg["Events"][0]["HITId"] == assignment_stub_record.amt_hit_id + assert ( + json_msg["Events"][0]["AssignmentId"] + == assignment_stub_record.amt_assignment_id + ) + + res = await client.post( + url=f"/{settings.sns_path}/", json=mturk_event_body_record + ) + res.raise_for_status() + + # AMT SNS needs to receive a 200 response to stop retrying the notification + assert res.status_code == 200 + assert res.json() == {"status": "ok"} + + # Check that the event was enqueued in Redis + msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) + msg_res = msg_res[0][1][0] + msg_id, msg = msg_res + redis.xdel(JB_EVENTS_STREAM, msg_id) + + msg_json = msg["data"] + event = MTurkEvent.model_validate_json(msg_json) + + # Confirm that the event that we got from redis is what was POSTED by + # AMT SNS. We're using the fixture() so the values are the same as + # what was in the example_mturk_event_body + assert event.event_type == "AssignmentSubmitted" + + assert event.amt_assignment_id == assignment_stub_record.amt_assignment_id + assert event.amt_hit_id == hit_record.amt_hit_id + assert event.amt_hit_type_id == hit_record.amt_hit_type_id diff --git a/tests/http/test_preview.py b/tests/http/test_preview.py index 2bdf265..467c63c 100644 --- a/tests/http/test_preview.py +++ b/tests/http/test_preview.py @@ -3,8 +3,8 @@ import pytest from httpx import AsyncClient - from jb.models.hit import Hit +from jb.models.assignment import AssignmentStub class TestPreview: @@ -15,8 +15,6 @@ class TestPreview: res = await client.get("/preview/") assert res.status_code == 200 - # the response is an html page - assert res.headers["content-type"] == "text/html; charset=utf-8" assert res.num_bytes_downloaded == 507 @@ -25,11 +23,17 @@ class TestPreview: assert "https://cdn.jamesbillings67.com/james-billings.js" in res.text @pytest.mark.anyio - async def test_preview_redirect_from_work( - self, httpxclient: AsyncClient, amt_hit_id, amt_assignment_id + async def test_preview_redirect_from_work_random_str( + self, httpxclient: AsyncClient, amt_hit_id: str, amt_assignment_id: str ): client = httpxclient + """ + The redirect occurs regardless of any parameter validation. This is + because the query params will be used for record lookup on the work + page itself. + """ + params = { "workerId": None, "assignmentId": amt_assignment_id, @@ -39,3 +43,28 @@ class TestPreview: assert res.status_code == 302 assert "/preview/" in res.headers["location"] + + @pytest.mark.anyio + async def test_preview_redirect_from_work_records( + self, + httpxclient: AsyncClient, + hit_record: Hit, + assignment_stub_record: AssignmentStub, + ): + client = httpxclient + + """ + The redirect occurs regardless of any parameter validation. This is + because the query params will be used for record lookup on the work + page itself. + """ + + params = { + "workerId": None, + "assignmentId": assignment_stub_record.amt_assignment_id, + "hitId": hit_record.amt_hit_id, + } + res = await client.get("/work/", params=params) + + assert res.status_code == 302 + assert "/preview/" in res.headers["location"] diff --git a/tests/http/test_work.py b/tests/http/test_work.py index c69118b..66251f6 100644 --- a/tests/http/test_work.py +++ b/tests/http/test_work.py @@ -1,5 +1,9 @@ import pytest from httpx import AsyncClient +from jb.models.hit import Hit +from jb.models.assignment import AssignmentStub + +from jb.managers.assignment import AssignmentManager class TestWork: @@ -8,17 +12,105 @@ class TestWork: async def test_work( self, httpxclient: AsyncClient, - hit_record, - amt_assignment_id, - amt_worker_id, + hit_record: Hit, + amt_assignment_id: str, + amt_worker_id: str, ): client = httpxclient + assert isinstance(hit_record.id, int) + params = { "workerId": amt_worker_id, "assignmentId": amt_assignment_id, "hitId": hit_record.amt_hit_id, } res = await client.get("/work/", params=params) + assert res.status_code == 200 + + @pytest.mark.anyio + async def test_work_no_hit_record( + self, + httpxclient: AsyncClient, + hit: Hit, + amt_assignment_id: str, + amt_worker_id: str, + ): + client = httpxclient + + # Because no AssignmentStub record is created, and we're just using + # random strings as IDs, we should also confirm that the Hit record + # is not a saved record. + assert hit.id is None + + params = { + "workerId": amt_worker_id, + "assignmentId": amt_assignment_id, + "hitId": hit.amt_hit_id, + } + res = await client.get("/work/", params=params) + assert res.status_code == 500 + + @pytest.mark.anyio + async def test_work_assignment_stub_existing( + self, + httpxclient: AsyncClient, + am: AssignmentManager, + hit: Hit, + assignment_stub_record: AssignmentStub, + amt_assignment_id: str, + amt_worker_id: str, + ): + client = httpxclient + + # Because the AssignmentStub is created with a reference to the Hit, + # the Hit is actually a "Hit Record" (with a primary key), so it's + # saved in the database + assert isinstance(hit.id, int) + + # Confirm that it exists in the database before the call + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert isinstance(res, AssignmentStub) + assert isinstance(res.id, int) + + params = { + "workerId": amt_worker_id, + "assignmentId": assignment_stub_record.amt_assignment_id, + "hitId": hit.amt_hit_id, + } + res = await client.get("/work/", params=params) + assert res.status_code == 200 + + # Confirm that it exists in the database + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert isinstance(res, AssignmentStub) + assert isinstance(res.id, int) + + @pytest.mark.anyio + async def test_work_assignment_stub_created( + self, + httpxclient: AsyncClient, + am: AssignmentManager, + hit_record: Hit, + assignment_stub: AssignmentStub, + amt_assignment_id: str, + amt_worker_id: str, + ): + client = httpxclient + + # Confirm that it exists in the database before the call + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert res is None + params = { + "workerId": amt_worker_id, + "assignmentId": assignment_stub.amt_assignment_id, + "hitId": hit_record.amt_hit_id, + } + res = await client.get("/work/", params=params) assert res.status_code == 200 + + # Confirm that it exists in the database + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert isinstance(res, AssignmentStub) + assert isinstance(res.id, int) |
