aboutsummaryrefslogtreecommitdiff
path: root/tests/http/test_notifications.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/test_notifications.py')
-rw-r--r--tests/http/test_notifications.py163
1 files changed, 97 insertions, 66 deletions
diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py
index 6770044..4386863 100644
--- a/tests/http/test_notifications.py
+++ b/tests/http/test_notifications.py
@@ -1,71 +1,102 @@
-import json
-
import pytest
+import json
+import redis
+from typing import Dict, Any
from httpx import AsyncClient
-import secrets
+from uuid import uuid4
from jb.config import JB_EVENTS_STREAM, settings
from jb.models.event import MTurkEvent
-from tests import generate_amt_id
-
-
-def generate_hex_id(length: int = 40) -> str:
- # length is number of hex chars, so we need length//2 bytes
- return secrets.token_hex(length // 2)
-
-
-@pytest.fixture
-def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id):
- return {
- "Type": "Notification",
- "Message": json.dumps(
- {
- "Events": [
- {
- "EventType": "AssignmentSubmitted",
- "EventTimestamp": "2025-10-16T18:45:51.000000Z",
- "HITId": amt_hit_id,
- "AssignmentId": amt_assignment_id,
- "HITTypeId": amt_hit_type_id,
- }
- ],
- "EventDocId": generate_hex_id(),
- "SourceAccount": settings.aws_owner_id,
- "CustomerId": generate_amt_id(length=14),
- "EventDocVersion": "2006-05-05",
- }
- ),
- }
-
-
-@pytest.fixture()
-def clean_mturk_events_redis_stream(redis):
- redis.xtrim(JB_EVENTS_STREAM, maxlen=0)
- assert redis.xlen(JB_EVENTS_STREAM) == 0
- yield
- redis.xtrim(JB_EVENTS_STREAM, maxlen=0)
- assert redis.xlen(JB_EVENTS_STREAM) == 0
-
-
-@pytest.mark.anyio
-async def test_mturk_notifications(
- redis,
- httpxclient: AsyncClient,
- no_limit,
- example_mturk_event_body,
- amt_assignment_id,
- clean_mturk_events_redis_stream,
-):
- client = httpxclient
-
- res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body)
- res.raise_for_status()
-
- msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100)
- msg_res = msg_res[0][1][0]
- msg_id, msg = msg_res
- redis.xdel(JB_EVENTS_STREAM, msg_id)
-
- msg_json = msg["data"]
- event = MTurkEvent.model_validate_json(msg_json)
- assert event.amt_assignment_id == amt_assignment_id
+from jb.models.hit import Hit
+
+
+class TestNotifications:
+
+ @pytest.mark.anyio
+ async def test_no_post_keys(
+ self,
+ httpxclient: AsyncClient,
+ ):
+ client = httpxclient
+
+ res = await client.post(url=f"/{settings.sns_path}/", json={})
+ assert res.status_code == 500
+ assert res.json() == {"detail": "Invalid JSON"}
+
+ @pytest.mark.anyio
+ async def test_no_post_data(
+ self,
+ httpxclient: AsyncClient,
+ ):
+ client = httpxclient
+
+ res = await client.post(url=f"/{settings.sns_path}/")
+ assert res.status_code == 500
+ assert res.json() == {"detail": "No POST data"}
+
+ @pytest.mark.anyio
+ async def test_invalid_post_data(
+ self,
+ httpxclient: AsyncClient,
+ ):
+ client = httpxclient
+
+ res = await client.post(
+ url=f"/{settings.sns_path}/", json={uuid4().hex: uuid4().hex}
+ )
+ assert res.status_code == 500
+ assert res.json() == {"detail": "Invalid JSON"}
+
+ @pytest.mark.anyio
+ async def test_mturk_notifications(
+ self,
+ redis: redis.Redis,
+ httpxclient: AsyncClient,
+ hit_record: Hit,
+ assignment_stub_record: AssignmentStub,
+ mturk_event_body_record: Dict[str, Any],
+ ):
+ client = httpxclient
+
+ json_msg = json.loads(mturk_event_body_record["Message"])
+ # Assert the mturk event is owned by the correct account
+ assert json_msg["SourceAccount"] == settings.aws_owner_id
+
+ # Assert the HIT and Assignment are "connected"
+ assert assignment_stub_record.hit_id == hit_record.id
+ assert assignment_stub_record.amt_hit_id == hit_record.amt_hit_id
+
+ # Assert the event body and the Assignment/HIT are "connected"
+ assert json_msg["Events"][0]["HITId"] == hit_record.amt_hit_id
+ assert json_msg["Events"][0]["HITId"] == assignment_stub_record.amt_hit_id
+ assert (
+ json_msg["Events"][0]["AssignmentId"]
+ == assignment_stub_record.amt_assignment_id
+ )
+
+ res = await client.post(
+ url=f"/{settings.sns_path}/", json=mturk_event_body_record
+ )
+ res.raise_for_status()
+
+ # AMT SNS needs to receive a 200 response to stop retrying the notification
+ assert res.status_code == 200
+ assert res.json() == {"status": "ok"}
+
+ # Check that the event was enqueued in Redis
+ msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100)
+ msg_res = msg_res[0][1][0]
+ msg_id, msg = msg_res
+ redis.xdel(JB_EVENTS_STREAM, msg_id)
+
+ msg_json = msg["data"]
+ event = MTurkEvent.model_validate_json(msg_json)
+
+ # Confirm that the event that we got from redis is what was POSTED by
+ # AMT SNS. We're using the fixture() so the values are the same as
+ # what was in the example_mturk_event_body
+ assert event.event_type == "AssignmentSubmitted"
+
+ assert event.amt_assignment_id == assignment_stub_record.amt_assignment_id
+ assert event.amt_hit_id == hit_record.amt_hit_id
+ assert event.amt_hit_type_id == hit_record.amt_hit_type_id