aboutsummaryrefslogtreecommitdiff
path: root/tests/http/test_notifications.py
blob: 508b2362f2f45becab8fd7c9b898a7bfeac94b20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import pytest
import json
import redis
from typing import Dict, Any
from httpx import AsyncClient
from uuid import uuid4

from jb.config import JB_EVENTS_STREAM, settings
from jb.models.event import MTurkEvent
from jb.models.hit import Hit
from jb.models.assignment import AssignmentStub


class TestNotifications:
    """HTTP tests for the SNS notification endpoint at ``/{settings.sns_path}/``.

    The error-path tests pin the endpoint's current responses to malformed
    input: HTTP 500 with a ``detail`` message. NOTE(review): malformed client
    input is reported as a 500 rather than a 4xx; these assertions lock in
    that current behavior. The happy-path test checks that a valid MTurk
    event body is acknowledged and enqueued onto ``JB_EVENTS_STREAM``.
    """

    @staticmethod
    def _sns_url() -> str:
        """Return the notification endpoint path, built at call time from settings."""
        return f"/{settings.sns_path}/"

    @pytest.mark.anyio
    async def test_no_post_keys(
        self,
        httpxclient: AsyncClient,
    ):
        """An empty JSON object is rejected as invalid JSON."""
        res = await httpxclient.post(url=self._sns_url(), json={})
        assert res.status_code == 500
        assert res.json() == {"detail": "Invalid JSON"}

    @pytest.mark.anyio
    async def test_no_post_data(
        self,
        httpxclient: AsyncClient,
    ):
        """A POST with no body at all is rejected with "No POST data"."""
        res = await httpxclient.post(url=self._sns_url())
        assert res.status_code == 500
        assert res.json() == {"detail": "No POST data"}

    @pytest.mark.anyio
    async def test_invalid_post_data(
        self,
        httpxclient: AsyncClient,
    ):
        """A JSON object with an unexpected (random) key is rejected as invalid JSON."""
        res = await httpxclient.post(
            url=self._sns_url(), json={uuid4().hex: uuid4().hex}
        )
        assert res.status_code == 500
        assert res.json() == {"detail": "Invalid JSON"}

    @pytest.mark.anyio
    async def test_mturk_notifications(
        self,
        redis: redis.Redis,
        httpxclient: AsyncClient,
        hit_record: Hit,
        assignment_stub_record: AssignmentStub,
        mturk_event_body_record: Dict[str, Any],
    ):
        """A valid MTurk SNS body is acknowledged and enqueued exactly once.

        The enqueued stream entry is read back, deleted, and validated as an
        ``MTurkEvent`` matching the HIT/Assignment fixtures.
        """
        # The body's "Message" field is itself a JSON string.
        json_msg = json.loads(mturk_event_body_record["Message"])
        first_event = json_msg["Events"][0]

        # --- Fixture sanity checks, before exercising the endpoint ---
        # The mturk event is owned by the configured account
        assert json_msg["SourceAccount"] == settings.aws_owner_id

        # The HIT and Assignment are "connected"
        assert assignment_stub_record.hit_id == hit_record.id
        assert assignment_stub_record.amt_hit_id == hit_record.amt_hit_id

        # The event body and the Assignment/HIT are "connected"
        assert first_event["HITId"] == hit_record.amt_hit_id
        assert first_event["HITId"] == assignment_stub_record.amt_hit_id
        assert first_event["AssignmentId"] == assignment_stub_record.amt_assignment_id

        # The stream must start empty so the count after the POST is meaningful
        assert redis.xlen(JB_EVENTS_STREAM) == 0

        res = await httpxclient.post(url=self._sns_url(), json=mturk_event_body_record)
        res.raise_for_status()

        # AMT SNS needs to receive a 200 response to stop retrying the notification.
        # Checked before the stream so a failed POST is reported as such.
        assert res.status_code == 200
        assert res.json() == {"status": "ok"}

        # The POST enqueued exactly one event
        assert redis.xlen(JB_EVENTS_STREAM) == 1

        # Read the enqueued entry back. Assert it exists first so an empty read
        # fails with a clear message instead of an IndexError/TypeError.
        streams = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100)
        assert streams, f"no entry could be read from {JB_EVENTS_STREAM!r}"
        # Indexed as [[stream_name, [(entry_id, fields), ...]], ...]
        msg_id, msg = streams[0][1][0]
        redis.xdel(JB_EVENTS_STREAM, msg_id)

        # After running xdel, the stream is empty again
        assert redis.xlen(JB_EVENTS_STREAM) == 0

        # NOTE(review): the str key "data" assumes the Redis client decodes
        # responses; with raw bytes the key would be b"data" — confirm fixture.
        event = MTurkEvent.model_validate_json(msg["data"])

        # The event read from redis matches what AMT SNS POSTed. The values
        # come from the same fixtures, so they line up with the event body.
        assert event.event_type == "AssignmentSubmitted"

        assert event.amt_assignment_id == assignment_stub_record.amt_assignment_id
        assert event.amt_hit_id == hit_record.amt_hit_id
        assert event.amt_hit_type_id == hit_record.amt_hit_type_id