diff options
| author | Max Nanis | 2026-02-21 02:15:52 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-21 02:15:52 -0500 |
| commit | 67ab724561e4ceb8fe8fb4031de277168f7d9724 (patch) | |
| tree | 4d85619973491e7239f0e83dc5cdd85618f0f248 /jb | |
| parent | af8057d58ff152f511f5161a7626b0fffa9d661a (diff) | |
| download | amt-jb-67ab724561e4ceb8fe8fb4031de277168f7d9724.tar.gz amt-jb-67ab724561e4ceb8fe8fb4031de277168f7d9724.zip | |
More pytest conf, some views, and defining more attrs on the settings config
Diffstat (limited to 'jb')
| -rw-r--r-- | jb/config.py | 36 | ||||
| -rw-r--r-- | jb/flow/tasks.py | 103 | ||||
| -rw-r--r-- | jb/settings.py | 75 | ||||
| -rw-r--r-- | jb/views/common.py | 186 |
4 files changed, 400 insertions, 0 deletions
diff --git a/jb/config.py b/jb/config.py new file mode 100644 index 0000000..c7d07e5 --- /dev/null +++ b/jb/config.py @@ -0,0 +1,36 @@ +import logging + +from generalresearchutils.config import is_debug + +from jb.settings import get_settings, get_test_settings + +if is_debug(): + print("running using TEST settings") + settings = get_test_settings() + assert settings.debug is True +else: + print("running using PROD settings") + settings = get_settings() + assert settings.debug is False + +if settings.debug: + LOG_LEVEL = logging.DEBUG +else: + LOG_LEVEL = logging.WARNING + +# The SNS topic that 1) JB Mturk will send notifications to, 2) will make http POSTs +# back to us (here) +TOPIC_ARN = f"arn:aws:sns:us-east-2:{settings.aws_owner_id}:amt-jb" +SUBSCRIPTION = { + "SubscriptionArn": settings.aws_subscription_arn, + "Owner": settings.aws_owner_id, + "Protocol": "https", + "Endpoint": f"https://jamesbillings67.com/{settings.sns_path}/", + "TopicArn": TOPIC_ARN, +} + +JB_EVENTS_STREAM = "amt_jb_events" +JB_EVENTS_FAILED_STREAM = "amt_jb_events_failed" +CONSUMER_GROUP = "amt-jb-0" +# We'll only have 1 consumer atm, change this if we don't +CONSUMER_NAME = "amt-jb-0" diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py new file mode 100644 index 0000000..e7c64b9 --- /dev/null +++ b/jb/flow/tasks.py @@ -0,0 +1,103 @@ +import logging +import time + +from generalresearchutils.config import is_debug + +from jb.decorators import HTM, HM, HQM, pg_config +from jb.flow.maintenance import check_hit_status +from jb.flow.monitoring import write_hit_gauge, emit_hit_event +from jb.managers.amt import AMTManager +from jb.models.definitions import HitStatus +from jb.models.hit import HitType, HitQuestion, Hit + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +def check_stale_hits(): + # Check live hits that haven't been modified in a long time. They may not + # be expired yet, but maybe something is wrong? + res = pg_config.execute_sql_query( + """ + SELECT amt_hit_id, amt_hit_type_id + FROM mtwerk_hit mh + JOIN mtwerk_hittype mht ON mht.id = mh.hit_type_id + WHERE status = %(status)s + ORDER BY modified_at + LIMIT 100;""", + params={"status": HitStatus.Assignable.value}, + ) + for hit in res: + logging.info(f"check_stale_hits: {hit["amt_hit_id"]}") + check_hit_status( + amt_hit_id=hit["amt_hit_id"], + amt_hit_type_id=hit["amt_hit_type_id"], + reason="cleanup", + ) + + +def check_expired_hits(): + # Check live/assignable hits that are expired (based on AMT's expiration time) + res = pg_config.execute_sql_query( + """ + SELECT amt_hit_id, amt_hit_type_id + FROM mtwerk_hit mh + JOIN mtwerk_hittype mht ON mht.id = mh.hit_type_id + WHERE status = %(status)s + AND expiration < now() + LIMIT 100;""", + params={"status": HitStatus.Assignable.value}, + ) + for hit in res: + logging.info(f"check_expired_hits: {hit["amt_hit_id"]}") + check_hit_status( + amt_hit_id=hit["amt_hit_id"], + amt_hit_type_id=hit["amt_hit_type_id"], + reason="expired", + ) + + +def create_hit_from_hittype(hit_type: HitType) -> Hit: + if is_debug(): + raise Exception("Handle AMT Sandbox issues.") + + else: + question = HQM.get_or_create( + HitQuestion(height=800, url="https://jamesbillings67.com/work/") + ) + + hit = AMTManager.create_hit_with_hit_type(hit_type=hit_type, question=question) + HM.create(hit) + emit_hit_event(status=hit.status, amt_hit_type_id=hit.amt_hit_type_id) + return hit + + +def refill_hits() -> None: + for hit_type in HTM.filter_active(): + active_count = HM.get_active_count(hit_type.id) + logging.info( + f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}" + ) + write_hit_gauge( + status=HitStatus.Assignable, + amt_hit_type_id=hit_type.amt_hit_type_id, + cnt=active_count, + ) + if active_count < hit_type.min_active: + cnt_todo = hit_type.min_active - active_count + logging.info(f"Refilling {cnt_todo} hits") + for _ in range(cnt_todo): + create_hit_from_hittype(hit_type) + + +def refill_hits_task(): + while True: + try: + check_expired_hits() + check_stale_hits() + refill_hits() + except Exception as e: + logging.exception(e) + finally: + time.sleep(5 * 60) diff --git a/jb/settings.py b/jb/settings.py new file mode 100644 index 0000000..28402f3 --- /dev/null +++ b/jb/settings.py @@ -0,0 +1,75 @@ +import os +from functools import lru_cache +from pathlib import Path +from typing import Optional + +from generalresearchutils.models.custom_types import InfluxDsn +from pydantic import Field, PostgresDsn, HttpUrl, RedisDsn +from pydantic_settings import BaseSettings, SettingsConfigDict + +from jb.models.custom_types import UUIDStr + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +BASE_HTML_PATH = Path(BASE_DIR) / "templates" / "base.html" +BASE_HTML = BASE_HTML_PATH.read_text() + + +class AmtJbBaseSettings(BaseSettings): + debug: bool = Field(default=True) + + redis: Optional[RedisDsn] = Field(default=None) + redis_timeout: float = Field(default=0.10) + + amt_jb_db: PostgresDsn = Field() + + amt_endpoint: Optional[HttpUrl] = Field(default=None) + amt_access_id: Optional[str] = Field(default=None) + amt_secret_key: Optional[str] = Field(default=None) + + aws_owner_id: str = Field() + aws_subscription_arn: str = Field() + + +class Settings(AmtJbBaseSettings): + model_config = SettingsConfigDict( + env_prefix="", + case_sensitive=False, + env_file=os.path.join(BASE_DIR, ".env"), + extra="allow", + cli_parse_args=False, + ) + debug: bool = False + app_name: str = "AMT JB API" + + fsb_host: HttpUrl = Field(default="https://fsb.generalresearch.com/") + # Needed for admin function on fsb w/o authentication + fsb_host_private_route: Optional[str] = Field(default=None) + + product_id: UUIDStr = Field() + + influx_db: Optional[InfluxDsn] = Field(default=None) + + sns_path: str = Field() + + +class TestSettings(Settings): + model_config = SettingsConfigDict( + env_prefix="", + case_sensitive=False, + env_file=os.path.join(BASE_DIR, ".env.dev"), + extra="allow", + cli_parse_args=False, + ) + debug: bool = True + app_name: str = "AMT JB API Development" + + +@lru_cache +def get_settings(): + return Settings() + + +@lru_cache +def get_test_settings(): + return TestSettings() diff --git a/jb/views/common.py b/jb/views/common.py new file mode 100644 index 0000000..46ac608 --- /dev/null +++ b/jb/views/common.py @@ -0,0 +1,186 @@ +import json +from typing import List +from uuid import uuid4 + +import requests +from fastapi import Request, APIRouter, Response, HTTPException, Query +from fastapi.responses import HTMLResponse, JSONResponse +from pydantic import BaseModel, ConfigDict, Field +from starlette.responses import RedirectResponse + +from jb.config import settings, JB_EVENTS_STREAM +from jb.decorators import REDIS, HM +from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event +from jb.models.currency import USDCent +from jb.models.definitions import ReportValue, AssignmentStatus +from jb.models.event import MTurkEvent +from jb.settings import BASE_HTML +from jb.config import settings +from jb.views.tasks import process_request + +common_router = APIRouter(prefix="", tags=["API"], include_in_schema=True) + + +class ReportTask(BaseModel): + model_config = ConfigDict(extra="forbid") + + worker_id: str = Field() + + reasons: List[ReportValue] = Field( + examples=[[3, 4]], + default_factory=list, + ) + + notes: str = Field( + default="", examples=["The survey wanted to watch me eat Haejang-guk"] + ) + + +@common_router.post("/report/") +def report(request: Request, data: ReportTask): + url = f"{settings.fsb_host}{settings.product_id}/report/" + params = { + "bpuid": data.worker_id, + "reasons": [x.value for x in data.reasons], + "notes": data.notes, + } + + req = requests.post(url, json=params) + res = req.json() + if res.status_code != 200: + raise HTTPException( + status_code=res.status_code, detail="Failed to submit report" + ) + + return Response(res) + + +@common_router.get(path="/work/", response_class=HTMLResponse) +async def work(request: Request): + """ + HTML Page that the worker lands on in an iFrame. + They can either be previewing the HIT, or have already accepted it. + """ + amt_assignment_id = request.query_params.get("assignmentId", None) + worker_id = request.query_params.get("workerId", None) + amt_hit_id = request.query_params.get("hitId", None) + print(f"work: {amt_assignment_id=} {worker_id=} {amt_hit_id=}") + + if not worker_id: + return RedirectResponse( + url=f"/preview/?{request.url.query}" if request.url.query else "/preview/", + status_code=302, + ) + if amt_assignment_id is None or amt_assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE": + # Worker is previewing the HIT + amt_hit_type_id = "unknown" + if amt_hit_id: + hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id) + amt_hit_type_id = hit.amt_hit_type_id + emit_assignment_event( + status=AssignmentStatus.PreviewState, amt_hit_type_id=amt_hit_type_id + ) + return RedirectResponse( + url=f"/preview/?{request.url.query}" if request.url.query else "/preview/", + status_code=302, + ) + + # The Worker has accepted the HIT + process_request(request) + + return HTMLResponse(BASE_HTML) + + +@common_router.get(path="/survey/", response_class=JSONResponse) +def survey( + request: Request, + worker_id: str = Query(), + duration: int = Query(default=1200), +): + if not worker_id: + raise HTTPException(status_code=400, detail="Missing worker_id") + + # (1) Check wallet + wallet_url = f"{settings.fsb_host}{settings.product_id}/wallet/" + wallet_res = requests.get(wallet_url, params={"bpuid": worker_id}) + if wallet_res.status_code != 200: + raise HTTPException(status_code=502, detail="Wallet check failed") + + wallet_data = wallet_res.json() + wallet_balance = wallet_data["wallet"]["amount"] + if wallet_balance < -100: + return JSONResponse( + { + "total_surveys": 0, + "link": None, + "duration": None, + "payout": None, + } + ) + + # (2) Get offerwall + client_ip = "69.253.144.55" if settings.debug else request.client.host + offerwall_url = f"{settings.fsb_host}{settings.product_id}/offerwall/d48cce47/" + offerwall_res = requests.get( + offerwall_url, + params={ + "bpuid": worker_id, + "ip": client_ip, + "n_bins": 1, + "duration": duration, + }, + ) + + if offerwall_res.status_code != 200: + raise HTTPException(status_code=502, detail="Offerwall request failed") + + try: + rj = offerwall_res.json() + bucket = rj["offerwall"]["buckets"][0] + return JSONResponse( + { + "total_surveys": rj["offerwall"]["availability_count"], + "link": bucket["uri"], + "duration": round(bucket["duration"]["q2"] / 60), + "payout": USDCent(bucket["payout"]["q2"]).to_usd_str(), + } + ) + except Exception: + return JSONResponse( + { + "total_surveys": 0, + "link": None, + "duration": None, + "payout": None, + } + ) + + +@common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False) +async def mturk_notifications(request: Request): + """ + Our SNS topic will POST to this endpoint whenever we get a new message + """ + message = await request.json() + msg_type = message.get("Type") + + if msg_type == "SubscriptionConfirmation": + subscribe_url = message["SubscribeURL"] + print("Confirming SNS subscription...") + requests.get(subscribe_url) + + elif msg_type == "Notification": + msg = json.loads(message["Message"]) + print("Received MTurk event:", msg) + enqueue_mturk_notifications(msg) + + return {"status": "ok"} + + +def enqueue_mturk_notifications(msg) -> None: + for evt in msg["Events"]: + event = MTurkEvent.from_sns(evt) + emit_mturk_notification_event( + event_type=event.event_type, amt_hit_type_id=event.amt_hit_type_id + ) + REDIS.xadd(JB_EVENTS_STREAM, {"data": event.model_dump_json()}) |
