aboutsummaryrefslogtreecommitdiff
path: root/jb
diff options
context:
space:
mode:
authorMax Nanis2026-02-25 16:20:18 -0500
committerMax Nanis2026-02-25 16:20:18 -0500
commit04aee0dc7e908ce020d2d2c3f8ffb4a96424b883 (patch)
treeefb99622da9a962a73921a945373c019f98e6273 /jb
parent8c1940445503fd6678d0961600f2be81622793a2 (diff)
downloadamt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.tar.gz
amt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.zip
test_notification (for sns mgmt), along with more type hinting on pytest conftest
Diffstat (limited to 'jb')
-rw-r--r--jb/managers/assignment.py3
-rw-r--r--jb/settings.py4
-rw-r--r--jb/views/common.py81
3 files changed, 31 insertions, 57 deletions
diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py
index dd3c866..c0a168e 100644
--- a/jb/managers/assignment.py
+++ b/jb/managers/assignment.py
@@ -17,7 +17,8 @@ class AssignmentManager(PostgresManager):
query = sql.SQL(
"""
INSERT INTO mtwerk_assignment
- (amt_assignment_id, amt_worker_id, status, created_at, modified_at, hit_id)
+ (amt_assignment_id, amt_worker_id, status,
+ created_at, modified_at, hit_id)
VALUES
(%(amt_assignment_id)s, %(amt_worker_id)s, %(status)s,
%(created_at)s, %(modified_at)s, %(hit_id)s)
diff --git a/jb/settings.py b/jb/settings.py
index 5754add..d529591 100644
--- a/jb/settings.py
+++ b/jb/settings.py
@@ -60,12 +60,12 @@ class TestSettings(Settings):
model_config = SettingsConfigDict(
env_prefix="",
case_sensitive=False,
- env_file=os.path.join(BASE_DIR, ".env.dev"),
+ env_file=os.path.join(BASE_DIR, ".env.test"),
extra="allow",
cli_parse_args=False,
)
debug: bool = True
- app_name: str = "AMT JB API Development"
+ app_name: str = "AMT JB API Test"
@lru_cache
diff --git a/jb/views/common.py b/jb/views/common.py
index 0dc8b56..c87b1a5 100644
--- a/jb/views/common.py
+++ b/jb/views/common.py
@@ -1,18 +1,16 @@
import json
-from typing import List
-from uuid import uuid4
+from unittest import case
+from typing import Dict, Any
import requests
-from fastapi import Request, APIRouter, Response, HTTPException, Query
-from fastapi.responses import HTMLResponse, JSONResponse
-from pydantic import BaseModel, ConfigDict, Field
+from fastapi import Request, APIRouter, HTTPException
+from fastapi.responses import HTMLResponse
from starlette.responses import RedirectResponse
from jb.config import settings, JB_EVENTS_STREAM
from jb.decorators import REDIS, HM
from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event
-from generalresearchutils.currency import USDCent
-from jb.models.definitions import ReportValue, AssignmentStatus
+from jb.models.definitions import AssignmentStatus
from jb.models.event import MTurkEvent
from jb.settings import BASE_HTML
from jb.config import settings
@@ -21,40 +19,6 @@ from jb.views.tasks import process_request
common_router = APIRouter(prefix="", tags=["API"], include_in_schema=True)
-class ReportTask(BaseModel):
- model_config = ConfigDict(extra="forbid")
-
- worker_id: str = Field()
-
- reasons: List[ReportValue] = Field(
- examples=[[3, 4]],
- default_factory=list,
- )
-
- notes: str = Field(
- default="", examples=["The survey wanted to watch me eat Haejang-guk"]
- )
-
-
-@common_router.post("/report/")
-def report(request: Request, data: ReportTask):
- url = f"{settings.fsb_host}{settings.product_id}/report/"
- params = {
- "bpuid": data.worker_id,
- "reasons": [x.value for x in data.reasons],
- "notes": data.notes,
- }
-
- req = requests.post(url, json=params)
- res = req.json()
- if res.status_code != 200:
- raise HTTPException(
- status_code=res.status_code, detail="Failed to submit report"
- )
-
- return Response(res)
-
-
@common_router.get(path="/work/", response_class=HTMLResponse)
async def work(request: Request):
"""
@@ -86,8 +50,11 @@ async def work(request: Request):
status_code=302,
)
- # The Worker has accepted the HIT
- process_request(request)
+ try:
+ # The Worker has accepted the HIT
+ process_request(request)
+ except Exception:
+ raise HTTPException(status_code=500, detail="Error processing request")
return HTMLResponse(BASE_HTML)
@@ -97,23 +64,29 @@ async def mturk_notifications(request: Request):
"""
Our SNS topic will POST to this endpoint whenever we get a new message
"""
- message = await request.json()
- msg_type = message.get("Type")
+ try:
+ message = await request.json()
+ except Exception:
+ raise HTTPException(status_code=500, detail="No POST data")
+
+ match message.get("Type"):
+ case "SubscriptionConfirmation":
+ subscribe_url = message["SubscribeURL"]
+ print("Confirming SNS subscription...")
+ requests.get(subscribe_url)
- if msg_type == "SubscriptionConfirmation":
- subscribe_url = message["SubscribeURL"]
- print("Confirming SNS subscription...")
- requests.get(subscribe_url)
+ case "Notification":
+ msg = json.loads(message["Message"])
+ print("Received MTurk event:", msg)
+ enqueue_mturk_notifications(msg)
- elif msg_type == "Notification":
- msg = json.loads(message["Message"])
- print("Received MTurk event:", msg)
- enqueue_mturk_notifications(msg)
+ case _:
+ raise HTTPException(status_code=500, detail="Invalid JSON")
return {"status": "ok"}
-def enqueue_mturk_notifications(msg) -> None:
+def enqueue_mturk_notifications(msg: Dict[str, Any]) -> None:
for evt in msg["Events"]:
event = MTurkEvent.from_sns(evt)
emit_mturk_notification_event(