diff options
| author | Max Nanis | 2026-02-25 16:20:18 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-25 16:20:18 -0500 |
| commit | 04aee0dc7e908ce020d2d2c3f8ffb4a96424b883 (patch) | |
| tree | efb99622da9a962a73921a945373c019f98e6273 /tests/http/test_notifications.py | |
| parent | 8c1940445503fd6678d0961600f2be81622793a2 (diff) | |
| download | amt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.tar.gz amt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.zip | |
test_notification (for sns mgmt), along with more type hinting on pytest conftest
Diffstat (limited to 'tests/http/test_notifications.py')
| -rw-r--r-- | tests/http/test_notifications.py | 163 |
1 files changed, 97 insertions, 66 deletions
diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py index 6770044..4386863 100644 --- a/tests/http/test_notifications.py +++ b/tests/http/test_notifications.py @@ -1,71 +1,102 @@ -import json - import pytest +import json +import redis +from typing import Dict, Any from httpx import AsyncClient -import secrets +from uuid import uuid4 from jb.config import JB_EVENTS_STREAM, settings from jb.models.event import MTurkEvent -from tests import generate_amt_id - - -def generate_hex_id(length: int = 40) -> str: - # length is number of hex chars, so we need length//2 bytes - return secrets.token_hex(length // 2) - - -@pytest.fixture -def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id): - return { - "Type": "Notification", - "Message": json.dumps( - { - "Events": [ - { - "EventType": "AssignmentSubmitted", - "EventTimestamp": "2025-10-16T18:45:51.000000Z", - "HITId": amt_hit_id, - "AssignmentId": amt_assignment_id, - "HITTypeId": amt_hit_type_id, - } - ], - "EventDocId": generate_hex_id(), - "SourceAccount": settings.aws_owner_id, - "CustomerId": generate_amt_id(length=14), - "EventDocVersion": "2006-05-05", - } - ), - } - - -@pytest.fixture() -def clean_mturk_events_redis_stream(redis): - redis.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert redis.xlen(JB_EVENTS_STREAM) == 0 - yield - redis.xtrim(JB_EVENTS_STREAM, maxlen=0) - assert redis.xlen(JB_EVENTS_STREAM) == 0 - - -@pytest.mark.anyio -async def test_mturk_notifications( - redis, - httpxclient: AsyncClient, - no_limit, - example_mturk_event_body, - amt_assignment_id, - clean_mturk_events_redis_stream, -): - client = httpxclient - - res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body) - res.raise_for_status() - - msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) - msg_res = msg_res[0][1][0] - msg_id, msg = msg_res - redis.xdel(JB_EVENTS_STREAM, msg_id) - - msg_json = msg["data"] - event = MTurkEvent.model_validate_json(msg_json) - assert event.amt_assignment_id == amt_assignment_id +from jb.models.hit import Hit + + +class TestNotifications: + + @pytest.mark.anyio + async def test_no_post_keys( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/", json={}) + assert res.status_code == 500 + assert res.json() == {"detail": "Invalid JSON"} + + @pytest.mark.anyio + async def test_no_post_data( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post(url=f"/{settings.sns_path}/") + assert res.status_code == 500 + assert res.json() == {"detail": "No POST data"} + + @pytest.mark.anyio + async def test_invalid_post_data( + self, + httpxclient: AsyncClient, + ): + client = httpxclient + + res = await client.post( + url=f"/{settings.sns_path}/", json={uuid4().hex: uuid4().hex} + ) + assert res.status_code == 500 + assert res.json() == {"detail": "Invalid JSON"} + + @pytest.mark.anyio + async def test_mturk_notifications( + self, + redis: redis.Redis, + httpxclient: AsyncClient, + hit_record: Hit, + assignment_stub_record: AssignmentStub, + mturk_event_body_record: Dict[str, Any], + ): + client = httpxclient + + json_msg = json.loads(mturk_event_body_record["Message"]) + # Assert the mturk event is owned by the correct account + assert json_msg["SourceAccount"] == settings.aws_owner_id + + # Assert the HIT and Assignment are "connected" + assert assignment_stub_record.hit_id == hit_record.id + assert assignment_stub_record.amt_hit_id == hit_record.amt_hit_id + + # Assert the event body and the Assignment/HIT are "connected" + assert json_msg["Events"][0]["HITId"] == hit_record.amt_hit_id + assert json_msg["Events"][0]["HITId"] == assignment_stub_record.amt_hit_id + assert ( + json_msg["Events"][0]["AssignmentId"] + == assignment_stub_record.amt_assignment_id + ) + + res = await client.post( + url=f"/{settings.sns_path}/", json=mturk_event_body_record + ) + res.raise_for_status() + + # AMT SNS needs to receive a 200 response to stop retrying the notification + assert res.status_code == 200 + assert res.json() == {"status": "ok"} + + # Check that the event was enqueued in Redis + msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100) + msg_res = msg_res[0][1][0] + msg_id, msg = msg_res + redis.xdel(JB_EVENTS_STREAM, msg_id) + + msg_json = msg["data"] + event = MTurkEvent.model_validate_json(msg_json) + + # Confirm that the event that we got from redis is what was POSTED by + # AMT SNS. We're using the fixture() so the values are the same as + # what was in the example_mturk_event_body + assert event.event_type == "AssignmentSubmitted" + + assert event.amt_assignment_id == assignment_stub_record.amt_assignment_id + assert event.amt_hit_id == hit_record.amt_hit_id + assert event.amt_hit_type_id == hit_record.amt_hit_type_id |
