aboutsummaryrefslogtreecommitdiff
path: root/jb
diff options
context:
space:
mode:
Diffstat (limited to 'jb')
-rw-r--r-- jb/config.py 36
-rw-r--r-- jb/flow/tasks.py 103
-rw-r--r-- jb/settings.py 75
-rw-r--r-- jb/views/common.py 186
4 files changed, 400 insertions, 0 deletions
diff --git a/jb/config.py b/jb/config.py
new file mode 100644
index 0000000..c7d07e5
--- /dev/null
+++ b/jb/config.py
@@ -0,0 +1,36 @@
+import logging
+
+from generalresearchutils.config import is_debug
+
+from jb.settings import get_settings, get_test_settings
+
+# Pick the settings object for this process. Each branch asserts that the
+# loaded settings agree with the debug flag that selected them.
+if not is_debug():
+    print("running using PROD settings")
+    settings = get_settings()
+    assert settings.debug is False
+else:
+    print("running using TEST settings")
+    settings = get_test_settings()
+    assert settings.debug is True
+
+# Verbose logging only when the debug settings are active.
+LOG_LEVEL = logging.DEBUG if settings.debug else logging.WARNING
+
+# The SNS topic that 1) JB MTurk sends notifications to, and 2) that makes HTTP
+# POSTs back to us (to the Endpoint declared in SUBSCRIPTION below).
+TOPIC_ARN = f"arn:aws:sns:us-east-2:{settings.aws_owner_id}:amt-jb"
+# Description of the subscription on that topic: HTTPS delivery to our
+# endpoint, whose path segment comes from settings.sns_path.
+SUBSCRIPTION = {
+ "SubscriptionArn": settings.aws_subscription_arn,
+ "Owner": settings.aws_owner_id,
+ "Protocol": "https",
+ "Endpoint": f"https://jamesbillings67.com/{settings.sns_path}/",
+ "TopicArn": TOPIC_ARN,
+}
+
+# Name of the stream that incoming MTurk events are appended to.
+JB_EVENTS_STREAM = "amt_jb_events"
+# NOTE(review): presumably the stream for events that could not be processed;
+# confirm against the consumer code.
+JB_EVENTS_FAILED_STREAM = "amt_jb_events_failed"
+CONSUMER_GROUP = "amt-jb-0"
+# There is only one consumer at the moment; change this name if more are added.
+CONSUMER_NAME = "amt-jb-0"
diff --git a/jb/flow/tasks.py b/jb/flow/tasks.py
new file mode 100644
index 0000000..e7c64b9
--- /dev/null
+++ b/jb/flow/tasks.py
@@ -0,0 +1,103 @@
+import logging
+import time
+
+from generalresearchutils.config import is_debug
+
+from jb.decorators import HTM, HM, HQM, pg_config
+from jb.flow.maintenance import check_hit_status
+from jb.flow.monitoring import write_hit_gauge, emit_hit_event
+from jb.managers.amt import AMTManager
+from jb.models.definitions import HitStatus
+from jb.models.hit import HitType, HitQuestion, Hit
+
+# Configure the root logger (getLogger() without a name returns the root
+# logger) and lower its threshold so INFO messages from the tasks are emitted.
+logging.basicConfig()
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+
+def check_stale_hits():
+    """
+    Re-check the Assignable HITs that were modified least recently.
+
+    Selects up to 100 HITs whose status is Assignable, oldest ``modified_at``
+    first, and passes each one to ``check_hit_status`` with reason "cleanup".
+    They may not be expired yet, but maybe something is wrong. The query
+    applies no age threshold: it simply takes the 100 least recently
+    modified rows.
+    """
+    res = pg_config.execute_sql_query(
+        """
+ SELECT amt_hit_id, amt_hit_type_id
+ FROM mtwerk_hit mh
+ JOIN mtwerk_hittype mht ON mht.id = mh.hit_type_id
+ WHERE status = %(status)s
+ ORDER BY modified_at
+ LIMIT 100;""",
+        params={"status": HitStatus.Assignable.value},
+    )
+    for hit in res:
+        # Lazy %-style arguments: the previous f-string reused double quotes
+        # inside the replacement field, which only parses on Python 3.12+.
+        logging.info("check_stale_hits: %s", hit["amt_hit_id"])
+        check_hit_status(
+            amt_hit_id=hit["amt_hit_id"],
+            amt_hit_type_id=hit["amt_hit_type_id"],
+            reason="cleanup",
+        )
+
+
+def check_expired_hits():
+    """
+    Re-check Assignable HITs whose expiration time has already passed.
+
+    Selects up to 100 HITs whose status is Assignable and whose ``expiration``
+    (based on AMT's expiration time) is earlier than the database's ``now()``,
+    and passes each one to ``check_hit_status`` with reason "expired".
+    """
+    res = pg_config.execute_sql_query(
+        """
+ SELECT amt_hit_id, amt_hit_type_id
+ FROM mtwerk_hit mh
+ JOIN mtwerk_hittype mht ON mht.id = mh.hit_type_id
+ WHERE status = %(status)s
+ AND expiration < now()
+ LIMIT 100;""",
+        params={"status": HitStatus.Assignable.value},
+    )
+    for hit in res:
+        # Lazy %-style arguments: the previous f-string reused double quotes
+        # inside the replacement field, which only parses on Python 3.12+.
+        logging.info("check_expired_hits: %s", hit["amt_hit_id"])
+        check_hit_status(
+            amt_hit_id=hit["amt_hit_id"],
+            amt_hit_type_id=hit["amt_hit_type_id"],
+            reason="expired",
+        )
+
+
+def create_hit_from_hittype(hit_type: HitType) -> Hit:
+    """
+    Create one HIT on AMT for ``hit_type``, hand it to ``HM.create`` and emit
+    a hit event carrying the new HIT's status.
+
+    Raises:
+        Exception: always when running in debug mode, because the AMT sandbox
+            case is not handled.
+    """
+    # Guard clause: nothing below may run in debug mode.
+    if is_debug():
+        raise Exception("Handle AMT Sandbox issues.")
+
+    # Every HIT points at the same externally hosted work page.
+    hit_question = HitQuestion(height=800, url="https://jamesbillings67.com/work/")
+    question = HQM.get_or_create(hit_question)
+
+    new_hit = AMTManager.create_hit_with_hit_type(
+        hit_type=hit_type,
+        question=question,
+    )
+    HM.create(new_hit)
+    emit_hit_event(
+        status=new_hit.status,
+        amt_hit_type_id=new_hit.amt_hit_type_id,
+    )
+    return new_hit
+
+
+def refill_hits() -> None:
+    """
+    Top up every active HitType to its ``min_active`` number of HITs.
+
+    For each HitType returned by ``HTM.filter_active()`` the current active
+    count is logged and written as an Assignable gauge; when the count is
+    below ``min_active`` the missing HITs are created one by one.
+    """
+    for hit_type in HTM.filter_active():
+        active_count = HM.get_active_count(hit_type.id)
+        logging.info(
+            f"HitType: {hit_type.amt_hit_type_id}, {hit_type.min_active=}, active_count={active_count}"
+        )
+        write_hit_gauge(
+            status=HitStatus.Assignable,
+            amt_hit_type_id=hit_type.amt_hit_type_id,
+            cnt=active_count,
+        )
+
+        # Already at or above the minimum: nothing to create for this type.
+        if active_count >= hit_type.min_active:
+            continue
+
+        cnt_todo = hit_type.min_active - active_count
+        logging.info(f"Refilling {cnt_todo} hits")
+        for _ in range(cnt_todo):
+            create_hit_from_hittype(hit_type)
+
+
+def refill_hits_task():
+    """
+    Run the HIT maintenance cycle forever, pausing five minutes between runs.
+
+    Each cycle checks expired HITs, then stale HITs, then refills HITs. Any
+    ``Exception`` raised by a cycle is logged with its traceback and the loop
+    carries on; the remaining steps of that cycle are skipped. This function
+    never returns.
+    """
+    while True:
+        try:
+            check_expired_hits()
+            check_stale_hits()
+            refill_hits()
+        except Exception as e:
+            logging.exception(e)
+        # Sleep after the try block rather than in ``finally``: a ``finally``
+        # clause also runs for KeyboardInterrupt / SystemExit, which made a
+        # shutdown request wait out the whole five-minute pause.
+        time.sleep(5 * 60)
diff --git a/jb/settings.py b/jb/settings.py
new file mode 100644
index 0000000..28402f3
--- /dev/null
+++ b/jb/settings.py
@@ -0,0 +1,75 @@
+import os
+from functools import lru_cache
+from pathlib import Path
+from typing import Optional
+
+from generalresearchutils.models.custom_types import InfluxDsn
+from pydantic import Field, PostgresDsn, HttpUrl, RedisDsn
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+from jb.models.custom_types import UUIDStr
+
+# Parent of the directory that holds this file, as an absolute path string.
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+# The base HTML template is read once, at import time, so importing this
+# module raises if templates/base.html does not exist under BASE_DIR.
+BASE_HTML_PATH = Path(BASE_DIR) / "templates" / "base.html"
+BASE_HTML = BASE_HTML_PATH.read_text()
+
+
+class AmtJbBaseSettings(BaseSettings):
+ """
+ Settings shared by every configuration of the app.
+
+ Fields declared with a bare ``Field()`` have no default and therefore must
+ be supplied; the others fall back to the default shown.
+ """
+
+ debug: bool = Field(default=True)
+
+ # Redis connection; optional, with a timeout expressed as a float.
+ # NOTE(review): the timeout unit is presumably seconds; confirm at the call site.
+ redis: Optional[RedisDsn] = Field(default=None)
+ redis_timeout: float = Field(default=0.10)
+
+ # Postgres DSN of the application database (required).
+ amt_jb_db: PostgresDsn = Field()
+
+ # AMT endpoint and credentials; all optional.
+ amt_endpoint: Optional[HttpUrl] = Field(default=None)
+ amt_access_id: Optional[str] = Field(default=None)
+ amt_secret_key: Optional[str] = Field(default=None)
+
+ # AWS account id and SNS subscription ARN (both required).
+ aws_owner_id: str = Field()
+ aws_subscription_arn: str = Field()
+
+
+class Settings(AmtJbBaseSettings):
+ """
+ Production settings, loaded from the environment and from ``.env`` in
+ BASE_DIR. Names are matched case-insensitively, no prefix is applied and
+ unknown keys are allowed.
+ """
+
+ model_config = SettingsConfigDict(
+ env_prefix="",
+ case_sensitive=False,
+ env_file=os.path.join(BASE_DIR, ".env"),
+ extra="allow",
+ cli_parse_args=False,
+ )
+ # Overrides the base-class default of True.
+ debug: bool = False
+ app_name: str = "AMT JB API"
+
+ # Base URL of the FSB service.
+ fsb_host: HttpUrl = Field(default="https://fsb.generalresearch.com/")
+ # Needed for admin function on fsb w/o authentication
+ fsb_host_private_route: Optional[str] = Field(default=None)
+
+ # Product identifier (required).
+ product_id: UUIDStr = Field()
+
+ influx_db: Optional[InfluxDsn] = Field(default=None)
+
+ # URL path segment of the SNS notification endpoint (required).
+ sns_path: str = Field()
+
+
+class TestSettings(Settings):
+ """
+ Development/test variant of ``Settings``: same fields, but values are read
+ from ``.env.dev`` in BASE_DIR, ``debug`` defaults to True and the app name
+ is marked as a development build.
+ """
+
+ model_config = SettingsConfigDict(
+ env_prefix="",
+ case_sensitive=False,
+ env_file=os.path.join(BASE_DIR, ".env.dev"),
+ extra="allow",
+ cli_parse_args=False,
+ )
+ debug: bool = True
+ app_name: str = "AMT JB API Development"
+
+
+@lru_cache
+def get_settings() -> Settings:
+    """Build the production ``Settings`` once; later calls return the cached object."""
+    prod_settings = Settings()
+    return prod_settings
+
+
+@lru_cache
+def get_test_settings() -> TestSettings:
+    """Build the ``TestSettings`` once; later calls return the cached object."""
+    dev_settings = TestSettings()
+    return dev_settings
diff --git a/jb/views/common.py b/jb/views/common.py
new file mode 100644
index 0000000..46ac608
--- /dev/null
+++ b/jb/views/common.py
@@ -0,0 +1,186 @@
+import json
+from typing import List
+from uuid import uuid4
+
+import requests
+from fastapi import Request, APIRouter, Response, HTTPException, Query
+from fastapi.responses import HTMLResponse, JSONResponse
+from pydantic import BaseModel, ConfigDict, Field
+from starlette.responses import RedirectResponse
+
+from jb.config import settings, JB_EVENTS_STREAM
+from jb.decorators import REDIS, HM
+from jb.flow.monitoring import emit_assignment_event, emit_mturk_notification_event
+from jb.models.currency import USDCent
+from jb.models.definitions import ReportValue, AssignmentStatus
+from jb.models.event import MTurkEvent
+from jb.settings import BASE_HTML
+from jb.config import settings
+from jb.views.tasks import process_request
+
+common_router = APIRouter(prefix="", tags=["API"], include_in_schema=True)
+
+
+# Request body of POST /report/: a worker's complaint about a task.
+class ReportTask(BaseModel):
+ # Unknown keys in the request body are rejected.
+ model_config = ConfigDict(extra="forbid")
+
+ # Id of the reporting worker (required); forwarded upstream as "bpuid".
+ worker_id: str = Field()
+
+ # Zero or more report reasons; defaults to an empty list.
+ reasons: List[ReportValue] = Field(
+ examples=[[3, 4]],
+ default_factory=list,
+ )
+
+ # Optional free-text explanation.
+ notes: str = Field(
+ default="", examples=["The survey wanted to watch me eat Haejang-guk"]
+ )
+
+@common_router.post("/report/")
+def report(request: Request, data: ReportTask):
+    """
+    Forward a worker's task report to FSB and relay FSB's JSON answer.
+
+    Raises:
+        HTTPException: 502 when FSB cannot be reached or its answer is not
+            JSON; FSB's own status code when that status is not 200.
+    """
+    url = f"{settings.fsb_host}{settings.product_id}/report/"
+    params = {
+        "bpuid": data.worker_id,
+        "reasons": [x.value for x in data.reasons],
+        "notes": data.notes,
+    }
+
+    try:
+        # Bounded wait so an unresponsive FSB cannot hang this request forever.
+        upstream = requests.post(url, json=params, timeout=30)
+    except requests.RequestException as exc:
+        raise HTTPException(
+            status_code=502, detail="Failed to submit report"
+        ) from exc
+
+    # The status code lives on the Response object, not on the decoded body,
+    # and is checked first because an error body need not be JSON.
+    if upstream.status_code != 200:
+        raise HTTPException(
+            status_code=upstream.status_code, detail="Failed to submit report"
+        )
+
+    try:
+        res = upstream.json()
+    except ValueError as exc:
+        raise HTTPException(
+            status_code=502, detail="Failed to submit report"
+        ) from exc
+
+    # A decoded JSON body needs JSONResponse; a plain Response cannot render it.
+    return JSONResponse(res)
+
+
+@common_router.get(path="/work/", response_class=HTMLResponse)
+async def work(request: Request):
+    """
+    HTML Page that the worker lands on in an iFrame.
+    They can either be previewing the HIT, or have already accepted it.
+    """
+    query = request.query_params
+    amt_assignment_id = query.get("assignmentId", None)
+    worker_id = query.get("workerId", None)
+    amt_hit_id = query.get("hitId", None)
+    print(f"work: {amt_assignment_id=} {worker_id=} {amt_hit_id=}")
+
+    def preview_redirect() -> RedirectResponse:
+        # 302 to /preview/, carrying the original query string along if any.
+        raw_query = request.url.query
+        target = f"/preview/?{raw_query}" if raw_query else "/preview/"
+        return RedirectResponse(url=target, status_code=302)
+
+    # Without a worker id the request goes straight to the preview page and
+    # no event is emitted.
+    if not worker_id:
+        return preview_redirect()
+
+    if amt_assignment_id in (None, "ASSIGNMENT_ID_NOT_AVAILABLE"):
+        # Worker is previewing the HIT
+        amt_hit_type_id = "unknown"
+        if amt_hit_id:
+            hit = HM.get_from_amt_id(amt_hit_id=amt_hit_id)
+            amt_hit_type_id = hit.amt_hit_type_id
+        emit_assignment_event(
+            status=AssignmentStatus.PreviewState,
+            amt_hit_type_id=amt_hit_type_id,
+        )
+        return preview_redirect()
+
+    # The Worker has accepted the HIT
+    process_request(request)
+    return HTMLResponse(BASE_HTML)
+
+
+def _no_survey_payload() -> dict:
+    # Payload telling the client that no survey is on offer.
+    return {
+        "total_surveys": 0,
+        "link": None,
+        "duration": None,
+        "payout": None,
+    }
+
+
+def _fsb_get(url: str, params: dict, failure_detail: str):
+    # GET from FSB; a transport failure or a non-200 answer becomes a 502.
+    try:
+        # Bounded wait so an unresponsive FSB cannot hang the request forever.
+        res = requests.get(url, params=params, timeout=30)
+    except requests.RequestException as exc:
+        raise HTTPException(status_code=502, detail=failure_detail) from exc
+    if res.status_code != 200:
+        raise HTTPException(status_code=502, detail=failure_detail)
+    return res
+
+
+@common_router.get(path="/survey/", response_class=JSONResponse)
+def survey(
+    request: Request,
+    worker_id: str = Query(),
+    duration: int = Query(default=1200),
+):
+    """
+    Return one survey offer for a worker, or an empty offer.
+
+    The worker's wallet is checked first: a balance below -100 yields the
+    empty offer. Otherwise a one-bucket offerwall is requested and its first
+    bucket is returned as ``total_surveys`` / ``link`` / ``duration`` /
+    ``payout``. If the offerwall answer cannot be turned into an offer, the
+    empty offer is returned instead.
+
+    Raises:
+        HTTPException: 400 when ``worker_id`` is empty; 502 when the wallet or
+            offerwall request fails, does not answer 200, or the wallet body
+            is malformed.
+    """
+    if not worker_id:
+        raise HTTPException(status_code=400, detail="Missing worker_id")
+
+    # (1) Check wallet
+    wallet_res = _fsb_get(
+        f"{settings.fsb_host}{settings.product_id}/wallet/",
+        params={"bpuid": worker_id},
+        failure_detail="Wallet check failed",
+    )
+    try:
+        wallet_balance = wallet_res.json()["wallet"]["amount"]
+    except (KeyError, TypeError, ValueError) as exc:
+        # Malformed upstream body: report it as an upstream failure.
+        raise HTTPException(status_code=502, detail="Wallet check failed") from exc
+
+    # NOTE(review): -100 is presumably in the wallet's smallest currency
+    # unit; confirm against the FSB wallet API.
+    if wallet_balance < -100:
+        return JSONResponse(_no_survey_payload())
+
+    # (2) Get offerwall
+    # In debug mode a fixed IP is sent instead of the caller's address.
+    client_ip = "69.253.144.55" if settings.debug else request.client.host
+    offerwall_res = _fsb_get(
+        f"{settings.fsb_host}{settings.product_id}/offerwall/d48cce47/",
+        params={
+            "bpuid": worker_id,
+            "ip": client_ip,
+            "n_bins": 1,
+            "duration": duration,
+        },
+        failure_detail="Offerwall request failed",
+    )
+
+    try:
+        rj = offerwall_res.json()
+        bucket = rj["offerwall"]["buckets"][0]
+        return JSONResponse(
+            {
+                "total_surveys": rj["offerwall"]["availability_count"],
+                "link": bucket["uri"],
+                "duration": round(bucket["duration"]["q2"] / 60),
+                "payout": USDCent(bucket["payout"]["q2"]).to_usd_str(),
+            }
+        )
+    except Exception:
+        # Deliberate best effort: an empty or unexpected offerwall means
+        # "nothing on offer" rather than an error for the worker.
+        return JSONResponse(_no_survey_payload())
+
+
+def _is_sns_subscribe_url(url) -> bool:
+    # True only for https URLs whose host is an AWS SNS endpoint
+    # (sns.<region>.amazonaws.com, optionally with the .cn suffix).
+    import re
+    from urllib.parse import urlparse
+
+    if not isinstance(url, str):
+        return False
+    try:
+        parts = urlparse(url)
+        host = parts.hostname or ""
+    except ValueError:
+        return False
+    if parts.scheme != "https":
+        return False
+    return re.fullmatch(r"sns\.[a-z0-9-]+\.amazonaws\.com(\.cn)?", host) is not None
+
+
+@common_router.post(path=f"/{settings.sns_path}/", include_in_schema=False)
+async def mturk_notifications(request: Request):
+    """
+    Our SNS topic will POST to this endpoint whenever we get a new message.
+
+    A "SubscriptionConfirmation" message is confirmed by fetching its
+    SubscribeURL; a "Notification" message has its JSON "Message" decoded and
+    its events enqueued. Other message types are acknowledged and ignored.
+
+    Raises:
+        HTTPException: 400 when a SubscriptionConfirmation carries a
+            SubscribeURL that is not an https AWS SNS address.
+    """
+    message = await request.json()
+    msg_type = message.get("Type")
+
+    if msg_type == "SubscriptionConfirmation":
+        subscribe_url = message.get("SubscribeURL")
+        # The body is untrusted input: never fetch an arbitrary URL taken from
+        # it, only a genuine SNS confirmation link.
+        if not _is_sns_subscribe_url(subscribe_url):
+            raise HTTPException(status_code=400, detail="Invalid SubscribeURL")
+        print("Confirming SNS subscription...")
+        requests.get(subscribe_url, timeout=10)
+
+    elif msg_type == "Notification":
+        msg = json.loads(message["Message"])
+        print("Received MTurk event:", msg)
+        enqueue_mturk_notifications(msg)
+
+    return {"status": "ok"}
+
+
+def enqueue_mturk_notifications(msg) -> None:
+    """
+    Emit a monitoring event for, and append to the events stream, every entry
+    of ``msg["Events"]``.
+    """
+    for raw_event in msg["Events"]:
+        parsed = MTurkEvent.from_sns(raw_event)
+        emit_mturk_notification_event(
+            event_type=parsed.event_type,
+            amt_hit_type_id=parsed.amt_hit_type_id,
+        )
+        # The whole event travels as one JSON string under the "data" field.
+        stream_entry = {"data": parsed.model_dump_json()}
+        REDIS.xadd(JB_EVENTS_STREAM, stream_entry)