aboutsummaryrefslogtreecommitdiff
path: root/jb/views/common.py
blob: 701155732a77705567f8e61fba4f439a99573c2e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import asyncio
import json
from typing import Dict, Any
from urllib.parse import urlsplit

import requests
from fastapi import Request, APIRouter, HTTPException
from fastapi.responses import HTMLResponse
from starlette.responses import RedirectResponse

from jb.config import settings, JB_EVENTS_STREAM
from jb.decorators import REDIS, HM
from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event
from jb.models.definitions import AssignmentStatus
from jb.models.event import MTurkEvent
from jb.settings import BASE_HTML
from jb.config import settings
from jb.views.tasks import process_request

# Router for this module's endpoints (the /work/ page and the SNS webhook);
# mounted with no path prefix and tagged "API" in the OpenAPI schema.
common_router = APIRouter(
    prefix="",
    tags=["API"],
    include_in_schema=True,
)


def _preview_redirect(request: Request) -> RedirectResponse:
    """Build a 302 redirect to ``/preview/`` that forwards the request's query string.

    :param request: Incoming request whose raw query string is carried over.
    :return: A ``RedirectResponse`` to ``/preview/?<query>``, or plain
        ``/preview/`` when the request had no query string.
    """
    query = request.url.query
    return RedirectResponse(
        url=f"/preview/?{query}" if query else "/preview/",
        status_code=302,
    )


@common_router.get(path="/work/", response_class=HTMLResponse)
async def work(request: Request):
    """
    HTML Page that the worker lands on in an iFrame.
    They can either be previewing the HIT, or have already accepted it.

    Flow:
      * no ``workerId`` -> redirect to ``/preview/`` (query string preserved);
      * ``assignmentId`` missing or ``"ASSIGNMENT_ID_NOT_AVAILABLE"`` -> emit a
        ``PreviewState`` assignment event, then redirect to ``/preview/``;
      * otherwise the HIT was accepted: run ``process_request`` and serve
        ``BASE_HTML``.

    :raises HTTPException: re-raised unchanged if ``process_request`` raises
        one; any other failure becomes a 500 "Error processing request".
    """
    amt_assignment_id = request.query_params.get("assignmentId")
    worker_id = request.query_params.get("workerId")
    amt_hit_id = request.query_params.get("hitId")
    print(f"work: {amt_assignment_id=} {worker_id=} {amt_hit_id=}")

    if not worker_id:
        return _preview_redirect(request)

    if amt_assignment_id is None or amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE":
        # Worker is previewing the HIT
        amt_hit_type_id = "unknown"
        if amt_hit_id:
            hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id)
            # NOTE(review): guard in case the lookup returns None for an unknown
            # hitId, so a preview can't crash with AttributeError; verify HM's contract.
            if hit is not None:
                amt_hit_type_id = hit.amt_hit_type_id
        emit_assignment_event(
            status=AssignmentStatus.PreviewState, amt_hit_type_id=amt_hit_type_id
        )
        return _preview_redirect(request)

    # The Worker has accepted the HIT.
    # NOTE(review): process_request is called synchronously; if it is a
    # coroutine function it must be awaited — confirm against jb.views.tasks.
    try:
        process_request(request)
    except HTTPException:
        # Already a deliberate HTTP error: keep its status code and detail
        # instead of masking it as a generic 500.
        raise
    except Exception as exc:
        # Log the real cause; FastAPI's HTTPException handler would not.
        print(f"work: error processing request: {exc!r}")
        raise HTTPException(status_code=500, detail="Error processing request") from exc

    return HTMLResponse(BASE_HTML)


@common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False)
async def mturk_notifications(request: Request):
    """
    Our SNS topic will POST to this endpoint whenever we get a new message
    """
    try:
        message = await request.json()
    except Exception:
        raise HTTPException(status_code=500, detail="No POST data")

    match message.get("Type"):
        case "SubscriptionConfirmation":
            subscribe_url = message["SubscribeURL"]
            print("Confirming SNS subscription...")
            requests.get(subscribe_url)

        case "Notification":
            msg = json.loads(message["Message"])
            print("Received MTurk event:", msg)
            enqueue_mturk_notifications(msg)

        case _:
            raise HTTPException(status_code=500, detail="Invalid JSON")

    return {"status": "ok"}


def enqueue_mturk_notifications(msg: Dict[str, Any]) -> None:
    """Fan out every MTurk event carried in a decoded SNS notification.

    Each entry of ``msg["Events"]`` is parsed with ``MTurkEvent.from_sns``.
    A monitoring event is emitted for it first, and then its JSON serialization
    is appended to the ``JB_EVENTS_STREAM`` Redis stream under the ``"data"`` field.

    :param msg: Decoded SNS ``Message`` body; must hold an iterable ``"Events"``.
    """
    raw_events = msg["Events"]
    for raw in raw_events:
        parsed = MTurkEvent.from_sns(raw)
        emit_mturk_notification_event(
            event_type=parsed.event_type,
            amt_hit_type_id=parsed.amt_hit_type_id,
        )
        payload = {"data": parsed.model_dump_json()}
        REDIS.xadd(JB_EVENTS_STREAM, payload)