aboutsummaryrefslogtreecommitdiff
path: root/jb/models
diff options
context:
space:
mode:
authorMax Nanis2026-02-24 17:26:15 -0500
committerMax Nanis2026-02-24 17:26:15 -0500
commit8c1940445503fd6678d0961600f2be81622793a2 (patch)
treeb9173562b8824b5eaa805e446d9d780e1f23fb2a /jb/models
parent25d8c3c214baf10f6520cc1351f78473150e5d7a (diff)
downloadamt-jb-8c1940445503fd6678d0961600f2be81622793a2.tar.gz
amt-jb-8c1940445503fd6678d0961600f2be81622793a2.zip
Extensive use of type checking. Movement of pytest conf towards handling managers (for db agnostic unittest). Starting to organize pytests.
Diffstat (limited to 'jb/models')
-rw-r--r--jb/models/assignment.py23
-rw-r--r--jb/models/bonus.py7
-rw-r--r--jb/models/currency.py70
-rw-r--r--jb/models/custom_types.py3
-rw-r--r--jb/models/definitions.py28
-rw-r--r--jb/models/event.py19
-rw-r--r--jb/models/hit.py24
7 files changed, 46 insertions, 128 deletions
diff --git a/jb/models/assignment.py b/jb/models/assignment.py
index 39ae47c..5dd0167 100644
--- a/jb/models/assignment.py
+++ b/jb/models/assignment.py
@@ -1,6 +1,6 @@
import logging
from datetime import datetime, timezone
-from typing import Optional, TypedDict
+from typing import Optional, TypedDict, Any
from xml.etree import ElementTree
from mypy_boto3_mturk.type_defs import AssignmentTypeDef
@@ -10,7 +10,6 @@ from pydantic import (
ConfigDict,
model_validator,
PositiveInt,
- computed_field,
TypeAdapter,
ValidationError,
)
@@ -116,10 +115,12 @@ class Assignment(AssignmentStub):
default=None,
min_length=3,
max_length=2_000,
- help_text="The feedback string included with the call to the "
- "ApproveAssignment operation or the RejectAssignment "
- "operation, if the Requester approved or rejected the "
- "assignment and specified feedback.",
+ json_schema_extra={
+ "help_text": "The feedback string included with the call to the "
+ "ApproveAssignment operation or the RejectAssignment "
+ "operation, if the Requester approved or rejected the "
+ "assignment and specified feedback."
+ },
)
answer_xml: Optional[str] = Field(default=None, exclude=True)
@@ -131,7 +132,7 @@ class Assignment(AssignmentStub):
# --- Validators ---
@model_validator(mode="before")
- def set_tsid(cls, values: dict):
+ def set_tsid(cls, values: dict[str, Any]) -> dict[str, Any]:
if values.get("tsid") is None and (answer_xml := values.get("answer_xml")):
answer_dict = cls.parse_answer_xml(answer_xml)
tsid = answer_dict.get("tsid")
@@ -175,10 +176,10 @@ class Assignment(AssignmentStub):
if self.answer_xml is None:
return None
- return self.parse_answer_xml(self.answer_xml)
+ return self.parse_answer_xml(self.answer_xml) # type: ignore
@staticmethod
- def parse_answer_xml(answer_xml: str):
+ def parse_answer_xml(answer_xml: str) -> dict[str, Any]:
root = ElementTree.fromstring(answer_xml)
ns = {
"mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"
@@ -186,8 +187,8 @@ class Assignment(AssignmentStub):
res = {}
for a in root.findall("mt:Answer", ns):
- name = a.find("mt:QuestionIdentifier", ns).text
- value = a.find("mt:FreeText", ns).text
+ name = a.find("mt:QuestionIdentifier", ns).text # type: ignore
+ value = a.find("mt:FreeText", ns).text # type: ignore
res[name] = value or ""
EXPECTED_KEYS = {"amt_assignment_id", "amt_worker_id", "tsid"}
diff --git a/jb/models/bonus.py b/jb/models/bonus.py
index 564a32d..a536dd1 100644
--- a/jb/models/bonus.py
+++ b/jb/models/bonus.py
@@ -1,11 +1,10 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, ConfigDict, PositiveInt
from typing_extensions import Self
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr
-from jb.models.definitions import PayoutStatus
class Bonus(BaseModel):
@@ -41,7 +40,7 @@ class Bonus(BaseModel):
return d
@classmethod
- def from_postgres(cls, data: Dict) -> Self:
+ def from_postgres(cls, data: Dict[str, Any]) -> Self:
data["amount"] = USDCent(round(data["amount"] * 100))
fields = set(cls.model_fields.keys())
data = {k: v for k, v in data.items() if k in fields}
diff --git a/jb/models/currency.py b/jb/models/currency.py
deleted file mode 100644
index 3094e2a..0000000
--- a/jb/models/currency.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import warnings
-from decimal import Decimal
-from typing import Any
-
-from pydantic import GetCoreSchemaHandler, NonNegativeInt
-from pydantic_core import CoreSchema, core_schema
-
-
-class USDCent(int):
- def __new__(cls, value, *args, **kwargs):
-
- if isinstance(value, float):
- warnings.warn(
- "USDCent init with a float. Rounding behavior may " "be unexpected"
- )
-
- if isinstance(value, Decimal):
- warnings.warn(
- "USDCent init with a Decimal. Rounding behavior may " "be unexpected"
- )
-
- if value < 0:
- raise ValueError("USDCent not be less than zero")
-
- return super(cls, cls).__new__(cls, value)
-
- def __add__(self, other):
- assert isinstance(other, USDCent)
- res = super(USDCent, self).__add__(other)
- return self.__class__(res)
-
- def __sub__(self, other):
- assert isinstance(other, USDCent)
- res = super(USDCent, self).__sub__(other)
- return self.__class__(res)
-
- def __mul__(self, other):
- assert isinstance(other, USDCent)
- res = super(USDCent, self).__mul__(other)
- return self.__class__(res)
-
- def __abs__(self):
- res = super(USDCent, self).__abs__()
- return self.__class__(res)
-
- def __truediv__(self, other):
- raise ValueError("Division not allowed for USDCent")
-
- def __str__(self):
- return "%d" % int(self)
-
- def __repr__(self):
- return "USDCent(%d)" % int(self)
-
- @classmethod
- def __get_pydantic_core_schema__(
- cls, source_type: Any, handler: GetCoreSchemaHandler
- ) -> CoreSchema:
- """
- https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__
- """
- return core_schema.no_info_after_validator_function(
- cls, handler(NonNegativeInt)
- )
-
- def to_usd(self) -> Decimal:
- return Decimal(int(self) / 100).quantize(Decimal(".01"))
-
- def to_usd_str(self) -> str:
- return "${:,.2f}".format(float(self.to_usd()))
diff --git a/jb/models/custom_types.py b/jb/models/custom_types.py
index 70bc5c1..10bc9d1 100644
--- a/jb/models/custom_types.py
+++ b/jb/models/custom_types.py
@@ -34,8 +34,7 @@ def convert_str_dt(v: Any) -> Optional[AwareDatetime]:
def assert_utc(v: AwareDatetime) -> AwareDatetime:
- if isinstance(v, datetime):
- assert v.tzinfo == timezone.utc, "Timezone is not UTC"
+ assert v.tzinfo == timezone.utc, "Timezone is not UTC"
return v
diff --git a/jb/models/definitions.py b/jb/models/definitions.py
index a3d27ba..4ae7a21 100644
--- a/jb/models/definitions.py
+++ b/jb/models/definitions.py
@@ -1,4 +1,4 @@
-from enum import IntEnum, StrEnum
+from enum import IntEnum
class AssignmentStatus(IntEnum):
@@ -37,32 +37,6 @@ class HitReviewStatus(IntEnum):
ReviewedInappropriate = 3
-class PayoutStatus(StrEnum):
- """These are GRL's payout statuses"""
-
- # The user has requested a payout. The money is taken from their
- # wallet. A PENDING request can either be APPROVED, REJECTED, or
- # CANCELLED. We can also implicitly skip the APPROVED step and go
- # straight to COMPLETE or FAILED.
- PENDING = "PENDING"
- # The request is approved (by us or automatically). Once approved,
- # it can be FAILED or COMPLETE.
- APPROVED = "APPROVED"
- # The request is rejected. The user loses the money.
- REJECTED = "REJECTED"
- # The user requests to cancel the request, the money goes back into their wallet.
- CANCELLED = "CANCELLED"
- # The payment was approved, but failed within external payment provider.
- # This is an "error" state, as the money won't have moved anywhere. A
- # FAILED payment can be tried again and be COMPLETE.
- FAILED = "FAILED"
- # The payment was sent successfully and (usually) a fee was charged
- # to us for it.
- COMPLETE = "COMPLETE"
- # Not supported # REFUNDED: I'm not sure if this is possible or
- # if we'd want to allow it.
-
-
class ReportValue(IntEnum):
"""
The reason a user reported a task.
diff --git a/jb/models/event.py b/jb/models/event.py
index c357772..c167420 100644
--- a/jb/models/event.py
+++ b/jb/models/event.py
@@ -11,13 +11,22 @@ class MTurkEvent(BaseModel):
What AWS SNS will POST to our mturk_notifications endpoint (inside the request body)
"""
- event_type: EventTypeType = Field(example="AssignmentSubmitted")
- event_timestamp: AwareDatetimeISO = Field(example="2025-10-16T18:45:51Z")
- amt_hit_id: AMTBoto3ID = Field(example="12345678901234567890")
+ event_type: EventTypeType = Field(
+ json_schema_extra={"example": "AssignmentSubmitted"}
+ )
+ event_timestamp: AwareDatetimeISO = Field(
+ json_schema_extra={"example": "2025-10-16T18:45:51Z"}
+ )
+ amt_hit_id: AMTBoto3ID = Field(
+ json_schema_extra={"example": "12345678901234567890"}
+ )
amt_assignment_id: str = Field(
- max_length=64, example="1234567890123456789012345678901234567890"
+ max_length=64,
+ json_schema_extra={"example": "1234567890123456789012345678901234567890"},
+ )
+ amt_hit_type_id: AMTBoto3ID = Field(
+ json_schema_extra={"example": "09876543210987654321"}
)
- amt_hit_type_id: AMTBoto3ID = Field(example="09876543210987654321")
@classmethod
def from_sns(cls, data: Dict):
diff --git a/jb/models/hit.py b/jb/models/hit.py
index c3734fa..fba2ecf 100644
--- a/jb/models/hit.py
+++ b/jb/models/hit.py
@@ -1,5 +1,5 @@
from datetime import datetime, timezone, timedelta
-from typing import Optional, List, Dict
+from typing import Optional, List, Dict, Any
from uuid import uuid4
from xml.etree import ElementTree
@@ -13,7 +13,7 @@ from pydantic import (
)
from typing_extensions import Self
-from jb.models.currency import USDCent
+from generalresearchutils.currency import USDCent
from jb.models.custom_types import AMTBoto3ID, HttpsUrlStr, AwareDatetimeISO
from jb.models.definitions import HitStatus, HitReviewStatus
@@ -104,11 +104,11 @@ class HitType(HitTypeCommon):
return d
@classmethod
- def from_postgres(cls, data: Dict) -> Self:
+ def from_postgres(cls, data: Dict[str, Any]) -> Self:
data["reward"] = USDCent(round(data["reward"] * 100))
return cls.model_validate(data)
- def generate_hit_amt_request(self, question: HitQuestion):
+ def generate_hit_amt_request(self, question: HitQuestion) -> Dict[str, Any]:
d = dict()
d["HITTypeId"] = self.amt_hit_type_id
d["MaxAssignments"] = 1
@@ -135,7 +135,12 @@ class Hit(HitTypeCommon):
status: HitStatus = Field()
review_status: HitReviewStatus = Field()
- creation_time: AwareDatetimeISO = Field(default=None, description="From aws")
+
+ # TODO: Check if this is actually ever going to be None. I type fixed it,
+ # but I don't have anything to suggest it isn't requred. -- Max 2026-02-24
+ creation_time: Optional[AwareDatetimeISO] = Field(
+ default=None, description="From aws"
+ )
expiration: Optional[AwareDatetimeISO] = Field(default=None)
# GRL Specific
@@ -150,7 +155,7 @@ class Hit(HitTypeCommon):
# -- Hit specific
- qualification_requirements: Optional[List[Dict]] = Field(default=None)
+ qualification_requirements: Optional[List[Dict[str, Any]]] = Field(default=None)
max_assignments: int = Field()
# # this comes back as expiration. only for the request
@@ -171,7 +176,7 @@ class Hit(HitTypeCommon):
assert hit_type.id is not None
assert hit_type.amt_hit_type_id is not None
- h = Hit.model_validate(
+ h = cls.model_validate(
dict(
amt_hit_id=data["HITId"],
amt_hit_type_id=data["HITTypeId"],
@@ -194,11 +199,12 @@ class Hit(HitTypeCommon):
hit_type_id=hit_type.id,
)
)
+
return h
@classmethod
def from_amt_get_hit(cls, data: HITTypeDef) -> Self:
- h = Hit.model_validate(
+ h = cls.model_validate(
dict(
amt_hit_id=data["HITId"],
amt_hit_type_id=data["HITTypeId"],
@@ -229,7 +235,7 @@ class Hit(HitTypeCommon):
return d
@classmethod
- def from_postgres(cls, data: Dict) -> Self:
+ def from_postgres(cls, data: Dict[str, Any]) -> Self:
data["reward"] = USDCent(round(data["reward"] * 100))
return cls.model_validate(data)