diff options
| author | Max Nanis | 2026-02-19 02:43:23 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-19 02:43:23 -0500 |
| commit | f0f96f83c2630e890a2cbcab53f77fd4c37e1684 (patch) | |
| tree | c6d2cb092e76bf5d499e0ea9949508d6b22164fd /jb | |
| parent | 3eaa56f0306ead818f64c3d99fc6d230d9b970a4 (diff) | |
| download | amt-jb-master.tar.gz amt-jb-master.zip | |
Diffstat (limited to 'jb')
| -rw-r--r-- | jb/managers/__init__.py | 23 | ||||
| -rw-r--r-- | jb/managers/amt.py | 216 | ||||
| -rw-r--r-- | jb/managers/assignment.py | 259 | ||||
| -rw-r--r-- | jb/managers/bonus.py | 54 | ||||
| -rw-r--r-- | jb/managers/hit.py | 338 | ||||
| -rw-r--r-- | jb/managers/worker.py | 16 | ||||
| -rw-r--r-- | jb/models/__init__.py | 40 | ||||
| -rw-r--r-- | jb/models/api_response.py | 17 | ||||
| -rw-r--r-- | jb/models/assignment.py | 388 | ||||
| -rw-r--r-- | jb/models/bonus.py | 48 | ||||
| -rw-r--r-- | jb/models/currency.py | 70 | ||||
| -rw-r--r-- | jb/models/custom_types.py | 113 | ||||
| -rw-r--r-- | jb/models/definitions.py | 90 | ||||
| -rw-r--r-- | jb/models/errors.py | 80 | ||||
| -rw-r--r-- | jb/models/event.py | 38 | ||||
| -rw-r--r-- | jb/models/hit.py | 251 | ||||
| -rw-r--r-- | jb/views/__init__.py | 0 |
17 files changed, 2041 insertions, 0 deletions
diff --git a/jb/managers/__init__.py b/jb/managers/__init__.py new file mode 100644 index 0000000..e2aab6d --- /dev/null +++ b/jb/managers/__init__.py @@ -0,0 +1,23 @@ +from enum import IntEnum +from typing import Collection + +from generalresearchutils.pg_helper import PostgresConfig + + +class Permission(IntEnum): + READ = 1 + UPDATE = 2 + CREATE = 3 + DELETE = 4 + + +class PostgresManager: + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.pg_config = pg_config + self.permissions = set(permissions) if permissions else set() diff --git a/jb/managers/amt.py b/jb/managers/amt.py new file mode 100644 index 0000000..79661c7 --- /dev/null +++ b/jb/managers/amt.py @@ -0,0 +1,216 @@ +import logging +from datetime import timezone, datetime +from typing import Tuple, Optional, List + +import botocore.exceptions +from mypy_boto3_mturk.type_defs import AssignmentTypeDef, BonusPaymentTypeDef + +from jb.config import TOPIC_ARN +from jb.decorators import AMT_CLIENT +from jb.models import AMTAccount +from jb.models.assignment import Assignment +from jb.models.bonus import Bonus +from jb.models.currency import USDCent +from jb.models.definitions import HitStatus +from jb.models.hit import HitType, HitQuestion, Hit + +REJECT_MESSAGE_UNKNOWN_ASSIGNMENT = "Unknown assignment" +REJECT_MESSAGE_NO_WORK = "Assignment was submitted with no attempted work." +REJECT_MESSAGE_BADDIE = "Quality has dropped below an acceptable level" +APPROVAL_MESSAGE = "Thank you!" +NO_WORK_APPROVAL_MESSAGE = ( + REJECT_MESSAGE_NO_WORK + " In the future, if you are not sent into a task, " + "please return the assignment, otherwise it will be rejected!" +) +BONUS_MESSAGE = "Great job! Bonus for a survey complete" + + +class AMTManager: + + @staticmethod + def fetch_account() -> AMTAccount: + res = AMT_CLIENT.get_account_balance() + return AMTAccount.model_validate( + { + "available_balance": res["AvailableBalance"], + "onhold_balance": res.get("OnHoldBalance", 0), + } + ) + + @staticmethod + def get_hit_if_exists(amt_hit_id: str) -> Tuple[Optional[Hit], Optional[str]]: + try: + res = AMT_CLIENT.get_hit(HITId=amt_hit_id) + except AMT_CLIENT.exceptions.RequestError as e: + msg = e.response.get("Error", {}).get("Message", "") + return None, msg + hit = Hit.from_amt_get_hit(res["HIT"]) + return hit, None + + @classmethod + def get_hit_status(cls, amt_hit_id: str): + res, msg = cls.get_hit_if_exists(amt_hit_id=amt_hit_id) + if res is None: + if " does not exist. (" in msg: + return HitStatus.Disposed + else: + logging.warning(msg) + return HitStatus.Unassignable + return res.status + + @staticmethod + def create_hit_type(hit_type: HitType): + res = AMT_CLIENT.create_hit_type(**hit_type.to_api_request_body()) + hit_type.amt_hit_type_id = res["HITTypeId"] + AMT_CLIENT.update_notification_settings( + HITTypeId=hit_type.amt_hit_type_id, + Notification={ + "Destination": TOPIC_ARN, + "Transport": "SNS", + "Version": "2006-05-05", + # you can add more events, see mypy_boto3_mturk.literals.EventTypeType + "EventTypes": ["AssignmentSubmitted"], + }, + Active=True, + ) + + return hit_type + + @staticmethod + def create_hit_with_hit_type(hit_type: HitType, question: HitQuestion) -> Hit: + """ + HITTypeId: str + LifetimeInSeconds: int + MaxAssignments: NotRequired[int] + Question: NotRequired[str] + UniqueRequestToken: NotRequired[str] + """ + assert hit_type.id is not None + assert hit_type.amt_hit_type_id is not None + + data = hit_type.generate_hit_amt_request(question=question) + res = AMT_CLIENT.create_hit_with_hit_type(**data) + return Hit.from_amt_create_hit(res["HIT"], hit_type=hit_type, question=question) + + @staticmethod + def get_assignment(amt_assignment_id: str) -> Assignment: + # note, you CANNOT get an assignment if it has been only ACCEPTED (by the worker) + # the api is stupid. it will only show up once it is submitted + res = AMT_CLIENT.get_assignment(AssignmentId=amt_assignment_id) + ass_res: AssignmentTypeDef = res["Assignment"] + assignment = Assignment.from_amt_get_assignment(ass_res) + # to be clear, this has not checked whether it exists in our db + assert assignment.id is None + return assignment + + @classmethod + def get_assignment_if_exists(cls, amt_assignment_id: str) -> Optional[Assignment]: + expected_err_msg = f"Assignment {amt_assignment_id} does not exist" + try: + return cls.get_assignment(amt_assignment_id=amt_assignment_id) + except botocore.exceptions.ClientError as e: + logging.warning(e) + error_code = e.response["Error"]["Code"] + error_msg = e.response["Error"]["Message"] + if error_code == "RequestError" and expected_err_msg in error_msg: + return None + raise e + + @staticmethod + def reject_assignment_if_possible( + amt_assignment_id: str, msg: str = REJECT_MESSAGE_UNKNOWN_ASSIGNMENT + ): + # Unclear to me when this would fail + try: + return AMT_CLIENT.reject_assignment( + AssignmentId=amt_assignment_id, RequesterFeedback=msg + ) + except botocore.exceptions.ClientError as e: + logging.warning(e) + return None + + @staticmethod + def approve_assignment_if_possible( + amt_assignment_id: str, + msg: str = APPROVAL_MESSAGE, + override_rejection: bool = False, + ): + # Unclear to me when this would fail + try: + return AMT_CLIENT.approve_assignment( + AssignmentId=amt_assignment_id, + RequesterFeedback=msg, + OverrideRejection=override_rejection, + ) + except botocore.exceptions.ClientError as e: + logging.warning(e) + return None + + @staticmethod + def update_hit_review_status(amt_hit_id: str, revert: bool = False) -> None: + try: + # Reviewable to Reviewing + AMT_CLIENT.update_hit_review_status(HITId=amt_hit_id, Revert=revert) + except botocore.exceptions.ClientError as e: + logging.warning(f"{amt_hit_id=}, {e}") + error_msg = e.response["Error"]["Message"] + if "does not exist" in error_msg: + raise ValueError(error_msg) + # elif "This HIT is currently in the state 'Reviewing'" in error_msg: + # logging.warning(error_msg) + return None + + @staticmethod + def send_bonus( + amt_worker_id: str, + amount: USDCent, + amt_assignment_id: str, + reason: str, + unique_request_token: str, + ): + try: + return AMT_CLIENT.send_bonus( + WorkerId=amt_worker_id, + BonusAmount=str(amount.to_usd()), + AssignmentId=amt_assignment_id, + Reason=reason, + UniqueRequestToken=unique_request_token, + ) + except botocore.exceptions.ClientError as e: + logging.warning(f"{amt_worker_id=} {amt_assignment_id=}, {e}") + return None + + @staticmethod + def get_bonus(amt_assignment_id: str, payout_event_id: str) -> Optional[Bonus]: + res: List[BonusPaymentTypeDef] = AMT_CLIENT.list_bonus_payments( + AssignmentId=amt_assignment_id + )["BonusPayments"] + assert ( + len(res) <= 1 + ), f"{amt_assignment_id=} Expected 1 or 0 bonuses, got {len(res)}" + d = res[0] if res else None + if d: + return Bonus.model_validate( + { + "amt_worker_id": d["WorkerId"], + "amount": USDCent(round(float(d["BonusAmount"]) * 100)), + "amt_assignment_id": d["AssignmentId"], + "reason": d["Reason"], + "grant_time": d["GrantTime"].astimezone(tz=timezone.utc), + "payout_event_id": payout_event_id, + } + ) + return None + + @staticmethod + def expire_all_hits(): + # used in testing only (or in an emergency I guess) + now = datetime.now(tz=timezone.utc) + paginator = AMT_CLIENT.get_paginator("list_hits") + + for page in paginator.paginate(): + for hit in page["HITs"]: + if hit["HITStatus"] in ("Assignable", "Reviewable", "Reviewing"): + AMT_CLIENT.update_expiration_for_hit( + HITId=hit["HITId"], ExpireAt=now + ) diff --git a/jb/managers/assignment.py b/jb/managers/assignment.py new file mode 100644 index 0000000..fca72e8 --- /dev/null +++ b/jb/managers/assignment.py @@ -0,0 +1,259 @@ +from datetime import datetime, timezone +from typing import Optional + +from psycopg import sql +from pydantic import NonNegativeInt + +from jb.managers import PostgresManager +from jb.models.assignment import AssignmentStub, Assignment +from jb.models.definitions import AssignmentStatus + + +class AssignmentManager(PostgresManager): + + def create_stub(self, stub: AssignmentStub) -> None: + assert stub.id is None + data = stub.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_assignment + (amt_assignment_id, amt_worker_id, status, created_at, modified_at, hit_id) + VALUES + (%(amt_assignment_id)s, %(amt_worker_id)s, %(status)s, + %(created_at)s, %(modified_at)s, %(hit_id)s) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + stub.id = pk + return None + + def create(self, assignment: Assignment) -> None: + # Typically this is NOT used (we'd create the stub when HIT is + # accepted, then update it once the assignment is submitted), however + # there may be cases where an assignment is submitted and the stub + # was never created (probably due to error/downtime/baddie). + + assert assignment.id is None + data = assignment.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_assignment + (amt_assignment_id, amt_worker_id, status, + created_at, modified_at, hit_id, + auto_approval_time, accept_time, submit_time, + approval_time, rejection_time, requester_feedback, + tsid) + VALUES + (%(amt_assignment_id)s, %(amt_worker_id)s, %(status)s, + %(created_at)s, %(modified_at)s, %(hit_id)s, + %(auto_approval_time)s, %(accept_time)s, %(submit_time)s, + %(approval_time)s, %(rejection_time)s, %(requester_feedback)s, + %(tsid)s) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + assignment.id = pk + return None + + def get_stub(self, amt_assignment_id: str) -> AssignmentStub: + res = self.pg_config.execute_sql_query( + """ + SELECT ma.id, ma.hit_id, + ma.amt_assignment_id, ma.amt_worker_id, + ma.status, ma.created_at, ma.modified_at, + mh.amt_hit_id + FROM mtwerk_assignment ma + JOIN mtwerk_hit mh ON ma.hit_id = mh.id + WHERE amt_assignment_id = %(amt_assignment_id)s + LIMIT 1; + """, + params={"amt_assignment_id": amt_assignment_id}, + ) + assert len(res) == 1 + return AssignmentStub.model_validate(res[0]) + + def get_stub_if_exists(self, amt_assignment_id: str) -> Optional[AssignmentStub]: + try: + return self.get_stub(amt_assignment_id=amt_assignment_id) + except AssertionError: + return None + + def get(self, amt_assignment_id: str) -> Assignment: + res = self.pg_config.execute_sql_query( + query=""" + SELECT mh.amt_hit_id, ma.id, ma.amt_assignment_id, + ma.amt_worker_id, ma.status, ma.auto_approval_time, + ma.accept_time, ma.submit_time, ma.approval_time, + ma.rejection_time, ma.requester_feedback, + ma.created_at, ma.modified_at, ma.tsid, ma.hit_id + FROM mtwerk_assignment ma + JOIN mtwerk_hit mh ON ma.hit_id = mh.id + WHERE amt_assignment_id = %(amt_assignment_id)s + LIMIT 1; + """, + params={"amt_assignment_id": amt_assignment_id}, + ) + assert len(res) == 1 + return Assignment.model_validate(res[0]) + + def update_answer(self, assignment: Assignment) -> None: + # We're assuming a stub already exists + # The assignment was submitted, but we haven't made a decision yet + now = datetime.now(tz=timezone.utc) + data = { + "status": assignment.status.value, + "submit_time": assignment.submit_time, + "auto_approval_time": assignment.auto_approval_time, + "tsid": assignment.tsid, + "amt_assignment_id": assignment.amt_assignment_id, + "modified_at": now, + } + query = sql.SQL( + """ + UPDATE mtwerk_assignment + SET submit_time = %(submit_time)s, + auto_approval_time = %(auto_approval_time)s, + status = %(status)s, + tsid = %(tsid)s, + modified_at = %(modified_at)s + WHERE amt_assignment_id = %(amt_assignment_id)s + """ + ) + # We force this to fail if the assignment doesn't already exist in the db + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + assert c.rowcount == 1, f"Expected 1 row, got {c.rowcount}" + conn.commit() + return None + + def reject(self, assignment: Assignment) -> None: + assert assignment.status == AssignmentStatus.Rejected + assert assignment.rejection_time is not None + assert assignment.approval_time is None + assert assignment.requester_feedback is not None + now = datetime.now(tz=timezone.utc) + data = { + "status": assignment.status.value, + "submit_time": assignment.submit_time, + "rejection_time": assignment.rejection_time, + "requester_feedback": assignment.requester_feedback, + "amt_assignment_id": assignment.amt_assignment_id, + "auto_approval_time": assignment.auto_approval_time, + "accept_time": assignment.accept_time, + "modified_at": now, + } + query = sql.SQL( + """ + UPDATE mtwerk_assignment + SET submit_time = %(submit_time)s, + rejection_time = %(rejection_time)s, + status = %(status)s, + requester_feedback = %(requester_feedback)s, + auto_approval_time = %(auto_approval_time)s, + accept_time = %(accept_time)s, + modified_at = %(modified_at)s + WHERE amt_assignment_id = %(amt_assignment_id)s + """ + ) + # We force this to fail if the assignment doesn't already exist in the db + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + assert c.rowcount == 1, f"Expected 1 row, got {c.rowcount}" + conn.commit() + return None + + def approve(self, assignment: Assignment) -> None: + assert assignment.status == AssignmentStatus.Approved + assert assignment.rejection_time is None + assert assignment.approval_time is not None + assert assignment.requester_feedback is not None + now = datetime.now(tz=timezone.utc) + data = { + "status": assignment.status.value, + "submit_time": assignment.submit_time, + "approval_time": assignment.approval_time, + "requester_feedback": assignment.requester_feedback, + "amt_assignment_id": assignment.amt_assignment_id, + "auto_approval_time": assignment.auto_approval_time, + "accept_time": assignment.accept_time, + "modified_at": now, + } + query = sql.SQL( + """ + UPDATE mtwerk_assignment + SET submit_time = %(submit_time)s, + approval_time = %(approval_time)s, + status = %(status)s, + requester_feedback = %(requester_feedback)s, + auto_approval_time = %(auto_approval_time)s, + accept_time = %(accept_time)s, + modified_at = %(modified_at)s + WHERE amt_assignment_id = %(amt_assignment_id)s + """ + ) + # We force this to fail if the assignment doesn't already exist in the db + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + assert c.rowcount == 1, f"Expected 1 row, got {c.rowcount}" + conn.commit() + return None + + def missing_tsid_count( + self, amt_worker_id: str, lookback_hrs: int = 24 + ) -> NonNegativeInt: + """ + Look at this user's previous N hrs of submitted assignments. + Count how many assignments they have submitted without a tsid. + """ + res = self.pg_config.execute_sql_query( + query=""" + SELECT COUNT(1) AS c + FROM mtwerk_assignment + WHERE amt_worker_id = %(amt_worker_id)s + AND submit_time > NOW() - (%(lookback_interval)s)::interval + AND tsid IS NULL; + """, + params={ + "amt_worker_id": amt_worker_id, + "lookback_interval": f"{lookback_hrs} hour", + }, + ) + return int(res[0]["c"]) + + def rejected_count( + self, amt_worker_id: str, lookback_hrs: int = 24 + ) -> NonNegativeInt: + """ + Look at this user's previous N hrs of submitted assignments. + Count how many rejected assignments they have. + """ + res = self.pg_config.execute_sql_query( + query=""" + SELECT COUNT(1) AS c + FROM mtwerk_assignment + WHERE amt_worker_id = %(amt_worker_id)s + AND submit_time > NOW() - (%(lookback_interval)s)::interval + AND status = %(status)s; + """, + params={ + "amt_worker_id": amt_worker_id, + "lookback_interval": f"{lookback_hrs} hour", + "status": AssignmentStatus.Rejected.value, + }, + ) + return int(res[0]["c"]) diff --git a/jb/managers/bonus.py b/jb/managers/bonus.py new file mode 100644 index 0000000..0cb8b02 --- /dev/null +++ b/jb/managers/bonus.py @@ -0,0 +1,54 @@ +from typing import List + +from psycopg import sql + +from jb.managers import PostgresManager +from jb.models.bonus import Bonus + + +class BonusManager(PostgresManager): + + def create(self, bonus: Bonus) -> None: + assert bonus.id is None + data = bonus.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_bonus + (payout_event_id, amt_worker_id, amount, grant_time, assignment_id, reason) + VALUES ( + %(payout_event_id)s, + %(amt_worker_id)s, + %(amount)s, + %(grant_time)s, + ( + SELECT id + FROM mtwerk_assignment + WHERE amt_assignment_id = %(amt_assignment_id)s + LIMIT 1 + ), + %(reason)s + ) + RETURNING id, assignment_id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + res = c.fetchone() + conn.commit() + bonus.id = res["id"] + bonus.assignment_id = res["assignment_id"] + return None + + def filter(self, amt_assignment_id: str) -> List[Bonus]: + res = self.pg_config.execute_sql_query( + """ + SELECT mb.*, ma.amt_assignment_id + FROM mtwerk_bonus mb + JOIN mtwerk_assignment ma ON ma.id = mb.assignment_id + WHERE amt_assignment_id = %(amt_assignment_id)s; + """, + params={"amt_assignment_id": amt_assignment_id}, + ) + return [Bonus.from_postgres(x) for x in res] diff --git a/jb/managers/hit.py b/jb/managers/hit.py new file mode 100644 index 0000000..3832418 --- /dev/null +++ b/jb/managers/hit.py @@ -0,0 +1,338 @@ +from datetime import datetime, timezone +from typing import Optional, List + +from psycopg import sql + +from jb.managers import PostgresManager +from jb.models.definitions import HitStatus +from jb.models.hit import HitQuestion, HitType, Hit + + +class HitQuestionManager(PostgresManager): + + def create(self, question: HitQuestion) -> None: + assert question.id is None + data = question.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_question (url, height) + VALUES (%(url)s, %(height)s) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + question.id = pk + return None + + def get_by_id(self, question_id: int) -> HitQuestion: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_question + WHERE id = %(question_id)s + LIMIT 1; + """, + params={"question_id": question_id}, + ) + assert len(res) == 1 + return HitQuestion.model_validate(res[0]) + + def get_by_values_if_exists(self, url: str, height: int) -> Optional[HitQuestion]: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_question + WHERE url = %(url)s AND height = %(height)s + LIMIT 2; + """, + params={"url": url, "height": height}, + ) + assert len(res) != 2, "More than 1 result!" + if len(res) == 0: + return None + + return HitQuestion.model_validate(res[0]) + + def get_or_create(self, question: HitQuestion) -> HitQuestion: + res = self.get_by_values_if_exists(url=question.url, height=question.height) + if res: + return res + self.create(question=question) + return question + + +class HitTypeManager(PostgresManager): + def create(self, hit_type: HitType) -> None: + assert hit_type.amt_hit_type_id is not None + data = hit_type.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_hittype ( + amt_hit_type_id, + title, + description, + reward, + assignment_duration, + auto_approval_delay, + keywords, + min_active + ) + VALUES ( + %(amt_hit_type_id)s, + %(title)s, + %(description)s, + %(reward)s, + %(assignment_duration)s, + %(auto_approval_delay)s, + %(keywords)s, + %(min_active)s + ) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + hit_type.id = pk + return None + + def filter_active(self) -> List[HitType]: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_hittype + WHERE min_active > 0 + LIMIT 50 + """ + ) + + if len(res) == 50: + raise ValueError("Too many HitTypes!") + + return [HitType.from_postgres(i) for i in res] + + def get(self, amt_hit_type_id: str) -> HitType: + res = self.pg_config.execute_sql_query( + """ + SELECT * + FROM mtwerk_hittype + WHERE amt_hit_type_id = %(amt_hit_type_id)s + """, + params={"amt_hit_type_id": amt_hit_type_id}, + ) + assert len(res) == 1 + return HitType.from_postgres(res[0]) + + def get_if_exists(self, amt_hit_type_id: str) -> Optional[HitType]: + try: + return self.get(amt_hit_type_id=amt_hit_type_id) + except AssertionError: + return None + + def get_or_create(self, hit_type: HitType) -> None: + res = self.get_if_exists(amt_hit_type_id=hit_type.amt_hit_type_id) + if res: + hit_type.id = res.id + if res is None: + self.create(hit_type=hit_type) + return None + + def set_min_active(self, hit_type: HitType) -> None: + assert hit_type.id, "must be in the db first!" + query = sql.SQL( + """ + UPDATE mtwerk_hittype + SET min_active = %(min_active)s + WHERE id = %(id)s + """ + ) + data = {"id": hit_type.id, "min_active": hit_type.min_active} + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + conn.commit() + assert c.rowcount == 1, c.rowcount + return None + + +class HitManager(PostgresManager): + def create(self, hit: Hit): + assert hit.amt_hit_id is not None + assert hit.id is None + data = hit.to_postgres() + query = sql.SQL( + """ + INSERT INTO mtwerk_hit ( + amt_hit_id, + hit_type_id, + amt_group_id, + status, + review_status, + creation_time, + expiration, + question_id, + created_at, + modified_at, + max_assignments, + assignment_pending_count, + assignment_completed_count, + assignment_available_count + ) + VALUES ( + %(amt_hit_id)s, + %(hit_type_id)s, + %(amt_group_id)s, + %(status)s, + %(review_status)s, + %(creation_time)s, + %(expiration)s, + %(question_id)s, + %(created_at)s, + %(modified_at)s, + %(max_assignments)s, + %(assignment_pending_count)s, + %(assignment_completed_count)s, + %(assignment_available_count)s + ) + RETURNING id; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + hit.id = pk + return hit + + def update_status(self, amt_hit_id: str, hit_status: HitStatus): + now = datetime.now(tz=timezone.utc) + query = sql.SQL( + """ + UPDATE mtwerk_hit + SET status = %(status)s, modified_at = %(modified_at)s + WHERE amt_hit_id = %(amt_hit_id)s; + """ + ) + + data = { + "amt_hit_id": amt_hit_id, + "status": hit_status.value, + "modified_at": now, + } + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + conn.commit() + assert c.rowcount == 1, c.rowcount + return None + + def update_hit(self, hit: Hit): + hit.modified_at = datetime.now(tz=timezone.utc) + # fields expected to change: + fields = { + "status", + "review_status", + "assignment_pending_count", + "assignment_available_count", + "assignment_completed_count", + "amt_hit_id", + "modified_at", + } + query = sql.SQL( + """ + UPDATE mtwerk_hit + SET status = %(status)s, review_status = %(review_status)s, + assignment_pending_count = %(assignment_pending_count)s, + assignment_available_count = %(assignment_available_count)s, + assignment_completed_count = %(assignment_completed_count)s, + modified_at = %(modified_at)s + WHERE amt_hit_id = %(amt_hit_id)s + RETURNING id; + """ + ) + + data = hit.model_dump(mode="json", include=fields) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + conn.commit() + assert c.rowcount == 1, c.rowcount + hit.id = c.fetchone()["id"] + return None + + def get_from_amt_id(self, amt_hit_id: str) -> Hit: + res = self.pg_config.execute_sql_query( + """ + SELECT + mh.id, + mh.amt_hit_id, + mh.amt_group_id, + mh.status, + mh.review_status, + mh.creation_time, + mh.expiration, + mh.created_at, + mh.modified_at, + mh.assignment_available_count, + mh.assignment_completed_count, + mh.assignment_pending_count, + mh.max_assignments, + mht.amt_hit_type_id, + mht.title, + mht.description, + mht.reward, + mht.assignment_duration, + mht.auto_approval_delay, + mht.keywords, + mq.id as question_id, + mq.height, + mq.url + FROM mtwerk_hit mh + JOIN mtwerk_hittype mht ON mh.hit_type_id = mht.id + JOIN mtwerk_question mq ON mh.question_id = mq.id + WHERE amt_hit_id = %(amt_hit_id)s + LIMIT 2; + """, + params={"amt_hit_id": amt_hit_id}, + ) + assert len(res) == 1 + res = res[0] + question_xml = HitQuestion.model_validate( + {"height": res.pop("height"), "url": res.pop("url")} + ).xml + res["question_id"] = res["question_id"] + res["hit_question_xml"] = question_xml + + return Hit.from_postgres(res) + + def get_active_count(self, hit_type_id: int): + return self.pg_config.execute_sql_query( + """ + SELECT COUNT(1) as active_count + FROM mtwerk_hit + WHERE status = %(status)s + AND hit_type_id = %(hit_type_id)s; + """, + params={"status": HitStatus.Assignable, "hit_type_id": hit_type_id}, + )[0]["active_count"] + + def filter_active_ids(self, hit_type_id: int): + res = self.pg_config.execute_sql_query( + """ + SELECT mh.amt_hit_id + FROM mtwerk_hit mh + WHERE hit_type_id = %(hit_type_id)s; + """, + params={"hit_type_id": hit_type_id}, + ) + return {x["amt_hit_id"] for x in res} diff --git a/jb/managers/worker.py b/jb/managers/worker.py new file mode 100644 index 0000000..e2d7237 --- /dev/null +++ b/jb/managers/worker.py @@ -0,0 +1,16 @@ +from typing import List + +from mypy_boto3_mturk.type_defs import WorkerBlockTypeDef + +from jb.decorators import AMT_CLIENT + + +class WorkerManager: + + @staticmethod + def fetch_worker_blocks() -> List[WorkerBlockTypeDef]: + p = AMT_CLIENT.get_paginator("list_worker_blocks") + res: List[WorkerBlockTypeDef] = [] + for item in p.paginate(): + res.extend(item["WorkerBlocks"]) + return res diff --git a/jb/models/__init__.py b/jb/models/__init__.py new file mode 100644 index 0000000..0aeae14 --- /dev/null +++ b/jb/models/__init__.py @@ -0,0 +1,40 @@ +from decimal import Decimal +from typing import Optional + +from pydantic import BaseModel, Field, ConfigDict + + +class HTTPHeaders(BaseModel): + request_id: str = Field(alias="x-amzn-requestid", min_length=36, max_length=36) + content_type: str = Field(alias="content-type", min_length=26, max_length=26) + # 'content-length': '1255', + content_length: str = Field(alias="content-length", min_length=2) + # 'Mon, 15 Jan 2024 23:40:32 GMT' + date: str = Field() + + connection: Optional[str] = Field(default=None) # 'close' + + +class ResponseMetadata(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + request_id: str = Field(alias="RequestId", min_length=36, max_length=36) + status_code: int = Field(alias="HTTPStatusCode", ge=200, le=599) + headers: HTTPHeaders = Field(alias="HTTPHeaders") + retry_attempts: int = Field(alias="RetryAttempts", ge=0) + + +class AMTAccount(BaseModel): + model_config = ConfigDict(extra="ignore", validate_assignment=True) + + # Remaining available AWS Billing usage if you have enabled AWS Billing. + available_balance: Decimal = Field() + onhold_balance: Decimal = Field(default=Decimal(0)) + + # --- Properties --- + + @property + def is_healthy(self) -> bool: + # A healthy account is one with at least $2,500 worth of + # credit available to it + return self.available_balance >= 2_500 diff --git a/jb/models/api_response.py b/jb/models/api_response.py new file mode 100644 index 0000000..6b29e51 --- /dev/null +++ b/jb/models/api_response.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from jb.models.assignment import Assignment +from jb.models.hit import Hit + + +class AssignmentResponse(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + assignment: Assignment = Field(alias="Assignment") + hit: Hit = Field(alias="HIT") + + @model_validator(mode="after") + def check_consistent_hit_id(self) -> "AssignmentResponse": + if self.hit.id != self.assignment.hit_id: + raise ValueError("Inconsistent Hit IDs") + return self diff --git a/jb/models/assignment.py b/jb/models/assignment.py new file mode 100644 index 0000000..39ae47c --- /dev/null +++ b/jb/models/assignment.py @@ -0,0 +1,388 @@ +import logging +from datetime import datetime, timezone +from typing import Optional, TypedDict +from xml.etree import ElementTree + +from mypy_boto3_mturk.type_defs import AssignmentTypeDef +from pydantic import ( + BaseModel, + Field, + ConfigDict, + model_validator, + PositiveInt, + computed_field, + TypeAdapter, + ValidationError, +) +from typing_extensions import Self + +from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr +from jb.models.definitions import AssignmentStatus + + +class AnswerDict(TypedDict): + amt_assignment_id: str + amt_worker_id: str + tsid: str + + +class AssignmentStub(BaseModel): + # todo: we need an "AssignmentStub" model that just has + # the IDs, this is used when a user accepts an assignment + # but hasn't submitted it yet. We want to create it in the db + # at that point. + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + id: Optional[PositiveInt] = Field(default=None) + hit_id: Optional[PositiveInt] = Field(default=None) + amt_assignment_id: AMTBoto3ID = Field() + amt_hit_id: AMTBoto3ID = Field() + amt_worker_id: str = Field(min_length=3, max_length=50) + + status: AssignmentStatus = Field() + + # GRL Specific + created_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this record was saved in the database", + ) + + modified_at: Optional[AwareDatetimeISO] = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this record was updated / modified in the database", + ) + + def to_postgres(self): + d = self.model_dump(mode="json") + return d + + +class Assignment(AssignmentStub): + """ + The Assignment data structure represents a single assignment of a HIT to + a Worker. The assignment tracks the Worker's efforts to complete the HIT, + and contains the results for later retrieval. + + The Assignment data structure is used as a response element for the + following operations: + + GetAssignment + GetAssignmentsForHIT + + https://docs.aws.amazon.com/AWSMechTurk/latest/AWSMturkAPI/ApiReference_AssignmentDataStructureArticle.html + """ + + auto_approval_time: AwareDatetimeISO = Field( + description="If results have been submitted, AutoApprovalTime is the " + "date and time the results of the assignment results are " + "considered Approved automatically if they have not already " + "been explicitly approved or rejected by the Requester. " + "This value is derived from the auto-approval delay " + "specified by the Requester in the HIT. This value is " + "omitted from the assignment if the Worker has not yet " + "submitted results.", + ) + + accept_time: AwareDatetimeISO = Field( + description="The date and time the Worker accepted the assignment.", + ) + + submit_time: AwareDatetimeISO = Field( + description="The date and time the assignment was submitted. This value " + "is omitted from the assignment if the Worker has not yet " + "submitted results.", + ) + + approval_time: Optional[AwareDatetimeISO] = Field( + default=None, + description="The date and time the Requester approved the results. This " + "value is omitted from the assignment if the Requester has " + "not yet approved the results.", + ) + rejection_time: Optional[AwareDatetimeISO] = Field( + default=None, + description="The date and time the Requester rejected the results.", + ) + + requester_feedback: Optional[str] = Field( + # Default: None. This field isn't returned with assignment data by + # default. To request this field, specify a response group of + # AssignmentFeedback. For information about response groups, see + # Common Parameters. + default=None, + min_length=3, + max_length=2_000, + help_text="The feedback string included with the call to the " + "ApproveAssignment operation or the RejectAssignment " + "operation, if the Requester approved or rejected the " + "assignment and specified feedback.", + ) + + answer_xml: Optional[str] = Field(default=None, exclude=True) + + # GRL Specific + + tsid: Optional[UUIDStr] = Field(default=None) + + # --- Validators --- + + @model_validator(mode="before") + def set_tsid(cls, values: dict): + if values.get("tsid") is None and (answer_xml := values.get("answer_xml")): + answer_dict = cls.parse_answer_xml(answer_xml) + tsid = answer_dict.get("tsid") + try: + values["tsid"] = TypeAdapter(UUIDStr).validate_python(tsid) + except ValidationError as e: + # Don't break the model validation if a baddie messes with the tsid in the answer. + logging.warning(e) + values["tsid"] = None + return values + + @model_validator(mode="after") + def check_time_sequences(self) -> Self: + if self.accept_time > self.submit_time: + raise ValueError("Assignment times invalid") + + return self + + @model_validator(mode="after") + def check_answers_alignment(self) -> Self: + if self.answers_dict is None: + return self + if self.amt_worker_id != self.answers_dict["amt_worker_id"]: + raise ValueError("Assignment answer invalid worker_id") + if self.amt_assignment_id != self.answers_dict["amt_assignment_id"]: + raise ValueError("Assignment answer invalid amt_assignment_id") + if ( + self.tsid + and self.answers_dict["tsid"] + and self.tsid != self.answers_dict["tsid"] + ): + raise ValueError("Assignment answer invalid tsid") + return self + + # --- Properties --- + + @property + def answers_dict(self) -> Optional[AnswerDict]: + # See https://docs.aws.amazon.com/AWSMechTurk/latest/AWSMturkAPI/ApiReference_AssignmentDataStructureArticle.html + # https://docs.aws.amazon.com/AWSMechTurk/latest/AWSMechanicalTurkRequester/Concepts_NotificationsArticle.html + if self.answer_xml is None: + return None + + return self.parse_answer_xml(self.answer_xml) + + @staticmethod + def parse_answer_xml(answer_xml: str): + root = ElementTree.fromstring(answer_xml) + ns = { + "mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd" + } + res = {} + + for a in root.findall("mt:Answer", ns): + name = a.find("mt:QuestionIdentifier", ns).text + value = a.find("mt:FreeText", ns).text + res[name] = value or "" + + EXPECTED_KEYS = {"amt_assignment_id", "amt_worker_id", "tsid"} + # We don't want validation to fail if a baddie inserts or changes url + # params, which will result in new or missing keys. Amazon generates the xml + # so that should always be correct + # assert all(k in res for k in EXPECTED_KEYS), list(res.keys()) + res = {k: v for k, v in res.items() if k in EXPECTED_KEYS} + return res + + @classmethod + def from_amt_get_assignment(cls, data: AssignmentTypeDef) -> Self: + assignment = cls( + amt_assignment_id=data["AssignmentId"], + amt_hit_id=data["HITId"], + amt_worker_id=data["WorkerId"], + status=AssignmentStatus[data["AssignmentStatus"]], + auto_approval_time=data["AutoApprovalTime"].astimezone(tz=timezone.utc), + accept_time=data["AcceptTime"].astimezone(tz=timezone.utc), + submit_time=data["SubmitTime"].astimezone(tz=timezone.utc), + approval_time=( + data["ApprovalTime"].astimezone(tz=timezone.utc) + if data.get("ApprovalTime") + else None + ), + rejection_time=( + data["RejectionTime"].astimezone(tz=timezone.utc) + if data.get("RejectionTime") + else None + ), + answer_xml=data["Answer"], + requester_feedback=data.get("RequesterFeedback"), + ) + return assignment + + def to_stub(self) -> AssignmentStub: + return AssignmentStub.model_validate( + self.model_dump(include=set(AssignmentStub.model_fields.keys())) + ) + + # --- Methods --- + # + # def refresh(self) -> Self: + # from tasks.mtwerk.managers.assignment import AssignmentManager + # return AssignmentManager.fetch_by_id(self) + # + # def reject(self, msg: str = REJECT_MESSAGE_UNKNOWN_ASSIGNMENT): + # """ + # Save in the database that the Assignment was rejected, and also + # Report to Amazon Mechanical Turk that this Assignment should be + # rejected + # + # TODO: can this only occur when the Assignment is in a certain status? + # + # :return: + # """ + # now = datetime.now(tz=None) + # + # MYSQLC.execute_sql_query(""" + # UPDATE `amt-jb`.`mtwerk_assignment` + # SET submit_time = %s, rejection_time = %s, status = %s, + # requester_feedback = %s + # WHERE assignment_id = %s""", + # params=[ + # now, now, + # AssignmentStatus.Rejected.value, + # msg, self.id], + # commit=True) + # + # CLIENT.reject_assignment( + # AssignmentId=self.id, + # RequesterFeedback=msg) + # + # def approve(self, msg: str = "Approved."): + # """ + # Report to Amazon Mechanical Turk that this Assignment should be + # approved + # + # TODO: can this only occur when the Assignment is in a certain status? + # + # :return: + # """ + # CLIENT.approve_assignment( + # AssignmentId=self.id, + # RequesterFeedback=msg) + # + # def submit_and_complete_request(self) -> Optional[str]: + # """ + # This approves the Assignment and issues the Reward + # amount (typically $.05) + # + # :return: + # """ + # worker = self.worker + # amount = DecimalUSDDollars(self.hit.reward) + # + # # If successful, returns the cashout id, otherwise, returns None + # cashout: Optional[dict] = worker.cashout_request( + # amount=amount, + # cashout_method_id=AMT_ASSIGNMENT_CASHOUT_METHOD) + # + # if cashout is None or cashout.get('status') != PayoutStatus.PENDING: + # return None + # + # cashout_id: str = cashout[id] + # + # approval: Optional[dict] = Bonus.manage_pending_cashout( + # cashout_id=cashout_id, + # action=PayoutStatus.APPROVED) + # + # if approval is None or approval['status'] != PayoutStatus.APPROVED: + # return None + # + # completion: Optional[dict] = Bonus.manage_pending_cashout( + # cashout_id=cashout_id, + # action=PayoutStatus.COMPLETE) + # + # if completion is None or completion['status'] != PayoutStatus.COMPLETE: + # return None + # + # return cashout_id + # + # # --- ORM --- + # + # def model_dump_mysql(self, *args, **kwargs) -> dict: + # d = self.model_dump(mode='json', *args, **kwargs) + # + # d['auto_approval_time'] = self.auto_approval_time.replace(tzinfo=None) + # d['accept_time'] = self.accept_time.replace(tzinfo=None) + # d['submit_time'] = self.submit_time.replace(tzinfo=None) + # + # if self.approval_time: + # d['approval_time'] = self.approval_time.replace(tzinfo=None) + # + # if self.rejection_time: + # d['rejection_time'] = self.rejection_time.replace(tzinfo=None) + # + # # created is automatically added by the database + # d['created'] = self.created.replace(tzinfo=None) + # + # if self.modified: + # d['modified'] = self.modified.replace(tzinfo=None) + # + # d['tsid'] = self.answers.get('tsid') + # + # return d + # + # def save(self) -> bool: + # """ + # Either INSERTS or UPDATES the Assignment instance to a Mysql + # record. + # """ + # + # # We're modifying the record, so set the time to right now! + # self.modified = datetime.now(tz=timezone.utc) + # + # query = """ + # INSERT `amt-jb`.`mtwerk_assignment` ( + # id, worker_id, hit_id, status, + # auto_approval_time, accept_time, submit_time, + # approval_time, rejection_time, + # requester_feedback, created, modified, tsid + # ) + # VALUES ( + # %(id)s, %(worker_id)s, %(hit_id)s, %(status)s, + # %(auto_approval_time)s, %(accept_time)s, %(submit_time)s, + # %(approval_time)s, %(rejection_time)s, + # %(requester_feedback)s, %(created)s, %(modified)s, %(tsid)s + # ) + # ON DUPLICATE KEY UPDATE + # worker_id = %(worker_id)s, + # hit_id = %(hit_id)s, + # status = %(status)s, + # + # auto_approval_time = %(auto_approval_time)s, + # accept_time = %(accept_time)s, + # submit_time = %(submit_time)s, + # + # approval_time = %(approval_time)s, + # rejection_time = %(rejection_time)s, + # + # requester_feedback = %(requester_feedback)s, + # -- Not going to update created just incase it changed + # -- in pydantic for some reason + # modified = %(modified)s, + # tsid = %(tsid)s + # """ + # + # try: + # MYSQLC.execute_sql_query(query, params=self.model_dump_mysql(), commit=True) + # return True + # + # except Exception as e: + # return False + # + + +# REJECT_MESSAGE_UNKNOWN_ASSIGNMENT = "Unknown assignment" diff --git a/jb/models/bonus.py b/jb/models/bonus.py new file mode 100644 index 0000000..564a32d --- /dev/null +++ b/jb/models/bonus.py @@ -0,0 +1,48 @@ +from typing import Optional, Dict + +from pydantic import BaseModel, Field, ConfigDict, PositiveInt +from typing_extensions import Self + +from jb.models.currency import USDCent +from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr +from jb.models.definitions import PayoutStatus + + +class Bonus(BaseModel): + """ + A Bonus is created (in our DB) ONLY associated with an APPROVED + thl-payout-event, AFTER the bonus has actually been sent to + the worker. + We have the payout_event uuid as the unique request token to make + sure it only gets sent once (param in the boto request). + """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + id: Optional[PositiveInt] = Field(default=None) + assignment_id: Optional[PositiveInt] = Field(default=None) + + amt_worker_id: str = Field(min_length=3, max_length=50) + amt_assignment_id: AMTBoto3ID = Field() + + amount: USDCent = Field() + reason: str = Field(min_length=5) + grant_time: AwareDatetimeISO = Field() + + # -- GRL Specific --- + payout_event_id: UUIDStr = Field() + # created: Optional[AwareDatetimeISO] = Field(default=None) + + def to_postgres(self): + d = self.model_dump(mode="json") + d["amount"] = self.amount.to_usd() + return d + + @classmethod + def from_postgres(cls, data: Dict) -> Self: + data["amount"] = USDCent(round(data["amount"] * 100)) + fields = set(cls.model_fields.keys()) + data = {k: v for k, v in data.items() if k in fields} + return cls.model_validate(data) diff --git a/jb/models/currency.py b/jb/models/currency.py new file mode 100644 index 0000000..3094e2a --- /dev/null +++ b/jb/models/currency.py @@ -0,0 +1,70 @@ +import warnings +from decimal import Decimal +from typing import Any + +from pydantic import GetCoreSchemaHandler, NonNegativeInt +from pydantic_core import CoreSchema, core_schema + + +class USDCent(int): + def __new__(cls, value, *args, **kwargs): + + if isinstance(value, float): + warnings.warn( + "USDCent init with a float. Rounding behavior may " "be unexpected" + ) + + if isinstance(value, Decimal): + warnings.warn( + "USDCent init with a Decimal. Rounding behavior may " "be unexpected" + ) + + if value < 0: + raise ValueError("USDCent not be less than zero") + + return super(cls, cls).__new__(cls, value) + + def __add__(self, other): + assert isinstance(other, USDCent) + res = super(USDCent, self).__add__(other) + return self.__class__(res) + + def __sub__(self, other): + assert isinstance(other, USDCent) + res = super(USDCent, self).__sub__(other) + return self.__class__(res) + + def __mul__(self, other): + assert isinstance(other, USDCent) + res = super(USDCent, self).__mul__(other) + return self.__class__(res) + + def __abs__(self): + res = super(USDCent, self).__abs__() + return self.__class__(res) + + def __truediv__(self, other): + raise ValueError("Division not allowed for USDCent") + + def __str__(self): + return "%d" % int(self) + + def __repr__(self): + return "USDCent(%d)" % int(self) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ + """ + return core_schema.no_info_after_validator_function( + cls, handler(NonNegativeInt) + ) + + def to_usd(self) -> Decimal: + return Decimal(int(self) / 100).quantize(Decimal(".01")) + + def to_usd_str(self) -> str: + return "${:,.2f}".format(float(self.to_usd())) diff --git a/jb/models/custom_types.py b/jb/models/custom_types.py new file mode 100644 index 0000000..70bc5c1 --- /dev/null +++ b/jb/models/custom_types.py @@ -0,0 +1,113 @@ +import re +from datetime import datetime, timezone +from typing import Any, Optional +from uuid import UUID + +from pydantic import ( + AwareDatetime, + StringConstraints, + TypeAdapter, + HttpUrl, +) +from pydantic.functional_serializers import PlainSerializer +from pydantic.functional_validators import AfterValidator, BeforeValidator +from pydantic.networks import UrlConstraints +from pydantic_core import Url +from typing_extensions import Annotated + + +def convert_datetime_to_iso_8601_with_z_suffix(dt: datetime) -> str: + # By default, datetimes are serialized with the %f optional. We don't want that because + # then the deserialization fails if the datetime didn't have microseconds. + return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + +def convert_str_dt(v: Any) -> Optional[AwareDatetime]: + # By default, pydantic is unable to handle tz-aware isoformat str. Attempt to parse a str + # that was dumped using the iso8601 format with Z suffix. + if v is not None and type(v) is str: + assert v.endswith("Z") and "T" in v, "invalid format" + return datetime.strptime(v, "%Y-%m-%dT%H:%M:%S.%fZ").replace( + tzinfo=timezone.utc + ) + return v + + +def assert_utc(v: AwareDatetime) -> AwareDatetime: + if isinstance(v, datetime): + assert v.tzinfo == timezone.utc, "Timezone is not UTC" + return v + + +# Our custom AwareDatetime that correctly serializes and deserializes +# to an ISO8601 str with timezone +AwareDatetimeISO = Annotated[ + AwareDatetime, + BeforeValidator(convert_str_dt), + AfterValidator(assert_utc), + PlainSerializer( + lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + when_used="json-unless-none", + ), +] + +# ISO 3166-1 alpha-2 (two-letter codes, lowercase) +# "Like" b/c it matches the format, but we're not explicitly checking +# it is one of our supported values. See models.thl.locales for that. +CountryISOLike = Annotated[ + str, StringConstraints(max_length=2, min_length=2, pattern=r"^[a-z]{2}$") +] +# 3-char ISO 639-2/B, lowercase +LanguageISOLike = Annotated[ + str, StringConstraints(max_length=3, min_length=3, pattern=r"^[a-z]{3}$") +] + + +def check_valid_uuid(v: str) -> str: + try: + assert UUID(v).hex == v + except Exception: + raise ValueError("Invalid UUID") + return v + + +# Our custom field that stores a UUID4 as the .hex string representation +UUIDStr = Annotated[ + str, + StringConstraints(min_length=32, max_length=32), + AfterValidator(check_valid_uuid), +] +# Accepts the non-hex representation and coerces +UUIDStrCoerce = Annotated[ + str, + StringConstraints(min_length=32, max_length=32), + BeforeValidator(lambda value: TypeAdapter(UUID).validate_python(value).hex), + AfterValidator(check_valid_uuid), +] + +# Same thing as UUIDStr with HttpUrl field. It is confusing that this +# is not a str https://github.com/pydantic/pydantic/discussions/6395 +HttpUrlStr = Annotated[ + str, + BeforeValidator(lambda value: str(TypeAdapter(HttpUrl).validate_python(value))), +] + +HttpsUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=["https"])] +HttpsUrlStr = Annotated[ + str, + BeforeValidator(lambda value: str(TypeAdapter(HttpsUrl).validate_python(value))), +] + + +def check_valid_amt_boto3_id(v: str) -> str: + # Test ids from amazon have 20 chars + if not re.fullmatch(r"[A-Z0-9]{20}|[A-Z0-9]{30}", v): + raise ValueError("Invalid AMT Boto3 ID") + return v + + +AMTBoto3ID = Annotated[ + str, + StringConstraints(min_length=20, max_length=30), + AfterValidator(check_valid_amt_boto3_id), +] diff --git a/jb/models/definitions.py b/jb/models/definitions.py new file mode 100644 index 0000000..a3d27ba --- /dev/null +++ b/jb/models/definitions.py @@ -0,0 +1,90 @@ +from enum import IntEnum, StrEnum + + +class AssignmentStatus(IntEnum): + # boto3.mturk specific + Submitted = 0 # same thing as Reviewable + Approved = 1 + Rejected = 2 + + # GRL specific + Accepted = 3 + PreviewState = 4 + # Invalid = 5 + # NotExist = 6 + + +class HitStatus(IntEnum): + """ + https://docs.aws.amazon.com/AWSMechTurk/latest/AWSMturkAPI/ApiReference_HITDataStructureArticle.html + """ + + # Official boto3.mturk + Assignable = 0 + Unassignable = 1 + Reviewable = 2 + Reviewing = 3 + Disposed = 4 + + # GRL Specific + NotExist = 5 + + +class HitReviewStatus(IntEnum): + NotReviewed = 0 + MarkedForReview = 1 + ReviewedAppropriate = 2 + ReviewedInappropriate = 3 + + +class PayoutStatus(StrEnum): + """These are GRL's payout statuses""" + + # The user has requested a payout. The money is taken from their + # wallet. A PENDING request can either be APPROVED, REJECTED, or + # CANCELLED. We can also implicitly skip the APPROVED step and go + # straight to COMPLETE or FAILED. + PENDING = "PENDING" + # The request is approved (by us or automatically). Once approved, + # it can be FAILED or COMPLETE. + APPROVED = "APPROVED" + # The request is rejected. The user loses the money. + REJECTED = "REJECTED" + # The user requests to cancel the request, the money goes back into their wallet. + CANCELLED = "CANCELLED" + # The payment was approved, but failed within external payment provider. + # This is an "error" state, as the money won't have moved anywhere. A + # FAILED payment can be tried again and be COMPLETE. + FAILED = "FAILED" + # The payment was sent successfully and (usually) a fee was charged + # to us for it. + COMPLETE = "COMPLETE" + # Not supported # REFUNDED: I'm not sure if this is possible or + # if we'd want to allow it. + + +class ReportValue(IntEnum): + """ + The reason a user reported a task. + """ + + # Used to indicate the user exited the task without giving feedback + REASON_UNKNOWN = 0 + # Task is in the wrong language/country, unanswerable question, won't proceed to + # next question, loading forever, error message + TECHNICAL_ERROR = 1 + # Task ended (completed or failed, and showed the user some dialog + # indicating the task was over), but failed to redirect + NO_REDIRECT = 2 + # Asked for full name, home address, identity on another site, cc# + PRIVACY_INVASION = 3 + # Asked about children, employer, medical issues, drug use, STDs, etc. + UNCOMFORTABLE_TOPICS = 4 + # Asked to install software, signup/login to external site, access webcam, + # promise to pay using external site, etc. + ASKED_FOR_NOT_ALLOWED_ACTION = 5 + # Task doesn't work well on a mobile device + BAD_ON_MOBILE = 6 + # Too long, too boring, confusing, complicated, too many + # open-ended/free-response questions + DIDNT_LIKE = 7 diff --git a/jb/models/errors.py b/jb/models/errors.py new file mode 100644 index 0000000..94f5fbb --- /dev/null +++ b/jb/models/errors.py @@ -0,0 +1,80 @@ +import re +from enum import Enum + +from pydantic import BaseModel, Field, ConfigDict, model_validator + +from jb.models import ResponseMetadata + + +class BotoRequestErrorOperation(str, Enum): + GET_ASSIGNMENT = "GetAssignment" + GET_HIT = "GetHIT" + + +class TurkErrorCode(str, Enum): + # Unclear: maybe when it's new? + HIT_NOT_EXIST = "AWS.MechanicalTurk.HITDoesNotExist" + + # This seems to be for really old Assignments + # Also maybe when it's only a Preview? + # Happens 2+ years back, and also from past 24hrs + INVALID_ASSIGNEMENT_STATE = "AWS.MechanicalTurk.InvalidAssignmentState" + + # If random assignmentId is used + ASSIGNMENT_NOT_EXIST = "AWS.MechanicalTurk.AssignmentDoesNotExist" + + +class BotoRequestErrorResponseErrorCodes(str, Enum): + REQUEST_ERROR = "RequestError" + + +class BotoRequestErrorResponseError(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + message: str = Field(alias="Message") + code: BotoRequestErrorResponseErrorCodes = Field(alias="Code") + + +class BotoRequestErrorResponse(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + error: BotoRequestErrorResponseError = Field(alias="Error") + response_metadata: ResponseMetadata = Field(alias="ResponseMetadata") + message: str = Field(alias="Message", min_length=50) + error_code: TurkErrorCode = Field(alias="TurkErrorCode") + + @model_validator(mode="after") + def check_consistent_hit_id(self) -> "BotoRequestErrorResponse": + + match self.error_code: + case TurkErrorCode.HIT_NOT_EXIST: + if not re.match( + r"Hit [A-Z0-9]{30} does not exist. \(\d{13}\)", self.message + ): + raise ValueError("Unknown message for TurkErrorCode.HIT_NOT_EXIST") + + case TurkErrorCode.INVALID_ASSIGNEMENT_STATE: + if not re.match( + r"This operation can be called with a status of: Reviewable,Approved,Rejected \(\d{13}\)", + self.message, + ): + raise ValueError( + "Unknown message for TurkErrorCode.INVALID_ASSIGNEMENT_STATE" + ) + + case TurkErrorCode.ASSIGNMENT_NOT_EXIST: + if not re.match( + r"Assignment [A-Z0-9]{30} does not exist. \(\d{13}\)", self.message + ): + raise ValueError( + "Unknown message for TurkErrorCode.ASSIGNMENT_NOT_EXIST" + ) + + return self + + +class BotoRequestError(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + response: BotoRequestErrorResponse = Field() + operation_name: BotoRequestErrorOperation = Field() diff --git a/jb/models/event.py b/jb/models/event.py new file mode 100644 index 0000000..c357772 --- /dev/null +++ b/jb/models/event.py @@ -0,0 +1,38 @@ +from typing import Literal, Dict + +from mypy_boto3_mturk.literals import EventTypeType +from pydantic import BaseModel, Field + +from jb.models.custom_types import AwareDatetimeISO, AMTBoto3ID + + +class MTurkEvent(BaseModel): + """ + What AWS SNS will POST to our mturk_notifications endpoint (inside the request body) + """ + + event_type: EventTypeType = Field(example="AssignmentSubmitted") + event_timestamp: AwareDatetimeISO = Field(example="2025-10-16T18:45:51Z") + amt_hit_id: AMTBoto3ID = Field(example="12345678901234567890") + amt_assignment_id: str = Field( + max_length=64, example="1234567890123456789012345678901234567890" + ) + amt_hit_type_id: AMTBoto3ID = Field(example="09876543210987654321") + + @classmethod + def from_sns(cls, data: Dict): + return cls.model_validate( + { + "event_type": data["EventType"], + "event_timestamp": cls.fix_mturk_timestamp(data["EventTimestamp"]), + "amt_hit_id": data["HITId"], + "amt_assignment_id": data["AssignmentId"], + "amt_hit_type_id": data["HITTypeId"], + } + ) + + @staticmethod + def fix_mturk_timestamp(ts: str) -> str: + if ts.endswith("Z") and "." not in ts: + ts = ts[:-1] + ".000Z" + return ts diff --git a/jb/models/hit.py b/jb/models/hit.py new file mode 100644 index 0000000..c3734fa --- /dev/null +++ b/jb/models/hit.py @@ -0,0 +1,251 @@ +from datetime import datetime, timezone, timedelta +from typing import Optional, List, Dict +from uuid import uuid4 +from xml.etree import ElementTree + +from mypy_boto3_mturk.type_defs import HITTypeDef +from pydantic import ( + BaseModel, + Field, + PositiveInt, + ConfigDict, + NonNegativeInt, +) +from typing_extensions import Self + +from jb.models.currency import USDCent +from jb.models.custom_types import AMTBoto3ID, HttpsUrlStr, AwareDatetimeISO +from jb.models.definitions import HitStatus, HitReviewStatus + + +class HitQuestion(BaseModel): + id: Optional[PositiveInt] = Field(default=None) + + url: HttpsUrlStr = Field() + height: PositiveInt = Field(default=1_200, ge=100, le=4_000) + + # --- Properties --- + + def to_postgres(self): + return self.model_dump(mode="json") + + @property + def xml(self) -> str: + return f"""<?xml version="1.0" encoding="UTF-8"?> + <ExternalQuestion xmlns="http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2006-07-14/ExternalQuestion.xsd"> + <ExternalURL>{str(self.url)}</ExternalURL> + <FrameHeight>{self.height}</FrameHeight> + </ExternalQuestion>""" + + +class HitTypeCommon(BaseModel): + """ + Fields on both the HitType and Hit + """ + + model_config = ConfigDict( + extra="forbid", validate_assignment=True, ser_json_timedelta="float" + ) + + title: str = Field( + min_length=3, + max_length=200, + description="The HIT post title that appears in the listing view", + ) + description: str = Field( + min_length=3, + max_length=2_000, + description="The expand more about textarea, has a max of 2000 characters", + ) + reward: USDCent = Field( + description="The amount of money the Requester will pay a Worker for successfully completing the HIT." + ) + + assignment_duration: timedelta = Field( + default=timedelta(minutes=90), + description="The amount of time, in seconds, that a Worker has to complete " + "the HIT after accepting it.", + ) + auto_approval_delay: timedelta = Field( + default=timedelta(days=7), + description="The number of seconds after an assignment for the HIT has " + "been submitted, after which the assignment is considered " + "Approved automatically unless the Requester explicitly " + "rejects it.", + ) + keywords: str = Field(min_length=3, max_length=999) + + +class HitType(HitTypeCommon): + """ + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/mturk/client/create_hit_type.html + https://docs.aws.amazon.com/AWSMechTurk/latest/AWSMturkAPI/ApiReference_CreateHITTypeOperation.html + """ + + id: Optional[PositiveInt] = Field(default=None) + amt_hit_type_id: Optional[AMTBoto3ID] = Field(default=None) + + # --- GRL Specific --- + min_active: NonNegativeInt = Field(default=0, le=100_000) + + def to_api_request_body(self): + return dict( + AutoApprovalDelayInSeconds=round(self.auto_approval_delay.total_seconds()), + AssignmentDurationInSeconds=round(self.assignment_duration.total_seconds()), + Reward=str(self.reward.to_usd()), + Title=self.title, + Keywords=self.keywords, + Description=self.description, + ) + + def to_postgres(self): + d = self.model_dump(mode="json") + d["reward"] = self.reward.to_usd() + return d + + @classmethod + def from_postgres(cls, data: Dict) -> Self: + data["reward"] = USDCent(round(data["reward"] * 100)) + return cls.model_validate(data) + + def generate_hit_amt_request(self, question: HitQuestion): + d = dict() + d["HITTypeId"] = self.amt_hit_type_id + d["MaxAssignments"] = 1 + d["LifetimeInSeconds"] = round(timedelta(days=14).total_seconds()) + d["Question"] = question.xml + d["UniqueRequestToken"] = uuid4().hex + return d + + +class Hit(HitTypeCommon): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + id: Optional[PositiveInt] = Field(default=None) + hit_type_id: Optional[PositiveInt] = Field(default=None) + question_id: Optional[PositiveInt] = Field(default=None) + + amt_hit_id: AMTBoto3ID = Field() + amt_hit_type_id: AMTBoto3ID = Field() + amt_group_id: AMTBoto3ID = Field() + hit_question_xml: str = Field() + + status: HitStatus = Field() + review_status: HitReviewStatus = Field() + creation_time: AwareDatetimeISO = Field(default=None, description="From aws") + expiration: Optional[AwareDatetimeISO] = Field(default=None) + + # GRL Specific + created_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this record was saved in the database", + ) + modified_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this record was last modified", + ) + + # -- Hit specific + + qualification_requirements: Optional[List[Dict]] = Field(default=None) + max_assignments: int = Field() + + # # this comes back as expiration. only for the request + # lifetime: timedelta = Field( + # default=timedelta(days=14), + # description="An amount of time, in seconds, after which the HIT is no longer " + # "available for users to accept.", + # ) + assignment_pending_count: NonNegativeInt = Field() + assignment_available_count: NonNegativeInt = Field() + assignment_completed_count: NonNegativeInt = Field() + + @classmethod + def from_amt_create_hit( + cls, data: HITTypeDef, question: HitQuestion, hit_type: HitType + ) -> Self: + assert question.id is not None + assert hit_type.id is not None + assert hit_type.amt_hit_type_id is not None + + h = Hit.model_validate( + dict( + amt_hit_id=data["HITId"], + amt_hit_type_id=data["HITTypeId"], + amt_group_id=data["HITGroupId"], + status=HitStatus[data["HITStatus"]], + review_status=HitReviewStatus[data["HITReviewStatus"]], + creation_time=data["CreationTime"].astimezone(tz=timezone.utc), + expiration=data["Expiration"].astimezone(tz=timezone.utc), + hit_question_xml=data["Question"], + qualification_requirements=data["QualificationRequirements"], + max_assignments=data["MaxAssignments"], + assignment_pending_count=data["NumberOfAssignmentsPending"], + assignment_available_count=data["NumberOfAssignmentsAvailable"], + assignment_completed_count=data["NumberOfAssignmentsCompleted"], + description=data["Description"], + keywords=data["Keywords"], + reward=USDCent(round(float(data["Reward"]) * 100)), + title=data["Title"], + question_id=question.id, + hit_type_id=hit_type.id, + ) + ) + return h + + @classmethod + def from_amt_get_hit(cls, data: HITTypeDef) -> Self: + h = Hit.model_validate( + dict( + amt_hit_id=data["HITId"], + amt_hit_type_id=data["HITTypeId"], + amt_group_id=data["HITGroupId"], + status=HitStatus[data["HITStatus"]], + review_status=HitReviewStatus[data["HITReviewStatus"]], + creation_time=data["CreationTime"].astimezone(tz=timezone.utc), + expiration=data["Expiration"].astimezone(tz=timezone.utc), + hit_question_xml=data["Question"], + qualification_requirements=data["QualificationRequirements"], + max_assignments=data["MaxAssignments"], + assignment_pending_count=data["NumberOfAssignmentsPending"], + assignment_available_count=data["NumberOfAssignmentsAvailable"], + assignment_completed_count=data["NumberOfAssignmentsCompleted"], + description=data["Description"], + keywords=data["Keywords"], + reward=USDCent(round(float(data["Reward"]) * 100)), + title=data["Title"], + question_id=None, + hit_type_id=None, + ) + ) + return h + + def to_postgres(self): + d = self.model_dump(mode="json") + d["reward"] = self.reward.to_usd() + return d + + @classmethod + def from_postgres(cls, data: Dict) -> Self: + data["reward"] = USDCent(round(data["reward"] * 100)) + return cls.model_validate(data) + + @property + def hit_question(self) -> HitQuestion: + root = ElementTree.fromstring(self.hit_question_xml) + + ns = { + "mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2006-07-14/ExternalQuestion.xsd" + } + res = {} + + lookup_table = dict(ExternalURL="url", FrameHeight="height") + for a in root.findall("mt:*", ns): + key = lookup_table[a.tag.split("}")[1]] + val = a.text + res[key] = val + + return HitQuestion.model_validate(res, from_attributes=True) diff --git a/jb/views/__init__.py b/jb/views/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/jb/views/__init__.py |
