diff options
Diffstat (limited to 'tests/http')
| -rw-r--r-- | tests/http/conftest.py | 49 | ||||
| -rw-r--r-- | tests/http/test_basic.py | 37 | ||||
| -rw-r--r-- | tests/http/test_notifications.py | 163 | ||||
| -rw-r--r-- | tests/http/test_preview.py | 39 | ||||
| -rw-r--r-- | tests/http/test_work.py | 98 |
5 files changed, 294 insertions, 92 deletions
diff --git a/tests/http/conftest.py b/tests/http/conftest.py index 200bf1c..4f11fde 100644 --- a/tests/http/conftest.py +++ b/tests/http/conftest.py @@ -1,10 +1,19 @@ import httpx +import redis import pytest import requests_mock from asgi_lifespan import LifespanManager from httpx import AsyncClient, ASGITransport +from typing import Dict, Any from jb.main import app +import json + +from httpx import AsyncClient +import secrets + +from jb.config import JB_EVENTS_STREAM, settings +from tests import generate_amt_id @pytest.fixture(scope="session") @@ -48,3 +57,43 @@ def httpxclient_ip(httpxclient): def mock_requests(): with requests_mock.Mocker() as m: yield m + + +def generate_hex_id(length: int = 40) -> str: + # length is number of hex chars, so we need length//2 bytes + return secrets.token_hex(length // 2) + + +@pytest.fixture +def mturk_event_body_record( + hit_record: Hit, assignment_stub_record: AssignmentStub +) -> Dict[str, Any]: + return { + "Type": "Notification", + "Message": json.dumps( + { + "Events": [ + { + "EventType": "AssignmentSubmitted", + "EventTimestamp": "2025-10-16T18:45:51.000000Z", + "HITId": hit_record.amt_hit_id, + "AssignmentId": assignment_stub_record.amt_assignment_id, + "HITTypeId": hit_record.amt_hit_type_id, + } + ], + "EventDocId": generate_hex_id(), + "SourceAccount": settings.aws_owner_id, + "CustomerId": generate_amt_id(length=14), + "EventDocVersion": "2006-05-05", + } + ), + } + + +@pytest.fixture() +def clean_mturk_events_redis_stream(redis: redis.Redis): + redis.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert redis.xlen(JB_EVENTS_STREAM) == 0 + yield + redis.xtrim(JB_EVENTS_STREAM, maxlen=0) + assert redis.xlen(JB_EVENTS_STREAM) == 0 diff --git a/tests/http/test_basic.py b/tests/http/test_basic.py index 18359da..e806fa9 100644 --- a/tests/http/test_basic.py +++ b/tests/http/test_basic.py @@ -2,23 +2,24 @@ import pytest from httpx import AsyncClient -@pytest.mark.anyio -async def test_base(httpxclient: AsyncClient): - client = httpxclient - res = await client.get("/") - # actually returns 404. old test expects 401. idk what is should be - print(res.text) - # assert res.status_code == 404 - assert res.status_code == 200 +class TestBase: + @pytest.mark.anyio + async def test_base(self, httpxclient: AsyncClient): + client = httpxclient + res = await client.get("/") + # actually returns 404. old test expects 401. idk what is should be + print(res.text) + # assert res.status_code == 404 + assert res.status_code == 200 -@pytest.mark.anyio -async def test_static_file_alias(httpxclient: AsyncClient): - client = httpxclient - """ - These are here for site crawlers and stuff.. - """ - for p in ["/robots.txt", "/sitemap.xml", "/favicon.ico"]: - res = await client.get(p) - assert res.status_code == 200, p - assert res.json() == {} + @pytest.mark.anyio + async def test_static_file_alias(self, httpxclient: AsyncClient): + client = httpxclient + """ + These are here for site crawlers and stuff.. + """ + for p in ["/robots.txt", "/sitemap.xml", "/favicon.ico"]: + res = await client.get(p) + assert res.status_code == 200, p + assert res.json() == {} diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py index 6770044..4386863 100644 --- a/tests/http/test_notifications.py +++ b/tests/http/test_notifications.py @@ -1,71 +1,102 @@ -import json - import pytest +import json +import redis +from typing import Dict, Any from httpx import AsyncClient -import secrets +from uuid import uuid4 from jb.config import JB_EVENTS_STREAM, settings from jb.models.event import MTurkEvent -from tests import generate_amt_id - - -def generate_hex_id(length: int = 40) -> str: - # length is number of hex chars, so we need length//2 bytes - return secrets.token_hex(length // 2) - - -@pytest.fixture -def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id): - return { - "Type": "Notification", - "Message": json.dumps( - { - "Events": [ - { - "EventType": "AssignmentSubmitted", - "EventTimestamp": "2025-10-16T18:45:51.000000Z", - "HITId": amt_hit_id, - "AssignmentId": amt_assignment_id, - "HITTypeId": amt_hit_type_id, - } - ], - "EventDocId": generate_hex_id(), - "SourceAccount": settings.aws_owner_id, - "CustomerId": generate_amt_id(length=14), - "EventDocVersion": "2006-05-05", - } - ), - } - - -@pytest.fixture() -def clean_mturk_events_redis_stream(redis): - redis.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert redis.xlen(JB_EVENTS_STREAM) == 0 - yield - redis.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert redis.xlen(JB_EVENTS_STREAM) == 0 - - -@pytest.mark.anyio -async def test_mturk_notifications( - redis, - httpxclient: AsyncClient, - no_limit, - example_mturk_event_body, - amt_assignment_id, - clean_mturk_events_redis_stream, -): - client = httpxclient - - res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body) - res.raise_for_status() - - msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) - msg_res = msg_res[0][1][0] - msg_id, msg = msg_res - redis.xdel(JB_EVENTS_STREAM, msg_id) - - msg_json = msg["data"] - event = MTurkEvent.model_validate_json(msg_json) - assert event.amt_assignment_id == amt_assignment_id +from jb.models.hit import Hit + + +class TestNotifications: + + @pytest.mark.anyio + async def test_no_post_keys( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/", json={}) + assert res.status_code == 500 + assert res.json() == {"detail": "Invalid JSON"} + + @pytest.mark.anyio + async def test_no_post_data( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/") + assert res.status_code == 500 + assert res.json() == {"detail": "No POST data"} + + @pytest.mark.anyio + async def test_invalid_post_data( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post( + url=f"/{settings.sns_path}/", json={uuid4().hex: uuid4().hex} + ) + assert res.status_code == 500 + assert res.json() == {"detail": "Invalid JSON"} + + @pytest.mark.anyio + async def test_mturk_notifications( + self, + redis: redis.Redis, + httpxclient: AsyncClient, + hit_record: Hit, + assignment_stub_record: AssignmentStub, + mturk_event_body_record: Dict[str, Any], + ): + client = httpxclient + + json_msg = json.loads(mturk_event_body_record["Message"]) + # Assert the mturk event is owned by the correct account + assert json_msg["SourceAccount"] == settings.aws_owner_id + + # Assert the HIT and Assignment are "connected" + assert assignment_stub_record.hit_id == hit_record.id + assert assignment_stub_record.amt_hit_id == hit_record.amt_hit_id + + # Assert the event body and the Assignment/HIT are "connected" + assert json_msg["Events"][0]["HITId"] == hit_record.amt_hit_id + assert json_msg["Events"][0]["HITId"] == assignment_stub_record.amt_hit_id + assert ( + json_msg["Events"][0]["AssignmentId"] + == assignment_stub_record.amt_assignment_id + ) + + res = await client.post( + url=f"/{settings.sns_path}/", json=mturk_event_body_record + ) + res.raise_for_status() + + # AMT SNS needs to receive a 200 response to stop retrying the notification + assert res.status_code == 200 + assert res.json() == {"status": "ok"} + + # Check that the event was enqueued in Redis + msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) + msg_res = msg_res[0][1][0] + msg_id, msg = msg_res + redis.xdel(JB_EVENTS_STREAM, msg_id) + + msg_json = msg["data"] + event = MTurkEvent.model_validate_json(msg_json) + + # Confirm that the event that we got from redis is what was POSTED by + # AMT SNS. We're using the fixture() so the values are the same as + # what was in the example_mturk_event_body + assert event.event_type == "AssignmentSubmitted" + + assert event.amt_assignment_id == assignment_stub_record.amt_assignment_id + assert event.amt_hit_id == hit_record.amt_hit_id + assert event.amt_hit_type_id == hit_record.amt_hit_type_id diff --git a/tests/http/test_preview.py b/tests/http/test_preview.py index 2bdf265..467c63c 100644 --- a/tests/http/test_preview.py +++ b/tests/http/test_preview.py @@ -3,8 +3,8 @@ import pytest from httpx import AsyncClient - from jb.models.hit import Hit +from jb.models.assignment import AssignmentStub class TestPreview: @@ -15,8 +15,6 @@ class TestPreview: res = await client.get("/preview/") assert res.status_code == 200 - # the response is an html page - assert res.headers["content-type"] == "text/html; charset=utf-8" assert res.num_bytes_downloaded == 507 @@ -25,11 +23,17 @@ class TestPreview: assert "https://cdn.jamesbillings67.com/james-billings.js" in res.text @pytest.mark.anyio - async def test_preview_redirect_from_work( - self, httpxclient: AsyncClient, amt_hit_id, amt_assignment_id + async def test_preview_redirect_from_work_random_str( + self, httpxclient: AsyncClient, amt_hit_id: str, amt_assignment_id: str ): client = httpxclient + """ + The redirect occurs regardless of any parameter validation. This is + because the query params will be used for record lookup on the work + page itself. + """ + params = { "workerId": None, "assignmentId": amt_assignment_id, @@ -39,3 +43,28 @@ class TestPreview: assert res.status_code == 302 assert "/preview/" in res.headers["location"] + + @pytest.mark.anyio + async def test_preview_redirect_from_work_records( + self, + httpxclient: AsyncClient, + hit_record: Hit, + assignment_stub_record: AssignmentStub, + ): + client = httpxclient + + """ + The redirect occurs regardless of any parameter validation. This is + because the query params will be used for record lookup on the work + page itself. + """ + + params = { + "workerId": None, + "assignmentId": assignment_stub_record.amt_assignment_id, + "hitId": hit_record.amt_hit_id, + } + res = await client.get("/work/", params=params) + + assert res.status_code == 302 + assert "/preview/" in res.headers["location"] diff --git a/tests/http/test_work.py b/tests/http/test_work.py index c69118b..66251f6 100644 --- a/tests/http/test_work.py +++ b/tests/http/test_work.py @@ -1,5 +1,9 @@ import pytest from httpx import AsyncClient +from jb.models.hit import Hit +from jb.models.assignment import AssignmentStub + +from jb.managers.assignment import AssignmentManager class TestWork: @@ -8,17 +12,105 @@ class TestWork: async def test_work( self, httpxclient: AsyncClient, - hit_record, - amt_assignment_id, - amt_worker_id, + hit_record: Hit, + amt_assignment_id: str, + amt_worker_id: str, ): client = httpxclient + assert isinstance(hit_record.id, int) + params = { "workerId": amt_worker_id, "assignmentId": amt_assignment_id, "hitId": hit_record.amt_hit_id, } res = await client.get("/work/", params=params) + assert res.status_code == 200 + + @pytest.mark.anyio + async def test_work_no_hit_record( + self, + httpxclient: AsyncClient, + hit: Hit, + amt_assignment_id: str, + amt_worker_id: str, + ): + client = httpxclient + + # Because no AssignmentStub record is created, and we're just using + # random strings as IDs, we should also confirm that the Hit record + # is not a saved record. + assert hit.id is None + + params = { + "workerId": amt_worker_id, + "assignmentId": amt_assignment_id, + "hitId": hit.amt_hit_id, + } + res = await client.get("/work/", params=params) + assert res.status_code == 500 + + @pytest.mark.anyio + async def test_work_assignment_stub_existing( + self, + httpxclient: AsyncClient, + am: AssignmentManager, + hit: Hit, + assignment_stub_record: AssignmentStub, + amt_assignment_id: str, + amt_worker_id: str, + ): + client = httpxclient + + # Because the AssignmentStub is created with a reference to the Hit, + # the Hit is actually a "Hit Record" (with a primary key), so it's + # saved in the database + assert isinstance(hit.id, int) + + # Confirm that it exists in the database before the call + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert isinstance(res, AssignmentStub) + assert isinstance(res.id, int) + + params = { + "workerId": amt_worker_id, + "assignmentId": assignment_stub_record.amt_assignment_id, + "hitId": hit.amt_hit_id, + } + res = await client.get("/work/", params=params) + assert res.status_code == 200 + + # Confirm that it exists in the database + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert isinstance(res, AssignmentStub) + assert isinstance(res.id, int) + + @pytest.mark.anyio + async def test_work_assignment_stub_created( + self, + httpxclient: AsyncClient, + am: AssignmentManager, + hit_record: Hit, + assignment_stub: AssignmentStub, + amt_assignment_id: str, + amt_worker_id: str, + ): + client = httpxclient + + # Confirm that it exists in the database before the call + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert res is None + params = { + "workerId": amt_worker_id, + "assignmentId": assignment_stub.amt_assignment_id, + "hitId": hit_record.amt_hit_id, + } + res = await client.get("/work/", params=params) assert res.status_code == 200 + + # Confirm that it exists in the database + res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id) + assert isinstance(res, AssignmentStub) + assert isinstance(res.id, int) |
