From ce291a165fab6b6dc9f053c7b75a699d0fdf389f Mon Sep 17 00:00:00 2001 From: Max Nanis Date: Sun, 8 Mar 2026 21:57:53 -0400 Subject: Simple typing changes, Ruff import formatter. p2 --- generalresearch/__init__.py | 9 +- generalresearch/config.py | 6 +- generalresearch/healing_ppe.py | 5 +- generalresearch/mariadb.py | 2 +- generalresearch/models/cint/question.py | 33 +++-- generalresearch/models/cint/survey.py | 23 ++-- generalresearch/models/cint/task_collection.py | 2 +- generalresearch/models/custom_types.py | 13 +- generalresearch/models/dynata/question.py | 10 +- generalresearch/models/dynata/survey.py | 38 ++++-- generalresearch/models/dynata/task_collection.py | 4 +- generalresearch/models/events.py | 14 +- generalresearch/models/gr/__init__.py | 4 +- generalresearch/models/gr/authentication.py | 18 +-- generalresearch/models/gr/business.py | 29 +++-- generalresearch/models/gr/team.py | 13 +- generalresearch/models/innovate/question.py | 26 ++-- generalresearch/models/innovate/survey.py | 31 ++--- generalresearch/models/innovate/task_collection.py | 2 +- generalresearch/models/legacy/bucket.py | 31 +++-- generalresearch/models/legacy/offerwall.py | 16 +-- generalresearch/models/legacy/questions.py | 21 +-- generalresearch/models/lucid/question.py | 10 +- generalresearch/models/lucid/survey.py | 12 +- generalresearch/models/marketplace/summary.py | 4 +- generalresearch/models/morning/question.py | 18 +-- generalresearch/models/morning/survey.py | 39 +++--- generalresearch/models/morning/task_collection.py | 7 +- generalresearch/models/pollfish/question.py | 15 ++- generalresearch/models/precision/question.py | 21 +-- generalresearch/models/precision/survey.py | 24 ++-- .../models/precision/task_collection.py | 6 +- generalresearch/models/prodege/__init__.py | 6 +- generalresearch/models/prodege/question.py | 49 ++++--- generalresearch/models/prodege/survey.py | 76 +++++++---- generalresearch/models/prodege/task_collection.py | 4 +- generalresearch/models/repdata/question.py | 29 +++-- generalresearch/models/repdata/survey.py | 23 ++-- generalresearch/models/repdata/task_collection.py | 8 +- generalresearch/models/sago/question.py | 29 +++-- generalresearch/models/sago/survey.py | 24 ++-- generalresearch/models/sago/task_collection.py | 8 +- generalresearch/models/spectrum/question.py | 25 ++-- generalresearch/models/spectrum/survey.py | 18 +-- generalresearch/models/spectrum/task_collection.py | 8 +- generalresearch/models/thl/__init__.py | 2 +- generalresearch/models/thl/category.py | 6 +- generalresearch/models/thl/contest/__init__.py | 19 +-- generalresearch/models/thl/contest/contest.py | 18 +-- .../models/thl/contest/contest_entry.py | 14 +- generalresearch/models/thl/contest/examples.py | 126 ++++++++++-------- generalresearch/models/thl/contest/leaderboard.py | 20 +-- generalresearch/models/thl/contest/milestone.py | 23 ++-- generalresearch/models/thl/contest/raffle.py | 19 +-- generalresearch/models/thl/contest/utils.py | 4 +- generalresearch/models/thl/definitions.py | 3 +- generalresearch/models/thl/demographics.py | 9 +- generalresearch/models/thl/finance.py | 18 ++- generalresearch/models/thl/grliq.py | 2 +- generalresearch/models/thl/ipinfo.py | 16 ++- generalresearch/models/thl/leaderboard.py | 14 +- generalresearch/models/thl/ledger.py | 20 +-- generalresearch/models/thl/ledger_example.py | 10 +- generalresearch/models/thl/locales.py | 2 +- generalresearch/models/thl/offerwall/__init__.py | 6 +- generalresearch/models/thl/offerwall/base.py | 14 +- generalresearch/models/thl/offerwall/behavior.py | 2 +- generalresearch/models/thl/offerwall/bucket.py | 6 +- generalresearch/models/thl/offerwall/cache.py | 4 +- generalresearch/models/thl/pagination.py | 2 +- generalresearch/models/thl/payout.py | 8 +- generalresearch/models/thl/payout_format.py | 4 +- generalresearch/models/thl/product.py | 39 +++--- .../models/thl/profiling/marketplace.py | 16 ++- generalresearch/models/thl/profiling/question.py | 8 +- .../models/thl/profiling/upk_property.py | 4 +- .../models/thl/profiling/upk_question.py | 20 +-- .../models/thl/profiling/upk_question_answer.py | 11 +- generalresearch/models/thl/profiling/user_info.py | 2 +- .../models/thl/profiling/user_question_answer.py | 12 +- generalresearch/models/thl/report_task.py | 2 +- generalresearch/models/thl/session.py | 29 ++--- generalresearch/models/thl/stats.py | 2 +- generalresearch/models/thl/survey/__init__.py | 15 ++- generalresearch/models/thl/survey/buyer.py | 14 +- generalresearch/models/thl/survey/condition.py | 10 +- generalresearch/models/thl/survey/model.py | 18 +-- generalresearch/models/thl/survey/penalty.py | 4 +- .../models/thl/survey/task_collection.py | 2 +- .../models/thl/synchronize_global_vars.py | 2 +- generalresearch/models/thl/task_adjustment.py | 4 +- generalresearch/models/thl/task_status.py | 16 +-- generalresearch/models/thl/user.py | 18 +-- generalresearch/models/thl/user_iphistory.py | 8 +- generalresearch/models/thl/user_profile.py | 6 +- generalresearch/models/thl/user_quality_event.py | 4 +- generalresearch/models/thl/user_streak.py | 10 +- generalresearch/models/thl/userhealth.py | 4 +- generalresearch/models/thl/wallet/__init__.py | 2 +- .../models/thl/wallet/cashout_method.py | 14 +- generalresearch/models/thl/wallet/payout.py | 12 +- generalresearch/models/thl/wallet/user_wallet.py | 4 +- generalresearch/pg_helper.py | 13 +- generalresearch/schemas/survey_stats.py | 2 +- generalresearch/sql_helper.py | 12 +- generalresearch/utils/aggregation.py | 4 +- generalresearch/utils/grpc_logger.py | 2 +- generalresearch/wall_status_codes/__init__.py | 11 +- generalresearch/wall_status_codes/cint.py | 5 +- generalresearch/wall_status_codes/dynata.py | 18 ++- generalresearch/wall_status_codes/fullcircle.py | 14 +- generalresearch/wall_status_codes/innovate.py | 15 ++- generalresearch/wall_status_codes/lucid.py | 33 +++-- generalresearch/wall_status_codes/morning.py | 15 ++- generalresearch/wall_status_codes/pollfish.py | 13 +- generalresearch/wall_status_codes/precision.py | 12 +- generalresearch/wall_status_codes/prodege.py | 10 +- generalresearch/wall_status_codes/repdata.py | 14 +- generalresearch/wall_status_codes/sago.py | 19 ++- generalresearch/wall_status_codes/spectrum.py | 12 +- generalresearch/wall_status_codes/wxet.py | 15 ++- generalresearch/wxet/models/definitions.py | 3 + generalresearch/wxet/models/finish_type.py | 2 +- test_utils/conftest.py | 9 +- test_utils/managers/conftest.py | 145 +++++++++++++-------- test_utils/managers/upk/conftest.py | 32 +++-- test_utils/models/conftest.py | 110 +++++++++------- test_utils/spectrum/conftest.py | 8 +- 128 files changed, 1179 insertions(+), 885 deletions(-) diff --git a/generalresearch/__init__.py b/generalresearch/__init__.py index 5993613..fb5f28d 100644 --- a/generalresearch/__init__.py +++ b/generalresearch/__init__.py @@ -1,12 +1,19 @@ import threading import time from functools import wraps +from typing import Any, Callable, Optional from decorator import decorator from wrapt import FunctionWrapper, ObjectProxy -def retry(exceptions, tries=4, delay=0.5, backoff=2, logger=None): +def retry( + exceptions, + tries: int = 4, + delay: float = 0.5, + backoff: int = 2, + logger: Optional[Any] = None, +) -> Callable: """ https://www.calazan.com/retry-decorator-for-python-3/ Retry calling the decorated function using an exponential backoff. diff --git a/generalresearch/config.py b/generalresearch/config.py index 1c3f6e6..94885b1 100644 --- a/generalresearch/config.py +++ b/generalresearch/config.py @@ -1,11 +1,11 @@ from datetime import datetime, timezone -from typing import Optional from pathlib import Path +from typing import Optional -from pydantic import RedisDsn, Field, MariaDBDsn, DirectoryPath, PostgresDsn +from pydantic import DirectoryPath, Field, MariaDBDsn, PostgresDsn, RedisDsn from pydantic_settings import BaseSettings -from generalresearch.models.custom_types import DaskDsn, SentryDsn, MySQLOrMariaDsn +from generalresearch.models.custom_types import DaskDsn, SentryDsn def is_debug() -> bool: diff --git a/generalresearch/healing_ppe.py b/generalresearch/healing_ppe.py index 2a3aecf..e8ee144 100644 --- a/generalresearch/healing_ppe.py +++ b/generalresearch/healing_ppe.py @@ -5,6 +5,7 @@ import time from collections import defaultdict from concurrent import futures from concurrent.futures.process import BrokenProcessPool +from typing import Optional logger = logging.getLogger() @@ -15,7 +16,9 @@ signal_int_name = defaultdict( class HealingProcessPoolExecutor: def __init__( - self, max_workers=None, name=None, slack_token=None, slack_channel=None + self, + max_workers: Optional[int] = None, + name: Optional[str] = None, ): if not name: try: diff --git a/generalresearch/mariadb.py b/generalresearch/mariadb.py index d95e46c..8bcd8ee 100644 --- a/generalresearch/mariadb.py +++ b/generalresearch/mariadb.py @@ -25,7 +25,7 @@ def example(): # actually we don't need the field flags. I didn't see, but there is an # extended field type returned also. Which explicitly tags uuid fields. conn = mariadb.connect( - host="127.0.0.1", user="root", password="", database="300large-morning" + host="127.0.0.1", user="root", password="", database="thl-morning" ) c = conn.cursor() c.execute("SELECT user_id, pid as greg FROM morning_userpid limit 1") diff --git a/generalresearch/models/cint/question.py b/generalresearch/models/cint/question.py index f840a46..77212a2 100644 --- a/generalresearch/models/cint/question.py +++ b/generalresearch/models/cint/question.py @@ -1,10 +1,10 @@ import json from datetime import datetime, timezone from enum import Enum -from typing import Optional, List, Literal, Dict, Any +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from uuid import UUID -from pydantic import Field, BaseModel, model_validator, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self from generalresearch.models import Source, string_utils @@ -15,6 +15,11 @@ from generalresearch.models.thl.profiling.marketplace import ( MarketplaceUserQuestionAnswer, ) +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + class CintQuestionType(str, Enum): SINGLE_SELECT = "s" @@ -118,27 +123,34 @@ class CintQuestion(MarketplaceQuestion): @field_validator("options") @classmethod - def order_options(cls, options): + def order_options( + cls, options: Optional[List[CintQuestionOption]] + ) -> Optional[List[CintQuestionOption]]: if options: options.sort(key=lambda x: x.order) + return options @field_validator("options") @classmethod - def validate_options(cls, options): + def validate_options( + cls, options: Optional[List[CintQuestionOption]] + ) -> Optional[List[CintQuestionOption]]: if options: ids = {x.id for x in options} assert len(ids) == len(options), "options.id must be unique" orders = {x.order for x in options} assert len(orders) == len(options), "options.order must be unique" + return options @classmethod - def from_api(cls, d: dict, country_iso: str, language_iso: str) -> Self: + def from_api(cls, d: Dict[str, Any], country_iso: str, language_iso: str) -> Self: options = None created_at = datetime.strptime( d["create_date"], "%Y-%m-%dT%H:%M:%S%z" ).astimezone(timezone.utc) + if d.get("question_options"): options = [ CintQuestionOption( @@ -166,15 +178,17 @@ class CintQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict) -> Self: + def from_db(cls, d: Dict[str, Any]) -> Self: options = None if d["options"]: options = [ CintQuestionOption(id=r["id"], text=r["text"], order=r["order"]) for r in d["options"] ] + if d.get("created_at"): d["created_at"] = d["created_at"].replace(tzinfo=timezone.utc) + return cls( question_id=d["question_id"], question_name=d["question_name"], @@ -194,15 +208,16 @@ class CintQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) if self.created_at: d["created_at"] = self.created_at.replace(tzinfo=None) + return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, ) upk_type_selector_map = { diff --git a/generalresearch/models/cint/survey.py b/generalresearch/models/cint/survey.py index 30b140c..be07e04 100644 --- a/generalresearch/models/cint/survey.py +++ b/generalresearch/models/cint/survey.py @@ -4,32 +4,32 @@ import json import logging from datetime import datetime, timezone from decimal import Decimal -from typing import Optional, Dict, Set, Tuple, List, Literal, Any, Type +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type from more_itertools import flatten from pydantic import ( - NonNegativeInt, - Field, - ConfigDict, BaseModel, + ConfigDict, + Field, + NonNegativeInt, computed_field, model_validator, ) -from typing_extensions import Self, Annotated +from typing_extensions import Annotated, Self from generalresearch.locales import Localelator from generalresearch.models import Source, TaskCalculationType from generalresearch.models.cint import CintQuestionIdType from generalresearch.models.custom_types import ( + AlphaNumStr, AwareDatetimeISO, CoercedStr, - AlphaNumStr, ) from generalresearch.models.thl.demographics import Gender from generalresearch.models.thl.survey import MarketplaceTask from generalresearch.models.thl.survey.condition import ( - MarketplaceCondition, ConditionValueType, + MarketplaceCondition, ) logging.basicConfig() @@ -95,7 +95,8 @@ class CintQuota(BaseModel): def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: # Matches means we meet all conditions. - # We can "match" a quota that is closed. In that case, we would not be eligible for the survey. + # We can "match" a quota that is closed. In that case, we would + # not be eligible for the survey. return all(criteria_evaluation.get(c) for c in self.condition_hashes) def matches_optional( @@ -245,7 +246,7 @@ class CintSurvey(MarketplaceTask): def is_live(self) -> bool: return self.is_live_raw - def model_dump(self, **kwargs: Any) -> dict: + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: data = super().model_dump(**kwargs) data["is_live"] = data.pop("is_live_raw", None) return data @@ -314,7 +315,7 @@ class CintSurvey(MarketplaceTask): } @classmethod - def from_api(cls, d: Dict) -> Optional[Self]: + def from_api(cls, d: Dict[str, Any]) -> Optional[Self]: try: return cls._from_api(d) except Exception as e: @@ -322,7 +323,7 @@ class CintSurvey(MarketplaceTask): return None @classmethod - def _from_api(cls, d: Dict) -> Self: + def _from_api(cls, d: Dict[str, Any]) -> Self: if "cpi" in d: d["gross_cpi"] = Decimal(d.pop("cpi")) if "revenue_per_interview" in d: diff --git a/generalresearch/models/cint/task_collection.py b/generalresearch/models/cint/task_collection.py index 91db35f..b43250a 100644 --- a/generalresearch/models/cint/task_collection.py +++ b/generalresearch/models/cint/task_collection.py @@ -1,7 +1,7 @@ from typing import List, Set import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models.cint.survey import CintSurvey diff --git a/generalresearch/models/custom_types.py b/generalresearch/models/custom_types.py index 54208a1..aefbbe9 100644 --- a/generalresearch/models/custom_types.py +++ b/generalresearch/models/custom_types.py @@ -1,16 +1,16 @@ import json -from datetime import datetime, timezone, timedelta -from typing import Any, Optional, Set, Literal +from datetime import datetime, timedelta, timezone +from typing import Any, Literal, Optional, Set from uuid import UUID from pydantic import ( + AnyUrl, AwareDatetime, - StringConstraints, - TypeAdapter, + Field, HttpUrl, IPvAnyAddress, - Field, - AnyUrl, + StringConstraints, + TypeAdapter, ) from pydantic.functional_serializers import PlainSerializer from pydantic.functional_validators import AfterValidator, BeforeValidator @@ -20,7 +20,6 @@ from typing_extensions import Annotated from generalresearch.models import DeviceType, Source - # if TYPE_CHECKING: # from generalresearch.models import DeviceType diff --git a/generalresearch/models/dynata/question.py b/generalresearch/models/dynata/question.py index 76670ce..4d675c0 100644 --- a/generalresearch/models/dynata/question.py +++ b/generalresearch/models/dynata/question.py @@ -5,11 +5,11 @@ import re from datetime import timedelta from enum import Enum from functools import cached_property -from typing import List, Optional, Literal, Any, Dict, Set +from typing import Any, Dict, List, Literal, Optional, Set -from pydantic import BaseModel, Field, model_validator, field_validator, PositiveInt +from pydantic import BaseModel, Field, PositiveInt, field_validator, model_validator -from generalresearch.models import Source, MAX_INT32 +from generalresearch.models import MAX_INT32, Source from generalresearch.models.custom_types import AwareDatetimeISO from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion @@ -228,11 +228,11 @@ class DynataQuestion(MarketplaceQuestion): def to_upk_question(self): from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, order_exclusive_options, ) diff --git a/generalresearch/models/dynata/survey.py b/generalresearch/models/dynata/survey.py index de4b177..ae12436 100644 --- a/generalresearch/models/dynata/survey.py +++ b/generalresearch/models/dynata/survey.py @@ -5,28 +5,28 @@ import logging from datetime import timezone from decimal import Decimal from functools import cached_property -from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Type +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type from more_itertools import flatten from pydantic import ( - Field, - ConfigDict, BaseModel, - model_validator, - field_validator, + ConfigDict, + Field, RootModel, computed_field, + field_validator, + model_validator, ) from typing_extensions import Self from generalresearch.locales import Localelator -from generalresearch.models import TaskCalculationType, Source +from generalresearch.models import Source, TaskCalculationType from generalresearch.models.custom_types import ( - CoercedStr, - AwareDatetimeISO, + AlphaNumStr, AlphaNumStrSet, + AwareDatetimeISO, + CoercedStr, DeviceTypes, - AlphaNumStr, ) from generalresearch.models.dynata import DynataStatus from generalresearch.models.thl.demographics import ( @@ -48,20 +48,30 @@ locale_helper = Localelator() class DynataRequirements(BaseModel): # Requires inviting (recontacting) specific respondents to a follow up survey. requires_recontact: bool = Field(default=False) - # Requires respondents to provide personally identifiable information (PII) within client survey. + + # Requires respondents to provide personally identifiable + # information (PII) within client survey. requires_pii_collection: bool = Field(default=False) + # Requires respondents to utilize their webcam to participate. requires_webcam: bool = Field(default=False) - # Requires use of facial recognition technology with respondents, such as eye tracking. + + # Requires use of facial recognition technology with + # respondents, such as eye tracking. requires_eye_tracking: bool = Field(default=False) + # Requires partner to allow Dynata to drop a cookie on respondent. requires_cookie_drops: bool = Field(default=False) - # Requires partner-uploaded respondent PII to expand third-party matched data. + + # Requires partner-uploaded respondent PII to expand + # third-party matched data. requires_sample_plus: bool = Field(default=False) + # Requires respondents to download a software application. requires_app_vpn: bool = Field(default=False) - # Requires additional incentives to be manually awarded to respondent by partner outside of the typical online - # survey flow. + + # Requires additional incentives to be manually awarded to + # respondent by partner outside of the typical online survey flow. requires_manual_rewards: bool = Field(default=False) def __repr__(self) -> str: diff --git a/generalresearch/models/dynata/task_collection.py b/generalresearch/models/dynata/task_collection.py index bd9b81a..a10fcf0 100644 --- a/generalresearch/models/dynata/task_collection.py +++ b/generalresearch/models/dynata/task_collection.py @@ -1,7 +1,7 @@ -from typing import List, Dict, Any +from typing import Any, Dict, List import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models import TaskCalculationType diff --git a/generalresearch/models/events.py b/generalresearch/models/events.py index 4504ab1..63ed2a1 100644 --- a/generalresearch/models/events.py +++ b/generalresearch/models/events.py @@ -1,30 +1,30 @@ -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from enum import StrEnum -from typing import Union, Literal, Optional, Dict +from typing import Dict, Literal, Optional, Union from uuid import uuid4 from pydantic import ( BaseModel, + ConfigDict, Field, - PositiveFloat, NonNegativeInt, - model_validator, + PositiveFloat, TypeAdapter, - ConfigDict, + model_validator, ) from typing_extensions import Annotated from generalresearch.models import Source from generalresearch.models.custom_types import ( + AwareDatetimeISO, CountryISOLike, UUIDStr, - AwareDatetimeISO, ) from generalresearch.models.thl.definitions import ( + SessionStatusCode2, Status, StatusCode1, WallStatusCode2, - SessionStatusCode2, ) diff --git a/generalresearch/models/gr/__init__.py b/generalresearch/models/gr/__init__.py index 713bba6..7e1516b 100644 --- a/generalresearch/models/gr/__init__.py +++ b/generalresearch/models/gr/__init__.py @@ -1,9 +1,9 @@ -from generalresearch.models.gr.authentication import GRUser, GRToken +from generalresearch.models.gr.authentication import GRToken, GRUser from generalresearch.models.gr.business import Business from generalresearch.models.gr.team import Team +from generalresearch.models.thl.finance import BusinessBalances from generalresearch.models.thl.payout import BrokerageProductPayoutEvent from generalresearch.models.thl.product import Product -from generalresearch.models.thl.finance import BusinessBalances _ = Business, Product, BrokerageProductPayoutEvent, BusinessBalances diff --git a/generalresearch/models/gr/authentication.py b/generalresearch/models/gr/authentication.py index ff1d065..764c694 100644 --- a/generalresearch/models/gr/authentication.py +++ b/generalresearch/models/gr/authentication.py @@ -4,16 +4,16 @@ import binascii import json import os from datetime import datetime, timezone -from typing import Optional, List, TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from pydantic import AnyHttpUrl from pydantic import ( + AnyHttpUrl, BaseModel, ConfigDict, Field, + NonNegativeInt, PositiveInt, field_validator, - NonNegativeInt, ) from typing_extensions import Self @@ -119,7 +119,7 @@ class GRUser(BaseModel): claims: Optional["Claims"] = Field(default=None) def prefetch_claims( - self, token: str, key: Dict, audience: str, issuer: AnyHttpUrl + self, token: str, key: Dict[str, Any], audience: str, issuer: AnyHttpUrl ) -> None: from jose import jwt @@ -220,7 +220,7 @@ class GRUser(BaseModel): LOG.warning("prefetch not run") return None - return [b.id for b in self.businesses] + return [b.id for b in self.businesses if b.id is not None] @property def team_uuids(self) -> Optional[List[UUIDStr]]: @@ -236,7 +236,7 @@ class GRUser(BaseModel): LOG.warning("prefetch not run") return None - return [t.id for t in self.teams] + return [t.id for t in self.teams if t.id is not None] @property def product_uuids(self) -> Optional[List[UUIDStr]]: @@ -294,7 +294,7 @@ class GRUser(BaseModel): return GRUser.model_validate(d) @classmethod - def from_redis(cls, d: Union[str, Dict]) -> Self: + def from_redis(cls, d: Union[str, Dict[str, Any]]) -> Self: if isinstance(d, str): d = json.loads(d) assert isinstance(d, dict) @@ -359,13 +359,13 @@ class GRToken(BaseModel): # --- Properties --- @property - def auth_header(self, key_name="Authorization") -> Dict: + def auth_header(self, key_name: str = "Authorization") -> Dict[str, str]: return {key_name: self.key} # --- ORM --- @classmethod - def from_redis(cls, d: Union[str, Dict]) -> Self: + def from_redis(cls, d: Union[str, Dict[str, Any]]) -> Self: if isinstance(d, str): d = json.loads(d) assert isinstance(d, dict) diff --git a/generalresearch/models/gr/business.py b/generalresearch/models/gr/business.py index b07c584..8a37f74 100644 --- a/generalresearch/models/gr/business.py +++ b/generalresearch/models/gr/business.py @@ -6,7 +6,7 @@ import os from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Optional, List, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, List, Optional, Union from uuid import uuid4 import pandas as pd @@ -24,12 +24,11 @@ from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge from generalresearch.incite.schemas.mergers.pop_ledger import ( numerical_col_names, ) -from generalresearch.utils.enum import ReprEnumMeta from generalresearch.models.admin.request import ReportRequest, ReportType from generalresearch.models.custom_types import ( + AwareDatetime, UUIDStr, UUIDStrCoerce, - AwareDatetime, ) from generalresearch.models.thl.finance import POPFinancial from generalresearch.models.thl.ledger import LedgerAccount, OrderBy @@ -37,12 +36,9 @@ from generalresearch.models.thl.payout import BusinessPayoutEvent from generalresearch.pg_helper import PostgresConfig from generalresearch.redis_helper import RedisConfig from generalresearch.utils.aggregation import group_by_year +from generalresearch.utils.enum import ReprEnumMeta if TYPE_CHECKING: - from generalresearch.models.thl.finance import BusinessBalances - - from generalresearch.models.thl.product import Product - from generalresearch.models.gr.team import Team from generalresearch.incite.base import GRLDatasets from generalresearch.incite.mergers.foundations.enriched_session import ( EnrichedSessionMerge, @@ -59,7 +55,9 @@ if TYPE_CHECKING: from generalresearch.managers.thl.payout import ( BusinessPayoutEventManager, ) - from generalresearch.managers.thl.product import ProductManager + from generalresearch.models.gr.team import Team + from generalresearch.models.thl.finance import BusinessBalances + from generalresearch.models.thl.product import Product class TransferMethod(Enum, metaclass=ReprEnumMeta): @@ -248,7 +246,7 @@ class Business(BaseModel): with pg_config.make_connection() as conn: with conn.cursor(row_factory=dict_row) as c: c.execute( - query=f""" + query=""" SELECT * FROM common_businessaddress AS ba WHERE ba.business_id = %s @@ -271,7 +269,7 @@ class Business(BaseModel): c: Cursor c.execute( - query=f""" + query=""" SELECT t.* FROM common_team AS t INNER JOIN common_team_businesses AS tb @@ -359,9 +357,11 @@ class Business(BaseModel): self.prefetch_products(thl_pg_config=thl_pg_config) accounts: List[LedgerAccount] = lm.get_accounts_if_exists( - qualified_names=[ - f"{lm.currency.value}:bp_wallet:{bpid}" for bpid in self.product_uuids - ] + qualified_names=( + [f"{lm.currency.value}:bp_wallet:{bpid}" for bpid in self.product_uuids] + if self.product_uuids + else [] + ) ) if len(accounts) != len(self.products): @@ -711,7 +711,8 @@ class Business(BaseModel): fields: List[str], gr_redis_config: RedisConfig, ) -> Optional[Self]: - keys: List = Business.required_fields() + fields + keys: List[str] = Business.required_fields() + fields + if "pop_financial" in keys: # We should explicitly pass the pop_financial years we want. By default, # at least get this year. diff --git a/generalresearch/models/gr/team.py b/generalresearch/models/gr/team.py index 38aff56..a3ac9cf 100644 --- a/generalresearch/models/gr/team.py +++ b/generalresearch/models/gr/team.py @@ -3,7 +3,7 @@ import os from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Optional, Union, List, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional, Union from uuid import uuid4 import pandas as pd @@ -25,7 +25,6 @@ from generalresearch.incite.mergers.foundations.enriched_session import ( from generalresearch.incite.mergers.foundations.enriched_wall import ( EnrichedWallMerge, ) -from generalresearch.utils.enum import ReprEnumMeta from generalresearch.models.admin.request import ReportRequest, ReportType from generalresearch.models.custom_types import ( AwareDatetimeISO, @@ -34,11 +33,12 @@ from generalresearch.models.custom_types import ( ) from generalresearch.pg_helper import PostgresConfig from generalresearch.redis_helper import RedisConfig +from generalresearch.utils.enum import ReprEnumMeta if TYPE_CHECKING: from generalresearch.incite.base import GRLDatasets - from generalresearch.models.gr.business import Business from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.business import Business from generalresearch.models.thl.product import Product @@ -77,7 +77,7 @@ class Membership(BaseModel): "account was created." ) - user_id: SkipJsonSchema[PositiveInt] = Field(default=None) + user_id: SkipJsonSchema[PositiveInt] = Field() team_id: SkipJsonSchema[PositiveInt] = Field() @@ -189,7 +189,7 @@ class Team(BaseModel): ) try: - test = pd.read_parquet(path, engine="pyarrow") + _ = pd.read_parquet(path, engine="pyarrow") except Exception as e: raise IOError(f"Parquet verification failed: {e}") @@ -234,7 +234,7 @@ class Team(BaseModel): ) try: - test = pd.read_parquet(path, engine="pyarrow") + _ = pd.read_parquet(path, engine="pyarrow") except Exception as e: raise IOError(f"Parquet verification failed: {e}") @@ -342,5 +342,6 @@ class Team(BaseModel): res: List = rc.hmget(name=f"team:{uuid}", keys=keys) d = {val: json.loads(res[idx]) for idx, val in enumerate(keys)} return Team.model_validate(d) + except (Exception,) as e: return None diff --git a/generalresearch/models/innovate/question.py b/generalresearch/models/innovate/question.py index 0fd2547..0d47ba7 100644 --- a/generalresearch/models/innovate/question.py +++ b/generalresearch/models/innovate/question.py @@ -4,9 +4,9 @@ from __future__ import annotations import json import logging from enum import Enum -from typing import List, Optional, Literal, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, model_validator, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from generalresearch.models import Source from generalresearch.models.innovate import InnovateQuestionID @@ -15,6 +15,11 @@ from generalresearch.models.thl.profiling.marketplace import ( MarketplaceUserQuestionAnswer, ) +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -47,8 +52,9 @@ class InnovateQuestionOption(BaseModel): class InnovateQuestionType(str, Enum): # API response: {'Multipunch', 'Numeric Open Ended', 'Single Punch'} - # "Numeric Open Ended" must be wrong... It can't be numeric, as UK's postcode question is marked - # as this, but it wants alphanumeric answers. So this is just text_entry. + # "Numeric Open Ended" must be wrong... It can't be numeric, as UK's + # postcode question is marked as this, but it wants alphanumeric + # answers. So this is just text_entry. SINGLE_SELECT = "s" MULTI_SELECT = "m" @@ -135,7 +141,7 @@ class InnovateQuestion(MarketplaceQuestion): @classmethod def from_api( - cls, d: dict, country_iso: str, language_iso: str + cls, d: Dict, country_iso: str, language_iso: str ) -> Optional["InnovateQuestion"]: """ :param d: Raw response from API @@ -179,13 +185,15 @@ class InnovateQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict) -> "InnovateQuestion": + def from_db(cls, d: Dict[str, Any]) -> "InnovateQuestion": + options = None if d["options"]: options = [ InnovateQuestionOption(id=r["id"], text=r["text"], order=r["order"]) for r in d["options"] ] + return cls( question_id=d["question_id"], question_key=d["question_key"], @@ -204,13 +212,13 @@ class InnovateQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, ) upk_type_selector_map = { diff --git a/generalresearch/models/innovate/survey.py b/generalresearch/models/innovate/survey.py index e230899..3a3d9e2 100644 --- a/generalresearch/models/innovate/survey.py +++ b/generalresearch/models/innovate/survey.py @@ -2,47 +2,47 @@ from __future__ import annotations import json import logging -from datetime import timezone, date +from datetime import date, timezone from decimal import Decimal from functools import cached_property from typing import ( - Optional, - Dict, + Annotated, Any, + Dict, List, Literal, + Optional, Set, Tuple, - Annotated, Type, ) from more_itertools import flatten from pydantic import ( - Field, - ConfigDict, BaseModel, - model_validator, + ConfigDict, + Field, computed_field, + model_validator, ) from typing_extensions import Self from generalresearch.locales import Localelator from generalresearch.models import ( - Source, LogicalOperator, + Source, TaskCalculationType, ) from generalresearch.models.custom_types import ( - CoercedStr, - AwareDatetimeISO, AlphaNumStrSet, + AwareDatetimeISO, + CoercedStr, DeviceTypes, ) from generalresearch.models.innovate import ( - InnovateStatus, - InnovateQuotaStatus, InnovateDuplicateCheckLevel, + InnovateQuotaStatus, + InnovateStatus, ) from generalresearch.models.innovate.question import InnovateQuestionID from generalresearch.models.thl.demographics import Gender @@ -266,7 +266,7 @@ class InnovateSurvey(MarketplaceTask): return data @classmethod - def from_api(cls, d: Dict) -> Optional["InnovateSurvey"]: + def from_api(cls, d: Dict[str, Any]) -> Optional["InnovateSurvey"]: try: return cls._from_api(d) except Exception as e: @@ -274,7 +274,7 @@ class InnovateSurvey(MarketplaceTask): return None @classmethod - def _from_api(cls, d: Dict): + def _from_api(cls, d: Dict[str, Any]) -> "InnovateSurvey": d["conditions"] = dict() # If we haven't hit the "detail" endpoint, we won't get this @@ -334,7 +334,7 @@ class InnovateSurvey(MarketplaceTask): ) return f"{self.__repr_name__()}({repr_str})" - def is_unchanged(self, other) -> bool: + def is_unchanged(self, other: "InnovateSurvey") -> bool: # Avoiding overloading __eq__ because it looks kind of complicated? I # want to be explicit that this is not testing object equivalence, # just that the objects don't require any db updates. We also exclude @@ -465,6 +465,7 @@ class InnovateSurvey(MarketplaceTask): return None, set( flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) ) + return False, set() def determine_eligibility( diff --git a/generalresearch/models/innovate/task_collection.py b/generalresearch/models/innovate/task_collection.py index 97e7a7c..4d7fe51 100644 --- a/generalresearch/models/innovate/task_collection.py +++ b/generalresearch/models/innovate/task_collection.py @@ -1,7 +1,7 @@ from typing import List, Set import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models.innovate import InnovateStatus diff --git a/generalresearch/models/legacy/bucket.py b/generalresearch/models/legacy/bucket.py index 5afb17b..8ce559b 100644 --- a/generalresearch/models/legacy/bucket.py +++ b/generalresearch/models/legacy/bucket.py @@ -4,23 +4,23 @@ import logging import math from datetime import timedelta from decimal import Decimal -from typing import Optional, Dict, List, Union, Literal, Tuple -from typing_extensions import Self +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pydantic import ( BaseModel, + ConfigDict, Field, + NonNegativeInt, field_validator, model_validator, - ConfigDict, - NonNegativeInt, ) +from typing_extensions import Self from generalresearch.models import Source from generalresearch.models.custom_types import ( HttpsUrl, - UUIDStr, PropertyCode, + UUIDStr, ) from generalresearch.models.thl.stats import StatisticalSummary @@ -181,8 +181,10 @@ class Bucket(BaseModel): loi_q1: Optional[timedelta] = Field(strict=True, default=None) loi_q2: Optional[timedelta] = Field(strict=True, default=None) loi_q3: Optional[timedelta] = Field(strict=True, default=None) - # decimal USD. This should not have more than 2 decimal places. - # There is no way to make this "strict" and optional, so we have a separate pre-validator + + # Decimal USD. This should not have more than 2 decimal places. + # There is no way to make this "strict" and optional, so + # we have a separate pre-validator user_payout_min: Optional[Decimal] = Field(default=None, lt=1000, gt=0) user_payout_max: Optional[Decimal] = Field(default=None, lt=1000, gt=0) user_payout_q1: Optional[Decimal] = Field(default=None, lt=1000, gt=0) @@ -358,7 +360,7 @@ class Bucket(BaseModel): ) @classmethod - def parse_from_offerwall_style2(cls, bucket: Dict): + def parse_from_offerwall_style2(cls, bucket: Dict[str, Any]): # {'payout': {'min': 123}} loi_min_sec = bucket.get("duration", {}).get("min") loi_max_sec = bucket.get("duration", {}).get("max") @@ -383,7 +385,7 @@ class Bucket(BaseModel): ) @classmethod - def parse_from_offerwall_style3(cls, bucket: Dict): + def parse_from_offerwall_style3(cls, bucket: Dict[str, Any]): # {'payout': 123, 'duration': 123} return cls( user_payout_min=cls.usd_cents_to_decimal(bucket["payout"]), @@ -397,13 +399,13 @@ class Bucket(BaseModel): ) @staticmethod - def usd_cents_to_decimal(v: int): + def usd_cents_to_decimal(v: Optional[int]) -> Optional[Decimal]: if v is None: return None return Decimal(Decimal(int(v)) / Decimal(100)) @staticmethod - def decimal_to_usd_cents(d: Decimal): + def decimal_to_usd_cents(d: Optional[Decimal]) -> Optional[Decimal]: if d is None: return None return round(d * Decimal(100), 2) @@ -437,7 +439,7 @@ class DurationSummary(StatisticalSummary): } @classmethod - def from_bucket(cls, bucket: Bucket): + def from_bucket(cls, bucket: Bucket) -> "DurationSummary": return cls( min=bucket.loi_min.total_seconds(), max=bucket.loi_max.total_seconds(), @@ -610,18 +612,21 @@ class TopNPlusBucket(BucketBase): def eligibility_ranks(cls, criteria): criteria = list(criteria) ranks = [c.rank for c in criteria] + if all(r is None for r in ranks): for i, c in enumerate(criteria): c.rank = i return tuple(criteria) + if any(r is None for r in ranks): raise ValueError("Set all or no ranks in eligibility_criteria") if len(ranks) != len(set(ranks)): raise ValueError("Duplicate ranks") + return tuple(sorted(criteria, key=lambda c: c.rank)) @classmethod - def from_bucket(cls, bucket: Bucket): + def from_bucket(cls, bucket: Bucket) -> "TopNPlusBucket": return cls.model_validate( { "id": bucket.id, diff --git a/generalresearch/models/legacy/offerwall.py b/generalresearch/models/legacy/offerwall.py index 67213f7..a8efe9a 100644 --- a/generalresearch/models/legacy/offerwall.py +++ b/generalresearch/models/legacy/offerwall.py @@ -1,22 +1,22 @@ from __future__ import annotations -from typing import List, Dict +from typing import Dict, List -from pydantic import BaseModel, Field, ConfigDict, NonNegativeInt +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt from generalresearch.models.custom_types import UUIDStr from generalresearch.models.legacy.bucket import ( BucketBase, - SoftPairBucket, - TopNBucket, - TimeBucksBucket, MarketplaceBucket, - TopNPlusBucket, - SingleEntryBucket, - WXETOfferwallBucket, OneShotOfferwallBucket, OneShotSoftPairOfferwallBucket, + SingleEntryBucket, + SoftPairBucket, + TimeBucksBucket, + TopNBucket, + TopNPlusBucket, TopNPlusRecontactBucket, + WXETOfferwallBucket, ) from generalresearch.models.legacy.definitions import OfferwallReason from generalresearch.models.thl.payout_format import ( diff --git a/generalresearch/models/legacy/questions.py b/generalresearch/models/legacy/questions.py index 1559f24..1a86121 100644 --- a/generalresearch/models/legacy/questions.py +++ b/generalresearch/models/legacy/questions.py @@ -1,21 +1,20 @@ from __future__ import annotations -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from pydantic import ( BaseModel, + BeforeValidator, + ConfigDict, Field, NonNegativeInt, - model_validator, StringConstraints, - ConfigDict, ValidationError, - BeforeValidator, field_validator, + model_validator, ) from sentry_sdk import capture_exception -from typing_extensions import Annotated -from typing_extensions import Self +from typing_extensions import Annotated, Self from generalresearch.models.custom_types import UUIDStr from generalresearch.models.legacy.api_status import StatusResponse @@ -34,10 +33,10 @@ if TYPE_CHECKING: class UpkQuestionResponse(StatusResponse): questions: List[UpkQuestionOut] = Field() - consent_questions: List[Dict] = Field( + consent_questions: List[Dict[str, Any]] = Field( description="For internal use", default_factory=list ) - special_questions: List[Dict] = Field( + special_questions: List[Dict[str, Any]] = Field( description="For internal use", default_factory=list ) count: NonNegativeInt = Field(description="The number of questions returned") @@ -221,7 +220,7 @@ class UserQuestionAnswers(BaseModel): def prefetch_user(self, um: "UserManager") -> None: from generalresearch.models.thl.user import User - res: User = um.get_user_if_exists( + res: Optional[User] = um.get_user_if_exists( product_id=self.product_id, product_user_id=self.product_user_id ) @@ -229,10 +228,11 @@ class UserQuestionAnswers(BaseModel): raise ValidationError("Invalid user") self.user = res + return None def prefetch_wall(self, wm: "WallManager") -> None: - from generalresearch.models.thl.session import Wall from generalresearch.models import Source + from generalresearch.models.thl.session import Wall res: Optional[Wall] = wm.get_from_uuid_if_exists(wall_uuid=self.session_id) @@ -252,3 +252,4 @@ class UserQuestionAnswers(BaseModel): raise ValueError("Not a valid GRS event status") self.wall = res + return None diff --git a/generalresearch/models/lucid/question.py b/generalresearch/models/lucid/question.py index b3d9b27..8542156 100644 --- a/generalresearch/models/lucid/question.py +++ b/generalresearch/models/lucid/question.py @@ -2,9 +2,9 @@ from __future__ import annotations import logging from enum import Enum -from typing import List, Optional, Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, model_validator, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self from generalresearch.models import Source @@ -94,7 +94,7 @@ class LucidQuestion(MarketplaceQuestion): return options @classmethod - def from_db(cls, d: dict) -> Self: + def from_db(cls, d: Dict[str, Any]) -> Self: options = None if d["options"]: options = [ @@ -112,11 +112,11 @@ class LucidQuestion(MarketplaceQuestion): def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, ) upk_type_selector_map = { diff --git a/generalresearch/models/lucid/survey.py b/generalresearch/models/lucid/survey.py index 6d81254..0e01243 100644 --- a/generalresearch/models/lucid/survey.py +++ b/generalresearch/models/lucid/survey.py @@ -1,20 +1,20 @@ from __future__ import annotations -from typing import Optional, Dict, Set, Tuple, List +from typing import Any, Dict, List, Optional, Self, Set, Tuple -from pydantic import NonNegativeInt, Field, ConfigDict, BaseModel +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt from generalresearch.models import Source from generalresearch.models.custom_types import ( AwareDatetimeISO, - UUIDStr, - CoercedStr, BigAutoInteger, + CoercedStr, + UUIDStr, ) from generalresearch.models.thl.locales import CountryISO, LanguageISO from generalresearch.models.thl.survey.condition import ( - MarketplaceCondition, ConditionValueType, + MarketplaceCondition, ) @@ -41,7 +41,7 @@ class LucidCondition(MarketplaceCondition): return hash(self.id) @classmethod - def from_mysql(cls, x): + def from_mysql(cls, x: Dict[str, Any]) -> Self: x["value_type"] = ConditionValueType.LIST x["negate"] = False x["values"] = x.pop("pre_codes").split("|") diff --git a/generalresearch/models/marketplace/summary.py b/generalresearch/models/marketplace/summary.py index 0dd3404..49cf7e9 100644 --- a/generalresearch/models/marketplace/summary.py +++ b/generalresearch/models/marketplace/summary.py @@ -1,10 +1,10 @@ from __future__ import annotations from abc import ABC -from typing import List, Optional, Dict, Literal, Collection +from typing import Collection, Dict, List, Literal, Optional import numpy as np -from pydantic import BaseModel, Field, computed_field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, computed_field from typing_extensions import Self from generalresearch.models.thl.stats import StatisticalSummary diff --git a/generalresearch/models/morning/question.py b/generalresearch/models/morning/question.py index 77ea209..6dca8c0 100644 --- a/generalresearch/models/morning/question.py +++ b/generalresearch/models/morning/question.py @@ -1,17 +1,17 @@ import json from enum import Enum -from typing import List, Optional, Dict, Literal, Any +from typing import Any, Dict, List, Literal, Optional from uuid import UUID -from pydantic import BaseModel, Field, model_validator, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self from generalresearch.locales import Localelator from generalresearch.models import Source from generalresearch.models.morning import MorningQuestionID from generalresearch.models.thl.profiling.marketplace import ( - MarketplaceUserQuestionAnswer, MarketplaceQuestion, + MarketplaceUserQuestionAnswer, ) # todo: we could validate that the country_iso / language_iso exists ... @@ -119,13 +119,14 @@ class MorningQuestion(MarketplaceQuestion): return options @classmethod - def from_api(cls, d: dict, country_iso: str, language_iso: str): + def from_api(cls, d: Dict[str, Any], country_iso: str, language_iso: str): options = None if d.get("responses"): options = [ MorningQuestionOption(id=r["id"], text=r["text"], order=order) for order, r in enumerate(d["responses"]) ] + return cls( id=d["id"], name=d["name"], @@ -137,7 +138,7 @@ class MorningQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict): + def from_db(cls, d: Dict[str, Any]) -> Self: options = None if d["options"]: options = [ @@ -146,6 +147,7 @@ class MorningQuestion(MarketplaceQuestion): ) for r in d["options"] ] + return cls( id=d["question_id"], name=d["question_name"], @@ -167,12 +169,12 @@ class MorningQuestion(MarketplaceQuestion): def to_upk_question(self): from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, + UpkQuestionSelectorHIDDEN, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestionSelectorHIDDEN, - UpkQuestion, + UpkQuestionType, ) upk_type_selector_map = { diff --git a/generalresearch/models/morning/survey.py b/generalresearch/models/morning/survey.py index 6f61661..bc71e58 100644 --- a/generalresearch/models/morning/survey.py +++ b/generalresearch/models/morning/survey.py @@ -6,26 +6,26 @@ from datetime import timezone from decimal import Decimal from functools import cached_property from typing import ( - Optional, - Dict, + Annotated, Any, + Dict, List, + Literal, + Optional, Set, - Annotated, Tuple, - Literal, Type, ) from pydantic import ( - Field, - ConfigDict, BaseModel, - computed_field, + ConfigDict, + Field, NonNegativeInt, - model_validator, PositiveInt, PrivateAttr, + computed_field, + model_validator, ) from typing_extensions import Self @@ -35,7 +35,7 @@ from generalresearch.models.custom_types import ( AwareDatetimeISO, UUIDStrCoerce, ) -from generalresearch.models.morning import MorningStatus, MorningQuestionID +from generalresearch.models.morning import MorningQuestionID, MorningStatus from generalresearch.models.morning.question import MorningQuestion from generalresearch.models.thl.demographics import Gender from generalresearch.models.thl.locales import ( @@ -387,7 +387,7 @@ class MorningBid(MorningTaskStatistics): @model_validator(mode="before") @classmethod - def setup_quota_fields(cls, data: dict) -> dict: + def setup_quota_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]: # These fields get "inherited" by each quota from its bid. quota_fields = [ "country_iso", @@ -396,10 +396,12 @@ class MorningBid(MorningTaskStatistics): "bid_loi", "used_question_ids", ] + for quota in data["quotas"]: for field in quota_fields: if field not in quota: quota[field] = data[field] + return data @model_validator(mode="before") @@ -417,9 +419,10 @@ class MorningBid(MorningTaskStatistics): @model_validator(mode="before") @classmethod - def setup_conditions(cls, data: dict) -> dict: + def setup_conditions(cls, data: Dict[str, Any]) -> Dict[str, Any]: if "conditions" in data: return data + data["conditions"] = dict() for quota in data["quotas"]: if "qualifications" in quota: @@ -445,7 +448,7 @@ class MorningBid(MorningTaskStatistics): @model_validator(mode="before") @classmethod - def clean_alias(cls, data: Dict) -> Dict: + def clean_alias(cls, data: Dict[str, Any]) -> Dict[str, Any]: # Make sure fields are named certain ways, so we don't have to check # aliases within other validators if "estimated_length_of_interview" in data: @@ -500,7 +503,7 @@ class MorningBid(MorningTaskStatistics): return d @classmethod - def from_db(cls, d: Dict[str, Any]): + def from_db(cls, d: Dict[str, Any]) -> Self: d["created"] = d["created"].replace(tzinfo=timezone.utc) d["updated"] = d["updated"].replace(tzinfo=timezone.utc) d["expected_end"] = d["expected_end"].replace(tzinfo=timezone.utc) @@ -522,10 +525,12 @@ class MorningBid(MorningTaskStatistics): self, criteria_evaluation: Dict[str, Optional[bool]] ) -> Tuple[Optional[bool], Optional[List[str]], Optional[Set[str]]]: """ - Quotas are mutually-exclusive. A user can only possibly match 1 quota. As such, all unknown - questions on any quota will be the same unknowns on all. - Returns (the eligibility (True/False/None), passing quota ID or None (if eligibility is not True), - unknown_hashes (or None)) + Quotas are mutually-exclusive. A user can only possibly match 1 + quota. As such, all unknown questions on any quota will be + the same unknowns on all. + + Returns (the eligibility (True/False/None), passing quota ID or + None (if eligibility is not True), unknown_hashes (or None)) """ unknown_quotas = [] unknown_hashes = set() diff --git a/generalresearch/models/morning/task_collection.py b/generalresearch/models/morning/task_collection.py index 174f5e1..8e1eb4b 100644 --- a/generalresearch/models/morning/task_collection.py +++ b/generalresearch/models/morning/task_collection.py @@ -1,14 +1,14 @@ from typing import List, Set import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models.morning import MorningStatus from generalresearch.models.morning.survey import MorningBid from generalresearch.models.thl.survey.task_collection import ( - create_empty_df_from_schema, TaskCollection, + create_empty_df_from_schema, ) COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() @@ -130,10 +130,11 @@ class MorningTaskCollection(TaskCollection): rows.append(d) return rows - def to_df(self): + def to_df(self) -> pd.DataFrame: rows = [] for s in self.items: rows.extend(self.to_rows(s)) + if rows: return pd.DataFrame.from_records(rows, index="quota_id") else: diff --git a/generalresearch/models/pollfish/question.py b/generalresearch/models/pollfish/question.py index 30f8088..a89a793 100644 --- a/generalresearch/models/pollfish/question.py +++ b/generalresearch/models/pollfish/question.py @@ -2,13 +2,18 @@ import json import logging from enum import Enum -from typing import List, Optional, Literal, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Self from pydantic import BaseModel, Field, model_validator from generalresearch.models import Source from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -73,7 +78,7 @@ class PollfishQuestion(MarketplaceQuestion): return self @classmethod - def from_db(cls, d: dict): + def from_db(cls, d: Dict[str, Any]) -> Self: options = None if d["options"]: options = [ @@ -97,13 +102,13 @@ class PollfishQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, order_exclusive_options, ) diff --git a/generalresearch/models/precision/question.py b/generalresearch/models/precision/question.py index 5750bb6..9673f54 100644 --- a/generalresearch/models/precision/question.py +++ b/generalresearch/models/precision/question.py @@ -2,9 +2,9 @@ import json import logging from enum import Enum -from typing import List, Optional, Literal, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Self -from pydantic import BaseModel, Field, model_validator, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from generalresearch.models import Source, string_utils from generalresearch.models.precision import PrecisionQuestionID @@ -13,6 +13,11 @@ from generalresearch.models.thl.profiling.marketplace import ( MarketplaceUserQuestionAnswer, ) +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -99,7 +104,7 @@ class PrecisionQuestion(MarketplaceQuestion): return self @classmethod - def from_api(cls, d: dict) -> Optional["PrecisionQuestion"]: + def from_api(cls, d: Dict[str, Any]) -> Optional["PrecisionQuestion"]: """ :param d: Raw response from API """ @@ -110,7 +115,7 @@ class PrecisionQuestion(MarketplaceQuestion): return None @classmethod - def _from_api(cls, d: dict) -> "PrecisionQuestion": + def _from_api(cls, d: Dict[str, Any]) -> "PrecisionQuestion": question_type = PrecisionQuestionType.from_api(d["question_type_name"]) # sometimes an empty option is returned .... ? options = [ @@ -132,7 +137,7 @@ class PrecisionQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict): + def from_db(cls, d: Dict[str, Any]) -> "PrecisionQuestion": options = None if d["options"]: options = [ @@ -156,13 +161,13 @@ class PrecisionQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, order_exclusive_options, ) diff --git a/generalresearch/models/precision/survey.py b/generalresearch/models/precision/survey.py index 847093b..646d60e 100644 --- a/generalresearch/models/precision/survey.py +++ b/generalresearch/models/precision/survey.py @@ -3,14 +3,14 @@ from __future__ import annotations import json from datetime import timezone from functools import cached_property -from typing import Optional, List, Literal, Set, Dict, Any, Tuple, Type +from typing import Any, Dict, List, Literal, Optional, Self, Set, Tuple, Type from more_itertools import flatten from pydantic import ( + BaseModel, ConfigDict, Field, PrivateAttr, - BaseModel, computed_field, model_validator, ) @@ -18,18 +18,18 @@ from typing_extensions import Annotated from generalresearch.models import Source from generalresearch.models.custom_types import ( - CoercedStr, - UUIDStrCoerce, - AwareDatetimeISO, AlphaNumStrSet, + AwareDatetimeISO, + CoercedStr, DeviceTypes, + UUIDStrCoerce, ) from generalresearch.models.precision import PrecisionQuestionID, PrecisionStatus from generalresearch.models.thl.demographics import Gender from generalresearch.models.thl.survey import MarketplaceTask from generalresearch.models.thl.survey.condition import ( - MarketplaceCondition, ConditionValueType, + MarketplaceCondition, ) @@ -137,7 +137,7 @@ class PrecisionSurvey(MarketplaceTask): # complete_pct: float = Field(ge=0, le=1, validation_alias="cp") # Also skipping: ismultiple (allowing multiple entrances). How is that even possible? They are all False anyways. - bid_loi: int = Field(default=None, ge=59, le=120 * 60, validation_alias="loi") + bid_loi: int = Field(ge=59, le=120 * 60, validation_alias="loi") bid_ir: float = Field(ge=0, le=1, validation_alias="ir") # Be careful with this, it doesn't make any sense. See survey 452481, has 12 completes with a 100% live_ir, # but the only quotas have 0 completes and 1052 terms. .... ?? @@ -257,12 +257,12 @@ class PrecisionSurvey(MarketplaceTask): ) return f"{self.__repr_name__()}({repr_str})" - def is_unchanged(self, other): + def is_unchanged(self, other) -> bool: return self.model_dump( exclude={"updated", "conditions", "created"} ) == other.model_dump(exclude={"updated", "conditions", "created"}) - def to_mysql(self): + def to_mysql(self) -> Dict[str, Any]: d = self.model_dump( mode="json", exclude={ @@ -283,7 +283,7 @@ class PrecisionSurvey(MarketplaceTask): return d @classmethod - def from_db(cls, d: Dict[str, Any]): + def from_db(cls, d: Dict[str, Any]) -> Self: d["created"] = d["created"].replace(tzinfo=timezone.utc) d["updated"] = d["updated"].replace(tzinfo=timezone.utc) d["expected_end_date"] = ( @@ -334,13 +334,13 @@ class PrecisionSurvey(MarketplaceTask): else: # we don't match any quotas, so everything is unknown return None, set( - flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + flatten([m[1] for _, m in quota_eval.items() if m[0] is None]) ) if True in evals: return True, set() if None in evals: return None, set( - flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + flatten([m[1] for _, m in quota_eval.items() if m[0] is None]) ) return False, set() diff --git a/generalresearch/models/precision/task_collection.py b/generalresearch/models/precision/task_collection.py index e2942b5..bd7b274 100644 --- a/generalresearch/models/precision/task_collection.py +++ b/generalresearch/models/precision/task_collection.py @@ -1,7 +1,7 @@ -from typing import List +from typing import Any, Dict, List import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models.precision import PrecisionStatus @@ -56,7 +56,7 @@ class PrecisionTaskCollection(TaskCollection): items: List[PrecisionSurvey] _schema = PrecisionTaskCollectionSchema - def to_row(self, s: PrecisionSurvey): + def to_row(self, s: PrecisionSurvey) -> Dict[str, Any]: d = s.model_dump( mode="json", exclude={ diff --git a/generalresearch/models/prodege/__init__.py b/generalresearch/models/prodege/__init__.py index c7bc4e7..d419c0c 100644 --- a/generalresearch/models/prodege/__init__.py +++ b/generalresearch/models/prodege/__init__.py @@ -30,8 +30,10 @@ class ProdegePastParticipationType(str, Enum): # This is the value of the 'status' url param in the redirect # https://developer.prodege.com/surveys-feed/redirects -# Note: there is no status for ProdegePastParticipationType.CLICK b/c that would be an abandonent +# Note: there is no status for ProdegePastParticipationType.CLICK b/c +# that would be an abandonent # Note: there is no ProdegePastParticipationType for quality (status 4) ProdgeRedirectStatus = Literal["1", "2", "3", "4"] -# I'm not using the ProdegePastParticipationType for the values here b/c there is not a 1-to-1 mapping. +# I'm not using the ProdegePastParticipationType for the values here +# b/c there is not a 1-to-1 mapping. ProdgeRedirectStatusNameMap = {"1": "complete", "2": "oq", "3": "dq", "4": "quality"} diff --git a/generalresearch/models/prodege/question.py b/generalresearch/models/prodege/question.py index 0ae4548..e3cdf08 100644 --- a/generalresearch/models/prodege/question.py +++ b/generalresearch/models/prodege/question.py @@ -6,16 +6,21 @@ import logging from datetime import datetime, timezone from enum import Enum from functools import cached_property -from typing import List, Optional, Literal, Any, Dict, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Self, Set -from pydantic import BaseModel, Field, model_validator, ConfigDict, PositiveInt +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator from generalresearch.locales import Localelator -from generalresearch.models import Source, MAX_INT32 +from generalresearch.models import MAX_INT32, Source from generalresearch.models.custom_types import AwareDatetimeISO from generalresearch.models.prodege import ProdegeQuestionIdType from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -24,23 +29,29 @@ locale_helper = Localelator() class ProdegeUserQuestionAnswer(BaseModel): - # This is optional b/c this model can be used for eligibility checks for "anonymous" users, which are represented - # by a list of question answers not associated with an actual user. No default b/c we must explicitly set - # the field to None. + # This is optional b/c this model can be used for eligibility checks + # for "anonymous" users, which are represented by a list of question + # answers not associated with an actual user. No default b/c we must + # explicitly set the field to None. user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) question_id: ProdegeQuestionIdType = Field() - # This is optional b/c we do not need it when writing these to the db. When these are fetched from the db - # for use in yield-management, we read this field from the prodege_question table. + + # This is optional b/c we do not need it when writing these to the + # db. When these are fetched from the db for use in yield-management, + # we read this field from the prodege_question table. question_type: Optional[ProdegeQuestionType] = Field(default=None) + # This may be a pipe-separated string if the question_type is multi. regex means any chars except capital letters option_id: str = Field(pattern=r"^[^A-Z]*$") created: AwareDatetimeISO = Field( default_factory=lambda: datetime.now(tz=timezone.utc) ) + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) country_iso: str = Field( max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True ) + # 3-char ISO 639-2/B, lowercase language_iso: str = Field( max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True @@ -72,10 +83,12 @@ class ProdegeQuestionOption(BaseModel): validation_alias="option_text", description="The response text shown to respondents", ) - # Order does not come back explicitly in the API, but the responses seem to be ordered + # Order does not come back explicitly in the API, but the + # responses seem to be ordered order: int = Field() - # Both is_exclusive and is_anchored are returned, but I don't see how they are different. - # We are merging them both into is_exclusive. + + # Both is_exclusive and is_anchored are returned, but I don't see how + # they are different. We are merging them both into is_exclusive. is_exclusive: bool = Field(default=False) @@ -126,7 +139,9 @@ class ProdegeQuestion(MarketplaceQuestion): return self @classmethod - def from_api(cls, d: dict, country_iso: str) -> Optional["ProdegeQuestion"]: + def from_api( + cls, d: Dict[str, Any], country_iso: str + ) -> Optional["ProdegeQuestion"]: """ :param d: Raw response from API """ @@ -137,7 +152,7 @@ class ProdegeQuestion(MarketplaceQuestion): return None @classmethod - def _from_api(cls, d: dict, country_iso: str) -> "ProdegeQuestion": + def _from_api(cls, d: Dict[str, Any], country_iso: str) -> "ProdegeQuestion": # The API has no concept of language at all. Questions for a country # are returned both in english and other languages. Questions do have # a field 'country_specific', and if True, that generally means the @@ -170,7 +185,7 @@ class ProdegeQuestion(MarketplaceQuestion): return cls.model_validate(d) @classmethod - def from_db(cls, d: dict): + def from_db(cls, d: Dict[str, Any]) -> "ProdegeQuestion": options = None if d["options"]: options = [ @@ -200,13 +215,13 @@ class ProdegeQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, order_exclusive_options, ) diff --git a/generalresearch/models/prodege/survey.py b/generalresearch/models/prodege/survey.py index 378f0a7..eaae883 100644 --- a/generalresearch/models/prodege/survey.py +++ b/generalresearch/models/prodege/survey.py @@ -4,34 +4,34 @@ from __future__ import annotations import json import logging from collections import defaultdict -from datetime import timezone, datetime +from datetime import datetime, timezone from decimal import Decimal from functools import cached_property -from typing import List, Optional, Dict, Any, Set, Literal, Tuple, Type +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type from pydantic import ( BaseModel, - Field, ConfigDict, + Field, computed_field, - model_validator, field_validator, + model_validator, ) from generalresearch.locales import Localelator from generalresearch.models import LogicalOperator, Source, TaskCalculationType from generalresearch.models.custom_types import ( AlphaNumStrSet, + AwareDatetimeISO, + CoercedStr, InclExcl, UUIDStr, - CoercedStr, - AwareDatetimeISO, ) from generalresearch.models.prodege import ( - ProdegeStatus, + ProdegePastParticipationType, ProdegeQuestionIdType, + ProdegeStatus, ProdgeRedirectStatus, - ProdegePastParticipationType, ) from generalresearch.models.prodege.definitions import PG_COUNTRY_TO_ISO from generalresearch.models.thl.demographics import Gender @@ -151,15 +151,18 @@ class ProdegeQuota(BaseModel): } @classmethod - def from_api(cls, d: Dict): + def from_api(cls, d: Dict[str, Any]) -> "ProdegeQuota": # the API doesn't handle None's correctly? idk if d["parent_quota_id"] == 0: d["parent_quota_id"] = None + d["calculation_type"] = TaskCalculationType.prodege_from_api( d["calculation_type"] ) + if d.get("country_id"): d["country_iso"] = PG_COUNTRY_TO_ISO[d["country_id"]] + return cls.model_validate(d) def passes( @@ -174,12 +177,13 @@ class ProdegeQuota(BaseModel): self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str ) -> bool: # Match means we meet all conditions. - # We can "match" a quota that is closed. In that case, we would fail the parent quota + # We can "match" a quota that is closed. In that case, we would + # fail the parent quota return self.matches_country(country_iso) and all( criteria_evaluation.get(c) for c in self.condition_hashes ) - def matches_country(self, country_iso: str): + def matches_country(self, country_iso: str) -> bool: return self.country_iso is None or self.country_iso == country_iso def passes_verbose( @@ -224,9 +228,12 @@ class ProdegeMaxClicksSetting(BaseModel): # The total number of clicks allowed before survey traffic is paused. cap: int = Field(validation_alias="max_clicks_cap") + # The current remaining number of clicks before survey traffic is paused. allowed_clicks: int = Field(validation_alias="max_clicks_allowed_clicks") - # The refill rate id for clicks (1: every 30 min, 2: every 1 hour, 3: every 24 hours, 0: one-time setting). + + # The refill rate id for clicks (1: every 30 min, 2: every 1 hour, + # 3: every 24 hours, 0: one-time setting). # (not going to bother structuring this, we can't really use it...) max_click_rate_id: int = Field(validation_alias="max_clicks_max_click_rate_id") @@ -243,24 +250,31 @@ class ProdegeUserPastParticipation(BaseModel): @property def participation_types(self) -> Set[ProdegePastParticipationType]: - # If the survey is filtering completes, then only a complete counts. But if the survey is filtering - # on clicks, then a person who got a complete ALSO did click. And so, the logic here is that + # If the survey is filtering completes, then only a complete + # counts. But if the survey is filtering on clicks, then a person + # who got a complete ALSO did click. And so, the logic here is that # participation_types should always include "click". if self.ext_status_code_1 is None: return {ProdegePastParticipationType.CLICK} + elif self.ext_status_code_1 == "1": return { ProdegePastParticipationType.CLICK, ProdegePastParticipationType.COMPLETE, } + elif self.ext_status_code_1 == "2": return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.OQ} + elif self.ext_status_code_1 == "3": return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.DQ} + elif self.ext_status_code_1 == "4": # 4 means "Quality Disqualification". unclear which participation type this is. return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.DQ} + raise ValueError(f"Unknown ext_status_code_1: {self.ext_status_code_1}") + def days_ago(self) -> float: now = datetime.now(timezone.utc) return (now - self.started).total_seconds() / (3600 * 24) @@ -285,15 +299,18 @@ class ProdegePastParticipation(BaseModel): """ @classmethod - def from_api(cls, d: Dict): + def from_api(cls, d: Dict[str, Any]) -> "ProdegePastParticipation": # the API doesn't handle None's correctly? idk if d["in_past_days"] == 0: d["in_past_days"] = None + d["participation_project_ids"] = list(map(str, d["participation_project_ids"])) return cls.model_validate(d) def user_participated(self, user_participation: ProdegeUserPastParticipation): - # Given this user's participation event (1 single event), is it being filtered by this survey? + # Given this user's participation event (1 single event), is it + # being filtered by this survey? + return ( user_participation.survey_id in self.survey_ids and ( @@ -493,7 +510,7 @@ class ProdegeSurvey(MarketplaceTask): return round(float(v), 2) @classmethod - def from_api(cls, d: Dict) -> Optional["ProdegeSurvey"]: + def from_api(cls, d: Dict[str, Any]) -> Optional["ProdegeSurvey"]: try: return cls._from_api(d) except Exception as e: @@ -501,16 +518,19 @@ class ProdegeSurvey(MarketplaceTask): return None @classmethod - def _from_api(cls, d: Dict): - # handle phases. keys in api response are 'loi' and 'actual_ir' + def _from_api(cls, d: Dict[str, Any]) -> "ProdegeSurvey": + + # Handle phases. keys in api response are 'loi' and 'actual_ir' if d["phases"]["loi_phase"] == "actual": d["actual_loi"] = d.pop("loi") * 60 else: d["bid_loi"] = d.pop("loi") * 60 + if d["phases"]["actual_ir_phase"] == "actual": d["actual_ir"] = d.pop("actual_ir") / 100 else: d["bid_ir"] = d.pop("actual_ir") / 100 + d["conversion_rate"] = ( d["conversion_rate"] / 100 if d["conversion_rate"] else None ) @@ -576,7 +596,7 @@ class ProdegeSurvey(MarketplaceTask): sub_res.append(q) return res - def is_unchanged(self, other): + def is_unchanged(self, other) -> bool: # Avoiding overloading __eq__ because it looks kind of complicated? I # want to be explicit that this is not testing object equivalence, # just that the objects don't require any db updates. We also exclude @@ -593,7 +613,8 @@ class ProdegeSurvey(MarketplaceTask): exclude={"created", "updated", "conditions", "survey_name"} ) if o1 == o2: - # We don't have to check bid/actual, b/c we already know it's not changed + # We don't have to check bid/actual, b/c we already know + # it's not changed return True # Ignore bid fields if either one is NULL @@ -604,7 +625,7 @@ class ProdegeSurvey(MarketplaceTask): return o1 == o2 - def to_mysql(self): + def to_mysql(self) -> Dict[str, Any]: d = self.model_dump( mode="json", exclude={ @@ -618,6 +639,7 @@ class ProdegeSurvey(MarketplaceTask): }, ) d["quotas"] = json.dumps(d["quotas"]) + for k in [ "max_clicks_settings", "past_participation", @@ -625,13 +647,15 @@ class ProdegeSurvey(MarketplaceTask): "exclude_psids", ]: d[k] = json.dumps(d[k]) if d[k] else None + d["used_question_ids"] = json.dumps(d["used_question_ids"]) d["created"] = self.created d["updated"] = self.updated + return d @classmethod - def from_db(cls, d: Dict[str, Any]): + def from_db(cls, d: Dict[str, Any]) -> "ProdegeSurvey": d["created"] = d["created"].replace(tzinfo=timezone.utc) d["updated"] = d["updated"].replace(tzinfo=timezone.utc) d["quotas"] = json.loads(d["quotas"]) @@ -653,7 +677,7 @@ class ProdegeSurvey(MarketplaceTask): self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str, - verbose=False, + verbose: bool = False, ) -> bool: # https://developer.prodege.com/surveys-feed/api-reference/survey-matching/quota-structure # https://developer.prodege.com/surveys-feed/api-reference/survey-matching/quota-matching-requirements @@ -700,7 +724,7 @@ class ProdegeSurvey(MarketplaceTask): criteria_evaluation: Dict[str, Optional[bool]], child_quotas: List[ProdegeQuota], country_iso: str, - verbose=False, + verbose: bool = False, ) -> bool: if len(child_quotas) == 0: # If the parent has no children, we pass @@ -745,3 +769,5 @@ class ProdegeSurvey(MarketplaceTask): criteria_evaluation, country_iso=country_iso, verbose=True ) ) + + return None diff --git a/generalresearch/models/prodege/task_collection.py b/generalresearch/models/prodege/task_collection.py index 4765f29..0ea09e1 100644 --- a/generalresearch/models/prodege/task_collection.py +++ b/generalresearch/models/prodege/task_collection.py @@ -1,7 +1,7 @@ -from typing import List, Dict, Any +from typing import Any, Dict, List import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models.prodege import ProdegeStatus diff --git a/generalresearch/models/repdata/question.py b/generalresearch/models/repdata/question.py index 8b4eb4e..d33a4d8 100644 --- a/generalresearch/models/repdata/question.py +++ b/generalresearch/models/repdata/question.py @@ -4,22 +4,27 @@ import json import logging from enum import Enum from functools import cached_property -from typing import List, Optional, Literal, Any, Dict, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set from uuid import UUID from pydantic import ( BaseModel, - Field, - model_validator, ConfigDict, - field_validator, + Field, PositiveInt, + field_validator, + model_validator, ) -from generalresearch.models import Source, MAX_INT32 -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models import MAX_INT32, Source +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -155,7 +160,7 @@ class RepDataQuestion(MarketplaceQuestion): @classmethod def from_api( - cls, d: dict, country_iso: str, language_iso: str + cls, d: Dict[str, Any], country_iso: str, language_iso: str ) -> Optional["RepDataQuestion"]: """ :param d: Raw response from API @@ -168,7 +173,7 @@ class RepDataQuestion(MarketplaceQuestion): @classmethod def _from_api( - cls, d: dict, country_iso: str, language_iso: str + cls, d: Dict[str, Any], country_iso: str, language_iso: str ) -> "RepDataQuestion": d["QualificationType"] = RepDataQuestionType.from_api(d["QualificationType"]) # zip code/age has a placeholder invalid option for some reason @@ -186,7 +191,7 @@ class RepDataQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict): + def from_db(cls, d: Dict[str, Any]) -> "RepDataQuestion": options = None if d["options"]: options = [ @@ -212,13 +217,13 @@ class RepDataQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, order_exclusive_options, ) diff --git a/generalresearch/models/repdata/survey.py b/generalresearch/models/repdata/survey.py index 572e696..b6acd92 100644 --- a/generalresearch/models/repdata/survey.py +++ b/generalresearch/models/repdata/survey.py @@ -6,38 +6,38 @@ import logging from datetime import datetime, timezone from decimal import Decimal from functools import cached_property -from typing import List, Optional, Dict, Any, Set, Literal, Type -from typing_extensions import Self +from typing import Any, Dict, List, Literal, Optional, Set, Type from uuid import UUID from pydantic import ( BaseModel, - Field, - field_validator, ConfigDict, + Field, computed_field, + field_validator, model_validator, ) +from typing_extensions import Self from generalresearch.grpc import timestamp_from_datetime from generalresearch.locales import Localelator from generalresearch.models import ( + DeviceType, LogicalOperator, Source, TaskCalculationType, - DeviceType, ) from generalresearch.models.custom_types import ( + AwareDatetimeISO, CoercedStr, UUIDStr, - AwareDatetimeISO, ) from generalresearch.models.repdata import RepDataStatus from generalresearch.models.thl.demographics import Gender from generalresearch.models.thl.survey import MarketplaceTask from generalresearch.models.thl.survey.condition import ( - MarketplaceCondition, ConditionValueType, + MarketplaceCondition, ) logging.basicConfig() @@ -328,7 +328,7 @@ class RepDataStream(MarketplaceTask): } @classmethod - def from_api(cls, stream_res, country_iso, language_iso): + def from_api(cls, stream_res, country_iso: str, language_iso: str): # qualifications and quotas need to be added to the stream_res manually d = stream_res.copy() d["CalculationType"] = TaskCalculationType.from_api(d["CalculationType"]) @@ -361,7 +361,9 @@ class RepDataStreamHashed(RepDataStream): return d @classmethod - def from_db(cls, res, survey: RepDataSurveyHashed): + def from_db( + cls, res: Dict[str, Any], survey: RepDataSurveyHashed + ) -> "RepDataStreamHashed": # We need certain fields copied over here so that a stream can exist # independent of the survey res["country_iso"] = survey.country_iso @@ -531,7 +533,7 @@ class RepDataSurveyHashed(RepDataSurvey): streams: None = Field(default=None, exclude=True) @classmethod - def from_db(cls, res): + def from_db(cls, res: Dict[str, Any]) -> "RepDataSurveyHashed": res["allowed_devices"] = [ DeviceType(int(x)) for x in res["allowed_devices"].split(",") ] @@ -553,6 +555,7 @@ class RepDataSurveyHashed(RepDataSurvey): def to_grpc(self, repdata_pb2): now = datetime.now(tz=timezone.utc) timestamp = timestamp_from_datetime(now) + return repdata_pb2.RepDataOpportunity( json_str=self.model_dump_json(), timestamp=timestamp, diff --git a/generalresearch/models/repdata/task_collection.py b/generalresearch/models/repdata/task_collection.py index 7b99638..740670b 100644 --- a/generalresearch/models/repdata/task_collection.py +++ b/generalresearch/models/repdata/task_collection.py @@ -1,7 +1,7 @@ -from typing import List +from typing import Any, Dict, List import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models import TaskCalculationType @@ -81,7 +81,7 @@ class RepDataTaskCollection(TaskCollection): items: List[RepDataSurveyHashed] _schema = RepDataTaskCollectionSchema - def to_rows(self, s: RepDataSurveyHashed): + def to_rows(self, s: RepDataSurveyHashed) -> List[Dict[str, Any]]: survey_fields = [ "survey_id", "survey_uuid", @@ -122,7 +122,7 @@ class RepDataTaskCollection(TaskCollection): rows.append(ds) return rows - def to_df(self): + def to_df(self) -> pd.DataFrame: rows = [] for s in self.items: rows.extend(self.to_rows(s)) diff --git a/generalresearch/models/sago/question.py b/generalresearch/models/sago/question.py index 8137b38..62ff363 100644 --- a/generalresearch/models/sago/question.py +++ b/generalresearch/models/sago/question.py @@ -4,21 +4,26 @@ import json import logging from enum import Enum from functools import cached_property -from typing import List, Optional, Literal, Any, Dict, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set from pydantic import ( BaseModel, + ConfigDict, Field, - model_validator, - field_validator, PositiveInt, - ConfigDict, + field_validator, + model_validator, ) -from generalresearch.models import Source, string_utils, MAX_INT32 +from generalresearch.models import MAX_INT32, Source, string_utils from generalresearch.models.custom_types import AwareDatetimeISO from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -165,7 +170,7 @@ class SagoQuestion(MarketplaceQuestion): @classmethod def from_api( - cls, d: dict, country_iso: str, language_iso: str + cls, d: Dict[str, Any], country_iso: str, language_iso: str ) -> Optional["SagoQuestion"]: """ :param d: Raw response from API @@ -180,7 +185,9 @@ class SagoQuestion(MarketplaceQuestion): return None @classmethod - def _from_api(cls, d: dict, country_iso, language_iso) -> "SagoQuestion": + def _from_api( + cls, d: Dict[str, Any], country_iso: str, language_iso: str + ) -> "SagoQuestion": sago_category_to_tags = { 1: "Standard", 2: "Custom", @@ -214,7 +221,7 @@ class SagoQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict): + def from_db(cls, d: Dict[str, Any]) -> "SagoQuestion": options = None if d["options"]: options = [ @@ -241,13 +248,13 @@ class SagoQuestion(MarketplaceQuestion): d["options"] = json.dumps(d["options"]) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, order_exclusive_options, ) diff --git a/generalresearch/models/sago/survey.py b/generalresearch/models/sago/survey.py index a2e9a8a..a8188b0 100644 --- a/generalresearch/models/sago/survey.py +++ b/generalresearch/models/sago/survey.py @@ -5,19 +5,19 @@ import logging from datetime import timezone from decimal import Decimal from functools import cached_property -from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Annotated, Type -from typing_extensions import Self +from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Tuple, Type from more_itertools import flatten -from pydantic import Field, ConfigDict, BaseModel, model_validator, computed_field +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator +from typing_extensions import Self from generalresearch.locales import Localelator -from generalresearch.models import Source, LogicalOperator +from generalresearch.models import LogicalOperator, Source from generalresearch.models.custom_types import ( - CoercedStr, - AwareDatetimeISO, - AlphaNumStrSet, AlphaNumStr, + AlphaNumStrSet, + AwareDatetimeISO, + CoercedStr, DeviceTypes, IPLikeStrSet, ) @@ -80,7 +80,7 @@ class SagoQuota(BaseModel): return self.remaining_count >= min_open_spots @classmethod - def from_api(cls, d: Dict) -> Self: + def from_api(cls, d: Dict[str, Any]) -> Self: return cls.model_validate(d) def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: @@ -259,7 +259,7 @@ class SagoSurvey(MarketplaceTask): } @classmethod - def from_api(cls, d: Dict) -> Optional["SagoSurvey"]: + def from_api(cls, d: Dict[str, Any]) -> Optional["SagoSurvey"]: try: return cls._from_api(d) except Exception as e: @@ -267,7 +267,7 @@ class SagoSurvey(MarketplaceTask): return None @classmethod - def _from_api(cls, d: Dict): + def _from_api(cls, d: Dict[str, Any]) -> "SagoSurvey": return cls.model_validate(d) def __repr__(self) -> str: @@ -285,7 +285,7 @@ class SagoSurvey(MarketplaceTask): ) return f"{self.__repr_name__()}({repr_str})" - def is_unchanged(self, other): + def is_unchanged(self, other) -> bool: # Avoiding overloading __eq__ because it looks kind of complicated? I # want to be explicit that this is not testing object equivalence, just # that the objects don't require any db updates. We also exclude @@ -294,7 +294,7 @@ class SagoSurvey(MarketplaceTask): exclude={"updated", "conditions", "created"} ) == other.model_dump(exclude={"updated", "conditions", "created"}) - def to_mysql(self): + def to_mysql(self) -> Dict[str, Any]: d = self.model_dump( mode="json", exclude={ diff --git a/generalresearch/models/sago/task_collection.py b/generalresearch/models/sago/task_collection.py index c2f168a..41d047e 100644 --- a/generalresearch/models/sago/task_collection.py +++ b/generalresearch/models/sago/task_collection.py @@ -1,7 +1,7 @@ -from typing import List, Set +from typing import Any, Dict, List, Set import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models.sago import SagoStatus @@ -51,7 +51,7 @@ class SagoTaskCollection(TaskCollection): items: List[SagoSurvey] _schema = SagoTaskCollectionSchema - def to_row(self, s: SagoSurvey): + def to_row(self, s: SagoSurvey) -> Dict[str, Any]: d = s.model_dump( mode="json", exclude={ @@ -71,7 +71,7 @@ class SagoTaskCollection(TaskCollection): d["cpi"] = float(s.cpi) return d - def to_df(self): + def to_df(self) -> pd.DataFrame: rows = [] for s in self.items: rows.append(self.to_row(s)) diff --git a/generalresearch/models/spectrum/question.py b/generalresearch/models/spectrum/question.py index 549fc7e..4f8b5e1 100644 --- a/generalresearch/models/spectrum/question.py +++ b/generalresearch/models/spectrum/question.py @@ -6,25 +6,30 @@ import logging from datetime import datetime, timezone from enum import Enum from functools import cached_property -from typing import List, Optional, Literal, Any, Dict, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set from uuid import UUID from pydantic import ( BaseModel, Field, - model_validator, - field_validator, PositiveInt, + field_validator, + model_validator, ) from typing_extensions import Self -from generalresearch.models import Source, string_utils, MAX_INT32 +from generalresearch.models import MAX_INT32, Source, string_utils from generalresearch.models.custom_types import AwareDatetimeISO from generalresearch.models.spectrum import SpectrumQuestionIdType from generalresearch.models.thl.profiling.marketplace import ( MarketplaceQuestion, ) +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + logging.basicConfig() logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -246,7 +251,7 @@ class SpectrumQuestion(MarketplaceQuestion): @classmethod def from_api( - cls, d: dict, country_iso: str, language_iso: str + cls, d: Dict[str, Any], country_iso: str, language_iso: str ) -> Optional["SpectrumQuestion"]: # To not pollute our logs, we know we are skipping any question that # meets the following conditions: @@ -263,7 +268,7 @@ class SpectrumQuestion(MarketplaceQuestion): return None @classmethod - def _from_api(cls, d: dict, country_iso: str, language_iso: str) -> Self: + def _from_api(cls, d: Dict[str, Any], country_iso: str, language_iso: str) -> Self: options = None if d.get("condition_codes"): # Sometimes they use the key "name" instead of "text" ... ? @@ -296,7 +301,7 @@ class SpectrumQuestion(MarketplaceQuestion): ) @classmethod - def from_db(cls, d: dict) -> Self: + def from_db(cls, d: Dict[str, Any]) -> Self: options = None if d["options"]: options = [ @@ -331,13 +336,13 @@ class SpectrumQuestion(MarketplaceQuestion): d["created"] = self.created.replace(tzinfo=None) return d - def to_upk_question(self): + def to_upk_question(self) -> "UpkQuestion": from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, UpkQuestionChoice, - UpkQuestionType, UpkQuestionSelectorMC, UpkQuestionSelectorTE, - UpkQuestion, + UpkQuestionType, ) upk_type_selector_map = { diff --git a/generalresearch/models/spectrum/survey.py b/generalresearch/models/spectrum/survey.py index 7bebaa2..72093eb 100644 --- a/generalresearch/models/spectrum/survey.py +++ b/generalresearch/models/spectrum/survey.py @@ -4,20 +4,20 @@ import json import logging from datetime import timezone from decimal import Decimal -from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Type -from typing_extensions import Self +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type from more_itertools import flatten -from pydantic import Field, ConfigDict, BaseModel, model_validator, computed_field +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator +from typing_extensions import Self from generalresearch.locales import Localelator -from generalresearch.models import TaskCalculationType, Source +from generalresearch.models import Source, TaskCalculationType from generalresearch.models.custom_types import ( - CoercedStr, - AwareDatetimeISO, + AlphaNumStr, AlphaNumStrSet, + AwareDatetimeISO, + CoercedStr, UUIDStrSet, - AlphaNumStr, ) from generalresearch.models.spectrum import SpectrumStatus from generalresearch.models.thl.demographics import Gender @@ -321,7 +321,7 @@ class SpectrumSurvey(MarketplaceTask): } @classmethod - def from_api(cls, d: Dict) -> Optional["SpectrumSurvey"]: + def from_api(cls, d: Dict[str, Any]) -> Optional["SpectrumSurvey"]: try: return cls._from_api(d) except Exception as e: @@ -329,7 +329,7 @@ class SpectrumSurvey(MarketplaceTask): return None @classmethod - def _from_api(cls, d: Dict) -> Self: + def _from_api(cls, d: Dict[str, Any]) -> Self: assert d["click_balancing"] in {0, 1}, "unknown click_balancing value" d["calculation_type"] = ( TaskCalculationType.STARTS diff --git a/generalresearch/models/spectrum/task_collection.py b/generalresearch/models/spectrum/task_collection.py index 8aeac7d..3c2ee85 100644 --- a/generalresearch/models/spectrum/task_collection.py +++ b/generalresearch/models/spectrum/task_collection.py @@ -1,15 +1,15 @@ -from typing import List, Set, Dict +from typing import Dict, List, Set import pandas as pd -from pandera import Column, DataFrameSchema, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models import TaskCalculationType from generalresearch.models.spectrum import SpectrumStatus from generalresearch.models.spectrum.survey import SpectrumSurvey from generalresearch.models.thl.survey.task_collection import ( - create_empty_df_from_schema, TaskCollection, + create_empty_df_from_schema, ) COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() @@ -100,7 +100,7 @@ class SpectrumTaskCollection(TaskCollection): rows.append(d) return rows - def to_df(self): + def to_df(self) -> pd.DataFrame: rows = [] for s in self.items: rows.extend(self.to_rows(s)) diff --git a/generalresearch/models/thl/__init__.py b/generalresearch/models/thl/__init__.py index 875b2bb..0356842 100644 --- a/generalresearch/models/thl/__init__.py +++ b/generalresearch/models/thl/__init__.py @@ -2,8 +2,8 @@ from decimal import Decimal from typing import Optional from generalresearch.models.thl.finance import ( - ProductBalances, POPFinancial, + ProductBalances, ) from generalresearch.models.thl.payout import ( BrokerageProductPayoutEvent, diff --git a/generalresearch/models/thl/category.py b/generalresearch/models/thl/category.py index fe330f1..4e9e2ff 100644 --- a/generalresearch/models/thl/category.py +++ b/generalresearch/models/thl/category.py @@ -1,7 +1,7 @@ -from typing import Optional +from typing import Any, Dict, Optional from uuid import uuid4 -from pydantic import BaseModel, Field, model_validator, PositiveInt +from pydantic import BaseModel, Field, PositiveInt, model_validator from typing_extensions import Self from generalresearch.models.custom_types import UUIDStr @@ -53,7 +53,7 @@ class Category(BaseModel, frozen=True): def is_root(self) -> bool: return self.parent_path is None - def to_offerwall_api(self) -> dict: + def to_offerwall_api(self) -> Dict[str, Any]: return { "id": self.uuid, "label": self.label, diff --git a/generalresearch/models/thl/contest/__init__.py b/generalresearch/models/thl/contest/__init__.py index 7cf1f54..f9693eb 100644 --- a/generalresearch/models/thl/contest/__init__.py +++ b/generalresearch/models/thl/contest/__init__.py @@ -1,20 +1,20 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from uuid import uuid4 from pydantic import ( BaseModel, Field, - model_validator, - computed_field, PositiveInt, + computed_field, + model_validator, ) from typing_extensions import Self from generalresearch.currency import USDCent -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.contest.definitions import ContestPrizeKind from generalresearch.models.thl.user import User @@ -34,17 +34,18 @@ class ContestEntryRule(BaseModel): default=None, ) - # TODO: Only allow entries if user meets some criteria: gold-membership status, - # ID/phone verified, min_completes etc... - # Maybe these get put in a separate model b/c the could apply if the ContestType is not ENTRY + # TODO: Only allow entries if user meets some criteria: gold-membership + # status, ID/phone verified, min_completes etc... Maybe these get put + # in a separate model b/c the could apply if the ContestType is not ENTRY min_completes: Optional[int] = None min_membership_level: Optional[int] = None id_verified: Optional[bool] = None class ContestEndCondition(BaseModel): - """Defines the conditions to evaluate to determine when the contest is over. - Multiple conditions can be set. The contest is over once ANY conditions are met. + """Defines the conditions to evaluate to determine when the contest + is over. Multiple conditions can be set. The contest is over + once ANY conditions are met. """ target_entry_amount: USDCent | PositiveInt | None = Field( diff --git a/generalresearch/models/thl/contest/contest.py b/generalresearch/models/thl/contest/contest.py index 232a038..e644ae4 100644 --- a/generalresearch/models/thl/contest/contest.py +++ b/generalresearch/models/thl/contest/contest.py @@ -1,31 +1,31 @@ from __future__ import annotations import json -from abc import abstractmethod, ABC -from datetime import timezone, datetime -from typing import List, Tuple, Optional, Dict +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple from uuid import uuid4 from pydantic import ( BaseModel, + ConfigDict, Field, HttpUrl, - ConfigDict, - model_validator, NonNegativeInt, + model_validator, ) from typing_extensions import Self -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.contest import ( ContestEndCondition, ContestPrize, ContestWinner, ) from generalresearch.models.thl.contest.definitions import ( + ContestEndReason, ContestStatus, ContestType, - ContestEndReason, ) from generalresearch.models.thl.locales import CountryISOs @@ -169,7 +169,7 @@ class Contest(ContestBase): ) return None - def model_dump_mysql(self, **kwargs) -> Dict: + def model_dump_mysql(self, **kwargs) -> Dict[str, Any]: d = self.model_dump(mode="json", **kwargs) d["created_at"] = self.created_at @@ -183,7 +183,7 @@ class Contest(ContestBase): return d @classmethod - def model_validate_mysql(cls, data) -> Self: + def model_validate_mysql(cls, data: Dict[str, Any]) -> Self: data = {k: v for k, v in data.items() if k in cls.model_fields.keys()} if isinstance(data["end_condition"], dict): data["end_condition"] = ContestEndCondition.model_validate( diff --git a/generalresearch/models/thl/contest/contest_entry.py b/generalresearch/models/thl/contest/contest_entry.py index 146f06f..19586da 100644 --- a/generalresearch/models/thl/contest/contest_entry.py +++ b/generalresearch/models/thl/contest/contest_entry.py @@ -1,18 +1,18 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Union, Dict, Any +from typing import Any, Dict, Union from uuid import uuid4 from pydantic import ( - Field, BaseModel, - model_validator, + Field, computed_field, + model_validator, ) from generalresearch.currency import USDCent -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.contest.definitions import ContestEntryType from generalresearch.models.thl.user import User @@ -23,8 +23,8 @@ class ContestEntryCreate(BaseModel): amount: Union[USDCent, int] = Field( description="The amount of the entry in integer counts or USD Cents", gt=0, - default=None, ) + # This is used in the Create Entry API. We'll look up the user and set # user_id. When we return this model in the API, user_id is excluded product_user_id: str = Field( @@ -54,7 +54,6 @@ class ContestEntry(BaseModel): amount: Union[USDCent, int] = Field( description="The amount of the entry in integer counts or USD Cents", gt=0, - default=None, ) # user_id used internally, for DB joins/index @@ -77,6 +76,7 @@ class ContestEntry(BaseModel): elif entry_type == ContestEntryType.CASH: # This may be coming from the DB, in which case it is an int. data["amount"] = USDCent(data["amount"]) + return data @computed_field() @@ -91,6 +91,8 @@ class ContestEntry(BaseModel): elif self.entry_type == ContestEntryType.CASH: return self.amount.to_usd_str() + raise ValueError(f"Unknown entry_type: {self.entry_type}") + @computed_field() @property def censored_product_user_id(self) -> str: diff --git a/generalresearch/models/thl/contest/examples.py b/generalresearch/models/thl/contest/examples.py index 9748597..4835c14 100644 --- a/generalresearch/models/thl/contest/examples.py +++ b/generalresearch/models/thl/contest/examples.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from pydantic import HttpUrl @@ -6,21 +6,21 @@ from generalresearch.config import EXAMPLE_PRODUCT_ID from generalresearch.currency import USDCent -def _example_raffle_create(schema: Dict) -> None: - from generalresearch.models.thl.contest.raffle import ( - RaffleContestCreate, - ) +def _example_raffle_create(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestEndCondition, - ContestPrize, ContestEntryRule, + ContestPrize, + ) + from generalresearch.models.thl.contest.contest_entry import ( + ContestEntryType, ) from generalresearch.models.thl.contest.definitions import ( - ContestType, ContestPrizeKind, + ContestType, ) - from generalresearch.models.thl.contest.contest_entry import ( - ContestEntryType, + from generalresearch.models.thl.contest.raffle import ( + RaffleContestCreate, ) schema["example"] = RaffleContestCreate( @@ -47,20 +47,20 @@ def _example_raffle_create(schema: Dict) -> None: def _example_raffle(schema: Dict) -> None: - from generalresearch.models.thl.contest.raffle import RaffleContest from generalresearch.models.thl.contest import ( ContestEndCondition, - ContestPrize, ContestEntryRule, + ContestPrize, + ) + from generalresearch.models.thl.contest.contest_entry import ( + ContestEntryType, ) from generalresearch.models.thl.contest.definitions import ( - ContestStatus, ContestPrizeKind, + ContestStatus, ContestType, ) - from generalresearch.models.thl.contest.contest_entry import ( - ContestEntryType, - ) + from generalresearch.models.thl.contest.raffle import RaffleContest schema["example"] = RaffleContest( name="Win an iPhone", @@ -91,22 +91,24 @@ def _example_raffle(schema: Dict) -> None: product_id=EXAMPLE_PRODUCT_ID, ).model_dump(mode="json") + return None + -def _example_raffle_user_view(schema: Dict) -> None: - from generalresearch.models.thl.contest.raffle import RaffleUserView +def _example_raffle_user_view(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestEndCondition, - ContestPrize, ContestEntryRule, + ContestPrize, + ) + from generalresearch.models.thl.contest.contest_entry import ( + ContestEntryType, ) from generalresearch.models.thl.contest.definitions import ( - ContestStatus, ContestPrizeKind, + ContestStatus, ContestType, ) - from generalresearch.models.thl.contest.contest_entry import ( - ContestEntryType, - ) + from generalresearch.models.thl.contest.raffle import RaffleUserView schema["example"] = RaffleUserView( name="Win an iPhone", @@ -140,19 +142,21 @@ def _example_raffle_user_view(schema: Dict) -> None: product_user_id="test-user", ).model_dump(mode="json") + return None -def _example_milestone_create(schema: Dict) -> None: - from generalresearch.models.thl.contest.milestone import ( - MilestoneContestCreate, - MilestoneContestEndCondition, - ContestEntryTrigger, - ) + +def _example_milestone_create(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestPrize, ) from generalresearch.models.thl.contest.definitions import ( - ContestType, ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.milestone import ( + ContestEntryTrigger, + MilestoneContestCreate, + MilestoneContestEndCondition, ) schema["example"] = MilestoneContestCreate( @@ -179,19 +183,21 @@ def _example_milestone_create(schema: Dict) -> None: terms_and_conditions=HttpUrl("https://www.example.com"), ).model_dump(mode="json") + return None -def _example_milestone(schema: Dict) -> None: - from generalresearch.models.thl.contest.milestone import ( - MilestoneContest, - MilestoneContestEndCondition, - ContestEntryTrigger, - ) + +def _example_milestone(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestPrize, ) from generalresearch.models.thl.contest.definitions import ( - ContestType, ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.milestone import ( + ContestEntryTrigger, + MilestoneContest, + MilestoneContestEndCondition, ) schema["example"] = MilestoneContest( @@ -223,17 +229,19 @@ def _example_milestone(schema: Dict) -> None: win_count=12, ).model_dump(mode="json") + return None -def _example_milestone_user_view(schema: Dict) -> None: - from generalresearch.models.thl.contest.milestone import ( - MilestoneUserView, - MilestoneContestEndCondition, - ContestEntryTrigger, - ) + +def _example_milestone_user_view(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ContestPrize from generalresearch.models.thl.contest.definitions import ( - ContestType, ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.milestone import ( + ContestEntryTrigger, + MilestoneContestEndCondition, + MilestoneUserView, ) schema["example"] = MilestoneUserView( @@ -267,17 +275,19 @@ def _example_milestone_user_view(schema: Dict) -> None: product_user_id="test-user", ).model_dump(mode="json") + return None -def _example_leaderboard_contest_create(schema: Dict) -> None: - from generalresearch.models.thl.contest.leaderboard import ( - LeaderboardContestCreate, - ) + +def _example_leaderboard_contest_create(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestPrize, ) from generalresearch.models.thl.contest.definitions import ( - ContestType, ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestCreate, ) schema["example"] = LeaderboardContestCreate( @@ -313,16 +323,16 @@ def _example_leaderboard_contest_create(schema: Dict) -> None: return None -def _example_leaderboard_contest(schema: Dict) -> None: - from generalresearch.models.thl.contest.leaderboard import ( - LeaderboardContest, - ) +def _example_leaderboard_contest(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestPrize, ) from generalresearch.models.thl.contest.definitions import ( - ContestType, ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, ) schema["example"] = LeaderboardContest( @@ -359,10 +369,7 @@ def _example_leaderboard_contest(schema: Dict) -> None: return None -def _example_leaderboard_contest_user_view(schema: Dict) -> None: - from generalresearch.models.thl.contest.leaderboard import ( - LeaderboardContestUserView, - ) +def _example_leaderboard_contest_user_view(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.contest import ( ContestPrize, ) @@ -370,6 +377,9 @@ def _example_leaderboard_contest_user_view(schema: Dict) -> None: ContestPrizeKind, ContestType, ) + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestUserView, + ) schema["example"] = LeaderboardContestUserView( name="Prizes for top survey takers this week", @@ -402,3 +412,5 @@ def _example_leaderboard_contest_user_view(schema: Dict) -> None: product_id=EXAMPLE_PRODUCT_ID, product_user_id="test-user", ).model_dump(mode="json") + + return None diff --git a/generalresearch/models/thl/contest/leaderboard.py b/generalresearch/models/thl/contest/leaderboard.py index 0b24190..d0681c4 100644 --- a/generalresearch/models/thl/contest/leaderboard.py +++ b/generalresearch/models/thl/contest/leaderboard.py @@ -1,12 +1,12 @@ -from datetime import datetime, timezone, timedelta -from typing import Optional, Literal, List, Tuple, Dict, Any +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Literal, Optional, Tuple from pydantic import ( - Field, ConfigDict, + Field, + PrivateAttr, computed_field, model_validator, - PrivateAttr, ) from redis import Redis from typing_extensions import Self @@ -18,8 +18,8 @@ from generalresearch.managers.thl.user_manager.user_manager import ( UserManager, ) from generalresearch.models.thl.contest import ( - ContestWinner, ContestEndCondition, + ContestWinner, ) from generalresearch.models.thl.contest.contest import ( Contest, @@ -27,16 +27,16 @@ from generalresearch.models.thl.contest.contest import ( ContestUserView, ) from generalresearch.models.thl.contest.definitions import ( + ContestEndReason, + ContestPrizeKind, ContestStatus, ContestType, - ContestPrizeKind, - ContestEndReason, LeaderboardTieBreakStrategy, ) from generalresearch.models.thl.contest.examples import ( - _example_leaderboard_contest_user_view, _example_leaderboard_contest, _example_leaderboard_contest_create, + _example_leaderboard_contest_user_view, ) from generalresearch.models.thl.leaderboard import ( Leaderboard, @@ -96,7 +96,7 @@ class LeaderboardContestCreate(ContestBase): return self @property - def leaderboard_key_parts(self) -> Dict: + def leaderboard_key_parts(self) -> Dict[str, Any]: assert self.leaderboard_key.count(":") == 5, "invalid leaderboard_key" parts = self.leaderboard_key.split(":") _, product_id, country_iso, freq_str, date_str, board_code_value = parts @@ -281,7 +281,7 @@ class LeaderboardContestUserView(LeaderboardContest, ContestUserView): return False, "Contest hasn't started" # This would indicate something is wrong, as something else should have done this - e, reason = self.should_end() + e, _ = self.should_end() if e: LOG.warning("contest should be over") return False, "contest is over" diff --git a/generalresearch/models/thl/contest/milestone.py b/generalresearch/models/thl/contest/milestone.py index f902401..8c51a88 100644 --- a/generalresearch/models/thl/contest/milestone.py +++ b/generalresearch/models/thl/contest/milestone.py @@ -2,33 +2,33 @@ from __future__ import annotations import logging from datetime import timedelta -from typing import Literal, Optional, Tuple, Dict +from typing import Any, Dict, Literal, Optional, Tuple from pydantic import ( - Field, - ConfigDict, BaseModel, + ConfigDict, + Field, PositiveInt, ) from typing_extensions import Self from generalresearch.models.custom_types import AwareDatetimeISO from generalresearch.models.thl.contest.contest import ( - ContestBase, Contest, + ContestBase, ContestUserView, ) from generalresearch.models.thl.contest.contest_entry import ContestEntry from generalresearch.models.thl.contest.definitions import ( + ContestEndReason, + ContestEntryTrigger, ContestEntryType, ContestStatus, ContestType, - ContestEndReason, - ContestEntryTrigger, ) from generalresearch.models.thl.contest.examples import ( - _example_milestone_create, _example_milestone, + _example_milestone_create, _example_milestone_user_view, ) @@ -42,9 +42,7 @@ class MilestoneEntry(ContestEntry): entry_type: Literal[ContestEntryType.COUNT] = Field(default=ContestEntryType.COUNT) - # TODO: Must fix - how can the default be None if it's not Optional... amount: int = Field( - default=None, description="The amount of the entry in integer counts", gt=0, ) @@ -147,7 +145,7 @@ class MilestoneContest(MilestoneContestCreate, Contest): # just does nothing return None - def model_dump_mysql(self): + def model_dump_mysql(self) -> Dict[str, Any]: d = super().model_dump_mysql( exclude={ "entry_trigger", @@ -165,7 +163,7 @@ class MilestoneContest(MilestoneContestCreate, Contest): return d @classmethod - def model_validate_mysql(cls, data: Dict) -> Self: + def model_validate_mysql(cls, data: Dict[str, Any]) -> Self: data.update( MilestoneContestConfig.model_validate(data["milestone_config"]).model_dump() ) @@ -218,9 +216,10 @@ class MilestoneUserView(MilestoneContest, ContestUserView): # i.e. it hasn't been >24 hrs since user signed up, or whatever # This would indicate something is wrong, as something else should have done this - e, reason = self.should_end() + e, _ = self.should_end() if e: LOG.warning("contest should be over") return False, "contest is over" + # TODO: others in self.entry_rule ... min_completes, id_verified, etc. return True, "" diff --git a/generalresearch/models/thl/contest/raffle.py b/generalresearch/models/thl/contest/raffle.py index 9a01d0f..1592857 100644 --- a/generalresearch/models/thl/contest/raffle.py +++ b/generalresearch/models/thl/contest/raffle.py @@ -4,14 +4,14 @@ import logging import random from collections import defaultdict from datetime import datetime, timezone -from typing import Literal, List, Dict, Tuple, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pydantic import ( + ConfigDict, Field, - model_validator, computed_field, field_validator, - ConfigDict, + model_validator, ) from scipy.stats import hypergeom from typing_extensions import Self @@ -28,14 +28,14 @@ from generalresearch.models.thl.contest.contest import ( ) from generalresearch.models.thl.contest.contest_entry import ContestEntry from generalresearch.models.thl.contest.definitions import ( + ContestEndReason, ContestEntryType, ContestStatus, ContestType, - ContestEndReason, ) from generalresearch.models.thl.contest.examples import ( - _example_raffle_create, _example_raffle, + _example_raffle_create, _example_raffle_user_view, ) @@ -186,10 +186,10 @@ class RaffleContest(RaffleContestCreate, Contest): def get_current_participants(self) -> int: return len({entry.user.user_id for entry in self.entries}) - def get_current_amount(self) -> Union[int | USDCent]: + def get_current_amount(self) -> Union[int, USDCent]: return sum([x.amount for x in self.entries]) - def get_user_amount(self, product_user_id: str) -> Union[int | USDCent]: + def get_user_amount(self, product_user_id: str) -> Union[int, USDCent]: # Sum of this user's amounts return sum( e.amount for e in self.entries if e.user.product_user_id == product_user_id @@ -206,7 +206,7 @@ class RaffleContest(RaffleContestCreate, Contest): return True return False - def model_dump_mysql(self): + def model_dump_mysql(self) -> Dict[str, Any]: d = super().model_dump_mysql() d["entry_rule"] = self.entry_rule.model_dump_json() return d @@ -309,9 +309,10 @@ class RaffleUserView(RaffleContest, ContestUserView): return False, "Reached max amount today." # This would indicate something is wrong, as something else should have done this - e, reason = self.should_end() + e, _ = self.should_end() if e: LOG.warning("contest should be over") return False, "contest is over" + # todo: others in self.entry_rule ... min_completes, id_verified, etc. return True, "" diff --git a/generalresearch/models/thl/contest/utils.py b/generalresearch/models/thl/contest/utils.py index 1505f7b..e5043c2 100644 --- a/generalresearch/models/thl/contest/utils.py +++ b/generalresearch/models/thl/contest/utils.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING, List, Dict +from typing import TYPE_CHECKING, Dict, List if TYPE_CHECKING: - from generalresearch.models.thl.user import User from generalresearch.currency import USDCent from generalresearch.models.thl.leaderboard import LeaderboardRow + from generalresearch.models.thl.user import User def censor_product_user_id(user: "User") -> str: diff --git a/generalresearch/models/thl/definitions.py b/generalresearch/models/thl/definitions.py index 259093b..22ca4b9 100644 --- a/generalresearch/models/thl/definitions.py +++ b/generalresearch/models/thl/definitions.py @@ -67,7 +67,8 @@ class THLPaths(str, Enum, metaclass=ReprEnumMeta): class Status(str, Enum, metaclass=ReprEnumMeta): """ - The outcome of a session or wall event. If the session is still in progress, the status will be NULL. + The outcome of a session or wall event. If the session is still in + progress, the status will be NULL. """ # User completed the job successfully and should be paid something diff --git a/generalresearch/models/thl/demographics.py b/generalresearch/models/thl/demographics.py index fd833c5..a688e49 100644 --- a/generalresearch/models/thl/demographics.py +++ b/generalresearch/models/thl/demographics.py @@ -1,10 +1,10 @@ from __future__ import annotations import copy -from collections import defaultdict, Counter +from collections import Counter, defaultdict from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Literal, List, Dict +from typing import TYPE_CHECKING, Any, Dict, List, Literal import numpy as np @@ -76,7 +76,7 @@ class AgeGroup(Enum): return self.label -def calculate_demographic_metrics(opps: List[MarketplaceTask]): +def calculate_demographic_metrics(opps: List[MarketplaceTask]) -> List: """ Measurement: marketplace_survey_demographics tags: source (marketplace) @@ -141,12 +141,13 @@ def calculate_demographic_metrics(opps: List[MarketplaceTask]): d["tags"].update(k.to_tags()) d["fields"].update(v) points.append(d) + return points def calculate_used_question_metrics( opps: List[MarketplaceTask], qid_label: Dict[str, str] -): +) -> List[Dict[str, Any]]: """ Measurement: marketplace_survey_targeting tags: source (marketplace), "type", country (all and individual) diff --git a/generalresearch/models/thl/finance.py b/generalresearch/models/thl/finance.py index 6a24b5e..8d526ac 100644 --- a/generalresearch/models/thl/finance.py +++ b/generalresearch/models/thl/finance.py @@ -1,23 +1,23 @@ import random from datetime import timezone -from typing import Optional, TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from uuid import uuid4 import pandas as pd from pydantic import ( BaseModel, + ConfigDict, Field, NonNegativeInt, - ConfigDict, - model_validator, computed_field, field_validator, + model_validator, ) from pydantic.json_schema import SkipJsonSchema from generalresearch.currency import USDCent from generalresearch.decorators import LOG -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.definitions import SessionAdjustedStatus from generalresearch.pg_helper import PostgresConfig @@ -25,9 +25,8 @@ payout_example = random.randint(150, 750 * 100) adjustment_example = random.randint(-1_000, 50 * 100) if TYPE_CHECKING: - from generalresearch.models.thl.ledger import LedgerAccount from generalresearch.managers.thl.product import ProductManager - from generalresearch.models.thl.ledger import AccountType, Direction + from generalresearch.models.thl.ledger import AccountType, Direction, LedgerAccount class AdjustmentType(BaseModel): @@ -111,12 +110,11 @@ class POPFinancial(BaseModel): has a single Product. """ + from generalresearch.config import is_debug from generalresearch.incite.schemas.mergers.pop_ledger import ( numerical_col_names, ) - from generalresearch.config import is_debug - # Validate the input accounts assert len(accounts) > 0, "Must provide accounts" from generalresearch.models.thl.ledger import ( @@ -504,7 +502,7 @@ class ProductBalances(BaseModel): @staticmethod def from_pandas( input_data: pd.DataFrame | pd.Series, - ): + ) -> "ProductBalances": LOG.debug(f"ProductBalances.from_pandas(input_data={input_data.shape})") if isinstance(input_data, pd.Series): @@ -832,11 +830,11 @@ class BusinessBalances(BaseModel): from generalresearch.incite.schemas.mergers.pop_ledger import ( numerical_col_names, ) - from generalresearch.models.thl.product import Product from generalresearch.models.thl.ledger import ( AccountType, Direction, ) + from generalresearch.models.thl.product import Product # Validate the input accounts assert len(accounts) > 0, "Must provide accounts" diff --git a/generalresearch/models/thl/grliq.py b/generalresearch/models/thl/grliq.py index 1e769aa..69b23bd 100644 --- a/generalresearch/models/thl/grliq.py +++ b/generalresearch/models/thl/grliq.py @@ -1,6 +1,6 @@ from generalresearch.grliq.models.decider import ( - Decider, AttemptDecision, + Decider, GrlIqAttemptResult, ) diff --git a/generalresearch/models/thl/ipinfo.py b/generalresearch/models/thl/ipinfo.py index d8abfc0..a98689f 100644 --- a/generalresearch/models/thl/ipinfo.py +++ b/generalresearch/models/thl/ipinfo.py @@ -1,23 +1,23 @@ import ipaddress from datetime import datetime, timezone -from typing import Optional, Dict, Any, Literal, Tuple +from typing import Any, Dict, Literal, Optional, Tuple import geoip2.models from faker import Faker from pydantic import ( BaseModel, + ConfigDict, Field, PositiveInt, - field_validator, PrivateAttr, - ConfigDict, + field_validator, ) from typing_extensions import Self from generalresearch.models.custom_types import ( AwareDatetimeISO, - IPvAnyAddressStr, CountryISOLike, + IPvAnyAddressStr, ) from generalresearch.models.thl.maxmind.definitions import UserType from generalresearch.pg_helper import PostgresConfig @@ -104,9 +104,10 @@ class IPGeoname(BaseModel): "subdivision_2_iso", mode="before", ) - def make_lower(cls, value: str): + def make_lower(cls, value: Optional[str]) -> Optional[str]: if value is not None: return value.lower() + return value # --- ORM --- @@ -116,7 +117,7 @@ class IPGeoname(BaseModel): return d @classmethod - def from_mysql(cls, d: Dict) -> Self: + def from_mysql(cls, d: Dict[str, Any]) -> Self: d["updated"] = d["updated"].replace(tzinfo=timezone.utc) return cls.model_validate(d) @@ -250,9 +251,10 @@ class IPInformation(BaseModel): _geoname: Optional[IPGeoname] = PrivateAttr(default=None) @field_validator("country_iso", "registered_country_iso", mode="before") - def make_lower(cls, value: str): + def make_lower(cls, value: Optional[str]) -> Optional[str]: if value is not None: return value.lower() + return value @property diff --git a/generalresearch/models/thl/leaderboard.py b/generalresearch/models/thl/leaderboard.py index a4c2134..7d79091 100644 --- a/generalresearch/models/thl/leaderboard.py +++ b/generalresearch/models/thl/leaderboard.py @@ -1,25 +1,25 @@ from __future__ import annotations import logging +import math from datetime import datetime, timedelta, timezone from enum import Enum from typing import List, Literal from uuid import UUID, uuid3 -from zoneinfo import ZoneInfo -import math import pandas as pd from pydantic import ( + AwareDatetime, BaseModel, Field, NonNegativeInt, - model_validator, computed_field, - AwareDatetime, field_validator, + model_validator, ) +from zoneinfo import ZoneInfo -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.legacy.api_status import StatusResponse from generalresearch.models.thl.locales import CountryISO from generalresearch.utils.enum import ReprEnumMeta @@ -81,13 +81,11 @@ class Leaderboard(BaseModel): id: UUIDStr = Field( description="Unique ID for this leaderboard", examples=["845b0074ad533df580ebb9c80cc3bce1"], - default=None, ) name: str = Field( description="Descriptive name for the leaderboard based on the board_code", examples=["Number of Completes"], - default=None, ) board_code: LeaderboardCode = Field( @@ -111,7 +109,6 @@ class Leaderboard(BaseModel): timezone_name: str = Field( description="The timezone for the requested country", examples=["America/New_York"], - default=None, ) sort_order: Literal["ascending", "descending"] = Field(default="descending") @@ -155,7 +152,6 @@ class Leaderboard(BaseModel): ) ], # exclude=True, - default=None, ) @property diff --git a/generalresearch/models/thl/ledger.py b/generalresearch/models/thl/ledger.py index 045eeba..31679d2 100644 --- a/generalresearch/models/thl/ledger.py +++ b/generalresearch/models/thl/ledger.py @@ -1,31 +1,31 @@ from datetime import datetime, timezone from enum import Enum -from typing import Dict, Optional, List, Literal, Annotated, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Union from uuid import uuid4 from pydantic import ( BaseModel, - Field, - field_validator, - model_validator, ConfigDict, + Field, + NonNegativeInt, PositiveInt, computed_field, - NonNegativeInt, + field_validator, + model_validator, ) from typing_extensions import Self from generalresearch.models.custom_types import ( - UUIDStr, AwareDatetimeISO, - check_valid_uuid, HttpsUrlStr, + UUIDStr, + check_valid_uuid, ) from generalresearch.models.thl.ledger_example import ( - _example_user_tx_payout, + _example_user_tx_adjustment, _example_user_tx_bonus, _example_user_tx_complete, - _example_user_tx_adjustment, + _example_user_tx_payout, ) from generalresearch.models.thl.pagination import Page from generalresearch.models.thl.payout_format import ( @@ -300,7 +300,7 @@ class LedgerTransaction(BaseModel): ), "ledger entries must balance" return entries - def model_dump_mysql(self, *args, **kwargs) -> dict: + def model_dump_mysql(self, *args, **kwargs) -> Dict[str, Any]: d = self.model_dump(mode="json", *args, **kwargs) if "created" in d: d["created"] = self.created.replace(tzinfo=None) diff --git a/generalresearch/models/thl/ledger_example.py b/generalresearch/models/thl/ledger_example.py index 32cd464..bf120cb 100644 --- a/generalresearch/models/thl/ledger_example.py +++ b/generalresearch/models/thl/ledger_example.py @@ -1,9 +1,9 @@ from datetime import datetime, timezone -from typing import Dict +from typing import Any, Dict from uuid import uuid4 -def _example_user_tx_payout(schema: Dict) -> None: +def _example_user_tx_payout(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.ledger import ( UserLedgerTransactionUserPayout, ) @@ -18,7 +18,7 @@ def _example_user_tx_payout(schema: Dict) -> None: ).model_dump(mode="json") -def _example_user_tx_bonus(schema: Dict) -> None: +def _example_user_tx_bonus(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.ledger import ( UserLedgerTransactionUserBonus, ) @@ -32,7 +32,7 @@ def _example_user_tx_bonus(schema: Dict) -> None: ).model_dump(mode="json") -def _example_user_tx_complete(schema: Dict) -> None: +def _example_user_tx_complete(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.ledger import ( UserLedgerTransactionTaskComplete, ) @@ -47,7 +47,7 @@ def _example_user_tx_complete(schema: Dict) -> None: ).model_dump(mode="json") -def _example_user_tx_adjustment(schema: Dict) -> None: +def _example_user_tx_adjustment(schema: Dict[str, Any]) -> None: from generalresearch.models.thl.ledger import ( UserLedgerTransactionTaskAdjustment, ) diff --git a/generalresearch/models/thl/locales.py b/generalresearch/models/thl/locales.py index 210a158..bff774a 100644 --- a/generalresearch/models/thl/locales.py +++ b/generalresearch/models/thl/locales.py @@ -4,8 +4,8 @@ from pydantic import AfterValidator from generalresearch.locales import Localelator from generalresearch.models.custom_types import ( - to_comma_sep_str, from_comma_sep_str, + to_comma_sep_str, ) locale_helper = Localelator() diff --git a/generalresearch/models/thl/offerwall/__init__.py b/generalresearch/models/thl/offerwall/__init__.py index 2da9d43..bf6b7ea 100644 --- a/generalresearch/models/thl/offerwall/__init__.py +++ b/generalresearch/models/thl/offerwall/__init__.py @@ -4,14 +4,14 @@ import hashlib import json from decimal import Decimal from enum import Enum -from typing import Literal, Optional, Dict, Set, Any +from typing import Any, Dict, Literal, Optional, Set from pydantic import ( BaseModel, Field, - model_validator, - computed_field, PositiveInt, + computed_field, + model_validator, ) from typing_extensions import Self diff --git a/generalresearch/models/thl/offerwall/base.py b/generalresearch/models/thl/offerwall/base.py index 66149a9..8ea259d 100644 --- a/generalresearch/models/thl/offerwall/base.py +++ b/generalresearch/models/thl/offerwall/base.py @@ -2,29 +2,31 @@ import statistics from datetime import timedelta from decimal import Decimal from string import Formatter -from typing import Optional, List, Any, Set, Dict, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from uuid import uuid4 import numpy as np import pandas as pd from pydantic import ( BaseModel, + ConfigDict, Field, NonNegativeFloat, NonNegativeInt, - ConfigDict, field_validator, model_validator, ) -from typing_extensions import Self, Annotated +from typing_extensions import Annotated, Self from generalresearch.models import Source -from generalresearch.models.custom_types import UUIDStr, HttpsUrl +from generalresearch.models.custom_types import HttpsUrl, UUIDStr from generalresearch.models.legacy.bucket import ( Bucket as LegacyBucket, - Eligibility, +) +from generalresearch.models.legacy.bucket import ( CategoryAssociation, DurationSummary, + Eligibility, PayoutSummary, PayoutSummaryDecimal, SurveyEligibilityCriterion, @@ -32,9 +34,9 @@ from generalresearch.models.legacy.bucket import ( from generalresearch.models.legacy.definitions import OfferwallReason from generalresearch.models.thl.locales import CountryISO from generalresearch.models.thl.offerwall import ( + OFFERWALL_TYPE_CLASS, OfferWallType, OfferWallTypeClass, - OFFERWALL_TYPE_CLASS, ) from generalresearch.models.thl.offerwall.bucket import ( generate_offerwall_entry_url, diff --git a/generalresearch/models/thl/offerwall/behavior.py b/generalresearch/models/thl/offerwall/behavior.py index e8da334..8e6c089 100644 --- a/generalresearch/models/thl/offerwall/behavior.py +++ b/generalresearch/models/thl/offerwall/behavior.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Literal -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field class OfferWallBehavior(BaseModel): diff --git a/generalresearch/models/thl/offerwall/bucket.py b/generalresearch/models/thl/offerwall/bucket.py index dd93cd1..1e08e32 100644 --- a/generalresearch/models/thl/offerwall/bucket.py +++ b/generalresearch/models/thl/offerwall/bucket.py @@ -8,9 +8,9 @@ def generate_offerwall_entry_url( bp_user_id: str, request_id: Optional[str] = None, nudge_id: Optional[str] = None, -): - # for an offerwall entry link, we need the clicked bucket_id and the request hash (so we know - # which GetOfferwall cache to get +) -> str: + # For an offerwall entry link, we need the clicked bucket_id and the + # request hash (so we know which GetOfferwall cache to get query_dict = {"i": obj_id, "b": bp_user_id} if request_id: query_dict["66482fb"] = request_id diff --git a/generalresearch/models/thl/offerwall/cache.py b/generalresearch/models/thl/offerwall/cache.py index 6040175..b75a803 100644 --- a/generalresearch/models/thl/offerwall/cache.py +++ b/generalresearch/models/thl/offerwall/cache.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Dict, Any, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -8,8 +8,8 @@ from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.offerwall import OfferWallRequest from generalresearch.models.thl.offerwall.base import ( OfferwallBase, - TaskResult, ScoredTaskResult, + TaskResult, ) diff --git a/generalresearch/models/thl/pagination.py b/generalresearch/models/thl/pagination.py index 6679cb4..1b31078 100644 --- a/generalresearch/models/thl/pagination.py +++ b/generalresearch/models/thl/pagination.py @@ -1,6 +1,6 @@ +from math import ceil from typing import Optional -from math import ceil from pydantic import BaseModel, Field, computed_field diff --git a/generalresearch/models/thl/payout.py b/generalresearch/models/thl/payout.py index b6f880f..8ab01ec 100644 --- a/generalresearch/models/thl/payout.py +++ b/generalresearch/models/thl/payout.py @@ -1,19 +1,19 @@ import json from datetime import datetime, timezone -from typing import Dict, Optional, Collection, List +from typing import Collection, Dict, List, Optional from uuid import uuid4 from pydantic import ( BaseModel, Field, + PositiveInt, computed_field, field_validator, - PositiveInt, ) from typing_extensions import Self from generalresearch.currency import USDCent -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.definitions import PayoutStatus from generalresearch.models.thl.ledger import OrderBy from generalresearch.models.thl.wallet import PayoutType @@ -229,6 +229,7 @@ class BrokerageProductPayoutEvent(PayoutEvent): account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, redis_config: Optional[RedisConfig] = None, ) -> Self: + # TODO!: prevent re-assignment, rework this... if account_product_mapping is None: rc = redis_config.create_redis_client() @@ -248,6 +249,7 @@ class BrokerageProductPayoutEvent(PayoutEvent): account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, redis_config: Optional[RedisConfig] = None, ) -> List[Self]: + # TODO!: prevent re-assignment, rework this... if account_product_mapping is None: rc = redis_config.create_redis_client() diff --git a/generalresearch/models/thl/payout_format.py b/generalresearch/models/thl/payout_format.py index f108e46..4989bdc 100644 --- a/generalresearch/models/thl/payout_format.py +++ b/generalresearch/models/thl/payout_format.py @@ -20,7 +20,9 @@ def validate_payout_format(payout_format: str) -> str: def format_payout_format(payout_format: str, payout_int: int) -> str: """ - Generate a str representation of a payout. Typically, this would be displayed to a user. + Generate a str representation of a payout. Typically, this would be + displayed to a user. + :param payout_format: see BPC_DEFAULTS.payout_format :param payout_int: The actual value in integer usd cents. """ diff --git a/generalresearch/models/thl/product.py b/generalresearch/models/thl/product.py index 949f1f8..5889dc3 100644 --- a/generalresearch/models/thl/product.py +++ b/generalresearch/models/thl/product.py @@ -125,16 +125,17 @@ class ProfilingConfig(BaseModel): enabled: bool = Field( default=True, - description="If False, the harmonizer/profiling system is not used at all. This should " - "never be False unless special circumstances", + description="If False, the harmonizer/profiling system is not used at " + "all. This should never be False unless special circumstances", ) grs_enabled: bool = Field( default=True, - description="""If grs_enabled is False, and is_grs is passed in the profiling-questions call, - then don't actually return any questions. This allows a client to hit the endpoint with no limit - and still get questions. In effect, this means that we'll redirect the user through the GRS - system but won't present them any questions.""", + description="""If grs_enabled is False, and is_grs is passed in the + profiling-questions call, then don't actually return any questions. + This allows a client to hit the endpoint with no limit and still get + questions. In effect, this means that we'll redirect the user through + the GRS system but won't present them any questions.""", ) n_questions: Optional[PositiveInt] = Field( @@ -155,14 +156,16 @@ class ProfilingConfig(BaseModel): # Don't set this to 0, use enabled task_injection_freq_mult: PositiveFloat = Field( default=1, - description="Scale how frequently we inject profiling questions, relative to the default." - "1 is default, 2 is twice as often. 10 means always. 0.5 half as often", + description="""Scale how frequently we inject profiling questions, + relative to the default. 1 is default, 2 is twice as often. 10 means + always. 0.5 half as often""", ) non_us_mult: PositiveFloat = Field( default=2, - description="Non-us multiplier, used to increase freq and length of profilers in all non-us countries." - "This value is multiplied by task_injection_freq_mult and avg_question_count.", + description="""Non-us multiplier, used to increase freq and length of + profilers in all non-us countries. This value is multiplied by + task_injection_freq_mult and avg_question_count.""", ) hidden_questions_expiration_hours: PositiveInt = Field( @@ -186,13 +189,13 @@ class UserHealthConfig(BaseModel): # are blocked. banned_countries: List[CountryISOLike] = Field(default_factory=list) - # Decide if a user can be blocked for IP-related triggers such as sharing IPs - # and location history. This should eventually be deprecated and replaced - # with something with more specificity. + # Decide if a user can be blocked for IP-related triggers such as sharing + # IPs and location history. This should eventually be deprecated and + # replaced with something with more specificity. allow_ban_iphist: bool = Field(default=True) - # These are only checked by ym-user-predict, which I'm not sure even works properly. - # To be deprecated ... don't even use them. + # These are only checked by ym-user-predict, which I'm not sure even + # works properly. To be deprecated ... don't even use them. userprofit_cutoff: Optional[Decimal] = Field(default=None, exclude=True) recon_cutoff: Optional[float] = Field(default=None, exclude=True) droprate_cutoff: Optional[float] = Field(default=None, exclude=True) @@ -205,9 +208,11 @@ class UserHealthConfig(BaseModel): class OfferWallRequestYieldmanParams(BaseModel): # model_config = ConfigDict(extra='forbid') - # keys: use_stats, use_harmonizer, allow_pii, add_default_lang_eng, first_n_completes_easier_per_day are + # keys: use_stats, use_harmonizer, allow_pii, add_default_lang_eng, + # first_n_completes_easier_per_day are # ignored/deprecated - # allow_pii: bool = Field(default=True, description="Allow tasks that request PII. This actually does nothing.") + # allow_pii: bool = Field(default=True, description="Allow tasks that + # request PII. This actually does nothing.") # see thl-grpc:yield_management.scoring.score() for more info conversion_factor_adj: float = Field( diff --git a/generalresearch/models/thl/profiling/marketplace.py b/generalresearch/models/thl/profiling/marketplace.py index 414a437..2dd7028 100644 --- a/generalresearch/models/thl/profiling/marketplace.py +++ b/generalresearch/models/thl/profiling/marketplace.py @@ -1,16 +1,16 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from functools import cached_property -from typing import Any, Dict, Set, Optional, Tuple +from typing import Any, Dict, Optional, Set, Tuple -from pydantic import PositiveInt, BaseModel, Field, computed_field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, computed_field from generalresearch.models import MAX_INT32, Source from generalresearch.models.custom_types import ( AwareDatetimeISO, - UUIDStr, CountryISOLike, LanguageISOLike, + UUIDStr, ) from generalresearch.models.thl.locales import CountryISO, LanguageISO @@ -48,8 +48,9 @@ class MarketplaceQuestion(BaseModel, ABC): @property @abstractmethod def internal_id(self) -> str: - """This is the value that is used for this question within the marketplace. Typically, - this is question_id. Innovate uses question_key.""" + """This is the value that is used for this question within the + marketplace. Typically, this is question_id. Innovate uses question_key. + """ ... @property @@ -58,8 +59,9 @@ class MarketplaceQuestion(BaseModel, ABC): @property def _key(self) -> Tuple[str, CountryISOLike, LanguageISOLike]: - """This uniquely identifies a question in a locale. There is a unique index - on this in the db. e.g. (question_id, country_iso, language_iso)""" + """This uniquely identifies a question in a locale. There is a unique + index on this in the db. e.g. (question_id, country_iso, language_iso) + """ return self.internal_id, self.country_iso, self.language_iso @abstractmethod diff --git a/generalresearch/models/thl/profiling/question.py b/generalresearch/models/thl/profiling/question.py index 6f6b270..72115fe 100644 --- a/generalresearch/models/thl/profiling/question.py +++ b/generalresearch/models/thl/profiling/question.py @@ -1,19 +1,19 @@ from __future__ import annotations -from typing import Optional, Dict, Tuple +from typing import Any, Dict, Optional, Tuple from pydantic import ( BaseModel, - Field, ConfigDict, + Field, computed_field, ) from generalresearch.models.custom_types import ( - UUIDStr, AwareDatetimeISO, CountryISOLike, LanguageISOLike, + UUIDStr, ) from generalresearch.models.thl.profiling.upk_question import UpkQuestion @@ -34,7 +34,7 @@ class Question(BaseModel): ) data: UpkQuestion = Field() is_live: bool = Field() - custom: Dict = Field(default_factory=dict) + custom: Dict[str, Any] = Field(default_factory=dict) last_updated: AwareDatetimeISO = Field() @computed_field diff --git a/generalresearch/models/thl/profiling/upk_property.py b/generalresearch/models/thl/profiling/upk_property.py index 2b63aef..9eede95 100644 --- a/generalresearch/models/thl/profiling/upk_property.py +++ b/generalresearch/models/thl/profiling/upk_property.py @@ -1,11 +1,11 @@ from enum import Enum from functools import cached_property -from typing import List, Optional, Dict +from typing import Dict, List, Optional from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter -from generalresearch.models.custom_types import UUIDStr, CountryISOLike +from generalresearch.models.custom_types import CountryISOLike, UUIDStr from generalresearch.models.thl.category import Category from generalresearch.utils.enum import ReprEnumMeta diff --git a/generalresearch/models/thl/profiling/upk_question.py b/generalresearch/models/thl/profiling/upk_question.py index 2b952ec..5c908b0 100644 --- a/generalresearch/models/thl/profiling/upk_question.py +++ b/generalresearch/models/thl/profiling/upk_question.py @@ -5,16 +5,16 @@ import json import re from enum import Enum from functools import cached_property -from typing import List, Optional, Union, Literal, Dict, Tuple, Set +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union from pydantic import ( BaseModel, - Field, - model_validator, - field_validator, ConfigDict, + Field, NonNegativeInt, PositiveInt, + field_validator, + model_validator, ) from typing_extensions import Annotated @@ -302,7 +302,7 @@ class UpkQuestion(BaseModel): # Don't set a min_length=1 here. We'll allow this to be created, but it # won't be askable with empty choices. choices: Optional[List[UpkQuestionChoice]] = Field(default=None) - selector: SelectorType = Field(default=None) + selector: SelectorType = Field() configuration: Optional[Configuration] = Field(default=None) validation: Optional[UpkQuestionValidation] = Field(default=None) importance: Optional[UPKImportance] = Field(default=None) @@ -348,7 +348,7 @@ class UpkQuestion(BaseModel): @model_validator(mode="before") @classmethod - def check_configuration_type(cls, data: Dict): + def check_configuration_type(cls, data: Dict[str, Any]) -> Dict[str, Any]: # The model knows what the type of Configuration to grab depending on # the key 'type' which it expects inside the configuration object. # Here, we grab the type from the top-level model instead. @@ -433,19 +433,23 @@ class UpkQuestion(BaseModel): @field_validator("choices") @classmethod - def order_choices(cls, choices): + def order_choices(cls, choices: List): if choices: choices.sort(key=lambda x: x.order) return choices @field_validator("choices") @classmethod - def validate_choices(cls, choices): + def validate_choices( + cls, choices: Optional[List[UpkQuestionChoice]] + ) -> Optional[List[UpkQuestionChoice]]: if choices: ids = {x.id for x in choices} assert len(ids) == len(choices), "choices.id must be unique" + orders = {x.order for x in choices} assert len(orders) == len(choices), "choices.order must be unique" + return choices @field_validator("explanation_template", "explanation_fragment_template") diff --git a/generalresearch/models/thl/profiling/upk_question_answer.py b/generalresearch/models/thl/profiling/upk_question_answer.py index 28b9f27..2eb52e1 100644 --- a/generalresearch/models/thl/profiling/upk_question_answer.py +++ b/generalresearch/models/thl/profiling/upk_question_answer.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, Union, Dict +from typing import Any, Dict, Optional, Union from uuid import uuid4 from pydantic import ( @@ -7,25 +7,24 @@ from pydantic import ( ConfigDict, Field, PositiveInt, - model_validator, computed_field, + model_validator, ) from typing_extensions import Self from generalresearch.models import MAX_INT32 from generalresearch.models.custom_types import ( - UUIDStr, AwareDatetimeISO, CountryISOLike, + UUIDStr, ) from generalresearch.models.thl.profiling.upk_property import ( - PropertyType, Cardinality, + PropertyType, ) class UpkQuestionAnswer(BaseModel): - """ """ model_config = ConfigDict(populate_by_name=True) @@ -110,7 +109,7 @@ class UpkQuestionAnswer(BaseModel): return self - def model_dump_mysql(self) -> Dict: + def model_dump_mysql(self) -> Dict[str, Any]: d = self.model_dump(mode="json") d["created"] = self.created return d diff --git a/generalresearch/models/thl/profiling/user_info.py b/generalresearch/models/thl/profiling/user_info.py index 609121e..15197c9 100644 --- a/generalresearch/models/thl/profiling/user_info.py +++ b/generalresearch/models/thl/profiling/user_info.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional from pydantic import BaseModel, ConfigDict, Field from pydantic.json_schema import SkipJsonSchema diff --git a/generalresearch/models/thl/profiling/user_question_answer.py b/generalresearch/models/thl/profiling/user_question_answer.py index b42183d..b325583 100644 --- a/generalresearch/models/thl/profiling/user_question_answer.py +++ b/generalresearch/models/thl/profiling/user_question_answer.py @@ -1,19 +1,19 @@ import json -from datetime import datetime, timezone, timedelta -from typing import Dict, Tuple, Iterator, Optional, Literal, Union, Any +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union from pydantic import ( - PositiveInt, + BaseModel, + ConfigDict, Field, + PositiveInt, field_validator, model_validator, - BaseModel, - ConfigDict, ) from typing_extensions import Self from generalresearch.grpc import timestamp_to_datetime -from generalresearch.models import Source, MAX_INT32 +from generalresearch.models import MAX_INT32, Source from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.locales import CountryISO, LanguageISO from generalresearch.models.thl.profiling.upk_question import UpkQuestion diff --git a/generalresearch/models/thl/report_task.py b/generalresearch/models/thl/report_task.py index e945b3b..1abd330 100644 --- a/generalresearch/models/thl/report_task.py +++ b/generalresearch/models/thl/report_task.py @@ -1,6 +1,6 @@ import random from collections import defaultdict -from typing import List, Collection, Optional +from typing import Collection, List, Optional from pydantic import BaseModel, ConfigDict, Field diff --git a/generalresearch/models/thl/session.py b/generalresearch/models/thl/session.py index 0066559..74ed5eb 100644 --- a/generalresearch/models/thl/session.py +++ b/generalresearch/models/thl/session.py @@ -1,29 +1,28 @@ import json import logging -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from decimal import Decimal -from typing import Optional, Dict, Any, Tuple, Union, List, Annotated +from typing import TYPE_CHECKING, Annotated, Any, Dict, List, Optional, Tuple, Union from uuid import uuid4 -from typing import TYPE_CHECKING from pydantic import ( - BaseModel, AwareDatetime, + BaseModel, + ConfigDict, Field, - model_validator, - field_validator, computed_field, - ConfigDict, field_serializer, + field_validator, + model_validator, ) from typing_extensions import Self from generalresearch.models import DeviceType, Source from generalresearch.models.custom_types import ( - UUIDStr, AwareDatetimeISO, - IPvAnyAddressStr, EnumNameSerializer, + IPvAnyAddressStr, + UUIDStr, ) from generalresearch.models.legacy.bucket import Bucket from generalresearch.models.thl import ( @@ -32,15 +31,15 @@ from generalresearch.models.thl import ( int_cents_to_decimal, ) from generalresearch.models.thl.definitions import ( - Status, + WALL_ALLOWED_STATUS_CODE_1_2, + WALL_ALLOWED_STATUS_STATUS_CODE, + ReportValue, SessionAdjustedStatus, - WallAdjustedStatus, + SessionStatusCode2, + Status, StatusCode1, - ReportValue, + WallAdjustedStatus, WallStatusCode2, - SessionStatusCode2, - WALL_ALLOWED_STATUS_CODE_1_2, - WALL_ALLOWED_STATUS_STATUS_CODE, ) from generalresearch.models.thl.user import User diff --git a/generalresearch/models/thl/stats.py b/generalresearch/models/thl/stats.py index fc8be59..9e9971d 100644 --- a/generalresearch/models/thl/stats.py +++ b/generalresearch/models/thl/stats.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field, model_validator, computed_field +from pydantic import BaseModel, Field, computed_field, model_validator class StatisticalSummary(BaseModel): diff --git a/generalresearch/models/thl/survey/__init__.py b/generalresearch/models/thl/survey/__init__.py index fd78091..5feb955 100644 --- a/generalresearch/models/thl/survey/__init__.py +++ b/generalresearch/models/thl/survey/__init__.py @@ -1,26 +1,26 @@ from abc import ABC, abstractmethod from decimal import Decimal from itertools import product -from typing import Set, Optional, List, Dict, Type +from typing import Dict, List, Optional, Set, Type from more_itertools import flatten from pydantic import BaseModel, Field from generalresearch.models import Source from generalresearch.models.thl.demographics import ( + AgeGroup, DemographicTarget, Gender, - AgeGroup, ) from generalresearch.models.thl.locales import ( - CountryISOs, - LanguageISOs, CountryISO, + CountryISOs, LanguageISO, + LanguageISOs, ) from generalresearch.models.thl.survey.condition import ( - MarketplaceCondition, ConditionValueType, + MarketplaceCondition, ) @@ -68,8 +68,9 @@ class MarketplaceTask(BaseModel, ABC): @property @abstractmethod def internal_id(self) -> str: - """This is the value that is used for this survey within the marketplace. Typically, - this is survey_id/survey_number. Morning is quota_id, repdata: stream_id. + """This is the value that is used for this survey within the + marketplace. Typically, this is survey_id/survey_number. Morning + is quota_id, repdata: stream_id. """ ... diff --git a/generalresearch/models/thl/survey/buyer.py b/generalresearch/models/thl/survey/buyer.py index 0d235ed..b4d8fb0 100644 --- a/generalresearch/models/thl/survey/buyer.py +++ b/generalresearch/models/thl/survey/buyer.py @@ -1,16 +1,16 @@ -from datetime import timezone, datetime +from datetime import datetime, timezone from decimal import Decimal -from typing import Optional, Annotated - from math import log +from typing import Annotated, Optional + from pydantic import ( - model_validator, BaseModel, ConfigDict, Field, - PositiveInt, NonNegativeInt, + PositiveInt, computed_field, + model_validator, ) from scipy.stats import beta as beta_dist @@ -24,7 +24,8 @@ from generalresearch.models.custom_types import ( class Buyer(BaseModel): """ - The entity that commissions and pays for a task and uses the resulting data or insights. + The entity that commissions and pays for a task and uses the + resulting data or insights. """ model_config = ConfigDict(validate_assignment=True) @@ -176,7 +177,6 @@ class BuyerCountryStat(BaseModel): # ---- Scoring ---- score: float = Field( description="Composite score calculated from all of the individual features", - default=None, examples=[-5.329389837486194], ) diff --git a/generalresearch/models/thl/survey/condition.py b/generalresearch/models/thl/survey/condition.py index a60b034..3610750 100644 --- a/generalresearch/models/thl/survey/condition.py +++ b/generalresearch/models/thl/survey/condition.py @@ -2,19 +2,19 @@ import hashlib from abc import ABC from enum import Enum from functools import cached_property -from typing import List, Dict, Set, Optional, Any, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from pydantic import ( BaseModel, + ConfigDict, Field, + PrivateAttr, + StringConstraints, computed_field, - ConfigDict, field_validator, model_validator, - StringConstraints, - PrivateAttr, ) -from typing_extensions import Self, Annotated +from typing_extensions import Annotated, Self from generalresearch.models import LogicalOperator diff --git a/generalresearch/models/thl/survey/model.py b/generalresearch/models/thl/survey/model.py index 9e7a402..8d37af4 100644 --- a/generalresearch/models/thl/survey/model.py +++ b/generalresearch/models/thl/survey/model.py @@ -1,28 +1,28 @@ -from datetime import timezone, datetime +from datetime import datetime, timezone from decimal import Decimal -from typing import Optional, List, Tuple, Dict -from typing_extensions import Annotated +from typing import Any, Dict, List, Optional, Tuple from pydantic import ( BaseModel, ConfigDict, Field, + NonNegativeFloat, + NonNegativeInt, PositiveInt, - model_validator, computed_field, - NonNegativeInt, - NonNegativeFloat, field_validator, + model_validator, ) +from typing_extensions import Annotated from generalresearch.managers.thl.buyer import Buyer from generalresearch.models import Source from generalresearch.models.custom_types import ( AwareDatetimeISO, CountryISOLike, - SurveyKey, EnumNameSerializer, PropertyCode, + SurveyKey, ) from generalresearch.models.thl.category import Category from generalresearch.models.thl.definitions import Status, StatusCode1 @@ -273,7 +273,7 @@ class TaskActivityPublic(BaseModel): last_entrance: Optional[AwareDatetimeISO] = Field(default=None) @field_validator("status_code_1_percentages", mode="before") - def transform_enum_name_pct(cls, value: dict) -> dict: + def transform_enum_name_pct(cls, value: Dict[str, Any]) -> Dict[str, Any]: # If we are serializing+deserializing this model (i.e. when we cache # it), this fails because we've replaced the enum value with the # name. Put it back here ... @@ -297,7 +297,7 @@ class TaskActivityPrivate(TaskActivityPublic): ) @field_validator("status_code_1_counts", mode="before") - def transform_enum_name_cnt(cls, value: dict) -> dict: + def transform_enum_name_cnt(cls, value: Dict[str, Any]) -> Dict[str, Any]: # If we are serializing+deserializing this model (i.e. when we cache # it), this fails because we've replaced the enum value with the # name. Put it back here ... diff --git a/generalresearch/models/thl/survey/penalty.py b/generalresearch/models/thl/survey/penalty.py index 915254f..a9c8a56 100644 --- a/generalresearch/models/thl/survey/penalty.py +++ b/generalresearch/models/thl/survey/penalty.py @@ -1,5 +1,5 @@ import abc -from datetime import timezone, datetime +from datetime import datetime, timezone from typing import List, Literal, Union from pydantic import BaseModel, ConfigDict, Field, TypeAdapter @@ -7,8 +7,8 @@ from typing_extensions import Annotated from generalresearch.models import Source from generalresearch.models.custom_types import ( - UUIDStr, AwareDatetimeISO, + UUIDStr, ) diff --git a/generalresearch/models/thl/survey/task_collection.py b/generalresearch/models/thl/survey/task_collection.py index 68c2ff3..c41d8b9 100644 --- a/generalresearch/models/thl/survey/task_collection.py +++ b/generalresearch/models/thl/survey/task_collection.py @@ -6,7 +6,7 @@ from typing import List import pandas as pd import pandera from pandera import DataFrameSchema -from pydantic import Field, ConfigDict, BaseModel, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from generalresearch.models.thl.survey import MarketplaceTask diff --git a/generalresearch/models/thl/synchronize_global_vars.py b/generalresearch/models/thl/synchronize_global_vars.py index 4a587e1..72d987d 100644 --- a/generalresearch/models/thl/synchronize_global_vars.py +++ b/generalresearch/models/thl/synchronize_global_vars.py @@ -1,6 +1,6 @@ from typing import List -from pydantic import Field, BaseModel +from pydantic import BaseModel, Field class SynchronizeGlobalVarsMsg(BaseModel): diff --git a/generalresearch/models/thl/task_adjustment.py b/generalresearch/models/thl/task_adjustment.py index afb2a6a..9d8a04d 100644 --- a/generalresearch/models/thl/task_adjustment.py +++ b/generalresearch/models/thl/task_adjustment.py @@ -5,8 +5,8 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator -from generalresearch.models import Source, MAX_INT32 -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models import MAX_INT32, Source +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.definitions import ( WallAdjustedStatus, ) diff --git a/generalresearch/models/thl/task_status.py b/generalresearch/models/thl/task_status.py index f476286..a7aec91 100644 --- a/generalresearch/models/thl/task_status.py +++ b/generalresearch/models/thl/task_status.py @@ -1,39 +1,39 @@ from datetime import datetime -from typing import Dict, Optional, Any, Literal, Annotated, List +from typing import Annotated, Any, Dict, List, Literal, Optional from pydantic import ( BaseModel, Field, - model_validator, NonNegativeInt, computed_field, - field_validator, field_serializer, + field_validator, + model_validator, ) from typing_extensions import Self from generalresearch.models.custom_types import ( - UUIDStr, AwareDatetimeISO, EnumNameSerializer, + UUIDStr, ) from generalresearch.models.thl import decimal_to_int_cents from generalresearch.models.thl.definitions import ( - StatusCode1, + SessionAdjustedStatus, SessionStatusCode2, Status, - SessionAdjustedStatus, + StatusCode1, ) from generalresearch.models.thl.pagination import Page from generalresearch.models.thl.payout_format import ( - PayoutFormatType, PayoutFormatOptionalField, + PayoutFormatType, ) from generalresearch.models.thl.product import ( PayoutTransformation, Product, ) -from generalresearch.models.thl.session import WallOut, Session +from generalresearch.models.thl.session import Session, WallOut # API uses the ints, b/c this is what the grpc returned originally ... STATUS_MAP = { diff --git a/generalresearch/models/thl/user.py b/generalresearch/models/thl/user.py index e393830..852766c 100644 --- a/generalresearch/models/thl/user.py +++ b/generalresearch/models/thl/user.py @@ -3,20 +3,20 @@ from __future__ import annotations import json import logging import re -from datetime import timezone, datetime -from typing import Optional, Dict, List, TYPE_CHECKING -from uuid import uuid4, UUID +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Dict, List, Optional +from uuid import UUID, uuid4 from pydantic import ( + AfterValidator, AwareDatetime, - Field, BaseModel, - field_validator, - model_validator, - PositiveInt, ConfigDict, + Field, + PositiveInt, StringConstraints, - AfterValidator, + field_validator, + model_validator, ) from sentry_sdk import set_tag, set_user from typing_extensions import Annotated, Self @@ -30,10 +30,10 @@ from generalresearch.models.thl.userhealth import AuditLog from generalresearch.pg_helper import PostgresConfig if TYPE_CHECKING: - from generalresearch.managers.thl.userhealth import AuditLogManager from generalresearch.managers.thl.ledger_manager.thl_ledger import ( ThlLedgerManager, ) + from generalresearch.managers.thl.userhealth import AuditLogManager # from generalresearch.managers.thl.userhealth import UserIpHistoryManager diff --git a/generalresearch/models/thl/user_iphistory.py b/generalresearch/models/thl/user_iphistory.py index ecdbc7a..159b97b 100644 --- a/generalresearch/models/thl/user_iphistory.py +++ b/generalresearch/models/thl/user_iphistory.py @@ -1,12 +1,12 @@ import ipaddress -from datetime import timezone, datetime, timedelta -from typing import List, Optional, Dict +from datetime import datetime, timedelta, timezone +from typing import Dict, List, Optional from faker import Faker from pydantic import ( BaseModel, - Field, ConfigDict, + Field, PositiveInt, field_validator, ) @@ -14,8 +14,8 @@ from typing_extensions import Self from generalresearch.models.custom_types import ( AwareDatetimeISO, - IPvAnyAddressStr, CountryISOLike, + IPvAnyAddressStr, ) from generalresearch.models.thl.ipinfo import ( GeoIPInformation, diff --git a/generalresearch/models/thl/user_profile.py b/generalresearch/models/thl/user_profile.py index 6514c7d..98b3326 100644 --- a/generalresearch/models/thl/user_profile.py +++ b/generalresearch/models/thl/user_profile.py @@ -1,16 +1,16 @@ import hashlib -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional from pydantic import ( - Field, BaseModel, ConfigDict, EmailStr, + Field, PositiveInt, computed_field, ) from pydantic.json_schema import SkipJsonSchema -from typing_extensions import Self, Annotated +from typing_extensions import Annotated, Self from generalresearch.models import MAX_INT32, Source from generalresearch.models.custom_types import UUIDStr diff --git a/generalresearch/models/thl/user_quality_event.py b/generalresearch/models/thl/user_quality_event.py index f98e42d..2b4873a 100644 --- a/generalresearch/models/thl/user_quality_event.py +++ b/generalresearch/models/thl/user_quality_event.py @@ -7,8 +7,8 @@ from typing import List, Literal, Optional from pydantic import BaseModel, Field, PositiveInt -from generalresearch.models import Source, MAX_INT32 -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models import MAX_INT32, Source +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.definitions import WallAdjustedStatus from generalresearch.models.thl.user import BPUIDStr from generalresearch.utils.enum import ReprEnumMeta diff --git a/generalresearch/models/thl/user_streak.py b/generalresearch/models/thl/user_streak.py index fa6d3b1..5d13bc5 100644 --- a/generalresearch/models/thl/user_streak.py +++ b/generalresearch/models/thl/user_streak.py @@ -1,20 +1,20 @@ from datetime import date, datetime, timedelta from enum import Enum from typing import Optional, Tuple -from zoneinfo import ZoneInfo import pandas as pd from pydantic import ( + AwareDatetime, BaseModel, - NonNegativeInt, + ConfigDict, Field, + NonNegativeInt, + PositiveInt, computed_field, - AwareDatetime, model_validator, - ConfigDict, - PositiveInt, ) from pydantic.json_schema import SkipJsonSchema +from zoneinfo import ZoneInfo from generalresearch.managers.leaderboard import country_timezone from generalresearch.models import MAX_INT32 diff --git a/generalresearch/models/thl/userhealth.py b/generalresearch/models/thl/userhealth.py index 8ea81dd..fd275de 100644 --- a/generalresearch/models/thl/userhealth.py +++ b/generalresearch/models/thl/userhealth.py @@ -1,8 +1,8 @@ from datetime import datetime, timezone from enum import Enum -from typing import Optional, Dict +from typing import Dict, Optional -from pydantic import Field, BaseModel, PositiveInt, NonNegativeFloat +from pydantic import BaseModel, Field, NonNegativeFloat, PositiveInt from typing_extensions import Self from generalresearch.models.custom_types import AwareDatetimeISO diff --git a/generalresearch/models/thl/wallet/__init__.py b/generalresearch/models/thl/wallet/__init__.py index 928d67f..e3f5144 100644 --- a/generalresearch/models/thl/wallet/__init__.py +++ b/generalresearch/models/thl/wallet/__init__.py @@ -13,7 +13,7 @@ class PayoutType(str, Enum, metaclass=ReprEnumMeta): # User is paid out to their personal PayPal email address PAYPAL = "PAYPAL" - # User is paid uut via a Tango Gift Card + # User is paid out via a Tango Gift Card TANGO = "TANGO" # DWOLLA DWOLLA = "DWOLLA" diff --git a/generalresearch/models/thl/wallet/cashout_method.py b/generalresearch/models/thl/wallet/cashout_method.py index def724d..c04bd66 100644 --- a/generalresearch/models/thl/wallet/cashout_method.py +++ b/generalresearch/models/thl/wallet/cashout_method.py @@ -4,31 +4,31 @@ import hashlib import logging from datetime import datetime, timezone from enum import Enum -from typing import List, Dict, Any, Optional, Literal, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import ( BaseModel, - Field, ConfigDict, + EmailStr, + Field, NonNegativeInt, PositiveInt, - EmailStr, - model_validator, field_validator, + model_validator, ) from typing_extensions import Self from generalresearch.currency import USDCent from generalresearch.models.custom_types import ( - UUIDStr, - HttpsUrlStr, AwareDatetimeISO, + HttpsUrlStr, + UUIDStr, ) from generalresearch.models.legacy.api_status import StatusResponse from generalresearch.models.thl.definitions import PayoutStatus from generalresearch.models.thl.locales import CountryISO from generalresearch.models.thl.user import BPUIDStr, User -from generalresearch.models.thl.wallet import PayoutType, Currency +from generalresearch.models.thl.wallet import Currency, PayoutType from generalresearch.utils.enum import ReprEnumMeta logger = logging.getLogger() diff --git a/generalresearch/models/thl/wallet/payout.py b/generalresearch/models/thl/wallet/payout.py index 97e0c3d..72aea61 100644 --- a/generalresearch/models/thl/wallet/payout.py +++ b/generalresearch/models/thl/wallet/payout.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Dict, Optional, Collection, List +from typing import Any, Collection, Dict, List, Optional, Union from uuid import uuid4 from pydantic import ( @@ -12,7 +12,7 @@ from pydantic import ( ) from generalresearch.currency import USDCent -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.thl.definitions import PayoutStatus from generalresearch.models.thl.wallet import PayoutType from generalresearch.models.thl.wallet.cashout_method import ( @@ -74,11 +74,11 @@ class PayoutEvent(BaseModel, validate_assignment=True): # Stores payout-type-specific information that is used to request this # payout from the external provider. - request_data: Dict = Field(default_factory=dict) + request_data: Dict[str, Any] = Field(default_factory=dict) # Stores payout-type-specific order information that is returned from # the external payout provider. - order_data: Optional[Dict | CashMailOrderData] = Field(default=None) + order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = Field(default=None) @field_validator("payout_type", mode="before") @classmethod @@ -94,7 +94,7 @@ class PayoutEvent(BaseModel, validate_assignment=True): self, status: PayoutStatus, ext_ref_id: Optional[str] = None, - order_data: Optional[Dict] = None, + order_data: Optional[Dict[str, Any]] = None, ) -> None: # These 3 things are the only modifiable attributes self.check_status_change_allowed(status) @@ -128,7 +128,7 @@ class PayoutEvent(BaseModel, validate_assignment=True): else: raise ValueError("this shouldn't happen") - def model_dump_mysql(self, *args, **kwargs) -> dict: + def model_dump_mysql(self, *args, **kwargs) -> Dict[str, Any]: d = self.model_dump(mode="json", *args, **kwargs) if "created" in d: d["created"] = self.created.replace(tzinfo=None) diff --git a/generalresearch/models/thl/wallet/user_wallet.py b/generalresearch/models/thl/wallet/user_wallet.py index 917a09d..dbe66fa 100644 --- a/generalresearch/models/thl/wallet/user_wallet.py +++ b/generalresearch/models/thl/wallet/user_wallet.py @@ -2,12 +2,12 @@ from __future__ import annotations import logging -from pydantic import BaseModel, Field, ConfigDict, NonNegativeInt +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt from generalresearch.models.legacy.api_status import StatusResponse from generalresearch.models.thl.payout_format import ( - PayoutFormatType, PayoutFormatField, + PayoutFormatType, ) logger = logging.getLogger() diff --git a/generalresearch/pg_helper.py b/generalresearch/pg_helper.py index 064a883..f7c9675 100644 --- a/generalresearch/pg_helper.py +++ b/generalresearch/pg_helper.py @@ -1,15 +1,14 @@ +from datetime import timezone from typing import Optional +import psycopg from psycopg.adapt import Buffer -from psycopg.types.net import InetLoader, Address, Interface +from psycopg.rows import RowFactory, dict_row +from psycopg.types.datetime import TimestampLoader +from psycopg.types.net import Address, InetLoader, Interface from psycopg.types.string import TextLoader -from pydantic import PostgresDsn - -import psycopg -from psycopg.rows import dict_row, RowFactory from psycopg.types.uuid import UUIDLoader -from psycopg.types.datetime import TimestampLoader -from datetime import timezone +from pydantic import PostgresDsn class UUIDHexLoader(UUIDLoader): diff --git a/generalresearch/schemas/survey_stats.py b/generalresearch/schemas/survey_stats.py index 532e0d7..1570554 100644 --- a/generalresearch/schemas/survey_stats.py +++ b/generalresearch/schemas/survey_stats.py @@ -1,5 +1,5 @@ import pandas as pd -from pandera import DataFrameSchema, Column, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.locales import Localelator from generalresearch.models import Source diff --git a/generalresearch/sql_helper.py b/generalresearch/sql_helper.py index b92813b..175830c 100644 --- a/generalresearch/sql_helper.py +++ b/generalresearch/sql_helper.py @@ -1,8 +1,8 @@ import logging -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID -from pydantic import MySQLDsn, PostgresDsn, MariaDBDsn +from pydantic import MariaDBDsn, MySQLDsn, PostgresDsn from pymysql import Connection ListOrTupleOfStrings = Union[List[str], Tuple[str, ...]] @@ -119,7 +119,7 @@ def is_uuid4(s: Any) -> bool: return False -def decode_uuids(row: Dict) -> Dict: +def decode_uuids(row: Dict[str, Any]) -> Dict[str, Any]: return { key: (UUID(value, version=4).hex if is_uuid4(value) else value) for key, value in row.items() @@ -131,7 +131,9 @@ class SqlHelper(SqlConnector): def __init__(self, dsn: Optional[DataBaseDsn] = None, **kwargs): super(SqlHelper, self).__init__(dsn, **kwargs) - def execute_sql_query(self, query, params=None, commit=False) -> List[Dict]: + def execute_sql_query( + self, query: str, params: Optional[Dict[str, Any]] = None, commit: bool = False + ) -> List[Dict[str, Any]]: for param in params if params else []: if isinstance(param, (tuple, list, set)) and len(param) == 0: logging.warning("param is empty. not executing query") @@ -157,7 +159,7 @@ class SqlHelper(SqlConnector): field_names: ListOrTupleOfStrings, values_to_insert: ListOrTupleOfListOrTuple, cursor=None, - ignore_existing=False, + ignore_existing: bool = False, ) -> None: """ :param table_name: name of table diff --git a/generalresearch/utils/aggregation.py b/generalresearch/utils/aggregation.py index b168e4c..4023dc9 100644 --- a/generalresearch/utils/aggregation.py +++ b/generalresearch/utils/aggregation.py @@ -1,8 +1,8 @@ from collections import defaultdict -from typing import Dict, List +from typing import Any, Dict, List -def group_by_year(records: List[Dict], datetime_field: str) -> Dict[int, List]: +def group_by_year(records: List[Dict], datetime_field: str) -> Dict[int, List[Any]]: """Memory efficient - processes records one at a time""" by_year = defaultdict(list) diff --git a/generalresearch/utils/grpc_logger.py b/generalresearch/utils/grpc_logger.py index a98a13b..59f7471 100644 --- a/generalresearch/utils/grpc_logger.py +++ b/generalresearch/utils/grpc_logger.py @@ -1,7 +1,7 @@ import json import logging -from logging.handlers import TimedRotatingFileHandler import time +from logging.handlers import TimedRotatingFileHandler handler = TimedRotatingFileHandler( "grpc_access.log", when="midnight", backupCount=3, encoding="utf-8" diff --git a/generalresearch/wall_status_codes/__init__.py b/generalresearch/wall_status_codes/__init__.py index 102b72a..3a0abb8 100644 --- a/generalresearch/wall_status_codes/__init__.py +++ b/generalresearch/wall_status_codes/__init__.py @@ -1,21 +1,21 @@ -from typing import Tuple, Optional +from typing import Optional, Tuple from generalresearch.models import Source from generalresearch.models.thl.definitions import Status, StatusCode1 from generalresearch.models.thl.session import Wall from generalresearch.wall_status_codes import ( + cint, dynata, fullcircle, innovate, + lucid, morning, pollfish, precision, - spectrum, - sago, - cint, - lucid, prodege, repdata, + sago, + spectrum, ) @@ -36,6 +36,7 @@ def annotate_status_code( return Status.FAIL, StatusCode1.UNKNOWN, None if source == Source.PULLEY: return Status.FAIL, StatusCode1.UNKNOWN, None + return { Source.CINT: cint.annotate_status_code, Source.DYNATA: dynata.annotate_status_code, diff --git a/generalresearch/wall_status_codes/cint.py b/generalresearch/wall_status_codes/cint.py index 0fa929c..8042cd2 100644 --- a/generalresearch/wall_status_codes/cint.py +++ b/generalresearch/wall_status_codes/cint.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Any, Optional, Tuple +from generalresearch.models.thl.definitions import Status, StatusCode1 from generalresearch.wall_status_codes import lucid @@ -7,7 +8,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -): +) -> Tuple[Status, StatusCode1, Optional[Any]]: return lucid.annotate_status_code( ext_status_code_1=ext_status_code_1, ext_status_code_2=ext_status_code_2, diff --git a/generalresearch/wall_status_codes/dynata.py b/generalresearch/wall_status_codes/dynata.py index f37fca5..958e06f 100644 --- a/generalresearch/wall_status_codes/dynata.py +++ b/generalresearch/wall_status_codes/dynata.py @@ -4,11 +4,11 @@ checked by Greg 2023-10-10 """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_name = { +status_codes_name: Dict[str, str] = { "0.0": "Unknown", "0.1": "Missing Language", "0.2": "Missing Respondent ID", @@ -51,10 +51,10 @@ status_codes_name = { "5.10": "Daily Limit", } -status_map = defaultdict( +status_map: Dict[str, Status] = defaultdict( lambda: Status.FAIL, **{"1.0": Status.COMPLETE, "1.1": Status.COMPLETE} ) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["1.0", "1.1"], StatusCode1.BUYER_FAIL: ["2.2", "3.2"], StatusCode1.BUYER_QUALITY_FAIL: ["5.1", "5.2"], @@ -88,9 +88,13 @@ status_codes_ext_map = { "5.10", ], } -ext_status_code_map = dict() +ext_status_code_map: Dict[str, StatusCode1] = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -98,7 +102,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: this is from the callback url params: disposition and status, '.'-joined @@ -113,7 +117,7 @@ def annotate_status_code( return status, status_code, None -def stop_marketplace_session(status_code_1, ext_status_code_1) -> bool: +def stop_marketplace_session(status_code_1: StatusCode1, ext_status_code_1) -> bool: if ext_status_code_1.startswith("5"): # '5.10' is the user hit a Daily Limit, so they should not be sent in again today return True diff --git a/generalresearch/wall_status_codes/fullcircle.py b/generalresearch/wall_status_codes/fullcircle.py index 08f2c44..eda4d9c 100644 --- a/generalresearch/wall_status_codes/fullcircle.py +++ b/generalresearch/wall_status_codes/fullcircle.py @@ -7,11 +7,11 @@ we'll try to infer based on the time spent in survey. from collections import defaultdict from datetime import timedelta -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_map = { +status_codes_map: Dict[str, str] = { "1": "Complete", "2": "Terminate", "3": "Over-quota", @@ -19,7 +19,7 @@ status_codes_map = { } status_map = defaultdict(lambda: Status.FAIL, **{"1": Status.COMPLETE}) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["1"], StatusCode1.BUYER_FAIL: ["2", "3"], StatusCode1.BUYER_QUALITY_FAIL: ["4"], @@ -29,9 +29,13 @@ status_codes_ext_map = { StatusCode1.PS_FAIL: [], StatusCode1.PS_OVERQUOTA: [], } -ext_status_code_map = dict() +ext_status_code_map: Dict[str, StatusCode1] = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -39,7 +43,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: this is from the callback url param 's' :params ext_status_code_2: not used diff --git a/generalresearch/wall_status_codes/innovate.py b/generalresearch/wall_status_codes/innovate.py index 0d11425..b650ade 100644 --- a/generalresearch/wall_status_codes/innovate.py +++ b/generalresearch/wall_status_codes/innovate.py @@ -9,11 +9,11 @@ can map directly, and some we have to look at the category. """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_innovate = { +status_codes_innovate: Dict[str, str] = { "1": "Complete", "2": "Buyer Fail", "3": "Buyer Over Quota", @@ -29,7 +29,7 @@ status_map = defaultdict( lambda: Status.FAIL, **{"1": Status.COMPLETE, "0": Status.ABANDON, "6": Status.ABANDON}, ) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.BUYER_FAIL: ["2", "3"], StatusCode1.BUYER_QUALITY_FAIL: ["4"], StatusCode1.PS_BLOCKED: [], @@ -43,14 +43,14 @@ for k, v in status_codes_ext_map.items(): for vv in v: ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k -category_innovate = { +category_innovate: Dict[str, StatusCode1] = { "Selected threat potential score at joblevel not allow the survey": StatusCode1.PS_QUALITY, "OE Validation": StatusCode1.PS_QUALITY, "Unique IP": StatusCode1.PS_DUPLICATE, "Unique PID": StatusCode1.PS_DUPLICATE, # 'Duplicated to token {token} and Group {groupID}': StatusCode1.PS_DUPLICATE, # 'Duplicate Due to Multi Groups: Token {token} and Group {groupID}': StatusCode1.PS_DUPLICATE, - # todo: we should not send them into this marketplace for a day? + # TODO: we should not send them into this marketplace for a day? "User has attended {count} survey in 5 range": StatusCode1.PS_FAIL, "PII_OPT": StatusCode1.PS_QUALITY, "Recaptcha": StatusCode1.PS_QUALITY, @@ -80,7 +80,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ Only quality terminate (4 and 8), and PS term (5) return a term_reason (af=). @@ -93,13 +93,16 @@ def annotate_status_code( status = status_map[ext_status_code_1] if status == Status.COMPLETE: return status, StatusCode1.COMPLETE, None + # First use the 1 through 8 status code. Then, using the reason (af=), if available, # try to maybe reclassify it. if ext_status_code_1 not in ext_status_code_map: return status, StatusCode1.UNKNOWN, None + status_code = ext_status_code_map.get(ext_status_code_1, StatusCode1.UNKNOWN) if ext_status_code_2 in category_innovate: status_code = category_innovate[ext_status_code_2] + if ext_status_code_2: # Some of these have ids in them... so we have to pattern match it if ( diff --git a/generalresearch/wall_status_codes/lucid.py b/generalresearch/wall_status_codes/lucid.py index bf49b45..c0098e6 100644 --- a/generalresearch/wall_status_codes/lucid.py +++ b/generalresearch/wall_status_codes/lucid.py @@ -5,11 +5,11 @@ https://support.lucidhq.com/s/article/Collecting-Data-From-Redirects """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -mp_codes = { +mp_codes: Dict[str, str] = { "-6": "Pre-Client Intermediary Page Drop Off", "-5": "Failure in the Post Answer Behavior", "-1": "Failure to Load the Lucid Marketplace", @@ -54,7 +54,7 @@ mp_codes = { } # todo: finish, there's a bunch more -client_status_map = { +client_status_map: Dict[str, StatusCode1] = { "30": StatusCode1.BUYER_QUALITY_FAIL, "33": StatusCode1.BUYER_QUALITY_FAIL, "34": StatusCode1.BUYER_QUALITY_FAIL, @@ -62,7 +62,7 @@ client_status_map = { } status_map = defaultdict(lambda: Status.FAIL, **{"s": Status.COMPLETE}) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: [], StatusCode1.BUYER_FAIL: ["3"], StatusCode1.BUYER_QUALITY_FAIL: [], @@ -101,9 +101,15 @@ status_codes_ext_map = { ], StatusCode1.PS_OVERQUOTA: ["40", "41", "42"], } -ext_status_code_map = dict() + +ext_status_code_map: Dict[str, StatusCode1] = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -111,18 +117,27 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: this indicates which callback url was hit. possible values {'s', *anything else*} :params ext_status_code_2: this is from the callback url params: InitialStatus :params ext_status_code_3: this is from the callback url params: ClientStatus + returns: (status, status_code_1, status_code_2) """ - status = status_map[ext_status_code_1] + status: Status = status_map[ext_status_code_1] + if ext_status_code_2 == "3": - status_code = client_status_map.get(ext_status_code_3, StatusCode1.BUYER_FAIL) + status_code = client_status_map.get( + key=ext_status_code_3, default=StatusCode1.BUYER_FAIL + ) + else: - status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + status_code = ext_status_code_map.get( + key=ext_status_code_2, default=StatusCode1.UNKNOWN + ) + if status == Status.COMPLETE: status_code = StatusCode1.COMPLETE + return status, status_code, None diff --git a/generalresearch/wall_status_codes/morning.py b/generalresearch/wall_status_codes/morning.py index 2539b55..318d1b2 100644 --- a/generalresearch/wall_status_codes/morning.py +++ b/generalresearch/wall_status_codes/morning.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 @@ -17,7 +17,7 @@ timeout: The respondent completed the survey after the timeout period had expire in_progress: The respondent interview session is still in progress, such as in the prescreener or survey. """ -short_code_to_status_codes_morning = { +short_code_to_status_codes_morning: Dict[str, str] = { "att_che": "attention_check", "banned": "banned", "bid_clo": "bid_closed", @@ -53,7 +53,8 @@ short_code_to_status_codes_morning = { "tem_ban": "temporarily_banned", } status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE}) -status_codes_ext_map = { + +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["complete"], StatusCode1.BUYER_FAIL: [ "in_survey_failure", @@ -96,9 +97,13 @@ status_codes_ext_map = { "quota_invalid_for_bid", ], } -ext_status_code_map = dict() +ext_status_code_map: Dict[str, StatusCode1] = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -106,7 +111,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: from callback url params: &sti={{status_id}} :params ext_status_code_2: from callback url params: &sdi={{status_detail_id}} diff --git a/generalresearch/wall_status_codes/pollfish.py b/generalresearch/wall_status_codes/pollfish.py index fbe1169..361785d 100644 --- a/generalresearch/wall_status_codes/pollfish.py +++ b/generalresearch/wall_status_codes/pollfish.py @@ -1,9 +1,9 @@ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_map = { +status_codes_map: Dict[str, str] = { "quo_ful": "quota_full", "sur_clo": "survey_closed", "profilin": "profiling", @@ -29,7 +29,7 @@ status_codes_map = { "complete": "complete", } status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE}) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["complete"], StatusCode1.BUYER_FAIL: ["third_party_termination", "screenout"], StatusCode1.BUYER_QUALITY_FAIL: [ @@ -60,7 +60,11 @@ status_codes_ext_map = { } ext_status_code_map = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -68,11 +72,12 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: from callback url params: &sti={{status_id}} :params ext_status_code_2: from callback url params: &sdi={{status_detail_id}} :params ext_status_code_3: not used + returns: (status, status_code_1, status_code_2) """ status = status_map[ext_status_code_1] diff --git a/generalresearch/wall_status_codes/precision.py b/generalresearch/wall_status_codes/precision.py index 147685a..ffeaeca 100644 --- a/generalresearch/wall_status_codes/precision.py +++ b/generalresearch/wall_status_codes/precision.py @@ -7,11 +7,11 @@ f - client approved the Preliminary complete as Final Complete """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_precision = { +status_codes_precision: Dict[str, str] = { "10": "Complete", "20": "Client Terminate", "21": "PS Terminate", @@ -46,7 +46,7 @@ status_codes_precision = { "80": "Final Complete", } status_map = defaultdict(lambda: Status.FAIL, **{"s": Status.COMPLETE}) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["10"], StatusCode1.BUYER_FAIL: ["20", "30"], StatusCode1.BUYER_QUALITY_FAIL: ["60"], @@ -75,7 +75,11 @@ status_codes_ext_map = { } ext_status_code_map = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -83,7 +87,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: from callback url params: status :params ext_status_code_2: from callback url params: code diff --git a/generalresearch/wall_status_codes/prodege.py b/generalresearch/wall_status_codes/prodege.py index 8ce4d22..a1e25ea 100644 --- a/generalresearch/wall_status_codes/prodege.py +++ b/generalresearch/wall_status_codes/prodege.py @@ -3,12 +3,12 @@ https://developer.prodege.com/surveys-feed/term-reasons """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 status_map = defaultdict(lambda: Status.FAIL, **{"1": Status.COMPLETE}) -status_code_map = { +status_code_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: [], StatusCode1.BUYER_FAIL: ["1", "2"], StatusCode1.BUYER_QUALITY_FAIL: ["10", "12"], @@ -33,7 +33,11 @@ status_code_map = { status_class = dict() for k, v in status_code_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str status_class[status_code_map.get(vv, vv)] = k @@ -41,7 +45,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: status from redirect url :params ext_status_code_2: termreason from redirect url diff --git a/generalresearch/wall_status_codes/repdata.py b/generalresearch/wall_status_codes/repdata.py index 5e575c4..c502532 100644 --- a/generalresearch/wall_status_codes/repdata.py +++ b/generalresearch/wall_status_codes/repdata.py @@ -3,11 +3,11 @@ Status codes are in a xlsx file. See thl-repdata readme """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_name = { +status_codes_name: Dict[str, str] = { "2": "Search Failed", "3": "Activity Failed", "4": "Review Failed", @@ -26,7 +26,7 @@ status_codes_name = { "6003": "In-Survey maximum exceeded (Research Desk)", } # See: 02, and 13 are de-dupes -rd_threat_name = { +rd_threat_name: Dict[str, str] = { "02": "Duplicate entrant into survey", "03": "Emulator Usage", "04": "VPN usage detected", @@ -47,7 +47,7 @@ rd_threat_name = { } status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE}) -status_code_map = { +status_code_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["1000"], StatusCode1.BUYER_FAIL: ["2000", "4000"], StatusCode1.BUYER_QUALITY_FAIL: ["3000"], @@ -60,7 +60,11 @@ status_code_map = { status_class = dict() for k, v in status_code_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str status_class[status_code_map.get(vv, vv)] = k @@ -68,7 +72,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: the redirect urls category (as defined in url param 549f3710b) {'term', 'overquota', 'fraud', 'complete'} diff --git a/generalresearch/wall_status_codes/sago.py b/generalresearch/wall_status_codes/sago.py index ac354ce..9c8710f 100644 --- a/generalresearch/wall_status_codes/sago.py +++ b/generalresearch/wall_status_codes/sago.py @@ -3,11 +3,11 @@ https://developer-beta.market-cube.com/api-details#api=definition-api&operation= """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_schlesinger = { +status_codes_schlesinger: Dict[str, str] = { "1": "Complete", "2": "Buyer Fail", "3": "Buyer Fail", @@ -20,7 +20,7 @@ status_codes_schlesinger = { "11": "Abandon", # really it is "Buyer Abandon" } -status_reason_name = { +status_reason_name: Dict[str, str] = { "1": "Not a Unique Sample Cube User", "4": "GeoIP - wrong country", "7": "Duplicate - not a unique IP", @@ -121,7 +121,7 @@ status_map = defaultdict( lambda: Status.FAIL, **{"1": Status.COMPLETE, "0": Status.ABANDON} ) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["48"], StatusCode1.BUYER_FAIL: ["16", "29", "49", "50", "78", "114", "110", "114"], StatusCode1.BUYER_QUALITY_FAIL: ["26", "52", "68", "81", "84"], @@ -167,9 +167,13 @@ status_codes_ext_map = { StatusCode1.PS_FAIL: ["7", "29", "36", "47", "56", "58", "64"], StatusCode1.PS_OVERQUOTA: ["29", "46", "33", "31"], } -ext_status_code_map = dict() +ext_status_code_map: Dict[str, StatusCode1] = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -177,17 +181,20 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: from callback url params: scstatus :params ext_status_code_2: from callback url params: scsecuritystatus :params ext_status_code_3: not used + returns: (status, status_code_1, status_code_2) """ status = status_map[ext_status_code_1] status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) # According to personal communication, scsecuritystatus may not always # come back for completes. Going to ignore it if the status is complete + if status == Status.COMPLETE: status_code = StatusCode1.COMPLETE + return status, status_code, None diff --git a/generalresearch/wall_status_codes/spectrum.py b/generalresearch/wall_status_codes/spectrum.py index 0cf5814..9e9e0a2 100644 --- a/generalresearch/wall_status_codes/spectrum.py +++ b/generalresearch/wall_status_codes/spectrum.py @@ -3,11 +3,11 @@ https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33613201/Minimizing+Clic """ from collections import defaultdict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from generalresearch.models.thl.definitions import Status, StatusCode1 -status_codes_spectrum = { +status_codes_spectrum: Dict[str, str] = { "11": "PS Drop", "12": "PS Quota Full Core", "13": "PS Termination Core", @@ -80,7 +80,7 @@ status_codes_spectrum = { "88": "PS_Supplier_Allocation_Throttle", } status_map = defaultdict(lambda: Status.FAIL, **{"21": Status.COMPLETE}) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[str]] = { StatusCode1.COMPLETE: ["21"], StatusCode1.BUYER_FAIL: ["16", "17", "18", "19", "30", "59", "84"], StatusCode1.BUYER_QUALITY_FAIL: ["20", "31"], @@ -142,7 +142,11 @@ status_codes_ext_map = { } ext_status_code_map = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[str] + for vv in v: + vv: str ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k @@ -150,7 +154,7 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[Any]]: """ :params ext_status_code_1: from url params: ps_rstatus https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33613201/Minimizing+Clickwaste+with+ps+rstatus diff --git a/generalresearch/wall_status_codes/wxet.py b/generalresearch/wall_status_codes/wxet.py index 1cbf568..e7cf67d 100644 --- a/generalresearch/wall_status_codes/wxet.py +++ b/generalresearch/wall_status_codes/wxet.py @@ -1,7 +1,7 @@ from collections import defaultdict -from typing import Optional, Dict, Tuple +from typing import Dict, List, Optional, Tuple -from generalresearch.models.thl.definitions import StatusCode1, Status +from generalresearch.models.thl.definitions import Status, StatusCode1 from generalresearch.wxet.models.definitions import ( WXETStatus, WXETStatusCode1, @@ -11,7 +11,7 @@ from generalresearch.wxet.models.definitions import ( status_map: Dict[WXETStatus, Status] = defaultdict( lambda: Status.FAIL, **{WXETStatus.COMPLETE: Status.COMPLETE} ) -status_codes_ext_map = { +status_codes_ext_map: Dict[StatusCode1, List[WXETStatusCode1]] = { StatusCode1.COMPLETE: [WXETStatusCode1.COMPLETE], StatusCode1.BUYER_FAIL: [ WXETStatusCode1.BUYER_DUPLICATE, @@ -32,10 +32,14 @@ status_codes_ext_map = { } ext_status_code_map = dict() for k, v in status_codes_ext_map.items(): + k: StatusCode1 + v: List[WXETStatusCode1] + for vv in v: + vv: WXETStatusCode1 ext_status_code_map[vv] = k -status_code2_map = { +status_code2_map: Dict[StatusCode1, List[WXETStatusCode2]] = { StatusCode1.PS_QUALITY: [], StatusCode1.PS_DUPLICATE: [ WXETStatusCode2.WORKER_INELIGIBLE, @@ -65,11 +69,12 @@ def annotate_status_code( ext_status_code_1: str, ext_status_code_2: Optional[str] = None, ext_status_code_3: Optional[str] = None, -) -> Tuple: +) -> Tuple[Status, StatusCode1, Optional[WXETStatusCode2]]: """ :params ext_status_code_1: WXETStatus :params ext_status_code_2: WXETStatusCode1 :params ext_status_code_3: WXETStatusCode2 + returns: (status, status_code_1, status_code_2) """ ext_status_code_1 = WXETStatus(ext_status_code_1) diff --git a/generalresearch/wxet/models/definitions.py b/generalresearch/wxet/models/definitions.py index 50853dd..801346d 100644 --- a/generalresearch/wxet/models/definitions.py +++ b/generalresearch/wxet/models/definitions.py @@ -39,10 +39,13 @@ class WXETStatus(str, Enum, metaclass=ReprEnumMeta): class WXETAdjustedStatus(str, Enum, metaclass=ReprEnumMeta): # Task was reconciled to complete ADJUSTED_TO_COMPLETE = "ac" + # Task was reconciled to incomplete ADJUSTED_TO_FAIL = "af" + # The cpi for a task was adjusted CPI_ADJUSTMENT = "ca" + # The user was redirected without a Postback, but the postback was then "immediately" # recieved. The supplier thinks this was a failure. This is distinct from an # actual adjustment to complete. diff --git a/generalresearch/wxet/models/finish_type.py b/generalresearch/wxet/models/finish_type.py index 2aa4c7f..af60fe6 100644 --- a/generalresearch/wxet/models/finish_type.py +++ b/generalresearch/wxet/models/finish_type.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Set, Optional +from typing import Optional, Set from generalresearch.utils.enum import ReprEnumMeta from generalresearch.wxet.models.definitions import WXETStatus, WXETStatusCode1 diff --git a/test_utils/conftest.py b/test_utils/conftest.py index e074301..54fb682 100644 --- a/test_utils/conftest.py +++ b/test_utils/conftest.py @@ -247,13 +247,16 @@ def utc_30days_ago() -> "datetime": @pytest.fixture(scope="function") -def delete_df_collection(thl_web_rw, create_main_accounts) -> Callable: +def delete_df_collection( + thl_web_rw: PostgresConfig, create_main_accounts: Callable[..., None] +) -> Callable[..., None]: + from generalresearch.incite.collections import ( DFCollection, DFCollectionType, ) - def _teardown_events(coll: "DFCollection"): + def _inner(coll: "DFCollection"): match coll.data_type: case DFCollectionType.LEDGER: for table in [ @@ -290,7 +293,7 @@ def delete_df_collection(thl_web_rw, create_main_accounts) -> Callable: query=f"DELETE FROM {coll.data_type.value};", ) - return _teardown_events + return _inner # === GR Related === diff --git a/test_utils/managers/conftest.py b/test_utils/managers/conftest.py index f1f774e..94dabae 100644 --- a/test_utils/managers/conftest.py +++ b/test_utils/managers/conftest.py @@ -1,15 +1,18 @@ from typing import TYPE_CHECKING, Callable -import pymysql import pytest from generalresearch.managers.base import Permission from generalresearch.models import Source +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig +from generalresearch.sql_helper import SqlHelper from test_utils.managers.cashout_methods import ( EXAMPLE_TANGO_CASHOUT_METHODS, ) if TYPE_CHECKING: + from generalresearch.config import GRLBaseSettings from generalresearch.grliq.managers.forensic_data import ( GrlIqDataManager, ) @@ -32,6 +35,7 @@ if TYPE_CHECKING: MembershipManager, TeamManager, ) + from generalresearch.managers.thl.category import CategoryManager from generalresearch.managers.thl.contest_manager import ContestManager from generalresearch.managers.thl.ipinfo import ( GeoIpInfoManager, @@ -84,7 +88,9 @@ if TYPE_CHECKING: @pytest.fixture(scope="session") -def ltxm(thl_web_rw, thl_redis_config) -> "LedgerTransactionManager": +def ltxm( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "LedgerTransactionManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ledger_manager.ledger import ( @@ -100,7 +106,9 @@ def ltxm(thl_web_rw, thl_redis_config) -> "LedgerTransactionManager": @pytest.fixture(scope="session") -def lam(thl_web_rw, thl_redis_config) -> "LedgerAccountManager": +def lam( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "LedgerAccountManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ledger_manager.ledger import ( @@ -116,7 +124,7 @@ def lam(thl_web_rw, thl_redis_config) -> "LedgerAccountManager": @pytest.fixture(scope="session") -def lm(thl_web_rw, thl_redis_config) -> "LedgerManager": +def lm(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig) -> "LedgerManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ledger_manager.ledger import ( @@ -137,7 +145,9 @@ def lm(thl_web_rw, thl_redis_config) -> "LedgerManager": @pytest.fixture(scope="session") -def thl_lm(thl_web_rw, thl_redis_config) -> "ThlLedgerManager": +def thl_lm( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "ThlLedgerManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ledger_manager.thl_ledger import ( @@ -158,7 +168,9 @@ def thl_lm(thl_web_rw, thl_redis_config) -> "ThlLedgerManager": @pytest.fixture(scope="session") -def payout_event_manager(thl_web_rw, thl_redis_config) -> "PayoutEventManager": +def payout_event_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "PayoutEventManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.payout import PayoutEventManager @@ -171,7 +183,9 @@ def payout_event_manager(thl_web_rw, thl_redis_config) -> "PayoutEventManager": @pytest.fixture(scope="session") -def user_payout_event_manager(thl_web_rw, thl_redis_config) -> "UserPayoutEventManager": +def user_payout_event_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "UserPayoutEventManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.payout import UserPayoutEventManager @@ -185,7 +199,7 @@ def user_payout_event_manager(thl_web_rw, thl_redis_config) -> "UserPayoutEventM @pytest.fixture(scope="session") def brokerage_product_payout_event_manager( - thl_web_rw, thl_redis_config + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig ) -> "BrokerageProductPayoutEventManager": assert "/unittest-" in thl_web_rw.dsn.path @@ -202,7 +216,7 @@ def brokerage_product_payout_event_manager( @pytest.fixture(scope="session") def business_payout_event_manager( - thl_web_rw, thl_redis_config + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig ) -> "BusinessPayoutEventManager": assert "/unittest-" in thl_web_rw.dsn.path @@ -218,7 +232,7 @@ def business_payout_event_manager( @pytest.fixture(scope="session") -def product_manager(thl_web_rw) -> "ProductManager": +def product_manager(thl_web_rw: PostgresConfig) -> "ProductManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.product import ProductManager @@ -227,7 +241,9 @@ def product_manager(thl_web_rw) -> "ProductManager": @pytest.fixture(scope="session") -def user_manager(settings, thl_web_rw, thl_web_rr) -> "UserManager": +def user_manager( + settings: "GRLBaseSettings", thl_web_rw: PostgresConfig, thl_web_rr: PostgresConfig +) -> "UserManager": assert "/unittest-" in thl_web_rw.dsn.path assert "/unittest-" in thl_web_rr.dsn.path @@ -243,7 +259,7 @@ def user_manager(settings, thl_web_rw, thl_web_rr) -> "UserManager": @pytest.fixture(scope="session") -def user_metadata_manager(thl_web_rw) -> "UserMetadataManager": +def user_metadata_manager(thl_web_rw: PostgresConfig) -> "UserMetadataManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.user_manager.user_metadata_manager import ( @@ -254,7 +270,7 @@ def user_metadata_manager(thl_web_rw) -> "UserMetadataManager": @pytest.fixture(scope="session") -def session_manager(thl_web_rw) -> "SessionManager": +def session_manager(thl_web_rw: PostgresConfig) -> "SessionManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.session import SessionManager @@ -263,7 +279,7 @@ def session_manager(thl_web_rw) -> "SessionManager": @pytest.fixture(scope="session") -def wall_manager(thl_web_rw) -> "WallManager": +def wall_manager(thl_web_rw: PostgresConfig) -> "WallManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.wall import WallManager @@ -272,7 +288,9 @@ def wall_manager(thl_web_rw) -> "WallManager": @pytest.fixture(scope="session") -def wall_cache_manager(thl_web_rw, thl_redis_config) -> "WallCacheManager": +def wall_cache_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "WallCacheManager": # assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.wall import WallCacheManager @@ -281,7 +299,7 @@ def wall_cache_manager(thl_web_rw, thl_redis_config) -> "WallCacheManager": @pytest.fixture(scope="session") -def task_adjustment_manager(thl_web_rw) -> "TaskAdjustmentManager": +def task_adjustment_manager(thl_web_rw: PostgresConfig) -> "TaskAdjustmentManager": # assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.task_adjustment import ( @@ -292,7 +310,7 @@ def task_adjustment_manager(thl_web_rw) -> "TaskAdjustmentManager": @pytest.fixture(scope="session") -def contest_manager(thl_web_rw) -> "ContestManager": +def contest_manager(thl_web_rw: PostgresConfig) -> "ContestManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.contest_manager import ContestManager @@ -309,7 +327,7 @@ def contest_manager(thl_web_rw) -> "ContestManager": @pytest.fixture(scope="session") -def category_manager(thl_web_rw): +def category_manager(thl_web_rw: PostgresConfig) -> "CategoryManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.category import CategoryManager @@ -317,7 +335,7 @@ def category_manager(thl_web_rw): @pytest.fixture(scope="session") -def buyer_manager(thl_web_rw): +def buyer_manager(thl_web_rw: PostgresConfig): # assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.buyer import BuyerManager @@ -325,7 +343,7 @@ def buyer_manager(thl_web_rw): @pytest.fixture(scope="session") -def survey_manager(thl_web_rw): +def survey_manager(thl_web_rw: PostgresConfig): # assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.survey import SurveyManager @@ -333,7 +351,7 @@ def survey_manager(thl_web_rw): @pytest.fixture(scope="session") -def surveystat_manager(thl_web_rw): +def surveystat_manager(thl_web_rw: PostgresConfig): # assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.survey import SurveyStatManager @@ -348,7 +366,7 @@ def surveypenalty_manager(thl_redis_config): @pytest.fixture(scope="session") -def upk_schema_manager(thl_web_rw): +def upk_schema_manager(thl_web_rw: PostgresConfig): assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.profiling.schema import ( UpkSchemaManager, @@ -358,7 +376,7 @@ def upk_schema_manager(thl_web_rw): @pytest.fixture(scope="session") -def user_upk_manager(thl_web_rw, thl_redis_config): +def user_upk_manager(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig): assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.profiling.user_upk import ( UserUpkManager, @@ -368,7 +386,7 @@ def user_upk_manager(thl_web_rw, thl_redis_config): @pytest.fixture(scope="session") -def question_manager(thl_web_rw, thl_redis_config): +def question_manager(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig): assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.profiling.question import ( QuestionManager, @@ -378,7 +396,7 @@ def question_manager(thl_web_rw, thl_redis_config): @pytest.fixture(scope="session") -def uqa_manager(thl_web_rw, thl_redis_config): +def uqa_manager(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig): assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.profiling.uqa import UQAManager @@ -395,7 +413,7 @@ def uqa_manager_clear_cache(uqa_manager, user): @pytest.fixture(scope="session") -def audit_log_manager(thl_web_rw) -> "AuditLogManager": +def audit_log_manager(thl_web_rw: PostgresConfig) -> "AuditLogManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.userhealth import AuditLogManager @@ -404,7 +422,7 @@ def audit_log_manager(thl_web_rw) -> "AuditLogManager": @pytest.fixture(scope="session") -def ip_geoname_manager(thl_web_rw) -> "IPGeonameManager": +def ip_geoname_manager(thl_web_rw: PostgresConfig) -> "IPGeonameManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ipinfo import IPGeonameManager @@ -413,7 +431,7 @@ def ip_geoname_manager(thl_web_rw) -> "IPGeonameManager": @pytest.fixture(scope="session") -def ip_information_manager(thl_web_rw) -> "IPInformationManager": +def ip_information_manager(thl_web_rw: PostgresConfig) -> "IPInformationManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ipinfo import IPInformationManager @@ -422,7 +440,9 @@ def ip_information_manager(thl_web_rw) -> "IPInformationManager": @pytest.fixture(scope="session") -def ip_record_manager(thl_web_rw, thl_redis_config) -> "IPRecordManager": +def ip_record_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "IPRecordManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.userhealth import IPRecordManager @@ -431,7 +451,9 @@ def ip_record_manager(thl_web_rw, thl_redis_config) -> "IPRecordManager": @pytest.fixture(scope="session") -def user_iphistory_manager(thl_web_rw, thl_redis_config) -> "UserIpHistoryManager": +def user_iphistory_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "UserIpHistoryManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.userhealth import ( @@ -451,7 +473,9 @@ def user_iphistory_manager_clear_cache(user_iphistory_manager, user): @pytest.fixture(scope="session") -def geoipinfo_manager(thl_web_rw, thl_redis_config) -> "GeoIpInfoManager": +def geoipinfo_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "GeoIpInfoManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.ipinfo import GeoIpInfoManager @@ -469,7 +493,9 @@ def maxmind_basic_manager() -> "MaxmindBasicManager": @pytest.fixture(scope="session") -def maxmind_manager(thl_web_rw, thl_redis_config) -> "MaxmindManager": +def maxmind_manager( + thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig +) -> "MaxmindManager": assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.maxmind import MaxmindManager @@ -478,7 +504,7 @@ def maxmind_manager(thl_web_rw, thl_redis_config) -> "MaxmindManager": @pytest.fixture(scope="session") -def cashout_method_manager(thl_web_rw): +def cashout_method_manager(thl_web_rw: PostgresConfig): assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.cashout_method import ( CashoutMethodManager, @@ -488,14 +514,14 @@ def cashout_method_manager(thl_web_rw): @pytest.fixture(scope="session") -def event_manager(thl_redis_config): +def event_manager(thl_redis_config: RedisConfig): from generalresearch.managers.events import EventManager return EventManager(redis_config=thl_redis_config) @pytest.fixture(scope="session") -def user_streak_manager(thl_web_rw): +def user_streak_manager(thl_web_rw: PostgresConfig): assert "/unittest-" in thl_web_rw.dsn.path from generalresearch.managers.thl.user_streak import ( UserStreakManager, @@ -505,7 +531,7 @@ def user_streak_manager(thl_web_rw): @pytest.fixture(scope="session") -def uqa_db_index(thl_web_rw): +def uqa_db_index(thl_web_rw: PostgresConfig): # There were some custom indices created not through django. # Make sure the index used in the index hint exists assert "/unittest-" in thl_web_rw.dsn.path @@ -521,7 +547,7 @@ def uqa_db_index(thl_web_rw): @pytest.fixture(scope="session") -def delete_cashoutmethod_db(thl_web_rw) -> Callable: +def delete_cashoutmethod_db(thl_web_rw: PostgresConfig) -> Callable[..., None]: def _delete_cashoutmethod_db(): thl_web_rw.execute_write( query="DELETE FROM accounting_cashoutmethod;", @@ -531,14 +557,21 @@ def delete_cashoutmethod_db(thl_web_rw) -> Callable: @pytest.fixture(scope="session") -def setup_cashoutmethod_db(settings, cashout_method_manager, delete_cashoutmethod_db): - settings.amt_ - +def setup_cashoutmethod_db( + settings: "GRLBaseSettings", cashout_method_manager, delete_cashoutmethod_db +): delete_cashoutmethod_db() for x in EXAMPLE_TANGO_CASHOUT_METHODS: cashout_method_manager.create(x) - cashout_method_manager.create(AMT_ASSIGNMENT_CASHOUT_METHOD) - cashout_method_manager.create(AMT_BONUS_CASHOUT_METHOD) + + # TODO: convert these ids into instances to use. + # settings.amt_bonus_cashout_method_id + # settings.amt_assignment_cashout_method_id + + # cashout_method_manager.create(AMT_ASSIGNMENT_CASHOUT_METHOD) + # cashout_method_manager.create(AMT_BONUS_CASHOUT_METHOD) + raise NotImplementedError("Need to implement setup_cashoutmethod_db") + return None @@ -546,7 +579,7 @@ def setup_cashoutmethod_db(settings, cashout_method_manager, delete_cashoutmetho @pytest.fixture(scope="session") -def spectrum_manager(spectrum_rw): +def spectrum_manager(spectrum_rw: SqlHelper) -> "SpectrumSurveyManager": from generalresearch.managers.spectrum.survey import ( SpectrumSurveyManager, ) @@ -556,7 +589,9 @@ def spectrum_manager(spectrum_rw): # === GR === @pytest.fixture(scope="session") -def business_manager(gr_db, gr_redis_config) -> "BusinessManager": +def business_manager( + gr_db: PostgresConfig, gr_redis_config: RedisConfig +) -> "BusinessManager": from generalresearch.redis_helper import RedisConfig assert "/unittest-" in gr_db.dsn.path @@ -571,7 +606,7 @@ def business_manager(gr_db, gr_redis_config) -> "BusinessManager": @pytest.fixture(scope="session") -def business_address_manager(gr_db) -> "BusinessAddressManager": +def business_address_manager(gr_db: PostgresConfig) -> "BusinessAddressManager": assert "/unittest-" in gr_db.dsn.path from generalresearch.managers.gr.business import BusinessAddressManager @@ -580,7 +615,9 @@ def business_address_manager(gr_db) -> "BusinessAddressManager": @pytest.fixture(scope="session") -def business_bank_account_manager(gr_db) -> "BusinessBankAccountManager": +def business_bank_account_manager( + gr_db: PostgresConfig, +) -> "BusinessBankAccountManager": assert "/unittest-" in gr_db.dsn.path from generalresearch.managers.gr.business import ( @@ -591,7 +628,7 @@ def business_bank_account_manager(gr_db) -> "BusinessBankAccountManager": @pytest.fixture(scope="session") -def team_manager(gr_db, gr_redis_config) -> "TeamManager": +def team_manager(gr_db: PostgresConfig, gr_redis_config: RedisConfig) -> "TeamManager": assert "/unittest-" in gr_db.dsn.path from generalresearch.managers.gr.team import TeamManager @@ -600,7 +637,7 @@ def team_manager(gr_db, gr_redis_config) -> "TeamManager": @pytest.fixture(scope="session") -def gr_um(gr_db, gr_redis_config) -> "GRUserManager": +def gr_um(gr_db: PostgresConfig, gr_redis_config: RedisConfig) -> "GRUserManager": assert "/unittest-" in gr_db.dsn.path from generalresearch.managers.gr.authentication import GRUserManager @@ -609,7 +646,7 @@ def gr_um(gr_db, gr_redis_config) -> "GRUserManager": @pytest.fixture(scope="session") -def gr_tm(gr_db) -> "GRTokenManager": +def gr_tm(gr_db: PostgresConfig) -> "GRTokenManager": assert "/unittest-" in gr_db.dsn.path from generalresearch.managers.gr.authentication import GRTokenManager @@ -618,7 +655,7 @@ def gr_tm(gr_db) -> "GRTokenManager": @pytest.fixture(scope="session") -def membership_manager(gr_db) -> "MembershipManager": +def membership_manager(gr_db: PostgresConfig) -> "MembershipManager": assert "/unittest-" in gr_db.dsn.path from generalresearch.managers.gr.team import MembershipManager @@ -630,7 +667,7 @@ def membership_manager(gr_db) -> "MembershipManager": @pytest.fixture(scope="session") -def grliq_dm(grliq_db) -> "GrlIqDataManager": +def grliq_dm(grliq_db: PostgresConfig) -> "GrlIqDataManager": assert "/unittest-" in grliq_db.dsn.path from generalresearch.grliq.managers.forensic_data import ( @@ -641,7 +678,7 @@ def grliq_dm(grliq_db) -> "GrlIqDataManager": @pytest.fixture(scope="session") -def grliq_em(grliq_db) -> "GrlIqEventManager": +def grliq_em(grliq_db: PostgresConfig) -> "GrlIqEventManager": assert "/unittest-" in grliq_db.dsn.path from generalresearch.grliq.managers.forensic_events import ( @@ -652,7 +689,7 @@ def grliq_em(grliq_db) -> "GrlIqEventManager": @pytest.fixture(scope="session") -def grliq_crr(grliq_db) -> "GrlIqCategoryResultsReader": +def grliq_crr(grliq_db: PostgresConfig) -> "GrlIqCategoryResultsReader": assert "/unittest-" in grliq_db.dsn.path from generalresearch.grliq.managers.forensic_results import ( @@ -663,7 +700,7 @@ def grliq_crr(grliq_db) -> "GrlIqCategoryResultsReader": @pytest.fixture(scope="session") -def delete_buyers_surveys(thl_web_rw, buyer_manager): +def delete_buyers_surveys(thl_web_rw: PostgresConfig, buyer_manager): # assert "/unittest-" in thl_web_rw.dsn.path thl_web_rw.execute_write( """ diff --git a/test_utils/managers/upk/conftest.py b/test_utils/managers/upk/conftest.py index 61be924..e28d085 100644 --- a/test_utils/managers/upk/conftest.py +++ b/test_utils/managers/upk/conftest.py @@ -1,6 +1,6 @@ import os import time -from typing import Optional +from typing import TYPE_CHECKING, Optional from uuid import UUID import pandas as pd @@ -8,6 +8,9 @@ import pytest from generalresearch.pg_helper import PostgresConfig +if TYPE_CHECKING: + from generalresearch.managers.thl.category import CategoryManager + def insert_data_from_csv( thl_web_rw: PostgresConfig, @@ -39,7 +42,9 @@ def insert_data_from_csv( @pytest.fixture(scope="session") -def category_data(thl_web_rw, category_manager) -> None: +def category_data( + thl_web_rw: PostgresConfig, category_manager: "CategoryManager" +) -> None: fp = os.path.join(os.path.dirname(__file__), "marketplace_category.csv.gz") insert_data_from_csv( thl_web_rw, @@ -66,20 +71,23 @@ def category_data(thl_web_rw, category_manager) -> None: @pytest.fixture(scope="session") -def property_data(thl_web_rw) -> None: +def property_data(thl_web_rw: PostgresConfig) -> None: fp = os.path.join(os.path.dirname(__file__), "marketplace_property.csv.gz") insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_property") @pytest.fixture(scope="session") -def item_data(thl_web_rw) -> None: +def item_data(thl_web_rw: PostgresConfig) -> None: fp = os.path.join(os.path.dirname(__file__), "marketplace_item.csv.gz") insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_item") @pytest.fixture(scope="session") def propertycategoryassociation_data( - thl_web_rw, category_data, property_data, category_manager + thl_web_rw: PostgresConfig, + category_data, + property_data, + category_manager: "CategoryManager", ) -> None: table_name = "marketplace_propertycategoryassociation" fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") @@ -93,27 +101,31 @@ def propertycategoryassociation_data( @pytest.fixture(scope="session") -def propertycountry_data(thl_web_rw, property_data) -> None: +def propertycountry_data(thl_web_rw: PostgresConfig, property_data) -> None: fp = os.path.join(os.path.dirname(__file__), "marketplace_propertycountry.csv.gz") insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_propertycountry") @pytest.fixture(scope="session") -def propertymarketplaceassociation_data(thl_web_rw, property_data) -> None: +def propertymarketplaceassociation_data( + thl_web_rw: PostgresConfig, property_data +) -> None: table_name = "marketplace_propertymarketplaceassociation" fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") insert_data_from_csv(thl_web_rw, fp=fp, table_name=table_name) @pytest.fixture(scope="session") -def propertyitemrange_data(thl_web_rw, property_data, item_data) -> None: +def propertyitemrange_data( + thl_web_rw: PostgresConfig, property_data, item_data +) -> None: table_name = "marketplace_propertyitemrange" fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") insert_data_from_csv(thl_web_rw, fp=fp, table_name=table_name) @pytest.fixture(scope="session") -def question_data(thl_web_rw) -> None: +def question_data(thl_web_rw: PostgresConfig) -> None: table_name = "marketplace_question" fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") insert_data_from_csv( @@ -122,7 +134,7 @@ def question_data(thl_web_rw) -> None: @pytest.fixture(scope="session") -def clear_upk_tables(thl_web_rw): +def clear_upk_tables(thl_web_rw: PostgresConfig): tables = [ "marketplace_propertyitemrange", "marketplace_propertymarketplaceassociation", diff --git a/test_utils/models/conftest.py b/test_utils/models/conftest.py index dcc3b66..81dc11e 100644 --- a/test_utils/models/conftest.py +++ b/test_utils/models/conftest.py @@ -5,7 +5,6 @@ from random import randint from typing import TYPE_CHECKING, Callable, Dict, List, Optional from uuid import uuid4 -import pytest from pydantic import AwareDatetime, PositiveInt from generalresearch.models import Source @@ -14,6 +13,7 @@ from generalresearch.models.thl.definitions import ( Status, ) from generalresearch.models.thl.survey.model import Buyer, Survey +from generalresearch.pg_helper import PostgresConfig from test_utils.managers.conftest import ( business_address_manager, business_manager, @@ -28,6 +28,7 @@ from test_utils.managers.conftest import ( if TYPE_CHECKING: from generalresearch.currency import USDCent + from generalresearch.managers.thl.session import SessionManager from generalresearch.models.gr.authentication import GRToken, GRUser from generalresearch.models.gr.business import ( Business, @@ -53,7 +54,7 @@ if TYPE_CHECKING: @pytest.fixture(scope="function") -def user(request, product_manager, user_manager, thl_web_rr) -> "User": +def user(request, product_manager, user_manager, thl_web_rr: PostgresConfig) -> "User": product = getattr(request, "product", None) if product is None: @@ -74,25 +75,29 @@ def user_with_wallet( @pytest.fixture -def user_with_wallet_amt(request, user_factory, product_amt_true: "Product") -> "User": +def user_with_wallet_amt( + request, user_factory: Callable[..., "User"], product_amt_true: "Product" +) -> "User": # A user on a product with user wallet enabled, on AMT, but they have no money return user_factory(product=product_amt_true) @pytest.fixture(scope="function") -def user_factory(user_manager, thl_web_rr) -> Callable: - def _create_user(product: "Product", created: Optional[datetime] = None): +def user_factory(user_manager, thl_web_rr: PostgresConfig) -> Callable[..., "User"]: + + def _inner(product: "Product", created: Optional[datetime] = None): u = user_manager.create_dummy(product=product, created=created) u.prefetch_product(pg_config=thl_web_rr) return u - return _create_user + return _inner @pytest.fixture(scope="function") -def wall_factory(wall_manager) -> Callable: - def _create_wall( +def wall_factory(wall_manager) -> Callable[..., "Wall"]: + + def _inner( session: "Session", wall_status: "Status", req_cpi: Optional[Decimal] = None ): @@ -126,10 +131,10 @@ def wall_factory(wall_manager) -> Callable: return wall - return _create_wall + return _inner -@pytest.fixture(scope="function") +@pytest.fixture def wall(session, user, wall_manager) -> Optional["Wall"]: from generalresearch.models.thl.task_status import StatusCode1 @@ -143,9 +148,12 @@ def wall(session, user, wall_manager) -> Optional["Wall"]: return wall -@pytest.fixture(scope="function") +@pytest.fixture def session_factory( - wall_factory, session_manager, wall_manager, utc_hour_ago + wall_factory, + session_manager: "SessionManager", + wall_manager, + utc_hour_ago: datetime, ) -> Callable[..., "Session"]: from generalresearch.models.thl.session import Source @@ -208,11 +216,13 @@ def session_factory( @pytest.fixture(scope="function") def finished_session_factory( - session_factory, session_manager, utc_hour_ago -) -> Callable: + session_factory: Callable[..., "Session"], + session_manager: "SessionManager", + utc_hour_ago: datetime, +) -> Callable[..., "Session"]: from generalresearch.models.thl.session import Source - def _create_finished_session( + def _inner( user: "User", # Wall details wall_count: int = 5, @@ -246,11 +256,11 @@ def finished_session_factory( ) return s - return _create_finished_session + return _inner @pytest.fixture(scope="function") -def session(user, session_manager, wall_manager) -> "Session": +def session(user, session_manager: "SessionManager", wall_manager) -> "Session": from generalresearch.models.thl.session import Session, Wall session: Session = session_manager.create_dummy(user=user, country_iso="us") @@ -294,7 +304,7 @@ def product_factory(product_manager) -> Callable: return _create_product -@pytest.fixture(scope="function") +@pytest.fixture def payout_config(request) -> "PayoutConfig": from generalresearch.models.thl.product import ( PayoutConfig, @@ -315,7 +325,7 @@ def payout_config(request) -> "PayoutConfig": ) -@pytest.fixture(scope="function") +@pytest.fixture def product_user_wallet_yes(payout_config, product_manager) -> "Product": from generalresearch.managers.thl.product import ProductManager from generalresearch.models.thl.product import UserWalletConfig @@ -326,7 +336,7 @@ def product_user_wallet_yes(payout_config, product_manager) -> "Product": ) -@pytest.fixture(scope="function") +@pytest.fixture def product_user_wallet_no(product_manager) -> "Product": from generalresearch.managers.thl.product import ProductManager from generalresearch.models.thl.product import UserWalletConfig @@ -337,7 +347,7 @@ def product_user_wallet_no(product_manager) -> "Product": ) -@pytest.fixture(scope="function") +@pytest.fixture def product_amt_true(product_manager, payout_config) -> "Product": from generalresearch.models.thl.product import UserWalletConfig @@ -347,7 +357,7 @@ def product_amt_true(product_manager, payout_config) -> "Product": ) -@pytest.fixture(scope="function") +@pytest.fixture def bp_payout_factory( thl_lm, product_manager, business_payout_event_manager ) -> Callable: @@ -380,7 +390,7 @@ def bp_payout_factory( # === GR === -@pytest.fixture(scope="function") +@pytest.fixture def business(request, business_manager) -> "Business": from generalresearch.managers.gr.business import BusinessManager @@ -388,7 +398,7 @@ def business(request, business_manager) -> "Business": return business_manager.create_dummy() -@pytest.fixture(scope="function") +@pytest.fixture def business_address(request, business, business_address_manager) -> "BusinessAddress": from generalresearch.managers.gr.business import BusinessAddressManager @@ -396,7 +406,7 @@ def business_address(request, business, business_address_manager) -> "BusinessAd return business_address_manager.create_dummy(business_id=business.id) -@pytest.fixture(scope="function") +@pytest.fixture def business_bank_account( request, business, business_bank_account_manager ) -> "BusinessBankAccount": @@ -406,7 +416,7 @@ def business_bank_account( return business_bank_account_manager.create_dummy(business_id=business.id) -@pytest.fixture(scope="function") +@pytest.fixture def team(request, team_manager) -> "Team": from generalresearch.managers.gr.team import TeamManager @@ -414,7 +424,7 @@ def team(request, team_manager) -> "Team": return team_manager.create_dummy() -@pytest.fixture(scope="function") +@pytest.fixture def gr_user(gr_um) -> "GRUser": from generalresearch.managers.gr.authentication import GRUserManager @@ -422,7 +432,7 @@ def gr_user(gr_um) -> "GRUser": return gr_um.create_dummy() -@pytest.fixture(scope="function") +@pytest.fixture def gr_user_cache(gr_user, gr_db, thl_web_rr, gr_redis_config): gr_user.set_cache( pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config @@ -430,12 +440,12 @@ def gr_user_cache(gr_user, gr_db, thl_web_rr, gr_redis_config): return gr_user -@pytest.fixture(scope="function") -def gr_user_factory(gr_um) -> Callable: - def _create_gr_user(): +@pytest.fixture +def gr_user_factory(gr_um) -> Callable[..., "GRUser"]: + def _inner(): return gr_um.create_dummy() - return _create_gr_user + return _inner @pytest.fixture() @@ -447,7 +457,7 @@ def gr_user_token(gr_user, gr_tm, gr_db) -> "GRToken": @pytest.fixture() -def gr_user_token_header(gr_user_token) -> Dict: +def gr_user_token_header(gr_user_token: "GRToken") -> Dict[str, str]: return gr_user_token.auth_header @@ -461,41 +471,41 @@ def membership(request, team, gr_user, team_manager) -> "Membership": @pytest.fixture(scope="function") def membership_factory( team: "Team", gr_user: "GRUser", membership_manager, team_manager, gr_um -) -> Callable: +) -> Callable[..., "Membership"]: from generalresearch.managers.gr.team import MembershipManager membership_manager: MembershipManager - def _create_membership(**kwargs): + def _inner(**kwargs) -> "Membership": _team = kwargs.get("team", team_manager.create_dummy()) _gr_user = kwargs.get("gr_user", gr_um.create_dummy()) return membership_manager.create(team=_team, gr_user=_gr_user) - return _create_membership + return _inner -@pytest.fixture(scope="function") -def audit_log(audit_log_manager, user) -> "AuditLog": +@pytest.fixture +def audit_log(audit_log_manager, user: "User") -> "AuditLog": from generalresearch.managers.thl.userhealth import AuditLogManager audit_log_manager: AuditLogManager return audit_log_manager.create_dummy(user_id=user.user_id) -@pytest.fixture(scope="function") -def audit_log_factory(audit_log_manager) -> Callable: +@pytest.fixture +def audit_log_factory(audit_log_manager) -> Callable[..., "AuditLog"]: from generalresearch.managers.thl.userhealth import AuditLogManager audit_log_manager: AuditLogManager - def _create_audit_log( + def _inner( user_id: PositiveInt, level: Optional["AuditLogLevel"] = None, event_type: Optional[str] = None, event_msg: Optional[str] = None, event_value: Optional[float] = None, - ): + ) -> "AuditLog": return audit_log_manager.create_dummy( user_id=user_id, level=level, @@ -504,10 +514,10 @@ def audit_log_factory(audit_log_manager) -> Callable: event_value=event_value, ) - return _create_audit_log + return _inner -@pytest.fixture(scope="function") +@pytest.fixture def ip_geoname(ip_geoname_manager) -> "IPGeoname": from generalresearch.managers.thl.ipinfo import IPGeonameManager @@ -515,7 +525,7 @@ def ip_geoname(ip_geoname_manager) -> "IPGeoname": return ip_geoname_manager.create_dummy() -@pytest.fixture(scope="function") +@pytest.fixture def ip_information(ip_information_manager, ip_geoname) -> "IPInformation": from generalresearch.managers.thl.ipinfo import IPInformationManager @@ -525,7 +535,7 @@ def ip_information(ip_information_manager, ip_geoname) -> "IPInformation": ) -@pytest.fixture(scope="function") +@pytest.fixture def ip_information_factory(ip_information_manager) -> Callable: from generalresearch.managers.thl.ipinfo import IPInformationManager @@ -542,8 +552,8 @@ def ip_information_factory(ip_information_manager) -> Callable: return _create_ip_info -@pytest.fixture(scope="function") -def ip_record(ip_record_manager, ip_geoname, user) -> "IPRecord": +@pytest.fixture +def ip_record(ip_record_manager, ip_geoname, user: "User") -> "IPRecord": from generalresearch.managers.thl.userhealth import IPRecordManager ip_record_manager: IPRecordManager @@ -551,8 +561,8 @@ def ip_record(ip_record_manager, ip_geoname, user) -> "IPRecord": return ip_record_manager.create_dummy(user_id=user.user_id) -@pytest.fixture(scope="function") -def ip_record_factory(ip_record_manager, user) -> Callable: +@pytest.fixture +def ip_record_factory(ip_record_manager, user: "User") -> Callable: from generalresearch.managers.thl.userhealth import IPRecordManager ip_record_manager: IPRecordManager diff --git a/test_utils/spectrum/conftest.py b/test_utils/spectrum/conftest.py index d737730..0afc3f5 100644 --- a/test_utils/spectrum/conftest.py +++ b/test_utils/spectrum/conftest.py @@ -34,21 +34,21 @@ def spectrum_rw(settings: "GRLBaseSettings") -> SqlHelper: @pytest.fixture(scope="session") -def spectrum_criteria_manager(spectrum_rw) -> SpectrumCriteriaManager: +def spectrum_criteria_manager(spectrum_rw: SqlHelper) -> SpectrumCriteriaManager: assert "/unittest-" in spectrum_rw.dsn.path return SpectrumCriteriaManager(spectrum_rw) @pytest.fixture(scope="session") -def spectrum_survey_manager(spectrum_rw) -> SpectrumSurveyManager: +def spectrum_survey_manager(spectrum_rw: SqlHelper) -> SpectrumSurveyManager: assert "/unittest-" in spectrum_rw.dsn.path return SpectrumSurveyManager(spectrum_rw) @pytest.fixture(scope="session") def setup_spectrum_surveys( - spectrum_rw, spectrum_survey_manager, spectrum_criteria_manager -): + spectrum_rw: SqlHelper, spectrum_survey_manager, spectrum_criteria_manager +) -> None: now = datetime.now(timezone.utc) # make sure these example surveys exist in db surveys = [SpectrumSurvey.model_validate_json(x) for x in SURVEYS_JSON] -- cgit v1.2.3