diff options
| author | Max Nanis | 2026-02-25 16:20:18 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-25 16:20:18 -0500 |
| commit | 04aee0dc7e908ce020d2d2c3f8ffb4a96424b883 (patch) | |
| tree | efb99622da9a962a73921a945373c019f98e6273 /jb | |
| parent | 8c1940445503fd6678d0961600f2be81622793a2 (diff) | |
| download | amt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.tar.gz amt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.zip | |
test_notification (for sns mgmt), along with more type hinting on pytest conftest
Diffstat (limited to 'jb')
| -rw-r--r-- | jb/managers/assignment.py | 3 | ||||
| -rw-r--r-- | jb/settings.py | 4 | ||||
| -rw-r--r-- | jb/views/common.py | 81 |
3 files changed, 31 insertions, 57 deletions
diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py index dd3c866..c0a168e 100644 --- a/jb/managers/assignment.py +++ b/jb/managers/assignment.py @@ -17,7 +17,8 @@ class AssignmentManager(PostgresManager): query = sql.SQL( """ INSERT INTO mtwerk_assignment - (amt_assignment_id, amt_worker_id, status, created_at, modified_at, hit_id) + (amt_assignment_id, amt_worker_id, status, + created_at, modified_at, hit_id) VALUES (%(amt_assignment_id)s, %(amt_worker_id)s, %(status)s, %(created_at)s, %(modified_at)s, %(hit_id)s) diff --git a/jb/settings.py b/jb/settings.py index 5754add..d529591 100644 --- a/jb/settings.py +++ b/jb/settings.py @@ -60,12 +60,12 @@ class TestSettings(Settings): model_config = SettingsConfigDict( env_prefix="", case_sensitive=False, - env_file=os.path.join(BASE_DIR, ".env.dev"), + env_file=os.path.join(BASE_DIR, ".env.test"), extra="allow", cli_parse_args=False, ) debug: bool = True - app_name: str = "AMT JB API Development" + app_name: str = "AMT JB API Test" @lru_cache diff --git a/jb/views/common.py b/jb/views/common.py index 0dc8b56..c87b1a5 100644 --- a/jb/views/common.py +++ b/jb/views/common.py @@ -1,18 +1,16 @@ import json -from typing import List -from uuid import uuid4 +from unittest import case +from typing import Dict, Any import requests -from fastapi import Request, APIRouter, Response, HTTPException, Query -from fastapi.responses import HTMLResponse, JSONResponse -from pydantic import BaseModel, ConfigDict, Field +from fastapi import Request, APIRouter, HTTPException +from fastapi.responses import HTMLResponse from starlette.responses import RedirectResponse from jb.config import settings, JB_EVENTS_STREAM from jb.decorators import REDIS, HM from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event -from generalresearchutils.currency import USDCent -from jb.models.definitions import ReportValue, AssignmentStatus +from jb.models.definitions import AssignmentStatus from jb.models.event import MTurkEvent from jb.settings import BASE_HTML from jb.config import settings @@ -21,40 +19,6 @@ from jb.views.tasks import process_request common_router = APIRouter(prefix="", tags=["API"], include_in_schema=True) -class ReportTask(BaseModel): - model_config = ConfigDict(extra="forbid") - - worker_id: str = Field() - - reasons: List[ReportValue] = Field( - examples=[[3, 4]], - default_factory=list, - ) - - notes: str = Field( - default="", examples=["The survey wanted to watch me eat Haejang-guk"] - ) - - -@common_router.post("/report/") -def report(request: Request, data: ReportTask): - url = f"{settings.fsb_host}{settings.product_id}/report/" - params = { - "bpuid": data.worker_id, - "reasons": [x.value for x in data.reasons], - "notes": data.notes, - } - - req = requests.post(url, json=params) - res = req.json() - if res.status_code != 200: - raise HTTPException( - status_code=res.status_code, detail="Failed to submit report" - ) - - return Response(res) - - @common_router.get(path="/work/", response_class=HTMLResponse) async def work(request: Request): """ @@ -86,8 +50,11 @@ async def work(request: Request): status_code=302, ) - # The Worker has accepted the HIT - process_request(request) + try: + # The Worker has accepted the HIT + process_request(request) + except Exception: + raise HTTPException(status_code=500, detail="Error processing request") return HTMLResponse(BASE_HTML) @@ -97,23 +64,29 @@ async def mturk_notifications(request: Request): """ Our SNS topic will POST to this endpoint whenever we get a new message """ - message = await request.json() - msg_type = message.get("Type") + try: + message = await request.json() + except Exception: + raise HTTPException(status_code=500, detail="No POST data") + + match message.get("Type"): + case "SubscriptionConfirmation": + subscribe_url = message["SubscribeURL"] + print("Confirming SNS subscription...") + requests.get(subscribe_url) - if msg_type == "SubscriptionConfirmation": - subscribe_url = message["SubscribeURL"] - print("Confirming SNS subscription...") - requests.get(subscribe_url) + case "Notification": + msg = json.loads(message["Message"]) + print("Received MTurk event:", msg) + enqueue_mturk_notifications(msg) - elif msg_type == "Notification": - msg = json.loads(message["Message"]) - print("Received MTurk event:", msg) - enqueue_mturk_notifications(msg) + case _: + raise HTTPException(status_code=500, detail="Invalid JSON") return {"status": "ok"} -def enqueue_mturk_notifications(msg) -> None: +def enqueue_mturk_notifications(msg: Dict[str, Any]) -> None: for evt in msg["Events"]: event = MTurkEvent.from_sns(evt) emit_mturk_notification_event( |
