diff options
| author | Max Nanis | 2026-02-24 17:26:15 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-02-24 17:26:15 -0500 |
| commit | 8c1940445503fd6678d0961600f2be81622793a2 (patch) | |
| tree | b9173562b8824b5eaa805e446d9d780e1f23fb2a /jb/models | |
| parent | 25d8c3c214baf10f6520cc1351f78473150e5d7a (diff) | |
| download | amt-jb-8c1940445503fd6678d0961600f2be81622793a2.tar.gz amt-jb-8c1940445503fd6678d0961600f2be81622793a2.zip | |
Extensive use of type checking. Movement of pytest conf towards handling managers (for db agnostic unittest). Starting to organize pytests.
Diffstat (limited to 'jb/models')
| -rw-r--r-- | jb/models/assignment.py | 23 | ||||
| -rw-r--r-- | jb/models/bonus.py | 7 | ||||
| -rw-r--r-- | jb/models/currency.py | 70 | ||||
| -rw-r--r-- | jb/models/custom_types.py | 3 | ||||
| -rw-r--r-- | jb/models/definitions.py | 28 | ||||
| -rw-r--r-- | jb/models/event.py | 19 | ||||
| -rw-r--r-- | jb/models/hit.py | 24 |
7 files changed, 46 insertions, 128 deletions
diff --git a/jb/models/assignment.py b/jb/models/assignment.py index 39ae47c..5dd0167 100644 --- a/jb/models/assignment.py +++ b/jb/models/assignment.py @@ -1,6 +1,6 @@ import logging from datetime import datetime, timezone -from typing import Optional, TypedDict +from typing import Optional, TypedDict, Any from xml.etree import ElementTree from mypy_boto3_mturk.type_defs import AssignmentTypeDef @@ -10,7 +10,6 @@ from pydantic import ( ConfigDict, model_validator, PositiveInt, - computed_field, TypeAdapter, ValidationError, ) @@ -116,10 +115,12 @@ class Assignment(AssignmentStub): default=None, min_length=3, max_length=2_000, - help_text="The feedback string included with the call to the " - "ApproveAssignment operation or the RejectAssignment " - "operation, if the Requester approved or rejected the " - "assignment and specified feedback.", + json_schema_extra={ + "help_text": "The feedback string included with the call to the " + "ApproveAssignment operation or the RejectAssignment " + "operation, if the Requester approved or rejected the " + "assignment and specified feedback." + }, ) answer_xml: Optional[str] = Field(default=None, exclude=True) @@ -131,7 +132,7 @@ class Assignment(AssignmentStub): # --- Validators --- @model_validator(mode="before") - def set_tsid(cls, values: dict): + def set_tsid(cls, values: dict[str, Any]) -> dict[str, Any]: if values.get("tsid") is None and (answer_xml := values.get("answer_xml")): answer_dict = cls.parse_answer_xml(answer_xml) tsid = answer_dict.get("tsid") @@ -175,10 +176,10 @@ class Assignment(AssignmentStub): if self.answer_xml is None: return None - return self.parse_answer_xml(self.answer_xml) + return self.parse_answer_xml(self.answer_xml) # type: ignore @staticmethod - def parse_answer_xml(answer_xml: str): + def parse_answer_xml(answer_xml: str) -> dict[str, Any]: root = ElementTree.fromstring(answer_xml) ns = { "mt": "http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd" @@ -186,8 +187,8 @@ class Assignment(AssignmentStub): res = {} for a in root.findall("mt:Answer", ns): - name = a.find("mt:QuestionIdentifier", ns).text - value = a.find("mt:FreeText", ns).text + name = a.find("mt:QuestionIdentifier", ns).text # type: ignore + value = a.find("mt:FreeText", ns).text # type: ignore res[name] = value or "" EXPECTED_KEYS = {"amt_assignment_id", "amt_worker_id", "tsid"} diff --git a/jb/models/bonus.py b/jb/models/bonus.py index 564a32d..a536dd1 100644 --- a/jb/models/bonus.py +++ b/jb/models/bonus.py @@ -1,11 +1,10 @@ -from typing import Optional, Dict +from typing import Optional, Dict, Any from pydantic import BaseModel, Field, ConfigDict, PositiveInt from typing_extensions import Self -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.custom_types import AMTBoto3ID, AwareDatetimeISO, UUIDStr -from jb.models.definitions import PayoutStatus class Bonus(BaseModel): @@ -41,7 +40,7 @@ class Bonus(BaseModel): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["amount"] = USDCent(round(data["amount"] * 100)) fields = set(cls.model_fields.keys()) data = {k: v for k, v in data.items() if k in fields} diff --git a/jb/models/currency.py b/jb/models/currency.py deleted file mode 100644 index 3094e2a..0000000 --- a/jb/models/currency.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings -from decimal import Decimal -from typing import Any - -from pydantic import GetCoreSchemaHandler, NonNegativeInt -from pydantic_core import CoreSchema, core_schema - - -class USDCent(int): - def __new__(cls, value, *args, **kwargs): - - if isinstance(value, float): - warnings.warn( - "USDCent init with a float. Rounding behavior may " "be unexpected" - ) - - if isinstance(value, Decimal): - warnings.warn( - "USDCent init with a Decimal. Rounding behavior may " "be unexpected" - ) - - if value < 0: - raise ValueError("USDCent not be less than zero") - - return super(cls, cls).__new__(cls, value) - - def __add__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__add__(other) - return self.__class__(res) - - def __sub__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__sub__(other) - return self.__class__(res) - - def __mul__(self, other): - assert isinstance(other, USDCent) - res = super(USDCent, self).__mul__(other) - return self.__class__(res) - - def __abs__(self): - res = super(USDCent, self).__abs__() - return self.__class__(res) - - def __truediv__(self, other): - raise ValueError("Division not allowed for USDCent") - - def __str__(self): - return "%d" % int(self) - - def __repr__(self): - return "USDCent(%d)" % int(self) - - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> CoreSchema: - """ - https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ - """ - return core_schema.no_info_after_validator_function( - cls, handler(NonNegativeInt) - ) - - def to_usd(self) -> Decimal: - return Decimal(int(self) / 100).quantize(Decimal(".01")) - - def to_usd_str(self) -> str: - return "${:,.2f}".format(float(self.to_usd())) diff --git a/jb/models/custom_types.py b/jb/models/custom_types.py index 70bc5c1..10bc9d1 100644 --- a/jb/models/custom_types.py +++ b/jb/models/custom_types.py @@ -34,8 +34,7 @@ def convert_str_dt(v: Any) -> Optional[AwareDatetime]: def assert_utc(v: AwareDatetime) -> AwareDatetime: - if isinstance(v, datetime): - assert v.tzinfo == timezone.utc, "Timezone is not UTC" + assert v.tzinfo == timezone.utc, "Timezone is not UTC" return v diff --git a/jb/models/definitions.py b/jb/models/definitions.py index a3d27ba..4ae7a21 100644 --- a/jb/models/definitions.py +++ b/jb/models/definitions.py @@ -1,4 +1,4 @@ -from enum import IntEnum, StrEnum +from enum import IntEnum class AssignmentStatus(IntEnum): @@ -37,32 +37,6 @@ class HitReviewStatus(IntEnum): ReviewedInappropriate = 3 -class PayoutStatus(StrEnum): - """These are GRL's payout statuses""" - - # The user has requested a payout. The money is taken from their - # wallet. A PENDING request can either be APPROVED, REJECTED, or - # CANCELLED. We can also implicitly skip the APPROVED step and go - # straight to COMPLETE or FAILED. - PENDING = "PENDING" - # The request is approved (by us or automatically). Once approved, - # it can be FAILED or COMPLETE. - APPROVED = "APPROVED" - # The request is rejected. The user loses the money. - REJECTED = "REJECTED" - # The user requests to cancel the request, the money goes back into their wallet. - CANCELLED = "CANCELLED" - # The payment was approved, but failed within external payment provider. - # This is an "error" state, as the money won't have moved anywhere. A - # FAILED payment can be tried again and be COMPLETE. - FAILED = "FAILED" - # The payment was sent successfully and (usually) a fee was charged - # to us for it. - COMPLETE = "COMPLETE" - # Not supported # REFUNDED: I'm not sure if this is possible or - # if we'd want to allow it. - - class ReportValue(IntEnum): """ The reason a user reported a task. diff --git a/jb/models/event.py b/jb/models/event.py index c357772..c167420 100644 --- a/jb/models/event.py +++ b/jb/models/event.py @@ -11,13 +11,22 @@ class MTurkEvent(BaseModel): What AWS SNS will POST to our mturk_notifications endpoint (inside the request body) """ - event_type: EventTypeType = Field(example="AssignmentSubmitted") - event_timestamp: AwareDatetimeISO = Field(example="2025-10-16T18:45:51Z") - amt_hit_id: AMTBoto3ID = Field(example="12345678901234567890") + event_type: EventTypeType = Field( + json_schema_extra={"example": "AssignmentSubmitted"} + ) + event_timestamp: AwareDatetimeISO = Field( + json_schema_extra={"example": "2025-10-16T18:45:51Z"} + ) + amt_hit_id: AMTBoto3ID = Field( + json_schema_extra={"example": "12345678901234567890"} + ) amt_assignment_id: str = Field( - max_length=64, example="1234567890123456789012345678901234567890" + max_length=64, + json_schema_extra={"example": "1234567890123456789012345678901234567890"}, + ) + amt_hit_type_id: AMTBoto3ID = Field( + json_schema_extra={"example": "09876543210987654321"} ) - amt_hit_type_id: AMTBoto3ID = Field(example="09876543210987654321") @classmethod def from_sns(cls, data: Dict): diff --git a/jb/models/hit.py b/jb/models/hit.py index c3734fa..fba2ecf 100644 --- a/jb/models/hit.py +++ b/jb/models/hit.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone, timedelta -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Any from uuid import uuid4 from xml.etree import ElementTree @@ -13,7 +13,7 @@ from pydantic import ( ) from typing_extensions import Self -from jb.models.currency import USDCent +from generalresearchutils.currency import USDCent from jb.models.custom_types import AMTBoto3ID, HttpsUrlStr, AwareDatetimeISO from jb.models.definitions import HitStatus, HitReviewStatus @@ -104,11 +104,11 @@ class HitType(HitTypeCommon): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["reward"] = USDCent(round(data["reward"] * 100)) return cls.model_validate(data) - def generate_hit_amt_request(self, question: HitQuestion): + def generate_hit_amt_request(self, question: HitQuestion) -> Dict[str, Any]: d = dict() d["HITTypeId"] = self.amt_hit_type_id d["MaxAssignments"] = 1 @@ -135,7 +135,12 @@ class Hit(HitTypeCommon): status: HitStatus = Field() review_status: HitReviewStatus = Field() - creation_time: AwareDatetimeISO = Field(default=None, description="From aws") + + # TODO: Check if this is actually ever going to be None. I type fixed it, + # but I don't have anything to suggest it isn't requred. -- Max 2026-02-24 + creation_time: Optional[AwareDatetimeISO] = Field( + default=None, description="From aws" + ) expiration: Optional[AwareDatetimeISO] = Field(default=None) # GRL Specific @@ -150,7 +155,7 @@ class Hit(HitTypeCommon): # -- Hit specific - qualification_requirements: Optional[List[Dict]] = Field(default=None) + qualification_requirements: Optional[List[Dict[str, Any]]] = Field(default=None) max_assignments: int = Field() # # this comes back as expiration. only for the request @@ -171,7 +176,7 @@ class Hit(HitTypeCommon): assert hit_type.id is not None assert hit_type.amt_hit_type_id is not None - h = Hit.model_validate( + h = cls.model_validate( dict( amt_hit_id=data["HITId"], amt_hit_type_id=data["HITTypeId"], @@ -194,11 +199,12 @@ class Hit(HitTypeCommon): hit_type_id=hit_type.id, ) ) + return h @classmethod def from_amt_get_hit(cls, data: HITTypeDef) -> Self: - h = Hit.model_validate( + h = cls.model_validate( dict( amt_hit_id=data["HITId"], amt_hit_type_id=data["HITTypeId"], @@ -229,7 +235,7 @@ class Hit(HitTypeCommon): return d @classmethod - def from_postgres(cls, data: Dict) -> Self: + def from_postgres(cls, data: Dict[str, Any]) -> Self: data["reward"] = USDCent(round(data["reward"] * 100)) return cls.model_validate(data) |
