aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--generalresearch/__init__.py9
-rw-r--r--generalresearch/config.py6
-rw-r--r--generalresearch/healing_ppe.py5
-rw-r--r--generalresearch/mariadb.py2
-rw-r--r--generalresearch/models/cint/question.py33
-rw-r--r--generalresearch/models/cint/survey.py23
-rw-r--r--generalresearch/models/cint/task_collection.py2
-rw-r--r--generalresearch/models/custom_types.py13
-rw-r--r--generalresearch/models/dynata/question.py10
-rw-r--r--generalresearch/models/dynata/survey.py38
-rw-r--r--generalresearch/models/dynata/task_collection.py4
-rw-r--r--generalresearch/models/events.py14
-rw-r--r--generalresearch/models/gr/__init__.py4
-rw-r--r--generalresearch/models/gr/authentication.py18
-rw-r--r--generalresearch/models/gr/business.py29
-rw-r--r--generalresearch/models/gr/team.py13
-rw-r--r--generalresearch/models/innovate/question.py26
-rw-r--r--generalresearch/models/innovate/survey.py31
-rw-r--r--generalresearch/models/innovate/task_collection.py2
-rw-r--r--generalresearch/models/legacy/bucket.py31
-rw-r--r--generalresearch/models/legacy/offerwall.py16
-rw-r--r--generalresearch/models/legacy/questions.py21
-rw-r--r--generalresearch/models/lucid/question.py10
-rw-r--r--generalresearch/models/lucid/survey.py12
-rw-r--r--generalresearch/models/marketplace/summary.py4
-rw-r--r--generalresearch/models/morning/question.py18
-rw-r--r--generalresearch/models/morning/survey.py39
-rw-r--r--generalresearch/models/morning/task_collection.py7
-rw-r--r--generalresearch/models/pollfish/question.py15
-rw-r--r--generalresearch/models/precision/question.py21
-rw-r--r--generalresearch/models/precision/survey.py24
-rw-r--r--generalresearch/models/precision/task_collection.py6
-rw-r--r--generalresearch/models/prodege/__init__.py6
-rw-r--r--generalresearch/models/prodege/question.py49
-rw-r--r--generalresearch/models/prodege/survey.py76
-rw-r--r--generalresearch/models/prodege/task_collection.py4
-rw-r--r--generalresearch/models/repdata/question.py29
-rw-r--r--generalresearch/models/repdata/survey.py23
-rw-r--r--generalresearch/models/repdata/task_collection.py8
-rw-r--r--generalresearch/models/sago/question.py29
-rw-r--r--generalresearch/models/sago/survey.py24
-rw-r--r--generalresearch/models/sago/task_collection.py8
-rw-r--r--generalresearch/models/spectrum/question.py25
-rw-r--r--generalresearch/models/spectrum/survey.py18
-rw-r--r--generalresearch/models/spectrum/task_collection.py8
-rw-r--r--generalresearch/models/thl/__init__.py2
-rw-r--r--generalresearch/models/thl/category.py6
-rw-r--r--generalresearch/models/thl/contest/__init__.py19
-rw-r--r--generalresearch/models/thl/contest/contest.py18
-rw-r--r--generalresearch/models/thl/contest/contest_entry.py14
-rw-r--r--generalresearch/models/thl/contest/examples.py126
-rw-r--r--generalresearch/models/thl/contest/leaderboard.py20
-rw-r--r--generalresearch/models/thl/contest/milestone.py23
-rw-r--r--generalresearch/models/thl/contest/raffle.py19
-rw-r--r--generalresearch/models/thl/contest/utils.py4
-rw-r--r--generalresearch/models/thl/definitions.py3
-rw-r--r--generalresearch/models/thl/demographics.py9
-rw-r--r--generalresearch/models/thl/finance.py18
-rw-r--r--generalresearch/models/thl/grliq.py2
-rw-r--r--generalresearch/models/thl/ipinfo.py16
-rw-r--r--generalresearch/models/thl/leaderboard.py14
-rw-r--r--generalresearch/models/thl/ledger.py20
-rw-r--r--generalresearch/models/thl/ledger_example.py10
-rw-r--r--generalresearch/models/thl/locales.py2
-rw-r--r--generalresearch/models/thl/offerwall/__init__.py6
-rw-r--r--generalresearch/models/thl/offerwall/base.py14
-rw-r--r--generalresearch/models/thl/offerwall/behavior.py2
-rw-r--r--generalresearch/models/thl/offerwall/bucket.py6
-rw-r--r--generalresearch/models/thl/offerwall/cache.py4
-rw-r--r--generalresearch/models/thl/pagination.py2
-rw-r--r--generalresearch/models/thl/payout.py8
-rw-r--r--generalresearch/models/thl/payout_format.py4
-rw-r--r--generalresearch/models/thl/product.py39
-rw-r--r--generalresearch/models/thl/profiling/marketplace.py16
-rw-r--r--generalresearch/models/thl/profiling/question.py8
-rw-r--r--generalresearch/models/thl/profiling/upk_property.py4
-rw-r--r--generalresearch/models/thl/profiling/upk_question.py20
-rw-r--r--generalresearch/models/thl/profiling/upk_question_answer.py11
-rw-r--r--generalresearch/models/thl/profiling/user_info.py2
-rw-r--r--generalresearch/models/thl/profiling/user_question_answer.py12
-rw-r--r--generalresearch/models/thl/report_task.py2
-rw-r--r--generalresearch/models/thl/session.py29
-rw-r--r--generalresearch/models/thl/stats.py2
-rw-r--r--generalresearch/models/thl/survey/__init__.py15
-rw-r--r--generalresearch/models/thl/survey/buyer.py14
-rw-r--r--generalresearch/models/thl/survey/condition.py10
-rw-r--r--generalresearch/models/thl/survey/model.py18
-rw-r--r--generalresearch/models/thl/survey/penalty.py4
-rw-r--r--generalresearch/models/thl/survey/task_collection.py2
-rw-r--r--generalresearch/models/thl/synchronize_global_vars.py2
-rw-r--r--generalresearch/models/thl/task_adjustment.py4
-rw-r--r--generalresearch/models/thl/task_status.py16
-rw-r--r--generalresearch/models/thl/user.py18
-rw-r--r--generalresearch/models/thl/user_iphistory.py8
-rw-r--r--generalresearch/models/thl/user_profile.py6
-rw-r--r--generalresearch/models/thl/user_quality_event.py4
-rw-r--r--generalresearch/models/thl/user_streak.py10
-rw-r--r--generalresearch/models/thl/userhealth.py4
-rw-r--r--generalresearch/models/thl/wallet/__init__.py2
-rw-r--r--generalresearch/models/thl/wallet/cashout_method.py14
-rw-r--r--generalresearch/models/thl/wallet/payout.py12
-rw-r--r--generalresearch/models/thl/wallet/user_wallet.py4
-rw-r--r--generalresearch/pg_helper.py13
-rw-r--r--generalresearch/schemas/survey_stats.py2
-rw-r--r--generalresearch/sql_helper.py12
-rw-r--r--generalresearch/utils/aggregation.py4
-rw-r--r--generalresearch/utils/grpc_logger.py2
-rw-r--r--generalresearch/wall_status_codes/__init__.py11
-rw-r--r--generalresearch/wall_status_codes/cint.py5
-rw-r--r--generalresearch/wall_status_codes/dynata.py18
-rw-r--r--generalresearch/wall_status_codes/fullcircle.py14
-rw-r--r--generalresearch/wall_status_codes/innovate.py15
-rw-r--r--generalresearch/wall_status_codes/lucid.py33
-rw-r--r--generalresearch/wall_status_codes/morning.py15
-rw-r--r--generalresearch/wall_status_codes/pollfish.py13
-rw-r--r--generalresearch/wall_status_codes/precision.py12
-rw-r--r--generalresearch/wall_status_codes/prodege.py10
-rw-r--r--generalresearch/wall_status_codes/repdata.py14
-rw-r--r--generalresearch/wall_status_codes/sago.py19
-rw-r--r--generalresearch/wall_status_codes/spectrum.py12
-rw-r--r--generalresearch/wall_status_codes/wxet.py15
-rw-r--r--generalresearch/wxet/models/definitions.py3
-rw-r--r--generalresearch/wxet/models/finish_type.py2
-rw-r--r--test_utils/conftest.py9
-rw-r--r--test_utils/managers/conftest.py145
-rw-r--r--test_utils/managers/upk/conftest.py32
-rw-r--r--test_utils/models/conftest.py110
-rw-r--r--test_utils/spectrum/conftest.py8
128 files changed, 1179 insertions, 885 deletions
diff --git a/generalresearch/__init__.py b/generalresearch/__init__.py
index 5993613..fb5f28d 100644
--- a/generalresearch/__init__.py
+++ b/generalresearch/__init__.py
@@ -1,12 +1,19 @@
import threading
import time
from functools import wraps
+from typing import Any, Callable, Optional
from decorator import decorator
from wrapt import FunctionWrapper, ObjectProxy
-def retry(exceptions, tries=4, delay=0.5, backoff=2, logger=None):
+def retry(
+ exceptions,
+ tries: int = 4,
+ delay: float = 0.5,
+ backoff: int = 2,
+ logger: Optional[Any] = None,
+) -> Callable:
"""
https://www.calazan.com/retry-decorator-for-python-3/
Retry calling the decorated function using an exponential backoff.
diff --git a/generalresearch/config.py b/generalresearch/config.py
index 1c3f6e6..94885b1 100644
--- a/generalresearch/config.py
+++ b/generalresearch/config.py
@@ -1,11 +1,11 @@
from datetime import datetime, timezone
-from typing import Optional
from pathlib import Path
+from typing import Optional
-from pydantic import RedisDsn, Field, MariaDBDsn, DirectoryPath, PostgresDsn
+from pydantic import DirectoryPath, Field, MariaDBDsn, PostgresDsn, RedisDsn
from pydantic_settings import BaseSettings
-from generalresearch.models.custom_types import DaskDsn, SentryDsn, MySQLOrMariaDsn
+from generalresearch.models.custom_types import DaskDsn, SentryDsn
def is_debug() -> bool:
diff --git a/generalresearch/healing_ppe.py b/generalresearch/healing_ppe.py
index 2a3aecf..e8ee144 100644
--- a/generalresearch/healing_ppe.py
+++ b/generalresearch/healing_ppe.py
@@ -5,6 +5,7 @@ import time
from collections import defaultdict
from concurrent import futures
from concurrent.futures.process import BrokenProcessPool
+from typing import Optional
logger = logging.getLogger()
@@ -15,7 +16,9 @@ signal_int_name = defaultdict(
class HealingProcessPoolExecutor:
def __init__(
- self, max_workers=None, name=None, slack_token=None, slack_channel=None
+ self,
+ max_workers: Optional[int] = None,
+ name: Optional[str] = None,
):
if not name:
try:
diff --git a/generalresearch/mariadb.py b/generalresearch/mariadb.py
index d95e46c..8bcd8ee 100644
--- a/generalresearch/mariadb.py
+++ b/generalresearch/mariadb.py
@@ -25,7 +25,7 @@ def example():
# actually we don't need the field flags. I didn't see, but there is an
# extended field type returned also. Which explicitly tags uuid fields.
conn = mariadb.connect(
- host="127.0.0.1", user="root", password="", database="300large-morning"
+ host="127.0.0.1", user="root", password="", database="thl-morning"
)
c = conn.cursor()
c.execute("SELECT user_id, pid as greg FROM morning_userpid limit 1")
diff --git a/generalresearch/models/cint/question.py b/generalresearch/models/cint/question.py
index f840a46..77212a2 100644
--- a/generalresearch/models/cint/question.py
+++ b/generalresearch/models/cint/question.py
@@ -1,10 +1,10 @@
import json
from datetime import datetime, timezone
from enum import Enum
-from typing import Optional, List, Literal, Dict, Any
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from uuid import UUID
-from pydantic import Field, BaseModel, model_validator, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from generalresearch.models import Source, string_utils
@@ -15,6 +15,11 @@ from generalresearch.models.thl.profiling.marketplace import (
MarketplaceUserQuestionAnswer,
)
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
class CintQuestionType(str, Enum):
SINGLE_SELECT = "s"
@@ -118,27 +123,34 @@ class CintQuestion(MarketplaceQuestion):
@field_validator("options")
@classmethod
- def order_options(cls, options):
+ def order_options(
+ cls, options: Optional[List[CintQuestionOption]]
+ ) -> Optional[List[CintQuestionOption]]:
if options:
options.sort(key=lambda x: x.order)
+
return options
@field_validator("options")
@classmethod
- def validate_options(cls, options):
+ def validate_options(
+ cls, options: Optional[List[CintQuestionOption]]
+ ) -> Optional[List[CintQuestionOption]]:
if options:
ids = {x.id for x in options}
assert len(ids) == len(options), "options.id must be unique"
orders = {x.order for x in options}
assert len(orders) == len(options), "options.order must be unique"
+
return options
@classmethod
- def from_api(cls, d: dict, country_iso: str, language_iso: str) -> Self:
+ def from_api(cls, d: Dict[str, Any], country_iso: str, language_iso: str) -> Self:
options = None
created_at = datetime.strptime(
d["create_date"], "%Y-%m-%dT%H:%M:%S%z"
).astimezone(timezone.utc)
+
if d.get("question_options"):
options = [
CintQuestionOption(
@@ -166,15 +178,17 @@ class CintQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict) -> Self:
+ def from_db(cls, d: Dict[str, Any]) -> Self:
options = None
if d["options"]:
options = [
CintQuestionOption(id=r["id"], text=r["text"], order=r["order"])
for r in d["options"]
]
+
if d.get("created_at"):
d["created_at"] = d["created_at"].replace(tzinfo=timezone.utc)
+
return cls(
question_id=d["question_id"],
question_name=d["question_name"],
@@ -194,15 +208,16 @@ class CintQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
if self.created_at:
d["created_at"] = self.created_at.replace(tzinfo=None)
+
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
)
upk_type_selector_map = {
diff --git a/generalresearch/models/cint/survey.py b/generalresearch/models/cint/survey.py
index 30b140c..be07e04 100644
--- a/generalresearch/models/cint/survey.py
+++ b/generalresearch/models/cint/survey.py
@@ -4,32 +4,32 @@ import json
import logging
from datetime import datetime, timezone
from decimal import Decimal
-from typing import Optional, Dict, Set, Tuple, List, Literal, Any, Type
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type
from more_itertools import flatten
from pydantic import (
- NonNegativeInt,
- Field,
- ConfigDict,
BaseModel,
+ ConfigDict,
+ Field,
+ NonNegativeInt,
computed_field,
model_validator,
)
-from typing_extensions import Self, Annotated
+from typing_extensions import Annotated, Self
from generalresearch.locales import Localelator
from generalresearch.models import Source, TaskCalculationType
from generalresearch.models.cint import CintQuestionIdType
from generalresearch.models.custom_types import (
+ AlphaNumStr,
AwareDatetimeISO,
CoercedStr,
- AlphaNumStr,
)
from generalresearch.models.thl.demographics import Gender
from generalresearch.models.thl.survey import MarketplaceTask
from generalresearch.models.thl.survey.condition import (
- MarketplaceCondition,
ConditionValueType,
+ MarketplaceCondition,
)
logging.basicConfig()
@@ -95,7 +95,8 @@ class CintQuota(BaseModel):
def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool:
# Matches means we meet all conditions.
- # We can "match" a quota that is closed. In that case, we would not be eligible for the survey.
+ # We can "match" a quota that is closed. In that case, we would
+ # not be eligible for the survey.
return all(criteria_evaluation.get(c) for c in self.condition_hashes)
def matches_optional(
@@ -245,7 +246,7 @@ class CintSurvey(MarketplaceTask):
def is_live(self) -> bool:
return self.is_live_raw
- def model_dump(self, **kwargs: Any) -> dict:
+ def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
data["is_live"] = data.pop("is_live_raw", None)
return data
@@ -314,7 +315,7 @@ class CintSurvey(MarketplaceTask):
}
@classmethod
- def from_api(cls, d: Dict) -> Optional[Self]:
+ def from_api(cls, d: Dict[str, Any]) -> Optional[Self]:
try:
return cls._from_api(d)
except Exception as e:
@@ -322,7 +323,7 @@ class CintSurvey(MarketplaceTask):
return None
@classmethod
- def _from_api(cls, d: Dict) -> Self:
+ def _from_api(cls, d: Dict[str, Any]) -> Self:
if "cpi" in d:
d["gross_cpi"] = Decimal(d.pop("cpi"))
if "revenue_per_interview" in d:
diff --git a/generalresearch/models/cint/task_collection.py b/generalresearch/models/cint/task_collection.py
index 91db35f..b43250a 100644
--- a/generalresearch/models/cint/task_collection.py
+++ b/generalresearch/models/cint/task_collection.py
@@ -1,7 +1,7 @@
from typing import List, Set
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models.cint.survey import CintSurvey
diff --git a/generalresearch/models/custom_types.py b/generalresearch/models/custom_types.py
index 54208a1..aefbbe9 100644
--- a/generalresearch/models/custom_types.py
+++ b/generalresearch/models/custom_types.py
@@ -1,16 +1,16 @@
import json
-from datetime import datetime, timezone, timedelta
-from typing import Any, Optional, Set, Literal
+from datetime import datetime, timedelta, timezone
+from typing import Any, Literal, Optional, Set
from uuid import UUID
from pydantic import (
+ AnyUrl,
AwareDatetime,
- StringConstraints,
- TypeAdapter,
+ Field,
HttpUrl,
IPvAnyAddress,
- Field,
- AnyUrl,
+ StringConstraints,
+ TypeAdapter,
)
from pydantic.functional_serializers import PlainSerializer
from pydantic.functional_validators import AfterValidator, BeforeValidator
@@ -20,7 +20,6 @@ from typing_extensions import Annotated
from generalresearch.models import DeviceType, Source
-
# if TYPE_CHECKING:
# from generalresearch.models import DeviceType
diff --git a/generalresearch/models/dynata/question.py b/generalresearch/models/dynata/question.py
index 76670ce..4d675c0 100644
--- a/generalresearch/models/dynata/question.py
+++ b/generalresearch/models/dynata/question.py
@@ -5,11 +5,11 @@ import re
from datetime import timedelta
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Literal, Any, Dict, Set
+from typing import Any, Dict, List, Literal, Optional, Set
-from pydantic import BaseModel, Field, model_validator, field_validator, PositiveInt
+from pydantic import BaseModel, Field, PositiveInt, field_validator, model_validator
-from generalresearch.models import Source, MAX_INT32
+from generalresearch.models import MAX_INT32, Source
from generalresearch.models.custom_types import AwareDatetimeISO
from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion
@@ -228,11 +228,11 @@ class DynataQuestion(MarketplaceQuestion):
def to_upk_question(self):
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
order_exclusive_options,
)
diff --git a/generalresearch/models/dynata/survey.py b/generalresearch/models/dynata/survey.py
index de4b177..ae12436 100644
--- a/generalresearch/models/dynata/survey.py
+++ b/generalresearch/models/dynata/survey.py
@@ -5,28 +5,28 @@ import logging
from datetime import timezone
from decimal import Decimal
from functools import cached_property
-from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Type
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type
from more_itertools import flatten
from pydantic import (
- Field,
- ConfigDict,
BaseModel,
- model_validator,
- field_validator,
+ ConfigDict,
+ Field,
RootModel,
computed_field,
+ field_validator,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.locales import Localelator
-from generalresearch.models import TaskCalculationType, Source
+from generalresearch.models import Source, TaskCalculationType
from generalresearch.models.custom_types import (
- CoercedStr,
- AwareDatetimeISO,
+ AlphaNumStr,
AlphaNumStrSet,
+ AwareDatetimeISO,
+ CoercedStr,
DeviceTypes,
- AlphaNumStr,
)
from generalresearch.models.dynata import DynataStatus
from generalresearch.models.thl.demographics import (
@@ -48,20 +48,30 @@ locale_helper = Localelator()
class DynataRequirements(BaseModel):
# Requires inviting (recontacting) specific respondents to a follow up survey.
requires_recontact: bool = Field(default=False)
- # Requires respondents to provide personally identifiable information (PII) within client survey.
+
+ # Requires respondents to provide personally identifiable
+ # information (PII) within client survey.
requires_pii_collection: bool = Field(default=False)
+
# Requires respondents to utilize their webcam to participate.
requires_webcam: bool = Field(default=False)
- # Requires use of facial recognition technology with respondents, such as eye tracking.
+
+ # Requires use of facial recognition technology with
+ # respondents, such as eye tracking.
requires_eye_tracking: bool = Field(default=False)
+
# Requires partner to allow Dynata to drop a cookie on respondent.
requires_cookie_drops: bool = Field(default=False)
- # Requires partner-uploaded respondent PII to expand third-party matched data.
+
+ # Requires partner-uploaded respondent PII to expand
+ # third-party matched data.
requires_sample_plus: bool = Field(default=False)
+
# Requires respondents to download a software application.
requires_app_vpn: bool = Field(default=False)
- # Requires additional incentives to be manually awarded to respondent by partner outside of the typical online
- # survey flow.
+
+ # Requires additional incentives to be manually awarded to
+ # respondent by partner outside of the typical online survey flow.
requires_manual_rewards: bool = Field(default=False)
def __repr__(self) -> str:
diff --git a/generalresearch/models/dynata/task_collection.py b/generalresearch/models/dynata/task_collection.py
index bd9b81a..a10fcf0 100644
--- a/generalresearch/models/dynata/task_collection.py
+++ b/generalresearch/models/dynata/task_collection.py
@@ -1,7 +1,7 @@
-from typing import List, Dict, Any
+from typing import Any, Dict, List
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models import TaskCalculationType
diff --git a/generalresearch/models/events.py b/generalresearch/models/events.py
index 4504ab1..63ed2a1 100644
--- a/generalresearch/models/events.py
+++ b/generalresearch/models/events.py
@@ -1,30 +1,30 @@
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
from enum import StrEnum
-from typing import Union, Literal, Optional, Dict
+from typing import Dict, Literal, Optional, Union
from uuid import uuid4
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
- PositiveFloat,
NonNegativeInt,
- model_validator,
+ PositiveFloat,
TypeAdapter,
- ConfigDict,
+ model_validator,
)
from typing_extensions import Annotated
from generalresearch.models import Source
from generalresearch.models.custom_types import (
+ AwareDatetimeISO,
CountryISOLike,
UUIDStr,
- AwareDatetimeISO,
)
from generalresearch.models.thl.definitions import (
+ SessionStatusCode2,
Status,
StatusCode1,
WallStatusCode2,
- SessionStatusCode2,
)
diff --git a/generalresearch/models/gr/__init__.py b/generalresearch/models/gr/__init__.py
index 713bba6..7e1516b 100644
--- a/generalresearch/models/gr/__init__.py
+++ b/generalresearch/models/gr/__init__.py
@@ -1,9 +1,9 @@
-from generalresearch.models.gr.authentication import GRUser, GRToken
+from generalresearch.models.gr.authentication import GRToken, GRUser
from generalresearch.models.gr.business import Business
from generalresearch.models.gr.team import Team
+from generalresearch.models.thl.finance import BusinessBalances
from generalresearch.models.thl.payout import BrokerageProductPayoutEvent
from generalresearch.models.thl.product import Product
-from generalresearch.models.thl.finance import BusinessBalances
_ = Business, Product, BrokerageProductPayoutEvent, BusinessBalances
diff --git a/generalresearch/models/gr/authentication.py b/generalresearch/models/gr/authentication.py
index ff1d065..764c694 100644
--- a/generalresearch/models/gr/authentication.py
+++ b/generalresearch/models/gr/authentication.py
@@ -4,16 +4,16 @@ import binascii
import json
import os
from datetime import datetime, timezone
-from typing import Optional, List, TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
-from pydantic import AnyHttpUrl
from pydantic import (
+ AnyHttpUrl,
BaseModel,
ConfigDict,
Field,
+ NonNegativeInt,
PositiveInt,
field_validator,
- NonNegativeInt,
)
from typing_extensions import Self
@@ -119,7 +119,7 @@ class GRUser(BaseModel):
claims: Optional["Claims"] = Field(default=None)
def prefetch_claims(
- self, token: str, key: Dict, audience: str, issuer: AnyHttpUrl
+ self, token: str, key: Dict[str, Any], audience: str, issuer: AnyHttpUrl
) -> None:
from jose import jwt
@@ -220,7 +220,7 @@ class GRUser(BaseModel):
LOG.warning("prefetch not run")
return None
- return [b.id for b in self.businesses]
+ return [b.id for b in self.businesses if b.id is not None]
@property
def team_uuids(self) -> Optional[List[UUIDStr]]:
@@ -236,7 +236,7 @@ class GRUser(BaseModel):
LOG.warning("prefetch not run")
return None
- return [t.id for t in self.teams]
+ return [t.id for t in self.teams if t.id is not None]
@property
def product_uuids(self) -> Optional[List[UUIDStr]]:
@@ -294,7 +294,7 @@ class GRUser(BaseModel):
return GRUser.model_validate(d)
@classmethod
- def from_redis(cls, d: Union[str, Dict]) -> Self:
+ def from_redis(cls, d: Union[str, Dict[str, Any]]) -> Self:
if isinstance(d, str):
d = json.loads(d)
assert isinstance(d, dict)
@@ -359,13 +359,13 @@ class GRToken(BaseModel):
# --- Properties ---
@property
- def auth_header(self, key_name="Authorization") -> Dict:
+ def auth_header(self, key_name: str = "Authorization") -> Dict[str, str]:
return {key_name: self.key}
# --- ORM ---
@classmethod
- def from_redis(cls, d: Union[str, Dict]) -> Self:
+ def from_redis(cls, d: Union[str, Dict[str, Any]]) -> Self:
if isinstance(d, str):
d = json.loads(d)
assert isinstance(d, dict)
diff --git a/generalresearch/models/gr/business.py b/generalresearch/models/gr/business.py
index b07c584..8a37f74 100644
--- a/generalresearch/models/gr/business.py
+++ b/generalresearch/models/gr/business.py
@@ -6,7 +6,7 @@ import os
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
-from typing import Optional, List, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, List, Optional, Union
from uuid import uuid4
import pandas as pd
@@ -24,12 +24,11 @@ from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge
from generalresearch.incite.schemas.mergers.pop_ledger import (
numerical_col_names,
)
-from generalresearch.utils.enum import ReprEnumMeta
from generalresearch.models.admin.request import ReportRequest, ReportType
from generalresearch.models.custom_types import (
+ AwareDatetime,
UUIDStr,
UUIDStrCoerce,
- AwareDatetime,
)
from generalresearch.models.thl.finance import POPFinancial
from generalresearch.models.thl.ledger import LedgerAccount, OrderBy
@@ -37,12 +36,9 @@ from generalresearch.models.thl.payout import BusinessPayoutEvent
from generalresearch.pg_helper import PostgresConfig
from generalresearch.redis_helper import RedisConfig
from generalresearch.utils.aggregation import group_by_year
+from generalresearch.utils.enum import ReprEnumMeta
if TYPE_CHECKING:
- from generalresearch.models.thl.finance import BusinessBalances
-
- from generalresearch.models.thl.product import Product
- from generalresearch.models.gr.team import Team
from generalresearch.incite.base import GRLDatasets
from generalresearch.incite.mergers.foundations.enriched_session import (
EnrichedSessionMerge,
@@ -59,7 +55,9 @@ if TYPE_CHECKING:
from generalresearch.managers.thl.payout import (
BusinessPayoutEventManager,
)
- from generalresearch.managers.thl.product import ProductManager
+ from generalresearch.models.gr.team import Team
+ from generalresearch.models.thl.finance import BusinessBalances
+ from generalresearch.models.thl.product import Product
class TransferMethod(Enum, metaclass=ReprEnumMeta):
@@ -248,7 +246,7 @@ class Business(BaseModel):
with pg_config.make_connection() as conn:
with conn.cursor(row_factory=dict_row) as c:
c.execute(
- query=f"""
+ query="""
SELECT *
FROM common_businessaddress AS ba
WHERE ba.business_id = %s
@@ -271,7 +269,7 @@ class Business(BaseModel):
c: Cursor
c.execute(
- query=f"""
+ query="""
SELECT t.*
FROM common_team AS t
INNER JOIN common_team_businesses AS tb
@@ -359,9 +357,11 @@ class Business(BaseModel):
self.prefetch_products(thl_pg_config=thl_pg_config)
accounts: List[LedgerAccount] = lm.get_accounts_if_exists(
- qualified_names=[
- f"{lm.currency.value}:bp_wallet:{bpid}" for bpid in self.product_uuids
- ]
+ qualified_names=(
+ [f"{lm.currency.value}:bp_wallet:{bpid}" for bpid in self.product_uuids]
+ if self.product_uuids
+ else []
+ )
)
if len(accounts) != len(self.products):
@@ -711,7 +711,8 @@ class Business(BaseModel):
fields: List[str],
gr_redis_config: RedisConfig,
) -> Optional[Self]:
- keys: List = Business.required_fields() + fields
+ keys: List[str] = Business.required_fields() + fields
+
if "pop_financial" in keys:
# We should explicitly pass the pop_financial years we want. By default,
# at least get this year.
diff --git a/generalresearch/models/gr/team.py b/generalresearch/models/gr/team.py
index 38aff56..a3ac9cf 100644
--- a/generalresearch/models/gr/team.py
+++ b/generalresearch/models/gr/team.py
@@ -3,7 +3,7 @@ import os
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
-from typing import Optional, Union, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional, Union
from uuid import uuid4
import pandas as pd
@@ -25,7 +25,6 @@ from generalresearch.incite.mergers.foundations.enriched_session import (
from generalresearch.incite.mergers.foundations.enriched_wall import (
EnrichedWallMerge,
)
-from generalresearch.utils.enum import ReprEnumMeta
from generalresearch.models.admin.request import ReportRequest, ReportType
from generalresearch.models.custom_types import (
AwareDatetimeISO,
@@ -34,11 +33,12 @@ from generalresearch.models.custom_types import (
)
from generalresearch.pg_helper import PostgresConfig
from generalresearch.redis_helper import RedisConfig
+from generalresearch.utils.enum import ReprEnumMeta
if TYPE_CHECKING:
from generalresearch.incite.base import GRLDatasets
- from generalresearch.models.gr.business import Business
from generalresearch.models.gr.authentication import GRUser
+ from generalresearch.models.gr.business import Business
from generalresearch.models.thl.product import Product
@@ -77,7 +77,7 @@ class Membership(BaseModel):
"account was created."
)
- user_id: SkipJsonSchema[PositiveInt] = Field(default=None)
+ user_id: SkipJsonSchema[PositiveInt] = Field()
team_id: SkipJsonSchema[PositiveInt] = Field()
@@ -189,7 +189,7 @@ class Team(BaseModel):
)
try:
- test = pd.read_parquet(path, engine="pyarrow")
+ _ = pd.read_parquet(path, engine="pyarrow")
except Exception as e:
raise IOError(f"Parquet verification failed: {e}")
@@ -234,7 +234,7 @@ class Team(BaseModel):
)
try:
- test = pd.read_parquet(path, engine="pyarrow")
+ _ = pd.read_parquet(path, engine="pyarrow")
except Exception as e:
raise IOError(f"Parquet verification failed: {e}")
@@ -342,5 +342,6 @@ class Team(BaseModel):
res: List = rc.hmget(name=f"team:{uuid}", keys=keys)
d = {val: json.loads(res[idx]) for idx, val in enumerate(keys)}
return Team.model_validate(d)
+
except (Exception,) as e:
return None
diff --git a/generalresearch/models/innovate/question.py b/generalresearch/models/innovate/question.py
index 0fd2547..0d47ba7 100644
--- a/generalresearch/models/innovate/question.py
+++ b/generalresearch/models/innovate/question.py
@@ -4,9 +4,9 @@ from __future__ import annotations
import json
import logging
from enum import Enum
-from typing import List, Optional, Literal, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
-from pydantic import BaseModel, Field, model_validator, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
from generalresearch.models import Source
from generalresearch.models.innovate import InnovateQuestionID
@@ -15,6 +15,11 @@ from generalresearch.models.thl.profiling.marketplace import (
MarketplaceUserQuestionAnswer,
)
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -47,8 +52,9 @@ class InnovateQuestionOption(BaseModel):
class InnovateQuestionType(str, Enum):
# API response: {'Multipunch', 'Numeric Open Ended', 'Single Punch'}
- # "Numeric Open Ended" must be wrong... It can't be numeric, as UK's postcode question is marked
- # as this, but it wants alphanumeric answers. So this is just text_entry.
+ # "Numeric Open Ended" must be wrong... It can't be numeric, as UK's
+ # postcode question is marked as this, but it wants alphanumeric
+ # answers. So this is just text_entry.
SINGLE_SELECT = "s"
MULTI_SELECT = "m"
@@ -135,7 +141,7 @@ class InnovateQuestion(MarketplaceQuestion):
@classmethod
def from_api(
- cls, d: dict, country_iso: str, language_iso: str
+ cls, d: Dict, country_iso: str, language_iso: str
) -> Optional["InnovateQuestion"]:
"""
:param d: Raw response from API
@@ -179,13 +185,15 @@ class InnovateQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict) -> "InnovateQuestion":
+ def from_db(cls, d: Dict[str, Any]) -> "InnovateQuestion":
+
options = None
if d["options"]:
options = [
InnovateQuestionOption(id=r["id"], text=r["text"], order=r["order"])
for r in d["options"]
]
+
return cls(
question_id=d["question_id"],
question_key=d["question_key"],
@@ -204,13 +212,13 @@ class InnovateQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
)
upk_type_selector_map = {
diff --git a/generalresearch/models/innovate/survey.py b/generalresearch/models/innovate/survey.py
index e230899..3a3d9e2 100644
--- a/generalresearch/models/innovate/survey.py
+++ b/generalresearch/models/innovate/survey.py
@@ -2,47 +2,47 @@ from __future__ import annotations
import json
import logging
-from datetime import timezone, date
+from datetime import date, timezone
from decimal import Decimal
from functools import cached_property
from typing import (
- Optional,
- Dict,
+ Annotated,
Any,
+ Dict,
List,
Literal,
+ Optional,
Set,
Tuple,
- Annotated,
Type,
)
from more_itertools import flatten
from pydantic import (
- Field,
- ConfigDict,
BaseModel,
- model_validator,
+ ConfigDict,
+ Field,
computed_field,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.locales import Localelator
from generalresearch.models import (
- Source,
LogicalOperator,
+ Source,
TaskCalculationType,
)
from generalresearch.models.custom_types import (
- CoercedStr,
- AwareDatetimeISO,
AlphaNumStrSet,
+ AwareDatetimeISO,
+ CoercedStr,
DeviceTypes,
)
from generalresearch.models.innovate import (
- InnovateStatus,
- InnovateQuotaStatus,
InnovateDuplicateCheckLevel,
+ InnovateQuotaStatus,
+ InnovateStatus,
)
from generalresearch.models.innovate.question import InnovateQuestionID
from generalresearch.models.thl.demographics import Gender
@@ -266,7 +266,7 @@ class InnovateSurvey(MarketplaceTask):
return data
@classmethod
- def from_api(cls, d: Dict) -> Optional["InnovateSurvey"]:
+ def from_api(cls, d: Dict[str, Any]) -> Optional["InnovateSurvey"]:
try:
return cls._from_api(d)
except Exception as e:
@@ -274,7 +274,7 @@ class InnovateSurvey(MarketplaceTask):
return None
@classmethod
- def _from_api(cls, d: Dict):
+ def _from_api(cls, d: Dict[str, Any]) -> "InnovateSurvey":
d["conditions"] = dict()
# If we haven't hit the "detail" endpoint, we won't get this
@@ -334,7 +334,7 @@ class InnovateSurvey(MarketplaceTask):
)
return f"{self.__repr_name__()}({repr_str})"
- def is_unchanged(self, other) -> bool:
+ def is_unchanged(self, other: "InnovateSurvey") -> bool:
# Avoiding overloading __eq__ because it looks kind of complicated? I
# want to be explicit that this is not testing object equivalence,
# just that the objects don't require any db updates. We also exclude
@@ -465,6 +465,7 @@ class InnovateSurvey(MarketplaceTask):
return None, set(
flatten([m[1] for q, m in quota_eval.items() if m[0] is None])
)
+
return False, set()
def determine_eligibility(
diff --git a/generalresearch/models/innovate/task_collection.py b/generalresearch/models/innovate/task_collection.py
index 97e7a7c..4d7fe51 100644
--- a/generalresearch/models/innovate/task_collection.py
+++ b/generalresearch/models/innovate/task_collection.py
@@ -1,7 +1,7 @@
from typing import List, Set
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models.innovate import InnovateStatus
diff --git a/generalresearch/models/legacy/bucket.py b/generalresearch/models/legacy/bucket.py
index 5afb17b..8ce559b 100644
--- a/generalresearch/models/legacy/bucket.py
+++ b/generalresearch/models/legacy/bucket.py
@@ -4,23 +4,23 @@ import logging
import math
from datetime import timedelta
from decimal import Decimal
-from typing import Optional, Dict, List, Union, Literal, Tuple
-from typing_extensions import Self
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
+ NonNegativeInt,
field_validator,
model_validator,
- ConfigDict,
- NonNegativeInt,
)
+from typing_extensions import Self
from generalresearch.models import Source
from generalresearch.models.custom_types import (
HttpsUrl,
- UUIDStr,
PropertyCode,
+ UUIDStr,
)
from generalresearch.models.thl.stats import StatisticalSummary
@@ -181,8 +181,10 @@ class Bucket(BaseModel):
loi_q1: Optional[timedelta] = Field(strict=True, default=None)
loi_q2: Optional[timedelta] = Field(strict=True, default=None)
loi_q3: Optional[timedelta] = Field(strict=True, default=None)
- # decimal USD. This should not have more than 2 decimal places.
- # There is no way to make this "strict" and optional, so we have a separate pre-validator
+
+ # Decimal USD. This should not have more than 2 decimal places.
+ # There is no way to make this "strict" and optional, so
+ # we have a separate pre-validator
user_payout_min: Optional[Decimal] = Field(default=None, lt=1000, gt=0)
user_payout_max: Optional[Decimal] = Field(default=None, lt=1000, gt=0)
user_payout_q1: Optional[Decimal] = Field(default=None, lt=1000, gt=0)
@@ -358,7 +360,7 @@ class Bucket(BaseModel):
)
@classmethod
- def parse_from_offerwall_style2(cls, bucket: Dict):
+ def parse_from_offerwall_style2(cls, bucket: Dict[str, Any]):
# {'payout': {'min': 123}}
loi_min_sec = bucket.get("duration", {}).get("min")
loi_max_sec = bucket.get("duration", {}).get("max")
@@ -383,7 +385,7 @@ class Bucket(BaseModel):
)
@classmethod
- def parse_from_offerwall_style3(cls, bucket: Dict):
+ def parse_from_offerwall_style3(cls, bucket: Dict[str, Any]):
# {'payout': 123, 'duration': 123}
return cls(
user_payout_min=cls.usd_cents_to_decimal(bucket["payout"]),
@@ -397,13 +399,13 @@ class Bucket(BaseModel):
)
@staticmethod
- def usd_cents_to_decimal(v: int):
+ def usd_cents_to_decimal(v: Optional[int]) -> Optional[Decimal]:
if v is None:
return None
return Decimal(Decimal(int(v)) / Decimal(100))
@staticmethod
- def decimal_to_usd_cents(d: Decimal):
+ def decimal_to_usd_cents(d: Optional[Decimal]) -> Optional[Decimal]:
if d is None:
return None
return round(d * Decimal(100), 2)
@@ -437,7 +439,7 @@ class DurationSummary(StatisticalSummary):
}
@classmethod
- def from_bucket(cls, bucket: Bucket):
+ def from_bucket(cls, bucket: Bucket) -> "DurationSummary":
return cls(
min=bucket.loi_min.total_seconds(),
max=bucket.loi_max.total_seconds(),
@@ -610,18 +612,21 @@ class TopNPlusBucket(BucketBase):
def eligibility_ranks(cls, criteria):
criteria = list(criteria)
ranks = [c.rank for c in criteria]
+
if all(r is None for r in ranks):
for i, c in enumerate(criteria):
c.rank = i
return tuple(criteria)
+
if any(r is None for r in ranks):
raise ValueError("Set all or no ranks in eligibility_criteria")
if len(ranks) != len(set(ranks)):
raise ValueError("Duplicate ranks")
+
return tuple(sorted(criteria, key=lambda c: c.rank))
@classmethod
- def from_bucket(cls, bucket: Bucket):
+ def from_bucket(cls, bucket: Bucket) -> "TopNPlusBucket":
return cls.model_validate(
{
"id": bucket.id,
diff --git a/generalresearch/models/legacy/offerwall.py b/generalresearch/models/legacy/offerwall.py
index 67213f7..a8efe9a 100644
--- a/generalresearch/models/legacy/offerwall.py
+++ b/generalresearch/models/legacy/offerwall.py
@@ -1,22 +1,22 @@
from __future__ import annotations
-from typing import List, Dict
+from typing import Dict, List
-from pydantic import BaseModel, Field, ConfigDict, NonNegativeInt
+from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt
from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.legacy.bucket import (
BucketBase,
- SoftPairBucket,
- TopNBucket,
- TimeBucksBucket,
MarketplaceBucket,
- TopNPlusBucket,
- SingleEntryBucket,
- WXETOfferwallBucket,
OneShotOfferwallBucket,
OneShotSoftPairOfferwallBucket,
+ SingleEntryBucket,
+ SoftPairBucket,
+ TimeBucksBucket,
+ TopNBucket,
+ TopNPlusBucket,
TopNPlusRecontactBucket,
+ WXETOfferwallBucket,
)
from generalresearch.models.legacy.definitions import OfferwallReason
from generalresearch.models.thl.payout_format import (
diff --git a/generalresearch/models/legacy/questions.py b/generalresearch/models/legacy/questions.py
index 1559f24..1a86121 100644
--- a/generalresearch/models/legacy/questions.py
+++ b/generalresearch/models/legacy/questions.py
@@ -1,21 +1,20 @@
from __future__ import annotations
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import (
BaseModel,
+ BeforeValidator,
+ ConfigDict,
Field,
NonNegativeInt,
- model_validator,
StringConstraints,
- ConfigDict,
ValidationError,
- BeforeValidator,
field_validator,
+ model_validator,
)
from sentry_sdk import capture_exception
-from typing_extensions import Annotated
-from typing_extensions import Self
+from typing_extensions import Annotated, Self
from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.legacy.api_status import StatusResponse
@@ -34,10 +33,10 @@ if TYPE_CHECKING:
class UpkQuestionResponse(StatusResponse):
questions: List[UpkQuestionOut] = Field()
- consent_questions: List[Dict] = Field(
+ consent_questions: List[Dict[str, Any]] = Field(
description="For internal use", default_factory=list
)
- special_questions: List[Dict] = Field(
+ special_questions: List[Dict[str, Any]] = Field(
description="For internal use", default_factory=list
)
count: NonNegativeInt = Field(description="The number of questions returned")
@@ -221,7 +220,7 @@ class UserQuestionAnswers(BaseModel):
def prefetch_user(self, um: "UserManager") -> None:
from generalresearch.models.thl.user import User
- res: User = um.get_user_if_exists(
+ res: Optional[User] = um.get_user_if_exists(
product_id=self.product_id, product_user_id=self.product_user_id
)
@@ -229,10 +228,11 @@ class UserQuestionAnswers(BaseModel):
raise ValidationError("Invalid user")
self.user = res
+ return None
def prefetch_wall(self, wm: "WallManager") -> None:
- from generalresearch.models.thl.session import Wall
from generalresearch.models import Source
+ from generalresearch.models.thl.session import Wall
res: Optional[Wall] = wm.get_from_uuid_if_exists(wall_uuid=self.session_id)
@@ -252,3 +252,4 @@ class UserQuestionAnswers(BaseModel):
raise ValueError("Not a valid GRS event status")
self.wall = res
+ return None
diff --git a/generalresearch/models/lucid/question.py b/generalresearch/models/lucid/question.py
index b3d9b27..8542156 100644
--- a/generalresearch/models/lucid/question.py
+++ b/generalresearch/models/lucid/question.py
@@ -2,9 +2,9 @@ from __future__ import annotations
import logging
from enum import Enum
-from typing import List, Optional, Literal, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
-from pydantic import BaseModel, Field, model_validator, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from generalresearch.models import Source
@@ -94,7 +94,7 @@ class LucidQuestion(MarketplaceQuestion):
return options
@classmethod
- def from_db(cls, d: dict) -> Self:
+ def from_db(cls, d: Dict[str, Any]) -> Self:
options = None
if d["options"]:
options = [
@@ -112,11 +112,11 @@ class LucidQuestion(MarketplaceQuestion):
def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
)
upk_type_selector_map = {
diff --git a/generalresearch/models/lucid/survey.py b/generalresearch/models/lucid/survey.py
index 6d81254..0e01243 100644
--- a/generalresearch/models/lucid/survey.py
+++ b/generalresearch/models/lucid/survey.py
@@ -1,20 +1,20 @@
from __future__ import annotations
-from typing import Optional, Dict, Set, Tuple, List
+from typing import Any, Dict, List, Optional, Self, Set, Tuple
-from pydantic import NonNegativeInt, Field, ConfigDict, BaseModel
+from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt
from generalresearch.models import Source
from generalresearch.models.custom_types import (
AwareDatetimeISO,
- UUIDStr,
- CoercedStr,
BigAutoInteger,
+ CoercedStr,
+ UUIDStr,
)
from generalresearch.models.thl.locales import CountryISO, LanguageISO
from generalresearch.models.thl.survey.condition import (
- MarketplaceCondition,
ConditionValueType,
+ MarketplaceCondition,
)
@@ -41,7 +41,7 @@ class LucidCondition(MarketplaceCondition):
return hash(self.id)
@classmethod
- def from_mysql(cls, x):
+ def from_mysql(cls, x: Dict[str, Any]) -> Self:
x["value_type"] = ConditionValueType.LIST
x["negate"] = False
x["values"] = x.pop("pre_codes").split("|")
diff --git a/generalresearch/models/marketplace/summary.py b/generalresearch/models/marketplace/summary.py
index 0dd3404..49cf7e9 100644
--- a/generalresearch/models/marketplace/summary.py
+++ b/generalresearch/models/marketplace/summary.py
@@ -1,10 +1,10 @@
from __future__ import annotations
from abc import ABC
-from typing import List, Optional, Dict, Literal, Collection
+from typing import Collection, Dict, List, Literal, Optional
import numpy as np
-from pydantic import BaseModel, Field, computed_field, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, computed_field
from typing_extensions import Self
from generalresearch.models.thl.stats import StatisticalSummary
diff --git a/generalresearch/models/morning/question.py b/generalresearch/models/morning/question.py
index 77ea209..6dca8c0 100644
--- a/generalresearch/models/morning/question.py
+++ b/generalresearch/models/morning/question.py
@@ -1,17 +1,17 @@
import json
from enum import Enum
-from typing import List, Optional, Dict, Literal, Any
+from typing import Any, Dict, List, Literal, Optional
from uuid import UUID
-from pydantic import BaseModel, Field, model_validator, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from generalresearch.locales import Localelator
from generalresearch.models import Source
from generalresearch.models.morning import MorningQuestionID
from generalresearch.models.thl.profiling.marketplace import (
- MarketplaceUserQuestionAnswer,
MarketplaceQuestion,
+ MarketplaceUserQuestionAnswer,
)
# todo: we could validate that the country_iso / language_iso exists ...
@@ -119,13 +119,14 @@ class MorningQuestion(MarketplaceQuestion):
return options
@classmethod
- def from_api(cls, d: dict, country_iso: str, language_iso: str):
+ def from_api(cls, d: Dict[str, Any], country_iso: str, language_iso: str):
options = None
if d.get("responses"):
options = [
MorningQuestionOption(id=r["id"], text=r["text"], order=order)
for order, r in enumerate(d["responses"])
]
+
return cls(
id=d["id"],
name=d["name"],
@@ -137,7 +138,7 @@ class MorningQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict):
+ def from_db(cls, d: Dict[str, Any]) -> Self:
options = None
if d["options"]:
options = [
@@ -146,6 +147,7 @@ class MorningQuestion(MarketplaceQuestion):
)
for r in d["options"]
]
+
return cls(
id=d["question_id"],
name=d["question_name"],
@@ -167,12 +169,12 @@ class MorningQuestion(MarketplaceQuestion):
def to_upk_question(self):
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
+ UpkQuestionSelectorHIDDEN,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestionSelectorHIDDEN,
- UpkQuestion,
+ UpkQuestionType,
)
upk_type_selector_map = {
diff --git a/generalresearch/models/morning/survey.py b/generalresearch/models/morning/survey.py
index 6f61661..bc71e58 100644
--- a/generalresearch/models/morning/survey.py
+++ b/generalresearch/models/morning/survey.py
@@ -6,26 +6,26 @@ from datetime import timezone
from decimal import Decimal
from functools import cached_property
from typing import (
- Optional,
- Dict,
+ Annotated,
Any,
+ Dict,
List,
+ Literal,
+ Optional,
Set,
- Annotated,
Tuple,
- Literal,
Type,
)
from pydantic import (
- Field,
- ConfigDict,
BaseModel,
- computed_field,
+ ConfigDict,
+ Field,
NonNegativeInt,
- model_validator,
PositiveInt,
PrivateAttr,
+ computed_field,
+ model_validator,
)
from typing_extensions import Self
@@ -35,7 +35,7 @@ from generalresearch.models.custom_types import (
AwareDatetimeISO,
UUIDStrCoerce,
)
-from generalresearch.models.morning import MorningStatus, MorningQuestionID
+from generalresearch.models.morning import MorningQuestionID, MorningStatus
from generalresearch.models.morning.question import MorningQuestion
from generalresearch.models.thl.demographics import Gender
from generalresearch.models.thl.locales import (
@@ -387,7 +387,7 @@ class MorningBid(MorningTaskStatistics):
@model_validator(mode="before")
@classmethod
- def setup_quota_fields(cls, data: dict) -> dict:
+ def setup_quota_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
# These fields get "inherited" by each quota from its bid.
quota_fields = [
"country_iso",
@@ -396,10 +396,12 @@ class MorningBid(MorningTaskStatistics):
"bid_loi",
"used_question_ids",
]
+
for quota in data["quotas"]:
for field in quota_fields:
if field not in quota:
quota[field] = data[field]
+
return data
@model_validator(mode="before")
@@ -417,9 +419,10 @@ class MorningBid(MorningTaskStatistics):
@model_validator(mode="before")
@classmethod
- def setup_conditions(cls, data: dict) -> dict:
+ def setup_conditions(cls, data: Dict[str, Any]) -> Dict[str, Any]:
if "conditions" in data:
return data
+
data["conditions"] = dict()
for quota in data["quotas"]:
if "qualifications" in quota:
@@ -445,7 +448,7 @@ class MorningBid(MorningTaskStatistics):
@model_validator(mode="before")
@classmethod
- def clean_alias(cls, data: Dict) -> Dict:
+ def clean_alias(cls, data: Dict[str, Any]) -> Dict[str, Any]:
# Make sure fields are named certain ways, so we don't have to check
# aliases within other validators
if "estimated_length_of_interview" in data:
@@ -500,7 +503,7 @@ class MorningBid(MorningTaskStatistics):
return d
@classmethod
- def from_db(cls, d: Dict[str, Any]):
+ def from_db(cls, d: Dict[str, Any]) -> Self:
d["created"] = d["created"].replace(tzinfo=timezone.utc)
d["updated"] = d["updated"].replace(tzinfo=timezone.utc)
d["expected_end"] = d["expected_end"].replace(tzinfo=timezone.utc)
@@ -522,10 +525,12 @@ class MorningBid(MorningTaskStatistics):
self, criteria_evaluation: Dict[str, Optional[bool]]
) -> Tuple[Optional[bool], Optional[List[str]], Optional[Set[str]]]:
"""
- Quotas are mutually-exclusive. A user can only possibly match 1 quota. As such, all unknown
- questions on any quota will be the same unknowns on all.
- Returns (the eligibility (True/False/None), passing quota ID or None (if eligibility is not True),
- unknown_hashes (or None))
+ Quotas are mutually-exclusive. A user can only possibly match 1
+ quota. As such, all unknown questions on any quota will be
+ the same unknowns on all.
+
+ Returns (the eligibility (True/False/None), passing quota ID or
+ None (if eligibility is not True), unknown_hashes (or None))
"""
unknown_quotas = []
unknown_hashes = set()
diff --git a/generalresearch/models/morning/task_collection.py b/generalresearch/models/morning/task_collection.py
index 174f5e1..8e1eb4b 100644
--- a/generalresearch/models/morning/task_collection.py
+++ b/generalresearch/models/morning/task_collection.py
@@ -1,14 +1,14 @@
from typing import List, Set
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models.morning import MorningStatus
from generalresearch.models.morning.survey import MorningBid
from generalresearch.models.thl.survey.task_collection import (
- create_empty_df_from_schema,
TaskCollection,
+ create_empty_df_from_schema,
)
COUNTRY_ISOS: Set[str] = Localelator().get_all_countries()
@@ -130,10 +130,11 @@ class MorningTaskCollection(TaskCollection):
rows.append(d)
return rows
- def to_df(self):
+ def to_df(self) -> pd.DataFrame:
rows = []
for s in self.items:
rows.extend(self.to_rows(s))
+
if rows:
return pd.DataFrame.from_records(rows, index="quota_id")
else:
diff --git a/generalresearch/models/pollfish/question.py b/generalresearch/models/pollfish/question.py
index 30f8088..a89a793 100644
--- a/generalresearch/models/pollfish/question.py
+++ b/generalresearch/models/pollfish/question.py
@@ -2,13 +2,18 @@
import json
import logging
from enum import Enum
-from typing import List, Optional, Literal, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Self
from pydantic import BaseModel, Field, model_validator
from generalresearch.models import Source
from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -73,7 +78,7 @@ class PollfishQuestion(MarketplaceQuestion):
return self
@classmethod
- def from_db(cls, d: dict):
+ def from_db(cls, d: Dict[str, Any]) -> Self:
options = None
if d["options"]:
options = [
@@ -97,13 +102,13 @@ class PollfishQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
order_exclusive_options,
)
diff --git a/generalresearch/models/precision/question.py b/generalresearch/models/precision/question.py
index 5750bb6..9673f54 100644
--- a/generalresearch/models/precision/question.py
+++ b/generalresearch/models/precision/question.py
@@ -2,9 +2,9 @@
import json
import logging
from enum import Enum
-from typing import List, Optional, Literal, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Self
-from pydantic import BaseModel, Field, model_validator, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
from generalresearch.models import Source, string_utils
from generalresearch.models.precision import PrecisionQuestionID
@@ -13,6 +13,11 @@ from generalresearch.models.thl.profiling.marketplace import (
MarketplaceUserQuestionAnswer,
)
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -99,7 +104,7 @@ class PrecisionQuestion(MarketplaceQuestion):
return self
@classmethod
- def from_api(cls, d: dict) -> Optional["PrecisionQuestion"]:
+ def from_api(cls, d: Dict[str, Any]) -> Optional["PrecisionQuestion"]:
"""
:param d: Raw response from API
"""
@@ -110,7 +115,7 @@ class PrecisionQuestion(MarketplaceQuestion):
return None
@classmethod
- def _from_api(cls, d: dict) -> "PrecisionQuestion":
+ def _from_api(cls, d: Dict[str, Any]) -> "PrecisionQuestion":
question_type = PrecisionQuestionType.from_api(d["question_type_name"])
# sometimes an empty option is returned .... ?
options = [
@@ -132,7 +137,7 @@ class PrecisionQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict):
+ def from_db(cls, d: Dict[str, Any]) -> "PrecisionQuestion":
options = None
if d["options"]:
options = [
@@ -156,13 +161,13 @@ class PrecisionQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
order_exclusive_options,
)
diff --git a/generalresearch/models/precision/survey.py b/generalresearch/models/precision/survey.py
index 847093b..646d60e 100644
--- a/generalresearch/models/precision/survey.py
+++ b/generalresearch/models/precision/survey.py
@@ -3,14 +3,14 @@ from __future__ import annotations
import json
from datetime import timezone
from functools import cached_property
-from typing import Optional, List, Literal, Set, Dict, Any, Tuple, Type
+from typing import Any, Dict, List, Literal, Optional, Self, Set, Tuple, Type
from more_itertools import flatten
from pydantic import (
+ BaseModel,
ConfigDict,
Field,
PrivateAttr,
- BaseModel,
computed_field,
model_validator,
)
@@ -18,18 +18,18 @@ from typing_extensions import Annotated
from generalresearch.models import Source
from generalresearch.models.custom_types import (
- CoercedStr,
- UUIDStrCoerce,
- AwareDatetimeISO,
AlphaNumStrSet,
+ AwareDatetimeISO,
+ CoercedStr,
DeviceTypes,
+ UUIDStrCoerce,
)
from generalresearch.models.precision import PrecisionQuestionID, PrecisionStatus
from generalresearch.models.thl.demographics import Gender
from generalresearch.models.thl.survey import MarketplaceTask
from generalresearch.models.thl.survey.condition import (
- MarketplaceCondition,
ConditionValueType,
+ MarketplaceCondition,
)
@@ -137,7 +137,7 @@ class PrecisionSurvey(MarketplaceTask):
# complete_pct: float = Field(ge=0, le=1, validation_alias="cp")
# Also skipping: ismultiple (allowing multiple entrances). How is that even possible? They are all False anyways.
- bid_loi: int = Field(default=None, ge=59, le=120 * 60, validation_alias="loi")
+ bid_loi: int = Field(ge=59, le=120 * 60, validation_alias="loi")
bid_ir: float = Field(ge=0, le=1, validation_alias="ir")
# Be careful with this, it doesn't make any sense. See survey 452481, has 12 completes with a 100% live_ir,
# but the only quotas have 0 completes and 1052 terms. .... ??
@@ -257,12 +257,12 @@ class PrecisionSurvey(MarketplaceTask):
)
return f"{self.__repr_name__()}({repr_str})"
- def is_unchanged(self, other):
+ def is_unchanged(self, other) -> bool:
return self.model_dump(
exclude={"updated", "conditions", "created"}
) == other.model_dump(exclude={"updated", "conditions", "created"})
- def to_mysql(self):
+ def to_mysql(self) -> Dict[str, Any]:
d = self.model_dump(
mode="json",
exclude={
@@ -283,7 +283,7 @@ class PrecisionSurvey(MarketplaceTask):
return d
@classmethod
- def from_db(cls, d: Dict[str, Any]):
+ def from_db(cls, d: Dict[str, Any]) -> Self:
d["created"] = d["created"].replace(tzinfo=timezone.utc)
d["updated"] = d["updated"].replace(tzinfo=timezone.utc)
d["expected_end_date"] = (
@@ -334,13 +334,13 @@ class PrecisionSurvey(MarketplaceTask):
else:
# we don't match any quotas, so everything is unknown
return None, set(
- flatten([m[1] for q, m in quota_eval.items() if m[0] is None])
+ flatten([m[1] for _, m in quota_eval.items() if m[0] is None])
)
if True in evals:
return True, set()
if None in evals:
return None, set(
- flatten([m[1] for q, m in quota_eval.items() if m[0] is None])
+ flatten([m[1] for _, m in quota_eval.items() if m[0] is None])
)
return False, set()
diff --git a/generalresearch/models/precision/task_collection.py b/generalresearch/models/precision/task_collection.py
index e2942b5..bd7b274 100644
--- a/generalresearch/models/precision/task_collection.py
+++ b/generalresearch/models/precision/task_collection.py
@@ -1,7 +1,7 @@
-from typing import List
+from typing import Any, Dict, List
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models.precision import PrecisionStatus
@@ -56,7 +56,7 @@ class PrecisionTaskCollection(TaskCollection):
items: List[PrecisionSurvey]
_schema = PrecisionTaskCollectionSchema
- def to_row(self, s: PrecisionSurvey):
+ def to_row(self, s: PrecisionSurvey) -> Dict[str, Any]:
d = s.model_dump(
mode="json",
exclude={
diff --git a/generalresearch/models/prodege/__init__.py b/generalresearch/models/prodege/__init__.py
index c7bc4e7..d419c0c 100644
--- a/generalresearch/models/prodege/__init__.py
+++ b/generalresearch/models/prodege/__init__.py
@@ -30,8 +30,10 @@ class ProdegePastParticipationType(str, Enum):
# This is the value of the 'status' url param in the redirect
# https://developer.prodege.com/surveys-feed/redirects
-# Note: there is no status for ProdegePastParticipationType.CLICK b/c that would be an abandonent
+# Note: there is no status for ProdegePastParticipationType.CLICK b/c
+# that would be an abandonment
# Note: there is no ProdegePastParticipationType for quality (status 4)
ProdgeRedirectStatus = Literal["1", "2", "3", "4"]
-# I'm not using the ProdegePastParticipationType for the values here b/c there is not a 1-to-1 mapping.
+# I'm not using the ProdegePastParticipationType for the values here
+# b/c there is not a 1-to-1 mapping.
ProdgeRedirectStatusNameMap = {"1": "complete", "2": "oq", "3": "dq", "4": "quality"}
diff --git a/generalresearch/models/prodege/question.py b/generalresearch/models/prodege/question.py
index 0ae4548..e3cdf08 100644
--- a/generalresearch/models/prodege/question.py
+++ b/generalresearch/models/prodege/question.py
@@ -6,16 +6,21 @@ import logging
from datetime import datetime, timezone
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Literal, Any, Dict, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Self, Set
-from pydantic import BaseModel, Field, model_validator, ConfigDict, PositiveInt
+from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator
from generalresearch.locales import Localelator
-from generalresearch.models import Source, MAX_INT32
+from generalresearch.models import MAX_INT32, Source
from generalresearch.models.custom_types import AwareDatetimeISO
from generalresearch.models.prodege import ProdegeQuestionIdType
from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -24,23 +29,29 @@ locale_helper = Localelator()
class ProdegeUserQuestionAnswer(BaseModel):
- # This is optional b/c this model can be used for eligibility checks for "anonymous" users, which are represented
- # by a list of question answers not associated with an actual user. No default b/c we must explicitly set
- # the field to None.
+ # This is optional b/c this model can be used for eligibility checks
+ # for "anonymous" users, which are represented by a list of question
+ # answers not associated with an actual user. No default b/c we must
+ # explicitly set the field to None.
user_id: Optional[PositiveInt] = Field(lt=MAX_INT32)
question_id: ProdegeQuestionIdType = Field()
- # This is optional b/c we do not need it when writing these to the db. When these are fetched from the db
- # for use in yield-management, we read this field from the prodege_question table.
+
+ # This is optional b/c we do not need it when writing these to the
+ # db. When these are fetched from the db for use in yield-management,
+ # we read this field from the prodege_question table.
question_type: Optional[ProdegeQuestionType] = Field(default=None)
+
# This may be a pipe-separated string if the question_type is multi. regex means any chars except capital letters
option_id: str = Field(pattern=r"^[^A-Z]*$")
created: AwareDatetimeISO = Field(
default_factory=lambda: datetime.now(tz=timezone.utc)
)
+
# ISO 3166-1 alpha-2 (two-letter codes, lowercase)
country_iso: str = Field(
max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True
)
+
# 3-char ISO 639-2/B, lowercase
language_iso: str = Field(
max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True
@@ -72,10 +83,12 @@ class ProdegeQuestionOption(BaseModel):
validation_alias="option_text",
description="The response text shown to respondents",
)
- # Order does not come back explicitly in the API, but the responses seem to be ordered
+ # Order does not come back explicitly in the API, but the
+ # responses seem to be ordered
order: int = Field()
- # Both is_exclusive and is_anchored are returned, but I don't see how they are different.
- # We are merging them both into is_exclusive.
+
+ # Both is_exclusive and is_anchored are returned, but I don't see how
+ # they are different. We are merging them both into is_exclusive.
is_exclusive: bool = Field(default=False)
@@ -126,7 +139,9 @@ class ProdegeQuestion(MarketplaceQuestion):
return self
@classmethod
- def from_api(cls, d: dict, country_iso: str) -> Optional["ProdegeQuestion"]:
+ def from_api(
+ cls, d: Dict[str, Any], country_iso: str
+ ) -> Optional["ProdegeQuestion"]:
"""
:param d: Raw response from API
"""
@@ -137,7 +152,7 @@ class ProdegeQuestion(MarketplaceQuestion):
return None
@classmethod
- def _from_api(cls, d: dict, country_iso: str) -> "ProdegeQuestion":
+ def _from_api(cls, d: Dict[str, Any], country_iso: str) -> "ProdegeQuestion":
# The API has no concept of language at all. Questions for a country
# are returned both in english and other languages. Questions do have
# a field 'country_specific', and if True, that generally means the
@@ -170,7 +185,7 @@ class ProdegeQuestion(MarketplaceQuestion):
return cls.model_validate(d)
@classmethod
- def from_db(cls, d: dict):
+ def from_db(cls, d: Dict[str, Any]) -> "ProdegeQuestion":
options = None
if d["options"]:
options = [
@@ -200,13 +215,13 @@ class ProdegeQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
order_exclusive_options,
)
diff --git a/generalresearch/models/prodege/survey.py b/generalresearch/models/prodege/survey.py
index 378f0a7..eaae883 100644
--- a/generalresearch/models/prodege/survey.py
+++ b/generalresearch/models/prodege/survey.py
@@ -4,34 +4,34 @@ from __future__ import annotations
import json
import logging
from collections import defaultdict
-from datetime import timezone, datetime
+from datetime import datetime, timezone
from decimal import Decimal
from functools import cached_property
-from typing import List, Optional, Dict, Any, Set, Literal, Tuple, Type
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type
from pydantic import (
BaseModel,
- Field,
ConfigDict,
+ Field,
computed_field,
- model_validator,
field_validator,
+ model_validator,
)
from generalresearch.locales import Localelator
from generalresearch.models import LogicalOperator, Source, TaskCalculationType
from generalresearch.models.custom_types import (
AlphaNumStrSet,
+ AwareDatetimeISO,
+ CoercedStr,
InclExcl,
UUIDStr,
- CoercedStr,
- AwareDatetimeISO,
)
from generalresearch.models.prodege import (
- ProdegeStatus,
+ ProdegePastParticipationType,
ProdegeQuestionIdType,
+ ProdegeStatus,
ProdgeRedirectStatus,
- ProdegePastParticipationType,
)
from generalresearch.models.prodege.definitions import PG_COUNTRY_TO_ISO
from generalresearch.models.thl.demographics import Gender
@@ -151,15 +151,18 @@ class ProdegeQuota(BaseModel):
}
@classmethod
- def from_api(cls, d: Dict):
+ def from_api(cls, d: Dict[str, Any]) -> "ProdegeQuota":
# the API doesn't handle None's correctly? idk
if d["parent_quota_id"] == 0:
d["parent_quota_id"] = None
+
d["calculation_type"] = TaskCalculationType.prodege_from_api(
d["calculation_type"]
)
+
if d.get("country_id"):
d["country_iso"] = PG_COUNTRY_TO_ISO[d["country_id"]]
+
return cls.model_validate(d)
def passes(
@@ -174,12 +177,13 @@ class ProdegeQuota(BaseModel):
self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str
) -> bool:
# Match means we meet all conditions.
- # We can "match" a quota that is closed. In that case, we would fail the parent quota
+ # We can "match" a quota that is closed. In that case, we would
+ # fail the parent quota
return self.matches_country(country_iso) and all(
criteria_evaluation.get(c) for c in self.condition_hashes
)
- def matches_country(self, country_iso: str):
+ def matches_country(self, country_iso: str) -> bool:
return self.country_iso is None or self.country_iso == country_iso
def passes_verbose(
@@ -224,9 +228,12 @@ class ProdegeMaxClicksSetting(BaseModel):
# The total number of clicks allowed before survey traffic is paused.
cap: int = Field(validation_alias="max_clicks_cap")
+
# The current remaining number of clicks before survey traffic is paused.
allowed_clicks: int = Field(validation_alias="max_clicks_allowed_clicks")
- # The refill rate id for clicks (1: every 30 min, 2: every 1 hour, 3: every 24 hours, 0: one-time setting).
+
+ # The refill rate id for clicks (1: every 30 min, 2: every 1 hour,
+ # 3: every 24 hours, 0: one-time setting).
# (not going to bother structuring this, we can't really use it...)
max_click_rate_id: int = Field(validation_alias="max_clicks_max_click_rate_id")
@@ -243,24 +250,31 @@ class ProdegeUserPastParticipation(BaseModel):
@property
def participation_types(self) -> Set[ProdegePastParticipationType]:
- # If the survey is filtering completes, then only a complete counts. But if the survey is filtering
- # on clicks, then a person who got a complete ALSO did click. And so, the logic here is that
+ # If the survey is filtering completes, then only a complete
+ # counts. But if the survey is filtering on clicks, then a person
+ # who got a complete ALSO did click. And so, the logic here is that
# participation_types should always include "click".
if self.ext_status_code_1 is None:
return {ProdegePastParticipationType.CLICK}
+
elif self.ext_status_code_1 == "1":
return {
ProdegePastParticipationType.CLICK,
ProdegePastParticipationType.COMPLETE,
}
+
elif self.ext_status_code_1 == "2":
return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.OQ}
+
elif self.ext_status_code_1 == "3":
return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.DQ}
+
elif self.ext_status_code_1 == "4":
# 4 means "Quality Disqualification". unclear which participation type this is.
return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.DQ}
+ raise ValueError(f"Unknown ext_status_code_1: {self.ext_status_code_1}")
+
def days_ago(self) -> float:
now = datetime.now(timezone.utc)
return (now - self.started).total_seconds() / (3600 * 24)
@@ -285,15 +299,18 @@ class ProdegePastParticipation(BaseModel):
"""
@classmethod
- def from_api(cls, d: Dict):
+ def from_api(cls, d: Dict[str, Any]) -> "ProdegePastParticipation":
# the API doesn't handle None's correctly? idk
if d["in_past_days"] == 0:
d["in_past_days"] = None
+
d["participation_project_ids"] = list(map(str, d["participation_project_ids"]))
return cls.model_validate(d)
def user_participated(self, user_participation: ProdegeUserPastParticipation):
- # Given this user's participation event (1 single event), is it being filtered by this survey?
+ # Given this user's participation event (1 single event), is it
+ # being filtered by this survey?
+
return (
user_participation.survey_id in self.survey_ids
and (
@@ -493,7 +510,7 @@ class ProdegeSurvey(MarketplaceTask):
return round(float(v), 2)
@classmethod
- def from_api(cls, d: Dict) -> Optional["ProdegeSurvey"]:
+ def from_api(cls, d: Dict[str, Any]) -> Optional["ProdegeSurvey"]:
try:
return cls._from_api(d)
except Exception as e:
@@ -501,16 +518,19 @@ class ProdegeSurvey(MarketplaceTask):
return None
@classmethod
- def _from_api(cls, d: Dict):
- # handle phases. keys in api response are 'loi' and 'actual_ir'
+ def _from_api(cls, d: Dict[str, Any]) -> "ProdegeSurvey":
+
+ # Handle phases. keys in api response are 'loi' and 'actual_ir'
if d["phases"]["loi_phase"] == "actual":
d["actual_loi"] = d.pop("loi") * 60
else:
d["bid_loi"] = d.pop("loi") * 60
+
if d["phases"]["actual_ir_phase"] == "actual":
d["actual_ir"] = d.pop("actual_ir") / 100
else:
d["bid_ir"] = d.pop("actual_ir") / 100
+
d["conversion_rate"] = (
d["conversion_rate"] / 100 if d["conversion_rate"] else None
)
@@ -576,7 +596,7 @@ class ProdegeSurvey(MarketplaceTask):
sub_res.append(q)
return res
- def is_unchanged(self, other):
+ def is_unchanged(self, other) -> bool:
# Avoiding overloading __eq__ because it looks kind of complicated? I
# want to be explicit that this is not testing object equivalence,
# just that the objects don't require any db updates. We also exclude
@@ -593,7 +613,8 @@ class ProdegeSurvey(MarketplaceTask):
exclude={"created", "updated", "conditions", "survey_name"}
)
if o1 == o2:
- # We don't have to check bid/actual, b/c we already know it's not changed
+ # We don't have to check bid/actual, b/c we already know
+ # it's not changed
return True
# Ignore bid fields if either one is NULL
@@ -604,7 +625,7 @@ class ProdegeSurvey(MarketplaceTask):
return o1 == o2
- def to_mysql(self):
+ def to_mysql(self) -> Dict[str, Any]:
d = self.model_dump(
mode="json",
exclude={
@@ -618,6 +639,7 @@ class ProdegeSurvey(MarketplaceTask):
},
)
d["quotas"] = json.dumps(d["quotas"])
+
for k in [
"max_clicks_settings",
"past_participation",
@@ -625,13 +647,15 @@ class ProdegeSurvey(MarketplaceTask):
"exclude_psids",
]:
d[k] = json.dumps(d[k]) if d[k] else None
+
d["used_question_ids"] = json.dumps(d["used_question_ids"])
d["created"] = self.created
d["updated"] = self.updated
+
return d
@classmethod
- def from_db(cls, d: Dict[str, Any]):
+ def from_db(cls, d: Dict[str, Any]) -> "ProdegeSurvey":
d["created"] = d["created"].replace(tzinfo=timezone.utc)
d["updated"] = d["updated"].replace(tzinfo=timezone.utc)
d["quotas"] = json.loads(d["quotas"])
@@ -653,7 +677,7 @@ class ProdegeSurvey(MarketplaceTask):
self,
criteria_evaluation: Dict[str, Optional[bool]],
country_iso: str,
- verbose=False,
+ verbose: bool = False,
) -> bool:
# https://developer.prodege.com/surveys-feed/api-reference/survey-matching/quota-structure
# https://developer.prodege.com/surveys-feed/api-reference/survey-matching/quota-matching-requirements
@@ -700,7 +724,7 @@ class ProdegeSurvey(MarketplaceTask):
criteria_evaluation: Dict[str, Optional[bool]],
child_quotas: List[ProdegeQuota],
country_iso: str,
- verbose=False,
+ verbose: bool = False,
) -> bool:
if len(child_quotas) == 0:
# If the parent has no children, we pass
@@ -745,3 +769,5 @@ class ProdegeSurvey(MarketplaceTask):
criteria_evaluation, country_iso=country_iso, verbose=True
)
)
+
+ return None
diff --git a/generalresearch/models/prodege/task_collection.py b/generalresearch/models/prodege/task_collection.py
index 4765f29..0ea09e1 100644
--- a/generalresearch/models/prodege/task_collection.py
+++ b/generalresearch/models/prodege/task_collection.py
@@ -1,7 +1,7 @@
-from typing import List, Dict, Any
+from typing import Any, Dict, List
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models.prodege import ProdegeStatus
diff --git a/generalresearch/models/repdata/question.py b/generalresearch/models/repdata/question.py
index 8b4eb4e..d33a4d8 100644
--- a/generalresearch/models/repdata/question.py
+++ b/generalresearch/models/repdata/question.py
@@ -4,22 +4,27 @@ import json
import logging
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Literal, Any, Dict, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
from uuid import UUID
from pydantic import (
BaseModel,
- Field,
- model_validator,
ConfigDict,
- field_validator,
+ Field,
PositiveInt,
+ field_validator,
+ model_validator,
)
-from generalresearch.models import Source, MAX_INT32
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models import MAX_INT32, Source
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -155,7 +160,7 @@ class RepDataQuestion(MarketplaceQuestion):
@classmethod
def from_api(
- cls, d: dict, country_iso: str, language_iso: str
+ cls, d: Dict[str, Any], country_iso: str, language_iso: str
) -> Optional["RepDataQuestion"]:
"""
:param d: Raw response from API
@@ -168,7 +173,7 @@ class RepDataQuestion(MarketplaceQuestion):
@classmethod
def _from_api(
- cls, d: dict, country_iso: str, language_iso: str
+ cls, d: Dict[str, Any], country_iso: str, language_iso: str
) -> "RepDataQuestion":
d["QualificationType"] = RepDataQuestionType.from_api(d["QualificationType"])
# zip code/age has a placeholder invalid option for some reason
@@ -186,7 +191,7 @@ class RepDataQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict):
+ def from_db(cls, d: Dict[str, Any]) -> "RepDataQuestion":
options = None
if d["options"]:
options = [
@@ -212,13 +217,13 @@ class RepDataQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
order_exclusive_options,
)
diff --git a/generalresearch/models/repdata/survey.py b/generalresearch/models/repdata/survey.py
index 572e696..b6acd92 100644
--- a/generalresearch/models/repdata/survey.py
+++ b/generalresearch/models/repdata/survey.py
@@ -6,38 +6,38 @@ import logging
from datetime import datetime, timezone
from decimal import Decimal
from functools import cached_property
-from typing import List, Optional, Dict, Any, Set, Literal, Type
-from typing_extensions import Self
+from typing import Any, Dict, List, Literal, Optional, Set, Type
from uuid import UUID
from pydantic import (
BaseModel,
- Field,
- field_validator,
ConfigDict,
+ Field,
computed_field,
+ field_validator,
model_validator,
)
+from typing_extensions import Self
from generalresearch.grpc import timestamp_from_datetime
from generalresearch.locales import Localelator
from generalresearch.models import (
+ DeviceType,
LogicalOperator,
Source,
TaskCalculationType,
- DeviceType,
)
from generalresearch.models.custom_types import (
+ AwareDatetimeISO,
CoercedStr,
UUIDStr,
- AwareDatetimeISO,
)
from generalresearch.models.repdata import RepDataStatus
from generalresearch.models.thl.demographics import Gender
from generalresearch.models.thl.survey import MarketplaceTask
from generalresearch.models.thl.survey.condition import (
- MarketplaceCondition,
ConditionValueType,
+ MarketplaceCondition,
)
logging.basicConfig()
@@ -328,7 +328,7 @@ class RepDataStream(MarketplaceTask):
}
@classmethod
- def from_api(cls, stream_res, country_iso, language_iso):
+ def from_api(cls, stream_res, country_iso: str, language_iso: str):
# qualifications and quotas need to be added to the stream_res manually
d = stream_res.copy()
d["CalculationType"] = TaskCalculationType.from_api(d["CalculationType"])
@@ -361,7 +361,9 @@ class RepDataStreamHashed(RepDataStream):
return d
@classmethod
- def from_db(cls, res, survey: RepDataSurveyHashed):
+ def from_db(
+ cls, res: Dict[str, Any], survey: RepDataSurveyHashed
+ ) -> "RepDataStreamHashed":
# We need certain fields copied over here so that a stream can exist
# independent of the survey
res["country_iso"] = survey.country_iso
@@ -531,7 +533,7 @@ class RepDataSurveyHashed(RepDataSurvey):
streams: None = Field(default=None, exclude=True)
@classmethod
- def from_db(cls, res):
+ def from_db(cls, res: Dict[str, Any]) -> "RepDataSurveyHashed":
res["allowed_devices"] = [
DeviceType(int(x)) for x in res["allowed_devices"].split(",")
]
@@ -553,6 +555,7 @@ class RepDataSurveyHashed(RepDataSurvey):
def to_grpc(self, repdata_pb2):
now = datetime.now(tz=timezone.utc)
timestamp = timestamp_from_datetime(now)
+
return repdata_pb2.RepDataOpportunity(
json_str=self.model_dump_json(),
timestamp=timestamp,
diff --git a/generalresearch/models/repdata/task_collection.py b/generalresearch/models/repdata/task_collection.py
index 7b99638..740670b 100644
--- a/generalresearch/models/repdata/task_collection.py
+++ b/generalresearch/models/repdata/task_collection.py
@@ -1,7 +1,7 @@
-from typing import List
+from typing import Any, Dict, List
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models import TaskCalculationType
@@ -81,7 +81,7 @@ class RepDataTaskCollection(TaskCollection):
items: List[RepDataSurveyHashed]
_schema = RepDataTaskCollectionSchema
- def to_rows(self, s: RepDataSurveyHashed):
+ def to_rows(self, s: RepDataSurveyHashed) -> List[Dict[str, Any]]:
survey_fields = [
"survey_id",
"survey_uuid",
@@ -122,7 +122,7 @@ class RepDataTaskCollection(TaskCollection):
rows.append(ds)
return rows
- def to_df(self):
+ def to_df(self) -> pd.DataFrame:
rows = []
for s in self.items:
rows.extend(self.to_rows(s))
diff --git a/generalresearch/models/sago/question.py b/generalresearch/models/sago/question.py
index 8137b38..62ff363 100644
--- a/generalresearch/models/sago/question.py
+++ b/generalresearch/models/sago/question.py
@@ -4,21 +4,26 @@ import json
import logging
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Literal, Any, Dict, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
- model_validator,
- field_validator,
PositiveInt,
- ConfigDict,
+ field_validator,
+ model_validator,
)
-from generalresearch.models import Source, string_utils, MAX_INT32
+from generalresearch.models import MAX_INT32, Source, string_utils
from generalresearch.models.custom_types import AwareDatetimeISO
from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -165,7 +170,7 @@ class SagoQuestion(MarketplaceQuestion):
@classmethod
def from_api(
- cls, d: dict, country_iso: str, language_iso: str
+ cls, d: Dict[str, Any], country_iso: str, language_iso: str
) -> Optional["SagoQuestion"]:
"""
:param d: Raw response from API
@@ -180,7 +185,9 @@ class SagoQuestion(MarketplaceQuestion):
return None
@classmethod
- def _from_api(cls, d: dict, country_iso, language_iso) -> "SagoQuestion":
+ def _from_api(
+ cls, d: Dict[str, Any], country_iso: str, language_iso: str
+ ) -> "SagoQuestion":
sago_category_to_tags = {
1: "Standard",
2: "Custom",
@@ -214,7 +221,7 @@ class SagoQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict):
+ def from_db(cls, d: Dict[str, Any]) -> "SagoQuestion":
options = None
if d["options"]:
options = [
@@ -241,13 +248,13 @@ class SagoQuestion(MarketplaceQuestion):
d["options"] = json.dumps(d["options"])
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
order_exclusive_options,
)
diff --git a/generalresearch/models/sago/survey.py b/generalresearch/models/sago/survey.py
index a2e9a8a..a8188b0 100644
--- a/generalresearch/models/sago/survey.py
+++ b/generalresearch/models/sago/survey.py
@@ -5,19 +5,19 @@ import logging
from datetime import timezone
from decimal import Decimal
from functools import cached_property
-from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Annotated, Type
-from typing_extensions import Self
+from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Tuple, Type
from more_itertools import flatten
-from pydantic import Field, ConfigDict, BaseModel, model_validator, computed_field
+from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator
+from typing_extensions import Self
from generalresearch.locales import Localelator
-from generalresearch.models import Source, LogicalOperator
+from generalresearch.models import LogicalOperator, Source
from generalresearch.models.custom_types import (
- CoercedStr,
- AwareDatetimeISO,
- AlphaNumStrSet,
AlphaNumStr,
+ AlphaNumStrSet,
+ AwareDatetimeISO,
+ CoercedStr,
DeviceTypes,
IPLikeStrSet,
)
@@ -80,7 +80,7 @@ class SagoQuota(BaseModel):
return self.remaining_count >= min_open_spots
@classmethod
- def from_api(cls, d: Dict) -> Self:
+ def from_api(cls, d: Dict[str, Any]) -> Self:
return cls.model_validate(d)
def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool:
@@ -259,7 +259,7 @@ class SagoSurvey(MarketplaceTask):
}
@classmethod
- def from_api(cls, d: Dict) -> Optional["SagoSurvey"]:
+ def from_api(cls, d: Dict[str, Any]) -> Optional["SagoSurvey"]:
try:
return cls._from_api(d)
except Exception as e:
@@ -267,7 +267,7 @@ class SagoSurvey(MarketplaceTask):
return None
@classmethod
- def _from_api(cls, d: Dict):
+ def _from_api(cls, d: Dict[str, Any]) -> "SagoSurvey":
return cls.model_validate(d)
def __repr__(self) -> str:
@@ -285,7 +285,7 @@ class SagoSurvey(MarketplaceTask):
)
return f"{self.__repr_name__()}({repr_str})"
- def is_unchanged(self, other):
+ def is_unchanged(self, other) -> bool:
# Avoiding overloading __eq__ because it looks kind of complicated? I
# want to be explicit that this is not testing object equivalence, just
# that the objects don't require any db updates. We also exclude
@@ -294,7 +294,7 @@ class SagoSurvey(MarketplaceTask):
exclude={"updated", "conditions", "created"}
) == other.model_dump(exclude={"updated", "conditions", "created"})
- def to_mysql(self):
+ def to_mysql(self) -> Dict[str, Any]:
d = self.model_dump(
mode="json",
exclude={
diff --git a/generalresearch/models/sago/task_collection.py b/generalresearch/models/sago/task_collection.py
index c2f168a..41d047e 100644
--- a/generalresearch/models/sago/task_collection.py
+++ b/generalresearch/models/sago/task_collection.py
@@ -1,7 +1,7 @@
-from typing import List, Set
+from typing import Any, Dict, List, Set
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models.sago import SagoStatus
@@ -51,7 +51,7 @@ class SagoTaskCollection(TaskCollection):
items: List[SagoSurvey]
_schema = SagoTaskCollectionSchema
- def to_row(self, s: SagoSurvey):
+ def to_row(self, s: SagoSurvey) -> Dict[str, Any]:
d = s.model_dump(
mode="json",
exclude={
@@ -71,7 +71,7 @@ class SagoTaskCollection(TaskCollection):
d["cpi"] = float(s.cpi)
return d
- def to_df(self):
+ def to_df(self) -> pd.DataFrame:
rows = []
for s in self.items:
rows.append(self.to_row(s))
diff --git a/generalresearch/models/spectrum/question.py b/generalresearch/models/spectrum/question.py
index 549fc7e..4f8b5e1 100644
--- a/generalresearch/models/spectrum/question.py
+++ b/generalresearch/models/spectrum/question.py
@@ -6,25 +6,30 @@ import logging
from datetime import datetime, timezone
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Literal, Any, Dict, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
from uuid import UUID
from pydantic import (
BaseModel,
Field,
- model_validator,
- field_validator,
PositiveInt,
+ field_validator,
+ model_validator,
)
from typing_extensions import Self
-from generalresearch.models import Source, string_utils, MAX_INT32
+from generalresearch.models import MAX_INT32, Source, string_utils
from generalresearch.models.custom_types import AwareDatetimeISO
from generalresearch.models.spectrum import SpectrumQuestionIdType
from generalresearch.models.thl.profiling.marketplace import (
MarketplaceQuestion,
)
+if TYPE_CHECKING:
+ from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
+ )
+
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -246,7 +251,7 @@ class SpectrumQuestion(MarketplaceQuestion):
@classmethod
def from_api(
- cls, d: dict, country_iso: str, language_iso: str
+ cls, d: Dict[str, Any], country_iso: str, language_iso: str
) -> Optional["SpectrumQuestion"]:
# To not pollute our logs, we know we are skipping any question that
# meets the following conditions:
@@ -263,7 +268,7 @@ class SpectrumQuestion(MarketplaceQuestion):
return None
@classmethod
- def _from_api(cls, d: dict, country_iso: str, language_iso: str) -> Self:
+ def _from_api(cls, d: Dict[str, Any], country_iso: str, language_iso: str) -> Self:
options = None
if d.get("condition_codes"):
# Sometimes they use the key "name" instead of "text" ... ?
@@ -296,7 +301,7 @@ class SpectrumQuestion(MarketplaceQuestion):
)
@classmethod
- def from_db(cls, d: dict) -> Self:
+ def from_db(cls, d: Dict[str, Any]) -> Self:
options = None
if d["options"]:
options = [
@@ -331,13 +336,13 @@ class SpectrumQuestion(MarketplaceQuestion):
d["created"] = self.created.replace(tzinfo=None)
return d
- def to_upk_question(self):
+ def to_upk_question(self) -> "UpkQuestion":
from generalresearch.models.thl.profiling.upk_question import (
+ UpkQuestion,
UpkQuestionChoice,
- UpkQuestionType,
UpkQuestionSelectorMC,
UpkQuestionSelectorTE,
- UpkQuestion,
+ UpkQuestionType,
)
upk_type_selector_map = {
diff --git a/generalresearch/models/spectrum/survey.py b/generalresearch/models/spectrum/survey.py
index 7bebaa2..72093eb 100644
--- a/generalresearch/models/spectrum/survey.py
+++ b/generalresearch/models/spectrum/survey.py
@@ -4,20 +4,20 @@ import json
import logging
from datetime import timezone
from decimal import Decimal
-from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Type
-from typing_extensions import Self
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type
from more_itertools import flatten
-from pydantic import Field, ConfigDict, BaseModel, model_validator, computed_field
+from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator
+from typing_extensions import Self
from generalresearch.locales import Localelator
-from generalresearch.models import TaskCalculationType, Source
+from generalresearch.models import Source, TaskCalculationType
from generalresearch.models.custom_types import (
- CoercedStr,
- AwareDatetimeISO,
+ AlphaNumStr,
AlphaNumStrSet,
+ AwareDatetimeISO,
+ CoercedStr,
UUIDStrSet,
- AlphaNumStr,
)
from generalresearch.models.spectrum import SpectrumStatus
from generalresearch.models.thl.demographics import Gender
@@ -321,7 +321,7 @@ class SpectrumSurvey(MarketplaceTask):
}
@classmethod
- def from_api(cls, d: Dict) -> Optional["SpectrumSurvey"]:
+ def from_api(cls, d: Dict[str, Any]) -> Optional["SpectrumSurvey"]:
try:
return cls._from_api(d)
except Exception as e:
@@ -329,7 +329,7 @@ class SpectrumSurvey(MarketplaceTask):
return None
@classmethod
- def _from_api(cls, d: Dict) -> Self:
+ def _from_api(cls, d: Dict[str, Any]) -> Self:
assert d["click_balancing"] in {0, 1}, "unknown click_balancing value"
d["calculation_type"] = (
TaskCalculationType.STARTS
diff --git a/generalresearch/models/spectrum/task_collection.py b/generalresearch/models/spectrum/task_collection.py
index 8aeac7d..3c2ee85 100644
--- a/generalresearch/models/spectrum/task_collection.py
+++ b/generalresearch/models/spectrum/task_collection.py
@@ -1,15 +1,15 @@
-from typing import List, Set, Dict
+from typing import Dict, List, Set
import pandas as pd
-from pandera import Column, DataFrameSchema, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models import TaskCalculationType
from generalresearch.models.spectrum import SpectrumStatus
from generalresearch.models.spectrum.survey import SpectrumSurvey
from generalresearch.models.thl.survey.task_collection import (
- create_empty_df_from_schema,
TaskCollection,
+ create_empty_df_from_schema,
)
COUNTRY_ISOS: Set[str] = Localelator().get_all_countries()
@@ -100,7 +100,7 @@ class SpectrumTaskCollection(TaskCollection):
rows.append(d)
return rows
- def to_df(self):
+ def to_df(self) -> pd.DataFrame:
rows = []
for s in self.items:
rows.extend(self.to_rows(s))
diff --git a/generalresearch/models/thl/__init__.py b/generalresearch/models/thl/__init__.py
index 875b2bb..0356842 100644
--- a/generalresearch/models/thl/__init__.py
+++ b/generalresearch/models/thl/__init__.py
@@ -2,8 +2,8 @@ from decimal import Decimal
from typing import Optional
from generalresearch.models.thl.finance import (
- ProductBalances,
POPFinancial,
+ ProductBalances,
)
from generalresearch.models.thl.payout import (
BrokerageProductPayoutEvent,
diff --git a/generalresearch/models/thl/category.py b/generalresearch/models/thl/category.py
index fe330f1..4e9e2ff 100644
--- a/generalresearch/models/thl/category.py
+++ b/generalresearch/models/thl/category.py
@@ -1,7 +1,7 @@
-from typing import Optional
+from typing import Any, Dict, Optional
from uuid import uuid4
-from pydantic import BaseModel, Field, model_validator, PositiveInt
+from pydantic import BaseModel, Field, PositiveInt, model_validator
from typing_extensions import Self
from generalresearch.models.custom_types import UUIDStr
@@ -53,7 +53,7 @@ class Category(BaseModel, frozen=True):
def is_root(self) -> bool:
return self.parent_path is None
- def to_offerwall_api(self) -> dict:
+ def to_offerwall_api(self) -> Dict[str, Any]:
return {
"id": self.uuid,
"label": self.label,
diff --git a/generalresearch/models/thl/contest/__init__.py b/generalresearch/models/thl/contest/__init__.py
index 7cf1f54..f9693eb 100644
--- a/generalresearch/models/thl/contest/__init__.py
+++ b/generalresearch/models/thl/contest/__init__.py
@@ -1,20 +1,20 @@
from __future__ import annotations
from datetime import datetime, timezone
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional
from uuid import uuid4
from pydantic import (
BaseModel,
Field,
- model_validator,
- computed_field,
PositiveInt,
+ computed_field,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.currency import USDCent
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.contest.definitions import ContestPrizeKind
from generalresearch.models.thl.user import User
@@ -34,17 +34,18 @@ class ContestEntryRule(BaseModel):
default=None,
)
- # TODO: Only allow entries if user meets some criteria: gold-membership status,
- # ID/phone verified, min_completes etc...
- # Maybe these get put in a separate model b/c the could apply if the ContestType is not ENTRY
+ # TODO: Only allow entries if user meets some criteria: gold-membership
+ # status, ID/phone verified, min_completes etc... Maybe these get put
+ # in a separate model b/c they could apply if the ContestType is not ENTRY
min_completes: Optional[int] = None
min_membership_level: Optional[int] = None
id_verified: Optional[bool] = None
class ContestEndCondition(BaseModel):
- """Defines the conditions to evaluate to determine when the contest is over.
- Multiple conditions can be set. The contest is over once ANY conditions are met.
+ """Defines the conditions to evaluate to determine when the contest
+ is over. Multiple conditions can be set. The contest is over
+ once ANY conditions are met.
"""
target_entry_amount: USDCent | PositiveInt | None = Field(
diff --git a/generalresearch/models/thl/contest/contest.py b/generalresearch/models/thl/contest/contest.py
index 232a038..e644ae4 100644
--- a/generalresearch/models/thl/contest/contest.py
+++ b/generalresearch/models/thl/contest/contest.py
@@ -1,31 +1,31 @@
from __future__ import annotations
import json
-from abc import abstractmethod, ABC
-from datetime import timezone, datetime
-from typing import List, Tuple, Optional, Dict
+from abc import ABC, abstractmethod
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
HttpUrl,
- ConfigDict,
- model_validator,
NonNegativeInt,
+ model_validator,
)
from typing_extensions import Self
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.contest import (
ContestEndCondition,
ContestPrize,
ContestWinner,
)
from generalresearch.models.thl.contest.definitions import (
+ ContestEndReason,
ContestStatus,
ContestType,
- ContestEndReason,
)
from generalresearch.models.thl.locales import CountryISOs
@@ -169,7 +169,7 @@ class Contest(ContestBase):
)
return None
- def model_dump_mysql(self, **kwargs) -> Dict:
+ def model_dump_mysql(self, **kwargs) -> Dict[str, Any]:
d = self.model_dump(mode="json", **kwargs)
d["created_at"] = self.created_at
@@ -183,7 +183,7 @@ class Contest(ContestBase):
return d
@classmethod
- def model_validate_mysql(cls, data) -> Self:
+ def model_validate_mysql(cls, data: Dict[str, Any]) -> Self:
data = {k: v for k, v in data.items() if k in cls.model_fields.keys()}
if isinstance(data["end_condition"], dict):
data["end_condition"] = ContestEndCondition.model_validate(
diff --git a/generalresearch/models/thl/contest/contest_entry.py b/generalresearch/models/thl/contest/contest_entry.py
index 146f06f..19586da 100644
--- a/generalresearch/models/thl/contest/contest_entry.py
+++ b/generalresearch/models/thl/contest/contest_entry.py
@@ -1,18 +1,18 @@
from __future__ import annotations
from datetime import datetime, timezone
-from typing import Union, Dict, Any
+from typing import Any, Dict, Union
from uuid import uuid4
from pydantic import (
- Field,
BaseModel,
- model_validator,
+ Field,
computed_field,
+ model_validator,
)
from generalresearch.currency import USDCent
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.contest.definitions import ContestEntryType
from generalresearch.models.thl.user import User
@@ -23,8 +23,8 @@ class ContestEntryCreate(BaseModel):
amount: Union[USDCent, int] = Field(
description="The amount of the entry in integer counts or USD Cents",
gt=0,
- default=None,
)
+
# This is used in the Create Entry API. We'll look up the user and set
# user_id. When we return this model in the API, user_id is excluded
product_user_id: str = Field(
@@ -54,7 +54,6 @@ class ContestEntry(BaseModel):
amount: Union[USDCent, int] = Field(
description="The amount of the entry in integer counts or USD Cents",
gt=0,
- default=None,
)
# user_id used internally, for DB joins/index
@@ -77,6 +76,7 @@ class ContestEntry(BaseModel):
elif entry_type == ContestEntryType.CASH:
# This may be coming from the DB, in which case it is an int.
data["amount"] = USDCent(data["amount"])
+
return data
@computed_field()
@@ -91,6 +91,8 @@ class ContestEntry(BaseModel):
elif self.entry_type == ContestEntryType.CASH:
return self.amount.to_usd_str()
+ raise ValueError(f"Unknown entry_type: {self.entry_type}")
+
@computed_field()
@property
def censored_product_user_id(self) -> str:
diff --git a/generalresearch/models/thl/contest/examples.py b/generalresearch/models/thl/contest/examples.py
index 9748597..4835c14 100644
--- a/generalresearch/models/thl/contest/examples.py
+++ b/generalresearch/models/thl/contest/examples.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Any, Dict
from pydantic import HttpUrl
@@ -6,21 +6,21 @@ from generalresearch.config import EXAMPLE_PRODUCT_ID
from generalresearch.currency import USDCent
-def _example_raffle_create(schema: Dict) -> None:
- from generalresearch.models.thl.contest.raffle import (
- RaffleContestCreate,
- )
+def _example_raffle_create(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestEndCondition,
- ContestPrize,
ContestEntryRule,
+ ContestPrize,
+ )
+ from generalresearch.models.thl.contest.contest_entry import (
+ ContestEntryType,
)
from generalresearch.models.thl.contest.definitions import (
- ContestType,
ContestPrizeKind,
+ ContestType,
)
- from generalresearch.models.thl.contest.contest_entry import (
- ContestEntryType,
+ from generalresearch.models.thl.contest.raffle import (
+ RaffleContestCreate,
)
schema["example"] = RaffleContestCreate(
@@ -47,20 +47,20 @@ def _example_raffle_create(schema: Dict) -> None:
def _example_raffle(schema: Dict) -> None:
- from generalresearch.models.thl.contest.raffle import RaffleContest
from generalresearch.models.thl.contest import (
ContestEndCondition,
- ContestPrize,
ContestEntryRule,
+ ContestPrize,
+ )
+ from generalresearch.models.thl.contest.contest_entry import (
+ ContestEntryType,
)
from generalresearch.models.thl.contest.definitions import (
- ContestStatus,
ContestPrizeKind,
+ ContestStatus,
ContestType,
)
- from generalresearch.models.thl.contest.contest_entry import (
- ContestEntryType,
- )
+ from generalresearch.models.thl.contest.raffle import RaffleContest
schema["example"] = RaffleContest(
name="Win an iPhone",
@@ -91,22 +91,24 @@ def _example_raffle(schema: Dict) -> None:
product_id=EXAMPLE_PRODUCT_ID,
).model_dump(mode="json")
+ return None
+
-def _example_raffle_user_view(schema: Dict) -> None:
- from generalresearch.models.thl.contest.raffle import RaffleUserView
+def _example_raffle_user_view(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestEndCondition,
- ContestPrize,
ContestEntryRule,
+ ContestPrize,
+ )
+ from generalresearch.models.thl.contest.contest_entry import (
+ ContestEntryType,
)
from generalresearch.models.thl.contest.definitions import (
- ContestStatus,
ContestPrizeKind,
+ ContestStatus,
ContestType,
)
- from generalresearch.models.thl.contest.contest_entry import (
- ContestEntryType,
- )
+ from generalresearch.models.thl.contest.raffle import RaffleUserView
schema["example"] = RaffleUserView(
name="Win an iPhone",
@@ -140,19 +142,21 @@ def _example_raffle_user_view(schema: Dict) -> None:
product_user_id="test-user",
).model_dump(mode="json")
+ return None
-def _example_milestone_create(schema: Dict) -> None:
- from generalresearch.models.thl.contest.milestone import (
- MilestoneContestCreate,
- MilestoneContestEndCondition,
- ContestEntryTrigger,
- )
+
+def _example_milestone_create(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestPrize,
)
from generalresearch.models.thl.contest.definitions import (
- ContestType,
ContestPrizeKind,
+ ContestType,
+ )
+ from generalresearch.models.thl.contest.milestone import (
+ ContestEntryTrigger,
+ MilestoneContestCreate,
+ MilestoneContestEndCondition,
)
schema["example"] = MilestoneContestCreate(
@@ -179,19 +183,21 @@ def _example_milestone_create(schema: Dict) -> None:
terms_and_conditions=HttpUrl("https://www.example.com"),
).model_dump(mode="json")
+ return None
-def _example_milestone(schema: Dict) -> None:
- from generalresearch.models.thl.contest.milestone import (
- MilestoneContest,
- MilestoneContestEndCondition,
- ContestEntryTrigger,
- )
+
+def _example_milestone(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestPrize,
)
from generalresearch.models.thl.contest.definitions import (
- ContestType,
ContestPrizeKind,
+ ContestType,
+ )
+ from generalresearch.models.thl.contest.milestone import (
+ ContestEntryTrigger,
+ MilestoneContest,
+ MilestoneContestEndCondition,
)
schema["example"] = MilestoneContest(
@@ -223,17 +229,19 @@ def _example_milestone(schema: Dict) -> None:
win_count=12,
).model_dump(mode="json")
+ return None
-def _example_milestone_user_view(schema: Dict) -> None:
- from generalresearch.models.thl.contest.milestone import (
- MilestoneUserView,
- MilestoneContestEndCondition,
- ContestEntryTrigger,
- )
+
+def _example_milestone_user_view(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import ContestPrize
from generalresearch.models.thl.contest.definitions import (
- ContestType,
ContestPrizeKind,
+ ContestType,
+ )
+ from generalresearch.models.thl.contest.milestone import (
+ ContestEntryTrigger,
+ MilestoneContestEndCondition,
+ MilestoneUserView,
)
schema["example"] = MilestoneUserView(
@@ -267,17 +275,19 @@ def _example_milestone_user_view(schema: Dict) -> None:
product_user_id="test-user",
).model_dump(mode="json")
+ return None
-def _example_leaderboard_contest_create(schema: Dict) -> None:
- from generalresearch.models.thl.contest.leaderboard import (
- LeaderboardContestCreate,
- )
+
+def _example_leaderboard_contest_create(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestPrize,
)
from generalresearch.models.thl.contest.definitions import (
- ContestType,
ContestPrizeKind,
+ ContestType,
+ )
+ from generalresearch.models.thl.contest.leaderboard import (
+ LeaderboardContestCreate,
)
schema["example"] = LeaderboardContestCreate(
@@ -313,16 +323,16 @@ def _example_leaderboard_contest_create(schema: Dict) -> None:
return None
-def _example_leaderboard_contest(schema: Dict) -> None:
- from generalresearch.models.thl.contest.leaderboard import (
- LeaderboardContest,
- )
+def _example_leaderboard_contest(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestPrize,
)
from generalresearch.models.thl.contest.definitions import (
- ContestType,
ContestPrizeKind,
+ ContestType,
+ )
+ from generalresearch.models.thl.contest.leaderboard import (
+ LeaderboardContest,
)
schema["example"] = LeaderboardContest(
@@ -359,10 +369,7 @@ def _example_leaderboard_contest(schema: Dict) -> None:
return None
-def _example_leaderboard_contest_user_view(schema: Dict) -> None:
- from generalresearch.models.thl.contest.leaderboard import (
- LeaderboardContestUserView,
- )
+def _example_leaderboard_contest_user_view(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.contest import (
ContestPrize,
)
@@ -370,6 +377,9 @@ def _example_leaderboard_contest_user_view(schema: Dict) -> None:
ContestPrizeKind,
ContestType,
)
+ from generalresearch.models.thl.contest.leaderboard import (
+ LeaderboardContestUserView,
+ )
schema["example"] = LeaderboardContestUserView(
name="Prizes for top survey takers this week",
@@ -402,3 +412,5 @@ def _example_leaderboard_contest_user_view(schema: Dict) -> None:
product_id=EXAMPLE_PRODUCT_ID,
product_user_id="test-user",
).model_dump(mode="json")
+
+ return None
diff --git a/generalresearch/models/thl/contest/leaderboard.py b/generalresearch/models/thl/contest/leaderboard.py
index 0b24190..d0681c4 100644
--- a/generalresearch/models/thl/contest/leaderboard.py
+++ b/generalresearch/models/thl/contest/leaderboard.py
@@ -1,12 +1,12 @@
-from datetime import datetime, timezone, timedelta
-from typing import Optional, Literal, List, Tuple, Dict, Any
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict, List, Literal, Optional, Tuple
from pydantic import (
- Field,
ConfigDict,
+ Field,
+ PrivateAttr,
computed_field,
model_validator,
- PrivateAttr,
)
from redis import Redis
from typing_extensions import Self
@@ -18,8 +18,8 @@ from generalresearch.managers.thl.user_manager.user_manager import (
UserManager,
)
from generalresearch.models.thl.contest import (
- ContestWinner,
ContestEndCondition,
+ ContestWinner,
)
from generalresearch.models.thl.contest.contest import (
Contest,
@@ -27,16 +27,16 @@ from generalresearch.models.thl.contest.contest import (
ContestUserView,
)
from generalresearch.models.thl.contest.definitions import (
+ ContestEndReason,
+ ContestPrizeKind,
ContestStatus,
ContestType,
- ContestPrizeKind,
- ContestEndReason,
LeaderboardTieBreakStrategy,
)
from generalresearch.models.thl.contest.examples import (
- _example_leaderboard_contest_user_view,
_example_leaderboard_contest,
_example_leaderboard_contest_create,
+ _example_leaderboard_contest_user_view,
)
from generalresearch.models.thl.leaderboard import (
Leaderboard,
@@ -96,7 +96,7 @@ class LeaderboardContestCreate(ContestBase):
return self
@property
- def leaderboard_key_parts(self) -> Dict:
+ def leaderboard_key_parts(self) -> Dict[str, Any]:
assert self.leaderboard_key.count(":") == 5, "invalid leaderboard_key"
parts = self.leaderboard_key.split(":")
_, product_id, country_iso, freq_str, date_str, board_code_value = parts
@@ -281,7 +281,7 @@ class LeaderboardContestUserView(LeaderboardContest, ContestUserView):
return False, "Contest hasn't started"
# This would indicate something is wrong, as something else should have done this
- e, reason = self.should_end()
+ e, _ = self.should_end()
if e:
LOG.warning("contest should be over")
return False, "contest is over"
diff --git a/generalresearch/models/thl/contest/milestone.py b/generalresearch/models/thl/contest/milestone.py
index f902401..8c51a88 100644
--- a/generalresearch/models/thl/contest/milestone.py
+++ b/generalresearch/models/thl/contest/milestone.py
@@ -2,33 +2,33 @@ from __future__ import annotations
import logging
from datetime import timedelta
-from typing import Literal, Optional, Tuple, Dict
+from typing import Any, Dict, Literal, Optional, Tuple
from pydantic import (
- Field,
- ConfigDict,
BaseModel,
+ ConfigDict,
+ Field,
PositiveInt,
)
from typing_extensions import Self
from generalresearch.models.custom_types import AwareDatetimeISO
from generalresearch.models.thl.contest.contest import (
- ContestBase,
Contest,
+ ContestBase,
ContestUserView,
)
from generalresearch.models.thl.contest.contest_entry import ContestEntry
from generalresearch.models.thl.contest.definitions import (
+ ContestEndReason,
+ ContestEntryTrigger,
ContestEntryType,
ContestStatus,
ContestType,
- ContestEndReason,
- ContestEntryTrigger,
)
from generalresearch.models.thl.contest.examples import (
- _example_milestone_create,
_example_milestone,
+ _example_milestone_create,
_example_milestone_user_view,
)
@@ -42,9 +42,7 @@ class MilestoneEntry(ContestEntry):
entry_type: Literal[ContestEntryType.COUNT] = Field(default=ContestEntryType.COUNT)
- # TODO: Must fix - how can the default be None if it's not Optional...
amount: int = Field(
- default=None,
description="The amount of the entry in integer counts",
gt=0,
)
@@ -147,7 +145,7 @@ class MilestoneContest(MilestoneContestCreate, Contest):
# just does nothing
return None
- def model_dump_mysql(self):
+ def model_dump_mysql(self) -> Dict[str, Any]:
d = super().model_dump_mysql(
exclude={
"entry_trigger",
@@ -165,7 +163,7 @@ class MilestoneContest(MilestoneContestCreate, Contest):
return d
@classmethod
- def model_validate_mysql(cls, data: Dict) -> Self:
+ def model_validate_mysql(cls, data: Dict[str, Any]) -> Self:
data.update(
MilestoneContestConfig.model_validate(data["milestone_config"]).model_dump()
)
@@ -218,9 +216,10 @@ class MilestoneUserView(MilestoneContest, ContestUserView):
# i.e. it hasn't been >24 hrs since user signed up, or whatever
# This would indicate something is wrong, as something else should have done this
- e, reason = self.should_end()
+ e, _ = self.should_end()
if e:
LOG.warning("contest should be over")
return False, "contest is over"
+
# TODO: others in self.entry_rule ... min_completes, id_verified, etc.
return True, ""
diff --git a/generalresearch/models/thl/contest/raffle.py b/generalresearch/models/thl/contest/raffle.py
index 9a01d0f..1592857 100644
--- a/generalresearch/models/thl/contest/raffle.py
+++ b/generalresearch/models/thl/contest/raffle.py
@@ -4,14 +4,14 @@ import logging
import random
from collections import defaultdict
from datetime import datetime, timezone
-from typing import Literal, List, Dict, Tuple, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import (
+ ConfigDict,
Field,
- model_validator,
computed_field,
field_validator,
- ConfigDict,
+ model_validator,
)
from scipy.stats import hypergeom
from typing_extensions import Self
@@ -28,14 +28,14 @@ from generalresearch.models.thl.contest.contest import (
)
from generalresearch.models.thl.contest.contest_entry import ContestEntry
from generalresearch.models.thl.contest.definitions import (
+ ContestEndReason,
ContestEntryType,
ContestStatus,
ContestType,
- ContestEndReason,
)
from generalresearch.models.thl.contest.examples import (
- _example_raffle_create,
_example_raffle,
+ _example_raffle_create,
_example_raffle_user_view,
)
@@ -186,10 +186,10 @@ class RaffleContest(RaffleContestCreate, Contest):
def get_current_participants(self) -> int:
return len({entry.user.user_id for entry in self.entries})
- def get_current_amount(self) -> Union[int | USDCent]:
+ def get_current_amount(self) -> Union[int, USDCent]:
return sum([x.amount for x in self.entries])
- def get_user_amount(self, product_user_id: str) -> Union[int | USDCent]:
+ def get_user_amount(self, product_user_id: str) -> Union[int, USDCent]:
# Sum of this user's amounts
return sum(
e.amount for e in self.entries if e.user.product_user_id == product_user_id
@@ -206,7 +206,7 @@ class RaffleContest(RaffleContestCreate, Contest):
return True
return False
- def model_dump_mysql(self):
+ def model_dump_mysql(self) -> Dict[str, Any]:
d = super().model_dump_mysql()
d["entry_rule"] = self.entry_rule.model_dump_json()
return d
@@ -309,9 +309,10 @@ class RaffleUserView(RaffleContest, ContestUserView):
return False, "Reached max amount today."
# This would indicate something is wrong, as something else should have done this
- e, reason = self.should_end()
+ e, _ = self.should_end()
if e:
LOG.warning("contest should be over")
return False, "contest is over"
+
# todo: others in self.entry_rule ... min_completes, id_verified, etc.
return True, ""
diff --git a/generalresearch/models/thl/contest/utils.py b/generalresearch/models/thl/contest/utils.py
index 1505f7b..e5043c2 100644
--- a/generalresearch/models/thl/contest/utils.py
+++ b/generalresearch/models/thl/contest/utils.py
@@ -1,9 +1,9 @@
-from typing import TYPE_CHECKING, List, Dict
+from typing import TYPE_CHECKING, Dict, List
if TYPE_CHECKING:
- from generalresearch.models.thl.user import User
from generalresearch.currency import USDCent
from generalresearch.models.thl.leaderboard import LeaderboardRow
+ from generalresearch.models.thl.user import User
def censor_product_user_id(user: "User") -> str:
diff --git a/generalresearch/models/thl/definitions.py b/generalresearch/models/thl/definitions.py
index 259093b..22ca4b9 100644
--- a/generalresearch/models/thl/definitions.py
+++ b/generalresearch/models/thl/definitions.py
@@ -67,7 +67,8 @@ class THLPaths(str, Enum, metaclass=ReprEnumMeta):
class Status(str, Enum, metaclass=ReprEnumMeta):
"""
- The outcome of a session or wall event. If the session is still in progress, the status will be NULL.
+ The outcome of a session or wall event. If the session is still in
+ progress, the status will be NULL.
"""
# User completed the job successfully and should be paid something
diff --git a/generalresearch/models/thl/demographics.py b/generalresearch/models/thl/demographics.py
index fd833c5..a688e49 100644
--- a/generalresearch/models/thl/demographics.py
+++ b/generalresearch/models/thl/demographics.py
@@ -1,10 +1,10 @@
from __future__ import annotations
import copy
-from collections import defaultdict, Counter
+from collections import Counter, defaultdict
from dataclasses import dataclass
from enum import Enum
-from typing import TYPE_CHECKING, Literal, List, Dict
+from typing import TYPE_CHECKING, Any, Dict, List, Literal
import numpy as np
@@ -76,7 +76,7 @@ class AgeGroup(Enum):
return self.label
-def calculate_demographic_metrics(opps: List[MarketplaceTask]):
+def calculate_demographic_metrics(opps: List[MarketplaceTask]) -> List:
"""
Measurement: marketplace_survey_demographics
tags: source (marketplace)
@@ -141,12 +141,13 @@ def calculate_demographic_metrics(opps: List[MarketplaceTask]):
d["tags"].update(k.to_tags())
d["fields"].update(v)
points.append(d)
+
return points
def calculate_used_question_metrics(
opps: List[MarketplaceTask], qid_label: Dict[str, str]
-):
+) -> List[Dict[str, Any]]:
"""
Measurement: marketplace_survey_targeting
tags: source (marketplace), "type", country (all and individual)
diff --git a/generalresearch/models/thl/finance.py b/generalresearch/models/thl/finance.py
index 6a24b5e..8d526ac 100644
--- a/generalresearch/models/thl/finance.py
+++ b/generalresearch/models/thl/finance.py
@@ -1,23 +1,23 @@
import random
from datetime import timezone
-from typing import Optional, TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
from uuid import uuid4
import pandas as pd
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
NonNegativeInt,
- ConfigDict,
- model_validator,
computed_field,
field_validator,
+ model_validator,
)
from pydantic.json_schema import SkipJsonSchema
from generalresearch.currency import USDCent
from generalresearch.decorators import LOG
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.definitions import SessionAdjustedStatus
from generalresearch.pg_helper import PostgresConfig
@@ -25,9 +25,8 @@ payout_example = random.randint(150, 750 * 100)
adjustment_example = random.randint(-1_000, 50 * 100)
if TYPE_CHECKING:
- from generalresearch.models.thl.ledger import LedgerAccount
from generalresearch.managers.thl.product import ProductManager
- from generalresearch.models.thl.ledger import AccountType, Direction
+ from generalresearch.models.thl.ledger import AccountType, Direction, LedgerAccount
class AdjustmentType(BaseModel):
@@ -111,12 +110,11 @@ class POPFinancial(BaseModel):
has a single Product.
"""
+ from generalresearch.config import is_debug
from generalresearch.incite.schemas.mergers.pop_ledger import (
numerical_col_names,
)
- from generalresearch.config import is_debug
-
# Validate the input accounts
assert len(accounts) > 0, "Must provide accounts"
from generalresearch.models.thl.ledger import (
@@ -504,7 +502,7 @@ class ProductBalances(BaseModel):
@staticmethod
def from_pandas(
input_data: pd.DataFrame | pd.Series,
- ):
+ ) -> "ProductBalances":
LOG.debug(f"ProductBalances.from_pandas(input_data={input_data.shape})")
if isinstance(input_data, pd.Series):
@@ -832,11 +830,11 @@ class BusinessBalances(BaseModel):
from generalresearch.incite.schemas.mergers.pop_ledger import (
numerical_col_names,
)
- from generalresearch.models.thl.product import Product
from generalresearch.models.thl.ledger import (
AccountType,
Direction,
)
+ from generalresearch.models.thl.product import Product
# Validate the input accounts
assert len(accounts) > 0, "Must provide accounts"
diff --git a/generalresearch/models/thl/grliq.py b/generalresearch/models/thl/grliq.py
index 1e769aa..69b23bd 100644
--- a/generalresearch/models/thl/grliq.py
+++ b/generalresearch/models/thl/grliq.py
@@ -1,6 +1,6 @@
from generalresearch.grliq.models.decider import (
- Decider,
AttemptDecision,
+ Decider,
GrlIqAttemptResult,
)
diff --git a/generalresearch/models/thl/ipinfo.py b/generalresearch/models/thl/ipinfo.py
index d8abfc0..a98689f 100644
--- a/generalresearch/models/thl/ipinfo.py
+++ b/generalresearch/models/thl/ipinfo.py
@@ -1,23 +1,23 @@
import ipaddress
from datetime import datetime, timezone
-from typing import Optional, Dict, Any, Literal, Tuple
+from typing import Any, Dict, Literal, Optional, Tuple
import geoip2.models
from faker import Faker
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
PositiveInt,
- field_validator,
PrivateAttr,
- ConfigDict,
+ field_validator,
)
from typing_extensions import Self
from generalresearch.models.custom_types import (
AwareDatetimeISO,
- IPvAnyAddressStr,
CountryISOLike,
+ IPvAnyAddressStr,
)
from generalresearch.models.thl.maxmind.definitions import UserType
from generalresearch.pg_helper import PostgresConfig
@@ -104,9 +104,10 @@ class IPGeoname(BaseModel):
"subdivision_2_iso",
mode="before",
)
- def make_lower(cls, value: str):
+ def make_lower(cls, value: Optional[str]) -> Optional[str]:
if value is not None:
return value.lower()
+
return value
# --- ORM ---
@@ -116,7 +117,7 @@ class IPGeoname(BaseModel):
return d
@classmethod
- def from_mysql(cls, d: Dict) -> Self:
+ def from_mysql(cls, d: Dict[str, Any]) -> Self:
d["updated"] = d["updated"].replace(tzinfo=timezone.utc)
return cls.model_validate(d)
@@ -250,9 +251,10 @@ class IPInformation(BaseModel):
_geoname: Optional[IPGeoname] = PrivateAttr(default=None)
@field_validator("country_iso", "registered_country_iso", mode="before")
- def make_lower(cls, value: str):
+ def make_lower(cls, value: Optional[str]) -> Optional[str]:
if value is not None:
return value.lower()
+
return value
@property
diff --git a/generalresearch/models/thl/leaderboard.py b/generalresearch/models/thl/leaderboard.py
index a4c2134..7d79091 100644
--- a/generalresearch/models/thl/leaderboard.py
+++ b/generalresearch/models/thl/leaderboard.py
@@ -1,25 +1,25 @@
from __future__ import annotations
import logging
+import math
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import List, Literal
from uuid import UUID, uuid3
-from zoneinfo import ZoneInfo
-import math
import pandas as pd
from pydantic import (
+ AwareDatetime,
BaseModel,
Field,
NonNegativeInt,
- model_validator,
computed_field,
- AwareDatetime,
field_validator,
+ model_validator,
)
+from zoneinfo import ZoneInfo
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.legacy.api_status import StatusResponse
from generalresearch.models.thl.locales import CountryISO
from generalresearch.utils.enum import ReprEnumMeta
@@ -81,13 +81,11 @@ class Leaderboard(BaseModel):
id: UUIDStr = Field(
description="Unique ID for this leaderboard",
examples=["845b0074ad533df580ebb9c80cc3bce1"],
- default=None,
)
name: str = Field(
description="Descriptive name for the leaderboard based on the board_code",
examples=["Number of Completes"],
- default=None,
)
board_code: LeaderboardCode = Field(
@@ -111,7 +109,6 @@ class Leaderboard(BaseModel):
timezone_name: str = Field(
description="The timezone for the requested country",
examples=["America/New_York"],
- default=None,
)
sort_order: Literal["ascending", "descending"] = Field(default="descending")
@@ -155,7 +152,6 @@ class Leaderboard(BaseModel):
)
],
# exclude=True,
- default=None,
)
@property
diff --git a/generalresearch/models/thl/ledger.py b/generalresearch/models/thl/ledger.py
index 045eeba..31679d2 100644
--- a/generalresearch/models/thl/ledger.py
+++ b/generalresearch/models/thl/ledger.py
@@ -1,31 +1,31 @@
from datetime import datetime, timezone
from enum import Enum
-from typing import Dict, Optional, List, Literal, Annotated, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import (
BaseModel,
- Field,
- field_validator,
- model_validator,
ConfigDict,
+ Field,
+ NonNegativeInt,
PositiveInt,
computed_field,
- NonNegativeInt,
+ field_validator,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.models.custom_types import (
- UUIDStr,
AwareDatetimeISO,
- check_valid_uuid,
HttpsUrlStr,
+ UUIDStr,
+ check_valid_uuid,
)
from generalresearch.models.thl.ledger_example import (
- _example_user_tx_payout,
+ _example_user_tx_adjustment,
_example_user_tx_bonus,
_example_user_tx_complete,
- _example_user_tx_adjustment,
+ _example_user_tx_payout,
)
from generalresearch.models.thl.pagination import Page
from generalresearch.models.thl.payout_format import (
@@ -300,7 +300,7 @@ class LedgerTransaction(BaseModel):
), "ledger entries must balance"
return entries
- def model_dump_mysql(self, *args, **kwargs) -> dict:
+ def model_dump_mysql(self, *args, **kwargs) -> Dict[str, Any]:
d = self.model_dump(mode="json", *args, **kwargs)
if "created" in d:
d["created"] = self.created.replace(tzinfo=None)
diff --git a/generalresearch/models/thl/ledger_example.py b/generalresearch/models/thl/ledger_example.py
index 32cd464..bf120cb 100644
--- a/generalresearch/models/thl/ledger_example.py
+++ b/generalresearch/models/thl/ledger_example.py
@@ -1,9 +1,9 @@
from datetime import datetime, timezone
-from typing import Dict
+from typing import Any, Dict
from uuid import uuid4
-def _example_user_tx_payout(schema: Dict) -> None:
+def _example_user_tx_payout(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.ledger import (
UserLedgerTransactionUserPayout,
)
@@ -18,7 +18,7 @@ def _example_user_tx_payout(schema: Dict) -> None:
).model_dump(mode="json")
-def _example_user_tx_bonus(schema: Dict) -> None:
+def _example_user_tx_bonus(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.ledger import (
UserLedgerTransactionUserBonus,
)
@@ -32,7 +32,7 @@ def _example_user_tx_bonus(schema: Dict) -> None:
).model_dump(mode="json")
-def _example_user_tx_complete(schema: Dict) -> None:
+def _example_user_tx_complete(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.ledger import (
UserLedgerTransactionTaskComplete,
)
@@ -47,7 +47,7 @@ def _example_user_tx_complete(schema: Dict) -> None:
).model_dump(mode="json")
-def _example_user_tx_adjustment(schema: Dict) -> None:
+def _example_user_tx_adjustment(schema: Dict[str, Any]) -> None:
from generalresearch.models.thl.ledger import (
UserLedgerTransactionTaskAdjustment,
)
diff --git a/generalresearch/models/thl/locales.py b/generalresearch/models/thl/locales.py
index 210a158..bff774a 100644
--- a/generalresearch/models/thl/locales.py
+++ b/generalresearch/models/thl/locales.py
@@ -4,8 +4,8 @@ from pydantic import AfterValidator
from generalresearch.locales import Localelator
from generalresearch.models.custom_types import (
- to_comma_sep_str,
from_comma_sep_str,
+ to_comma_sep_str,
)
locale_helper = Localelator()
diff --git a/generalresearch/models/thl/offerwall/__init__.py b/generalresearch/models/thl/offerwall/__init__.py
index 2da9d43..bf6b7ea 100644
--- a/generalresearch/models/thl/offerwall/__init__.py
+++ b/generalresearch/models/thl/offerwall/__init__.py
@@ -4,14 +4,14 @@ import hashlib
import json
from decimal import Decimal
from enum import Enum
-from typing import Literal, Optional, Dict, Set, Any
+from typing import Any, Dict, Literal, Optional, Set
from pydantic import (
BaseModel,
Field,
- model_validator,
- computed_field,
PositiveInt,
+ computed_field,
+ model_validator,
)
from typing_extensions import Self
diff --git a/generalresearch/models/thl/offerwall/base.py b/generalresearch/models/thl/offerwall/base.py
index 66149a9..8ea259d 100644
--- a/generalresearch/models/thl/offerwall/base.py
+++ b/generalresearch/models/thl/offerwall/base.py
@@ -2,29 +2,31 @@ import statistics
from datetime import timedelta
from decimal import Decimal
from string import Formatter
-from typing import Optional, List, Any, Set, Dict, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
from uuid import uuid4
import numpy as np
import pandas as pd
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
NonNegativeFloat,
NonNegativeInt,
- ConfigDict,
field_validator,
model_validator,
)
-from typing_extensions import Self, Annotated
+from typing_extensions import Annotated, Self
from generalresearch.models import Source
-from generalresearch.models.custom_types import UUIDStr, HttpsUrl
+from generalresearch.models.custom_types import HttpsUrl, UUIDStr
from generalresearch.models.legacy.bucket import (
Bucket as LegacyBucket,
- Eligibility,
+)
+from generalresearch.models.legacy.bucket import (
CategoryAssociation,
DurationSummary,
+ Eligibility,
PayoutSummary,
PayoutSummaryDecimal,
SurveyEligibilityCriterion,
@@ -32,9 +34,9 @@ from generalresearch.models.legacy.bucket import (
from generalresearch.models.legacy.definitions import OfferwallReason
from generalresearch.models.thl.locales import CountryISO
from generalresearch.models.thl.offerwall import (
+ OFFERWALL_TYPE_CLASS,
OfferWallType,
OfferWallTypeClass,
- OFFERWALL_TYPE_CLASS,
)
from generalresearch.models.thl.offerwall.bucket import (
generate_offerwall_entry_url,
diff --git a/generalresearch/models/thl/offerwall/behavior.py b/generalresearch/models/thl/offerwall/behavior.py
index e8da334..8e6c089 100644
--- a/generalresearch/models/thl/offerwall/behavior.py
+++ b/generalresearch/models/thl/offerwall/behavior.py
@@ -1,6 +1,6 @@
from typing import Any, Dict, Literal
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
class OfferWallBehavior(BaseModel):
diff --git a/generalresearch/models/thl/offerwall/bucket.py b/generalresearch/models/thl/offerwall/bucket.py
index dd93cd1..1e08e32 100644
--- a/generalresearch/models/thl/offerwall/bucket.py
+++ b/generalresearch/models/thl/offerwall/bucket.py
@@ -8,9 +8,9 @@ def generate_offerwall_entry_url(
bp_user_id: str,
request_id: Optional[str] = None,
nudge_id: Optional[str] = None,
-):
- # for an offerwall entry link, we need the clicked bucket_id and the request hash (so we know
- # which GetOfferwall cache to get
+) -> str:
+ # For an offerwall entry link, we need the clicked bucket_id and the
+ # request hash (so we know which GetOfferwall cache to get)
query_dict = {"i": obj_id, "b": bp_user_id}
if request_id:
query_dict["66482fb"] = request_id
diff --git a/generalresearch/models/thl/offerwall/cache.py b/generalresearch/models/thl/offerwall/cache.py
index 6040175..b75a803 100644
--- a/generalresearch/models/thl/offerwall/cache.py
+++ b/generalresearch/models/thl/offerwall/cache.py
@@ -1,5 +1,5 @@
from datetime import datetime, timezone
-from typing import Dict, Any, List, Optional
+from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@@ -8,8 +8,8 @@ from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.offerwall import OfferWallRequest
from generalresearch.models.thl.offerwall.base import (
OfferwallBase,
- TaskResult,
ScoredTaskResult,
+ TaskResult,
)
diff --git a/generalresearch/models/thl/pagination.py b/generalresearch/models/thl/pagination.py
index 6679cb4..1b31078 100644
--- a/generalresearch/models/thl/pagination.py
+++ b/generalresearch/models/thl/pagination.py
@@ -1,6 +1,6 @@
+from math import ceil
from typing import Optional
-from math import ceil
from pydantic import BaseModel, Field, computed_field
diff --git a/generalresearch/models/thl/payout.py b/generalresearch/models/thl/payout.py
index b6f880f..8ab01ec 100644
--- a/generalresearch/models/thl/payout.py
+++ b/generalresearch/models/thl/payout.py
@@ -1,19 +1,19 @@
import json
from datetime import datetime, timezone
-from typing import Dict, Optional, Collection, List
+from typing import Collection, Dict, List, Optional
from uuid import uuid4
from pydantic import (
BaseModel,
Field,
+ PositiveInt,
computed_field,
field_validator,
- PositiveInt,
)
from typing_extensions import Self
from generalresearch.currency import USDCent
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.definitions import PayoutStatus
from generalresearch.models.thl.ledger import OrderBy
from generalresearch.models.thl.wallet import PayoutType
@@ -229,6 +229,7 @@ class BrokerageProductPayoutEvent(PayoutEvent):
account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None,
redis_config: Optional[RedisConfig] = None,
) -> Self:
+ # TODO!: prevent re-assignment, rework this...
if account_product_mapping is None:
rc = redis_config.create_redis_client()
@@ -248,6 +249,7 @@ class BrokerageProductPayoutEvent(PayoutEvent):
account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None,
redis_config: Optional[RedisConfig] = None,
) -> List[Self]:
+ # TODO!: prevent re-assignment, rework this...
if account_product_mapping is None:
rc = redis_config.create_redis_client()
diff --git a/generalresearch/models/thl/payout_format.py b/generalresearch/models/thl/payout_format.py
index f108e46..4989bdc 100644
--- a/generalresearch/models/thl/payout_format.py
+++ b/generalresearch/models/thl/payout_format.py
@@ -20,7 +20,9 @@ def validate_payout_format(payout_format: str) -> str:
def format_payout_format(payout_format: str, payout_int: int) -> str:
"""
- Generate a str representation of a payout. Typically, this would be displayed to a user.
+ Generate a str representation of a payout. Typically, this would be
+ displayed to a user.
+
:param payout_format: see BPC_DEFAULTS.payout_format
:param payout_int: The actual value in integer usd cents.
"""
diff --git a/generalresearch/models/thl/product.py b/generalresearch/models/thl/product.py
index 949f1f8..5889dc3 100644
--- a/generalresearch/models/thl/product.py
+++ b/generalresearch/models/thl/product.py
@@ -125,16 +125,17 @@ class ProfilingConfig(BaseModel):
enabled: bool = Field(
default=True,
- description="If False, the harmonizer/profiling system is not used at all. This should "
- "never be False unless special circumstances",
+ description="If False, the harmonizer/profiling system is not used at "
+ "all. This should never be False unless special circumstances",
)
grs_enabled: bool = Field(
default=True,
- description="""If grs_enabled is False, and is_grs is passed in the profiling-questions call,
- then don't actually return any questions. This allows a client to hit the endpoint with no limit
- and still get questions. In effect, this means that we'll redirect the user through the GRS
- system but won't present them any questions.""",
+ description="""If grs_enabled is False, and is_grs is passed in the
+ profiling-questions call, then don't actually return any questions.
+ This allows a client to hit the endpoint with no limit and still get
+ questions. In effect, this means that we'll redirect the user through
+ the GRS system but won't present them any questions.""",
)
n_questions: Optional[PositiveInt] = Field(
@@ -155,14 +156,16 @@ class ProfilingConfig(BaseModel):
# Don't set this to 0, use enabled
task_injection_freq_mult: PositiveFloat = Field(
default=1,
- description="Scale how frequently we inject profiling questions, relative to the default."
- "1 is default, 2 is twice as often. 10 means always. 0.5 half as often",
+ description="""Scale how frequently we inject profiling questions,
+ relative to the default. 1 is default, 2 is twice as often. 10 means
+ always. 0.5 half as often""",
)
non_us_mult: PositiveFloat = Field(
default=2,
- description="Non-us multiplier, used to increase freq and length of profilers in all non-us countries."
- "This value is multiplied by task_injection_freq_mult and avg_question_count.",
+ description="""Non-us multiplier, used to increase freq and length of
+ profilers in all non-us countries. This value is multiplied by
+ task_injection_freq_mult and avg_question_count.""",
)
hidden_questions_expiration_hours: PositiveInt = Field(
@@ -186,13 +189,13 @@ class UserHealthConfig(BaseModel):
# are blocked.
banned_countries: List[CountryISOLike] = Field(default_factory=list)
- # Decide if a user can be blocked for IP-related triggers such as sharing IPs
- # and location history. This should eventually be deprecated and replaced
- # with something with more specificity.
+ # Decide if a user can be blocked for IP-related triggers such as sharing
+ # IPs and location history. This should eventually be deprecated and
+ # replaced with something with more specificity.
allow_ban_iphist: bool = Field(default=True)
- # These are only checked by ym-user-predict, which I'm not sure even works properly.
- # To be deprecated ... don't even use them.
+ # These are only checked by ym-user-predict, which I'm not sure even
+ # works properly. To be deprecated ... don't even use them.
userprofit_cutoff: Optional[Decimal] = Field(default=None, exclude=True)
recon_cutoff: Optional[float] = Field(default=None, exclude=True)
droprate_cutoff: Optional[float] = Field(default=None, exclude=True)
@@ -205,9 +208,11 @@ class UserHealthConfig(BaseModel):
class OfferWallRequestYieldmanParams(BaseModel):
# model_config = ConfigDict(extra='forbid')
- # keys: use_stats, use_harmonizer, allow_pii, add_default_lang_eng, first_n_completes_easier_per_day are
+ # keys: use_stats, use_harmonizer, allow_pii, add_default_lang_eng,
+ # first_n_completes_easier_per_day are
# ignored/deprecated
- # allow_pii: bool = Field(default=True, description="Allow tasks that request PII. This actually does nothing.")
+ # allow_pii: bool = Field(default=True, description="Allow tasks that
+ # request PII. This actually does nothing.")
# see thl-grpc:yield_management.scoring.score() for more info
conversion_factor_adj: float = Field(
diff --git a/generalresearch/models/thl/profiling/marketplace.py b/generalresearch/models/thl/profiling/marketplace.py
index 414a437..2dd7028 100644
--- a/generalresearch/models/thl/profiling/marketplace.py
+++ b/generalresearch/models/thl/profiling/marketplace.py
@@ -1,16 +1,16 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from functools import cached_property
-from typing import Any, Dict, Set, Optional, Tuple
+from typing import Any, Dict, Optional, Set, Tuple
-from pydantic import PositiveInt, BaseModel, Field, computed_field, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, PositiveInt, computed_field
from generalresearch.models import MAX_INT32, Source
from generalresearch.models.custom_types import (
AwareDatetimeISO,
- UUIDStr,
CountryISOLike,
LanguageISOLike,
+ UUIDStr,
)
from generalresearch.models.thl.locales import CountryISO, LanguageISO
@@ -48,8 +48,9 @@ class MarketplaceQuestion(BaseModel, ABC):
@property
@abstractmethod
def internal_id(self) -> str:
- """This is the value that is used for this question within the marketplace. Typically,
- this is question_id. Innovate uses question_key."""
+ """This is the value that is used for this question within the
+ marketplace. Typically, this is question_id. Innovate uses question_key.
+ """
...
@property
@@ -58,8 +59,9 @@ class MarketplaceQuestion(BaseModel, ABC):
@property
def _key(self) -> Tuple[str, CountryISOLike, LanguageISOLike]:
- """This uniquely identifies a question in a locale. There is a unique index
- on this in the db. e.g. (question_id, country_iso, language_iso)"""
+ """This uniquely identifies a question in a locale. There is a unique
+ index on this in the db. e.g. (question_id, country_iso, language_iso)
+ """
return self.internal_id, self.country_iso, self.language_iso
@abstractmethod
diff --git a/generalresearch/models/thl/profiling/question.py b/generalresearch/models/thl/profiling/question.py
index 6f6b270..72115fe 100644
--- a/generalresearch/models/thl/profiling/question.py
+++ b/generalresearch/models/thl/profiling/question.py
@@ -1,19 +1,19 @@
from __future__ import annotations
-from typing import Optional, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
from pydantic import (
BaseModel,
- Field,
ConfigDict,
+ Field,
computed_field,
)
from generalresearch.models.custom_types import (
- UUIDStr,
AwareDatetimeISO,
CountryISOLike,
LanguageISOLike,
+ UUIDStr,
)
from generalresearch.models.thl.profiling.upk_question import UpkQuestion
@@ -34,7 +34,7 @@ class Question(BaseModel):
)
data: UpkQuestion = Field()
is_live: bool = Field()
- custom: Dict = Field(default_factory=dict)
+ custom: Dict[str, Any] = Field(default_factory=dict)
last_updated: AwareDatetimeISO = Field()
@computed_field
diff --git a/generalresearch/models/thl/profiling/upk_property.py b/generalresearch/models/thl/profiling/upk_property.py
index 2b63aef..9eede95 100644
--- a/generalresearch/models/thl/profiling/upk_property.py
+++ b/generalresearch/models/thl/profiling/upk_property.py
@@ -1,11 +1,11 @@
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Dict
+from typing import Dict, List, Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
-from generalresearch.models.custom_types import UUIDStr, CountryISOLike
+from generalresearch.models.custom_types import CountryISOLike, UUIDStr
from generalresearch.models.thl.category import Category
from generalresearch.utils.enum import ReprEnumMeta
diff --git a/generalresearch/models/thl/profiling/upk_question.py b/generalresearch/models/thl/profiling/upk_question.py
index 2b952ec..5c908b0 100644
--- a/generalresearch/models/thl/profiling/upk_question.py
+++ b/generalresearch/models/thl/profiling/upk_question.py
@@ -5,16 +5,16 @@ import json
import re
from enum import Enum
from functools import cached_property
-from typing import List, Optional, Union, Literal, Dict, Tuple, Set
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
from pydantic import (
BaseModel,
- Field,
- model_validator,
- field_validator,
ConfigDict,
+ Field,
NonNegativeInt,
PositiveInt,
+ field_validator,
+ model_validator,
)
from typing_extensions import Annotated
@@ -302,7 +302,7 @@ class UpkQuestion(BaseModel):
# Don't set a min_length=1 here. We'll allow this to be created, but it
# won't be askable with empty choices.
choices: Optional[List[UpkQuestionChoice]] = Field(default=None)
- selector: SelectorType = Field(default=None)
+ selector: SelectorType = Field()
configuration: Optional[Configuration] = Field(default=None)
validation: Optional[UpkQuestionValidation] = Field(default=None)
importance: Optional[UPKImportance] = Field(default=None)
@@ -348,7 +348,7 @@ class UpkQuestion(BaseModel):
@model_validator(mode="before")
@classmethod
- def check_configuration_type(cls, data: Dict):
+ def check_configuration_type(cls, data: Dict[str, Any]) -> Dict[str, Any]:
# The model knows what the type of Configuration to grab depending on
# the key 'type' which it expects inside the configuration object.
# Here, we grab the type from the top-level model instead.
@@ -433,19 +433,23 @@ class UpkQuestion(BaseModel):
@field_validator("choices")
@classmethod
- def order_choices(cls, choices):
+ def order_choices(cls, choices: List):
if choices:
choices.sort(key=lambda x: x.order)
return choices
@field_validator("choices")
@classmethod
- def validate_choices(cls, choices):
+ def validate_choices(
+ cls, choices: Optional[List[UpkQuestionChoice]]
+ ) -> Optional[List[UpkQuestionChoice]]:
if choices:
ids = {x.id for x in choices}
assert len(ids) == len(choices), "choices.id must be unique"
+
orders = {x.order for x in choices}
assert len(orders) == len(choices), "choices.order must be unique"
+
return choices
@field_validator("explanation_template", "explanation_fragment_template")
diff --git a/generalresearch/models/thl/profiling/upk_question_answer.py b/generalresearch/models/thl/profiling/upk_question_answer.py
index 28b9f27..2eb52e1 100644
--- a/generalresearch/models/thl/profiling/upk_question_answer.py
+++ b/generalresearch/models/thl/profiling/upk_question_answer.py
@@ -1,5 +1,5 @@
from datetime import datetime, timezone
-from typing import Optional, Union, Dict
+from typing import Any, Dict, Optional, Union
from uuid import uuid4
from pydantic import (
@@ -7,25 +7,24 @@ from pydantic import (
ConfigDict,
Field,
PositiveInt,
- model_validator,
computed_field,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.models import MAX_INT32
from generalresearch.models.custom_types import (
- UUIDStr,
AwareDatetimeISO,
CountryISOLike,
+ UUIDStr,
)
from generalresearch.models.thl.profiling.upk_property import (
- PropertyType,
Cardinality,
+ PropertyType,
)
class UpkQuestionAnswer(BaseModel):
- """ """
model_config = ConfigDict(populate_by_name=True)
@@ -110,7 +109,7 @@ class UpkQuestionAnswer(BaseModel):
return self
- def model_dump_mysql(self) -> Dict:
+ def model_dump_mysql(self) -> Dict[str, Any]:
d = self.model_dump(mode="json")
d["created"] = self.created
return d
diff --git a/generalresearch/models/thl/profiling/user_info.py b/generalresearch/models/thl/profiling/user_info.py
index 609121e..15197c9 100644
--- a/generalresearch/models/thl/profiling/user_info.py
+++ b/generalresearch/models/thl/profiling/user_info.py
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic.json_schema import SkipJsonSchema
diff --git a/generalresearch/models/thl/profiling/user_question_answer.py b/generalresearch/models/thl/profiling/user_question_answer.py
index b42183d..b325583 100644
--- a/generalresearch/models/thl/profiling/user_question_answer.py
+++ b/generalresearch/models/thl/profiling/user_question_answer.py
@@ -1,19 +1,19 @@
import json
-from datetime import datetime, timezone, timedelta
-from typing import Dict, Tuple, Iterator, Optional, Literal, Union, Any
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union
from pydantic import (
- PositiveInt,
+ BaseModel,
+ ConfigDict,
Field,
+ PositiveInt,
field_validator,
model_validator,
- BaseModel,
- ConfigDict,
)
from typing_extensions import Self
from generalresearch.grpc import timestamp_to_datetime
-from generalresearch.models import Source, MAX_INT32
+from generalresearch.models import MAX_INT32, Source
from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.locales import CountryISO, LanguageISO
from generalresearch.models.thl.profiling.upk_question import UpkQuestion
diff --git a/generalresearch/models/thl/report_task.py b/generalresearch/models/thl/report_task.py
index e945b3b..1abd330 100644
--- a/generalresearch/models/thl/report_task.py
+++ b/generalresearch/models/thl/report_task.py
@@ -1,6 +1,6 @@
import random
from collections import defaultdict
-from typing import List, Collection, Optional
+from typing import Collection, List, Optional
from pydantic import BaseModel, ConfigDict, Field
diff --git a/generalresearch/models/thl/session.py b/generalresearch/models/thl/session.py
index 0066559..74ed5eb 100644
--- a/generalresearch/models/thl/session.py
+++ b/generalresearch/models/thl/session.py
@@ -1,29 +1,28 @@
import json
import logging
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
from decimal import Decimal
-from typing import Optional, Dict, Any, Tuple, Union, List, Annotated
+from typing import TYPE_CHECKING, Annotated, Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4
-from typing import TYPE_CHECKING
from pydantic import (
- BaseModel,
AwareDatetime,
+ BaseModel,
+ ConfigDict,
Field,
- model_validator,
- field_validator,
computed_field,
- ConfigDict,
field_serializer,
+ field_validator,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.models import DeviceType, Source
from generalresearch.models.custom_types import (
- UUIDStr,
AwareDatetimeISO,
- IPvAnyAddressStr,
EnumNameSerializer,
+ IPvAnyAddressStr,
+ UUIDStr,
)
from generalresearch.models.legacy.bucket import Bucket
from generalresearch.models.thl import (
@@ -32,15 +31,15 @@ from generalresearch.models.thl import (
int_cents_to_decimal,
)
from generalresearch.models.thl.definitions import (
- Status,
+ WALL_ALLOWED_STATUS_CODE_1_2,
+ WALL_ALLOWED_STATUS_STATUS_CODE,
+ ReportValue,
SessionAdjustedStatus,
- WallAdjustedStatus,
+ SessionStatusCode2,
+ Status,
StatusCode1,
- ReportValue,
+ WallAdjustedStatus,
WallStatusCode2,
- SessionStatusCode2,
- WALL_ALLOWED_STATUS_CODE_1_2,
- WALL_ALLOWED_STATUS_STATUS_CODE,
)
from generalresearch.models.thl.user import User
diff --git a/generalresearch/models/thl/stats.py b/generalresearch/models/thl/stats.py
index fc8be59..9e9971d 100644
--- a/generalresearch/models/thl/stats.py
+++ b/generalresearch/models/thl/stats.py
@@ -1,6 +1,6 @@
from typing import Optional
-from pydantic import BaseModel, Field, model_validator, computed_field
+from pydantic import BaseModel, Field, computed_field, model_validator
class StatisticalSummary(BaseModel):
diff --git a/generalresearch/models/thl/survey/__init__.py b/generalresearch/models/thl/survey/__init__.py
index fd78091..5feb955 100644
--- a/generalresearch/models/thl/survey/__init__.py
+++ b/generalresearch/models/thl/survey/__init__.py
@@ -1,26 +1,26 @@
from abc import ABC, abstractmethod
from decimal import Decimal
from itertools import product
-from typing import Set, Optional, List, Dict, Type
+from typing import Dict, List, Optional, Set, Type
from more_itertools import flatten
from pydantic import BaseModel, Field
from generalresearch.models import Source
from generalresearch.models.thl.demographics import (
+ AgeGroup,
DemographicTarget,
Gender,
- AgeGroup,
)
from generalresearch.models.thl.locales import (
- CountryISOs,
- LanguageISOs,
CountryISO,
+ CountryISOs,
LanguageISO,
+ LanguageISOs,
)
from generalresearch.models.thl.survey.condition import (
- MarketplaceCondition,
ConditionValueType,
+ MarketplaceCondition,
)
@@ -68,8 +68,9 @@ class MarketplaceTask(BaseModel, ABC):
@property
@abstractmethod
def internal_id(self) -> str:
- """This is the value that is used for this survey within the marketplace. Typically,
- this is survey_id/survey_number. Morning is quota_id, repdata: stream_id.
+ """This is the value that is used for this survey within the
+ marketplace. Typically, this is survey_id/survey_number. Morning
+ is quota_id, repdata: stream_id.
"""
...
diff --git a/generalresearch/models/thl/survey/buyer.py b/generalresearch/models/thl/survey/buyer.py
index 0d235ed..b4d8fb0 100644
--- a/generalresearch/models/thl/survey/buyer.py
+++ b/generalresearch/models/thl/survey/buyer.py
@@ -1,16 +1,16 @@
-from datetime import timezone, datetime
+from datetime import datetime, timezone
from decimal import Decimal
-from typing import Optional, Annotated
-
from math import log
+from typing import Annotated, Optional
+
from pydantic import (
- model_validator,
BaseModel,
ConfigDict,
Field,
- PositiveInt,
NonNegativeInt,
+ PositiveInt,
computed_field,
+ model_validator,
)
from scipy.stats import beta as beta_dist
@@ -24,7 +24,8 @@ from generalresearch.models.custom_types import (
class Buyer(BaseModel):
"""
- The entity that commissions and pays for a task and uses the resulting data or insights.
+ The entity that commissions and pays for a task and uses the
+ resulting data or insights.
"""
model_config = ConfigDict(validate_assignment=True)
@@ -176,7 +177,6 @@ class BuyerCountryStat(BaseModel):
# ---- Scoring ----
score: float = Field(
description="Composite score calculated from all of the individual features",
- default=None,
examples=[-5.329389837486194],
)
diff --git a/generalresearch/models/thl/survey/condition.py b/generalresearch/models/thl/survey/condition.py
index a60b034..3610750 100644
--- a/generalresearch/models/thl/survey/condition.py
+++ b/generalresearch/models/thl/survey/condition.py
@@ -2,19 +2,19 @@ import hashlib
from abc import ABC
from enum import Enum
from functools import cached_property
-from typing import List, Dict, Set, Optional, Any, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
from pydantic import (
BaseModel,
+ ConfigDict,
Field,
+ PrivateAttr,
+ StringConstraints,
computed_field,
- ConfigDict,
field_validator,
model_validator,
- StringConstraints,
- PrivateAttr,
)
-from typing_extensions import Self, Annotated
+from typing_extensions import Annotated, Self
from generalresearch.models import LogicalOperator
diff --git a/generalresearch/models/thl/survey/model.py b/generalresearch/models/thl/survey/model.py
index 9e7a402..8d37af4 100644
--- a/generalresearch/models/thl/survey/model.py
+++ b/generalresearch/models/thl/survey/model.py
@@ -1,28 +1,28 @@
-from datetime import timezone, datetime
+from datetime import datetime, timezone
from decimal import Decimal
-from typing import Optional, List, Tuple, Dict
-from typing_extensions import Annotated
+from typing import Any, Dict, List, Optional, Tuple
from pydantic import (
BaseModel,
ConfigDict,
Field,
+ NonNegativeFloat,
+ NonNegativeInt,
PositiveInt,
- model_validator,
computed_field,
- NonNegativeInt,
- NonNegativeFloat,
field_validator,
+ model_validator,
)
+from typing_extensions import Annotated
from generalresearch.managers.thl.buyer import Buyer
from generalresearch.models import Source
from generalresearch.models.custom_types import (
AwareDatetimeISO,
CountryISOLike,
- SurveyKey,
EnumNameSerializer,
PropertyCode,
+ SurveyKey,
)
from generalresearch.models.thl.category import Category
from generalresearch.models.thl.definitions import Status, StatusCode1
@@ -273,7 +273,7 @@ class TaskActivityPublic(BaseModel):
last_entrance: Optional[AwareDatetimeISO] = Field(default=None)
@field_validator("status_code_1_percentages", mode="before")
- def transform_enum_name_pct(cls, value: dict) -> dict:
+ def transform_enum_name_pct(cls, value: Dict[str, Any]) -> Dict[str, Any]:
# If we are serializing+deserializing this model (i.e. when we cache
# it), this fails because we've replaced the enum value with the
# name. Put it back here ...
@@ -297,7 +297,7 @@ class TaskActivityPrivate(TaskActivityPublic):
)
@field_validator("status_code_1_counts", mode="before")
- def transform_enum_name_cnt(cls, value: dict) -> dict:
+ def transform_enum_name_cnt(cls, value: Dict[str, Any]) -> Dict[str, Any]:
# If we are serializing+deserializing this model (i.e. when we cache
# it), this fails because we've replaced the enum value with the
# name. Put it back here ...
diff --git a/generalresearch/models/thl/survey/penalty.py b/generalresearch/models/thl/survey/penalty.py
index 915254f..a9c8a56 100644
--- a/generalresearch/models/thl/survey/penalty.py
+++ b/generalresearch/models/thl/survey/penalty.py
@@ -1,5 +1,5 @@
import abc
-from datetime import timezone, datetime
+from datetime import datetime, timezone
from typing import List, Literal, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
@@ -7,8 +7,8 @@ from typing_extensions import Annotated
from generalresearch.models import Source
from generalresearch.models.custom_types import (
- UUIDStr,
AwareDatetimeISO,
+ UUIDStr,
)
diff --git a/generalresearch/models/thl/survey/task_collection.py b/generalresearch/models/thl/survey/task_collection.py
index 68c2ff3..c41d8b9 100644
--- a/generalresearch/models/thl/survey/task_collection.py
+++ b/generalresearch/models/thl/survey/task_collection.py
@@ -6,7 +6,7 @@ from typing import List
import pandas as pd
import pandera
from pandera import DataFrameSchema
-from pydantic import Field, ConfigDict, BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator
from generalresearch.models.thl.survey import MarketplaceTask
diff --git a/generalresearch/models/thl/synchronize_global_vars.py b/generalresearch/models/thl/synchronize_global_vars.py
index 4a587e1..72d987d 100644
--- a/generalresearch/models/thl/synchronize_global_vars.py
+++ b/generalresearch/models/thl/synchronize_global_vars.py
@@ -1,6 +1,6 @@
from typing import List
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
class SynchronizeGlobalVarsMsg(BaseModel):
diff --git a/generalresearch/models/thl/task_adjustment.py b/generalresearch/models/thl/task_adjustment.py
index afb2a6a..9d8a04d 100644
--- a/generalresearch/models/thl/task_adjustment.py
+++ b/generalresearch/models/thl/task_adjustment.py
@@ -5,8 +5,8 @@ from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator
-from generalresearch.models import Source, MAX_INT32
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models import MAX_INT32, Source
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.definitions import (
WallAdjustedStatus,
)
diff --git a/generalresearch/models/thl/task_status.py b/generalresearch/models/thl/task_status.py
index f476286..a7aec91 100644
--- a/generalresearch/models/thl/task_status.py
+++ b/generalresearch/models/thl/task_status.py
@@ -1,39 +1,39 @@
from datetime import datetime
-from typing import Dict, Optional, Any, Literal, Annotated, List
+from typing import Annotated, Any, Dict, List, Literal, Optional
from pydantic import (
BaseModel,
Field,
- model_validator,
NonNegativeInt,
computed_field,
- field_validator,
field_serializer,
+ field_validator,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.models.custom_types import (
- UUIDStr,
AwareDatetimeISO,
EnumNameSerializer,
+ UUIDStr,
)
from generalresearch.models.thl import decimal_to_int_cents
from generalresearch.models.thl.definitions import (
- StatusCode1,
+ SessionAdjustedStatus,
SessionStatusCode2,
Status,
- SessionAdjustedStatus,
+ StatusCode1,
)
from generalresearch.models.thl.pagination import Page
from generalresearch.models.thl.payout_format import (
- PayoutFormatType,
PayoutFormatOptionalField,
+ PayoutFormatType,
)
from generalresearch.models.thl.product import (
PayoutTransformation,
Product,
)
-from generalresearch.models.thl.session import WallOut, Session
+from generalresearch.models.thl.session import Session, WallOut
# API uses the ints, b/c this is what the grpc returned originally ...
STATUS_MAP = {
diff --git a/generalresearch/models/thl/user.py b/generalresearch/models/thl/user.py
index e393830..852766c 100644
--- a/generalresearch/models/thl/user.py
+++ b/generalresearch/models/thl/user.py
@@ -3,20 +3,20 @@ from __future__ import annotations
import json
import logging
import re
-from datetime import timezone, datetime
-from typing import Optional, Dict, List, TYPE_CHECKING
-from uuid import uuid4, UUID
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Dict, List, Optional
+from uuid import UUID, uuid4
from pydantic import (
+ AfterValidator,
AwareDatetime,
- Field,
BaseModel,
- field_validator,
- model_validator,
- PositiveInt,
ConfigDict,
+ Field,
+ PositiveInt,
StringConstraints,
- AfterValidator,
+ field_validator,
+ model_validator,
)
from sentry_sdk import set_tag, set_user
from typing_extensions import Annotated, Self
@@ -30,10 +30,10 @@ from generalresearch.models.thl.userhealth import AuditLog
from generalresearch.pg_helper import PostgresConfig
if TYPE_CHECKING:
- from generalresearch.managers.thl.userhealth import AuditLogManager
from generalresearch.managers.thl.ledger_manager.thl_ledger import (
ThlLedgerManager,
)
+ from generalresearch.managers.thl.userhealth import AuditLogManager
# from generalresearch.managers.thl.userhealth import UserIpHistoryManager
diff --git a/generalresearch/models/thl/user_iphistory.py b/generalresearch/models/thl/user_iphistory.py
index ecdbc7a..159b97b 100644
--- a/generalresearch/models/thl/user_iphistory.py
+++ b/generalresearch/models/thl/user_iphistory.py
@@ -1,12 +1,12 @@
import ipaddress
-from datetime import timezone, datetime, timedelta
-from typing import List, Optional, Dict
+from datetime import datetime, timedelta, timezone
+from typing import Dict, List, Optional
from faker import Faker
from pydantic import (
BaseModel,
- Field,
ConfigDict,
+ Field,
PositiveInt,
field_validator,
)
@@ -14,8 +14,8 @@ from typing_extensions import Self
from generalresearch.models.custom_types import (
AwareDatetimeISO,
- IPvAnyAddressStr,
CountryISOLike,
+ IPvAnyAddressStr,
)
from generalresearch.models.thl.ipinfo import (
GeoIPInformation,
diff --git a/generalresearch/models/thl/user_profile.py b/generalresearch/models/thl/user_profile.py
index 6514c7d..98b3326 100644
--- a/generalresearch/models/thl/user_profile.py
+++ b/generalresearch/models/thl/user_profile.py
@@ -1,16 +1,16 @@
import hashlib
-from typing import Optional, Dict, Any, List
+from typing import Any, Dict, List, Optional
from pydantic import (
- Field,
BaseModel,
ConfigDict,
EmailStr,
+ Field,
PositiveInt,
computed_field,
)
from pydantic.json_schema import SkipJsonSchema
-from typing_extensions import Self, Annotated
+from typing_extensions import Annotated, Self
from generalresearch.models import MAX_INT32, Source
from generalresearch.models.custom_types import UUIDStr
diff --git a/generalresearch/models/thl/user_quality_event.py b/generalresearch/models/thl/user_quality_event.py
index f98e42d..2b4873a 100644
--- a/generalresearch/models/thl/user_quality_event.py
+++ b/generalresearch/models/thl/user_quality_event.py
@@ -7,8 +7,8 @@ from typing import List, Literal, Optional
from pydantic import BaseModel, Field, PositiveInt
-from generalresearch.models import Source, MAX_INT32
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models import MAX_INT32, Source
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.definitions import WallAdjustedStatus
from generalresearch.models.thl.user import BPUIDStr
from generalresearch.utils.enum import ReprEnumMeta
diff --git a/generalresearch/models/thl/user_streak.py b/generalresearch/models/thl/user_streak.py
index fa6d3b1..5d13bc5 100644
--- a/generalresearch/models/thl/user_streak.py
+++ b/generalresearch/models/thl/user_streak.py
@@ -1,20 +1,20 @@
from datetime import date, datetime, timedelta
from enum import Enum
from typing import Optional, Tuple
-from zoneinfo import ZoneInfo
import pandas as pd
from pydantic import (
+ AwareDatetime,
BaseModel,
- NonNegativeInt,
+ ConfigDict,
Field,
+ NonNegativeInt,
+ PositiveInt,
computed_field,
- AwareDatetime,
model_validator,
- ConfigDict,
- PositiveInt,
)
from pydantic.json_schema import SkipJsonSchema
+from zoneinfo import ZoneInfo
from generalresearch.managers.leaderboard import country_timezone
from generalresearch.models import MAX_INT32
diff --git a/generalresearch/models/thl/userhealth.py b/generalresearch/models/thl/userhealth.py
index 8ea81dd..fd275de 100644
--- a/generalresearch/models/thl/userhealth.py
+++ b/generalresearch/models/thl/userhealth.py
@@ -1,8 +1,8 @@
from datetime import datetime, timezone
from enum import Enum
-from typing import Optional, Dict
+from typing import Dict, Optional
-from pydantic import Field, BaseModel, PositiveInt, NonNegativeFloat
+from pydantic import BaseModel, Field, NonNegativeFloat, PositiveInt
from typing_extensions import Self
from generalresearch.models.custom_types import AwareDatetimeISO
diff --git a/generalresearch/models/thl/wallet/__init__.py b/generalresearch/models/thl/wallet/__init__.py
index 928d67f..e3f5144 100644
--- a/generalresearch/models/thl/wallet/__init__.py
+++ b/generalresearch/models/thl/wallet/__init__.py
@@ -13,7 +13,7 @@ class PayoutType(str, Enum, metaclass=ReprEnumMeta):
# User is paid out to their personal PayPal email address
PAYPAL = "PAYPAL"
- # User is paid uut via a Tango Gift Card
+ # User is paid out via a Tango Gift Card
TANGO = "TANGO"
# DWOLLA
DWOLLA = "DWOLLA"
diff --git a/generalresearch/models/thl/wallet/cashout_method.py b/generalresearch/models/thl/wallet/cashout_method.py
index def724d..c04bd66 100644
--- a/generalresearch/models/thl/wallet/cashout_method.py
+++ b/generalresearch/models/thl/wallet/cashout_method.py
@@ -4,31 +4,31 @@ import hashlib
import logging
from datetime import datetime, timezone
from enum import Enum
-from typing import List, Dict, Any, Optional, Literal, Union
+from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import (
BaseModel,
- Field,
ConfigDict,
+ EmailStr,
+ Field,
NonNegativeInt,
PositiveInt,
- EmailStr,
- model_validator,
field_validator,
+ model_validator,
)
from typing_extensions import Self
from generalresearch.currency import USDCent
from generalresearch.models.custom_types import (
- UUIDStr,
- HttpsUrlStr,
AwareDatetimeISO,
+ HttpsUrlStr,
+ UUIDStr,
)
from generalresearch.models.legacy.api_status import StatusResponse
from generalresearch.models.thl.definitions import PayoutStatus
from generalresearch.models.thl.locales import CountryISO
from generalresearch.models.thl.user import BPUIDStr, User
-from generalresearch.models.thl.wallet import PayoutType, Currency
+from generalresearch.models.thl.wallet import Currency, PayoutType
from generalresearch.utils.enum import ReprEnumMeta
logger = logging.getLogger()
diff --git a/generalresearch/models/thl/wallet/payout.py b/generalresearch/models/thl/wallet/payout.py
index 97e0c3d..72aea61 100644
--- a/generalresearch/models/thl/wallet/payout.py
+++ b/generalresearch/models/thl/wallet/payout.py
@@ -1,6 +1,6 @@
import json
from datetime import datetime, timezone
-from typing import Dict, Optional, Collection, List
+from typing import Any, Collection, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import (
@@ -12,7 +12,7 @@ from pydantic import (
)
from generalresearch.currency import USDCent
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.thl.definitions import PayoutStatus
from generalresearch.models.thl.wallet import PayoutType
from generalresearch.models.thl.wallet.cashout_method import (
@@ -74,11 +74,11 @@ class PayoutEvent(BaseModel, validate_assignment=True):
# Stores payout-type-specific information that is used to request this
# payout from the external provider.
- request_data: Dict = Field(default_factory=dict)
+ request_data: Dict[str, Any] = Field(default_factory=dict)
# Stores payout-type-specific order information that is returned from
# the external payout provider.
- order_data: Optional[Dict | CashMailOrderData] = Field(default=None)
+ order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = Field(default=None)
@field_validator("payout_type", mode="before")
@classmethod
@@ -94,7 +94,7 @@ class PayoutEvent(BaseModel, validate_assignment=True):
self,
status: PayoutStatus,
ext_ref_id: Optional[str] = None,
- order_data: Optional[Dict] = None,
+ order_data: Optional[Dict[str, Any]] = None,
) -> None:
# These 3 things are the only modifiable attributes
self.check_status_change_allowed(status)
@@ -128,7 +128,7 @@ class PayoutEvent(BaseModel, validate_assignment=True):
else:
raise ValueError("this shouldn't happen")
- def model_dump_mysql(self, *args, **kwargs) -> dict:
+ def model_dump_mysql(self, *args, **kwargs) -> Dict[str, Any]:
d = self.model_dump(mode="json", *args, **kwargs)
if "created" in d:
d["created"] = self.created.replace(tzinfo=None)
diff --git a/generalresearch/models/thl/wallet/user_wallet.py b/generalresearch/models/thl/wallet/user_wallet.py
index 917a09d..dbe66fa 100644
--- a/generalresearch/models/thl/wallet/user_wallet.py
+++ b/generalresearch/models/thl/wallet/user_wallet.py
@@ -2,12 +2,12 @@ from __future__ import annotations
import logging
-from pydantic import BaseModel, Field, ConfigDict, NonNegativeInt
+from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt
from generalresearch.models.legacy.api_status import StatusResponse
from generalresearch.models.thl.payout_format import (
- PayoutFormatType,
PayoutFormatField,
+ PayoutFormatType,
)
logger = logging.getLogger()
diff --git a/generalresearch/pg_helper.py b/generalresearch/pg_helper.py
index 064a883..f7c9675 100644
--- a/generalresearch/pg_helper.py
+++ b/generalresearch/pg_helper.py
@@ -1,15 +1,14 @@
+from datetime import timezone
from typing import Optional
+import psycopg
from psycopg.adapt import Buffer
-from psycopg.types.net import InetLoader, Address, Interface
+from psycopg.rows import RowFactory, dict_row
+from psycopg.types.datetime import TimestampLoader
+from psycopg.types.net import Address, InetLoader, Interface
from psycopg.types.string import TextLoader
-from pydantic import PostgresDsn
-
-import psycopg
-from psycopg.rows import dict_row, RowFactory
from psycopg.types.uuid import UUIDLoader
-from psycopg.types.datetime import TimestampLoader
-from datetime import timezone
+from pydantic import PostgresDsn
class UUIDHexLoader(UUIDLoader):
diff --git a/generalresearch/schemas/survey_stats.py b/generalresearch/schemas/survey_stats.py
index 532e0d7..1570554 100644
--- a/generalresearch/schemas/survey_stats.py
+++ b/generalresearch/schemas/survey_stats.py
@@ -1,5 +1,5 @@
import pandas as pd
-from pandera import DataFrameSchema, Column, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.locales import Localelator
from generalresearch.models import Source
diff --git a/generalresearch/sql_helper.py b/generalresearch/sql_helper.py
index b92813b..175830c 100644
--- a/generalresearch/sql_helper.py
+++ b/generalresearch/sql_helper.py
@@ -1,8 +1,8 @@
import logging
-from typing import Any, Dict, List, Tuple, Union, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
-from pydantic import MySQLDsn, PostgresDsn, MariaDBDsn
+from pydantic import MariaDBDsn, MySQLDsn, PostgresDsn
from pymysql import Connection
ListOrTupleOfStrings = Union[List[str], Tuple[str, ...]]
@@ -119,7 +119,7 @@ def is_uuid4(s: Any) -> bool:
return False
-def decode_uuids(row: Dict) -> Dict:
+def decode_uuids(row: Dict[str, Any]) -> Dict[str, Any]:
return {
key: (UUID(value, version=4).hex if is_uuid4(value) else value)
for key, value in row.items()
@@ -131,7 +131,9 @@ class SqlHelper(SqlConnector):
def __init__(self, dsn: Optional[DataBaseDsn] = None, **kwargs):
super(SqlHelper, self).__init__(dsn, **kwargs)
- def execute_sql_query(self, query, params=None, commit=False) -> List[Dict]:
+ def execute_sql_query(
+ self, query: str, params: Optional[Dict[str, Any]] = None, commit: bool = False
+ ) -> List[Dict[str, Any]]:
for param in params if params else []:
if isinstance(param, (tuple, list, set)) and len(param) == 0:
logging.warning("param is empty. not executing query")
@@ -157,7 +159,7 @@ class SqlHelper(SqlConnector):
field_names: ListOrTupleOfStrings,
values_to_insert: ListOrTupleOfListOrTuple,
cursor=None,
- ignore_existing=False,
+ ignore_existing: bool = False,
) -> None:
"""
:param table_name: name of table
diff --git a/generalresearch/utils/aggregation.py b/generalresearch/utils/aggregation.py
index b168e4c..4023dc9 100644
--- a/generalresearch/utils/aggregation.py
+++ b/generalresearch/utils/aggregation.py
@@ -1,8 +1,8 @@
from collections import defaultdict
-from typing import Dict, List
+from typing import Any, Dict, List
-def group_by_year(records: List[Dict], datetime_field: str) -> Dict[int, List]:
+def group_by_year(records: List[Dict], datetime_field: str) -> Dict[int, List[Any]]:
"""Memory efficient - processes records one at a time"""
by_year = defaultdict(list)
diff --git a/generalresearch/utils/grpc_logger.py b/generalresearch/utils/grpc_logger.py
index a98a13b..59f7471 100644
--- a/generalresearch/utils/grpc_logger.py
+++ b/generalresearch/utils/grpc_logger.py
@@ -1,7 +1,7 @@
import json
import logging
-from logging.handlers import TimedRotatingFileHandler
import time
+from logging.handlers import TimedRotatingFileHandler
handler = TimedRotatingFileHandler(
"grpc_access.log", when="midnight", backupCount=3, encoding="utf-8"
diff --git a/generalresearch/wall_status_codes/__init__.py b/generalresearch/wall_status_codes/__init__.py
index 102b72a..3a0abb8 100644
--- a/generalresearch/wall_status_codes/__init__.py
+++ b/generalresearch/wall_status_codes/__init__.py
@@ -1,21 +1,21 @@
-from typing import Tuple, Optional
+from typing import Optional, Tuple
from generalresearch.models import Source
from generalresearch.models.thl.definitions import Status, StatusCode1
from generalresearch.models.thl.session import Wall
from generalresearch.wall_status_codes import (
+ cint,
dynata,
fullcircle,
innovate,
+ lucid,
morning,
pollfish,
precision,
- spectrum,
- sago,
- cint,
- lucid,
prodege,
repdata,
+ sago,
+ spectrum,
)
@@ -36,6 +36,7 @@ def annotate_status_code(
return Status.FAIL, StatusCode1.UNKNOWN, None
if source == Source.PULLEY:
return Status.FAIL, StatusCode1.UNKNOWN, None
+
return {
Source.CINT: cint.annotate_status_code,
Source.DYNATA: dynata.annotate_status_code,
diff --git a/generalresearch/wall_status_codes/cint.py b/generalresearch/wall_status_codes/cint.py
index 0fa929c..8042cd2 100644
--- a/generalresearch/wall_status_codes/cint.py
+++ b/generalresearch/wall_status_codes/cint.py
@@ -1,5 +1,6 @@
-from typing import Optional
+from typing import Any, Optional, Tuple
+from generalresearch.models.thl.definitions import Status, StatusCode1
from generalresearch.wall_status_codes import lucid
@@ -7,7 +8,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-):
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
return lucid.annotate_status_code(
ext_status_code_1=ext_status_code_1,
ext_status_code_2=ext_status_code_2,
diff --git a/generalresearch/wall_status_codes/dynata.py b/generalresearch/wall_status_codes/dynata.py
index f37fca5..958e06f 100644
--- a/generalresearch/wall_status_codes/dynata.py
+++ b/generalresearch/wall_status_codes/dynata.py
@@ -4,11 +4,11 @@ checked by Greg 2023-10-10
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_name = {
+status_codes_name: Dict[str, str] = {
"0.0": "Unknown",
"0.1": "Missing Language",
"0.2": "Missing Respondent ID",
@@ -51,10 +51,10 @@ status_codes_name = {
"5.10": "Daily Limit",
}
-status_map = defaultdict(
+status_map: Dict[str, Status] = defaultdict(
lambda: Status.FAIL, **{"1.0": Status.COMPLETE, "1.1": Status.COMPLETE}
)
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["1.0", "1.1"],
StatusCode1.BUYER_FAIL: ["2.2", "3.2"],
StatusCode1.BUYER_QUALITY_FAIL: ["5.1", "5.2"],
@@ -88,9 +88,13 @@ status_codes_ext_map = {
"5.10",
],
}
-ext_status_code_map = dict()
+ext_status_code_map: Dict[str, StatusCode1] = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -98,7 +102,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: this is from the callback url params:
disposition and status, '.'-joined
@@ -113,7 +117,7 @@ def annotate_status_code(
return status, status_code, None
-def stop_marketplace_session(status_code_1, ext_status_code_1) -> bool:
+def stop_marketplace_session(status_code_1: StatusCode1, ext_status_code_1) -> bool:
if ext_status_code_1.startswith("5"):
# '5.10' is the user hit a Daily Limit, so they should not be sent in again today
return True
diff --git a/generalresearch/wall_status_codes/fullcircle.py b/generalresearch/wall_status_codes/fullcircle.py
index 08f2c44..eda4d9c 100644
--- a/generalresearch/wall_status_codes/fullcircle.py
+++ b/generalresearch/wall_status_codes/fullcircle.py
@@ -7,11 +7,11 @@ we'll try to infer based on the time spent in survey.
from collections import defaultdict
from datetime import timedelta
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_map = {
+status_codes_map: Dict[str, str] = {
"1": "Complete",
"2": "Terminate",
"3": "Over-quota",
@@ -19,7 +19,7 @@ status_codes_map = {
}
status_map = defaultdict(lambda: Status.FAIL, **{"1": Status.COMPLETE})
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["1"],
StatusCode1.BUYER_FAIL: ["2", "3"],
StatusCode1.BUYER_QUALITY_FAIL: ["4"],
@@ -29,9 +29,13 @@ status_codes_ext_map = {
StatusCode1.PS_FAIL: [],
StatusCode1.PS_OVERQUOTA: [],
}
-ext_status_code_map = dict()
+ext_status_code_map: Dict[str, StatusCode1] = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -39,7 +43,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: this is from the callback url param 's'
:params ext_status_code_2: not used
diff --git a/generalresearch/wall_status_codes/innovate.py b/generalresearch/wall_status_codes/innovate.py
index 0d11425..b650ade 100644
--- a/generalresearch/wall_status_codes/innovate.py
+++ b/generalresearch/wall_status_codes/innovate.py
@@ -9,11 +9,11 @@ can map directly, and some we have to look at the category.
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_innovate = {
+status_codes_innovate: Dict[str, str] = {
"1": "Complete",
"2": "Buyer Fail",
"3": "Buyer Over Quota",
@@ -29,7 +29,7 @@ status_map = defaultdict(
lambda: Status.FAIL,
**{"1": Status.COMPLETE, "0": Status.ABANDON, "6": Status.ABANDON},
)
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.BUYER_FAIL: ["2", "3"],
StatusCode1.BUYER_QUALITY_FAIL: ["4"],
StatusCode1.PS_BLOCKED: [],
@@ -43,14 +43,14 @@ for k, v in status_codes_ext_map.items():
for vv in v:
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
-category_innovate = {
+category_innovate: Dict[str, StatusCode1] = {
"Selected threat potential score at joblevel not allow the survey": StatusCode1.PS_QUALITY,
"OE Validation": StatusCode1.PS_QUALITY,
"Unique IP": StatusCode1.PS_DUPLICATE,
"Unique PID": StatusCode1.PS_DUPLICATE,
# 'Duplicated to token {token} and Group {groupID}': StatusCode1.PS_DUPLICATE,
# 'Duplicate Due to Multi Groups: Token {token} and Group {groupID}': StatusCode1.PS_DUPLICATE,
- # todo: we should not send them into this marketplace for a day?
+ # TODO: we should not send them into this marketplace for a day?
"User has attended {count} survey in 5 range": StatusCode1.PS_FAIL,
"PII_OPT": StatusCode1.PS_QUALITY,
"Recaptcha": StatusCode1.PS_QUALITY,
@@ -80,7 +80,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
Only quality terminate (4 and 8), and PS term (5) return a term_reason (af=).
@@ -93,13 +93,16 @@ def annotate_status_code(
status = status_map[ext_status_code_1]
if status == Status.COMPLETE:
return status, StatusCode1.COMPLETE, None
+
# First use the 1 through 8 status code. Then, using the reason (af=), if available,
# try to maybe reclassify it.
if ext_status_code_1 not in ext_status_code_map:
return status, StatusCode1.UNKNOWN, None
+
status_code = ext_status_code_map.get(ext_status_code_1, StatusCode1.UNKNOWN)
if ext_status_code_2 in category_innovate:
status_code = category_innovate[ext_status_code_2]
+
if ext_status_code_2:
# Some of these have ids in them... so we have to pattern match it
if (
diff --git a/generalresearch/wall_status_codes/lucid.py b/generalresearch/wall_status_codes/lucid.py
index bf49b45..c0098e6 100644
--- a/generalresearch/wall_status_codes/lucid.py
+++ b/generalresearch/wall_status_codes/lucid.py
@@ -5,11 +5,11 @@ https://support.lucidhq.com/s/article/Collecting-Data-From-Redirects
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-mp_codes = {
+mp_codes: Dict[str, str] = {
"-6": "Pre-Client Intermediary Page Drop Off",
"-5": "Failure in the Post Answer Behavior",
"-1": "Failure to Load the Lucid Marketplace",
@@ -54,7 +54,7 @@ mp_codes = {
}
# todo: finish, there's a bunch more
-client_status_map = {
+client_status_map: Dict[str, StatusCode1] = {
"30": StatusCode1.BUYER_QUALITY_FAIL,
"33": StatusCode1.BUYER_QUALITY_FAIL,
"34": StatusCode1.BUYER_QUALITY_FAIL,
@@ -62,7 +62,7 @@ client_status_map = {
}
status_map = defaultdict(lambda: Status.FAIL, **{"s": Status.COMPLETE})
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: [],
StatusCode1.BUYER_FAIL: ["3"],
StatusCode1.BUYER_QUALITY_FAIL: [],
@@ -101,9 +101,15 @@ status_codes_ext_map = {
],
StatusCode1.PS_OVERQUOTA: ["40", "41", "42"],
}
-ext_status_code_map = dict()
+
+ext_status_code_map: Dict[str, StatusCode1] = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
+
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -111,18 +117,27 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: this indicates which callback url was hit. possible values {'s', *anything else*}
:params ext_status_code_2: this is from the callback url params: InitialStatus
:params ext_status_code_3: this is from the callback url params: ClientStatus
+
returns: (status, status_code_1, status_code_2)
"""
- status = status_map[ext_status_code_1]
+ status: Status = status_map[ext_status_code_1]
+
if ext_status_code_2 == "3":
- status_code = client_status_map.get(ext_status_code_3, StatusCode1.BUYER_FAIL)
+        status_code = client_status_map.get(
+            ext_status_code_3, StatusCode1.BUYER_FAIL
+        )
+
else:
- status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN)
+        status_code = ext_status_code_map.get(
+            ext_status_code_2, StatusCode1.UNKNOWN
+        )
+
if status == Status.COMPLETE:
status_code = StatusCode1.COMPLETE
+
return status, status_code, None
diff --git a/generalresearch/wall_status_codes/morning.py b/generalresearch/wall_status_codes/morning.py
index 2539b55..318d1b2 100644
--- a/generalresearch/wall_status_codes/morning.py
+++ b/generalresearch/wall_status_codes/morning.py
@@ -1,5 +1,5 @@
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
@@ -17,7 +17,7 @@ timeout: The respondent completed the survey after the timeout period had expire
in_progress: The respondent interview session is still in progress, such as in the prescreener or survey.
"""
-short_code_to_status_codes_morning = {
+short_code_to_status_codes_morning: Dict[str, str] = {
"att_che": "attention_check",
"banned": "banned",
"bid_clo": "bid_closed",
@@ -53,7 +53,8 @@ short_code_to_status_codes_morning = {
"tem_ban": "temporarily_banned",
}
status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE})
-status_codes_ext_map = {
+
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["complete"],
StatusCode1.BUYER_FAIL: [
"in_survey_failure",
@@ -96,9 +97,13 @@ status_codes_ext_map = {
"quota_invalid_for_bid",
],
}
-ext_status_code_map = dict()
+ext_status_code_map: Dict[str, StatusCode1] = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -106,7 +111,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: from callback url params: &sti={{status_id}}
:params ext_status_code_2: from callback url params: &sdi={{status_detail_id}}
diff --git a/generalresearch/wall_status_codes/pollfish.py b/generalresearch/wall_status_codes/pollfish.py
index fbe1169..361785d 100644
--- a/generalresearch/wall_status_codes/pollfish.py
+++ b/generalresearch/wall_status_codes/pollfish.py
@@ -1,9 +1,9 @@
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_map = {
+status_codes_map: Dict[str, str] = {
"quo_ful": "quota_full",
"sur_clo": "survey_closed",
"profilin": "profiling",
@@ -29,7 +29,7 @@ status_codes_map = {
"complete": "complete",
}
status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE})
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["complete"],
StatusCode1.BUYER_FAIL: ["third_party_termination", "screenout"],
StatusCode1.BUYER_QUALITY_FAIL: [
@@ -60,7 +60,11 @@ status_codes_ext_map = {
}
ext_status_code_map = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -68,11 +72,12 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: from callback url params: &sti={{status_id}}
:params ext_status_code_2: from callback url params: &sdi={{status_detail_id}}
:params ext_status_code_3: not used
+
returns: (status, status_code_1, status_code_2)
"""
status = status_map[ext_status_code_1]
diff --git a/generalresearch/wall_status_codes/precision.py b/generalresearch/wall_status_codes/precision.py
index 147685a..ffeaeca 100644
--- a/generalresearch/wall_status_codes/precision.py
+++ b/generalresearch/wall_status_codes/precision.py
@@ -7,11 +7,11 @@ f - client approved the Preliminary complete as Final Complete
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_precision = {
+status_codes_precision: Dict[str, str] = {
"10": "Complete",
"20": "Client Terminate",
"21": "PS Terminate",
@@ -46,7 +46,7 @@ status_codes_precision = {
"80": "Final Complete",
}
status_map = defaultdict(lambda: Status.FAIL, **{"s": Status.COMPLETE})
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["10"],
StatusCode1.BUYER_FAIL: ["20", "30"],
StatusCode1.BUYER_QUALITY_FAIL: ["60"],
@@ -75,7 +75,11 @@ status_codes_ext_map = {
}
ext_status_code_map = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -83,7 +87,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: from callback url params: status
:params ext_status_code_2: from callback url params: code
diff --git a/generalresearch/wall_status_codes/prodege.py b/generalresearch/wall_status_codes/prodege.py
index 8ce4d22..a1e25ea 100644
--- a/generalresearch/wall_status_codes/prodege.py
+++ b/generalresearch/wall_status_codes/prodege.py
@@ -3,12 +3,12 @@ https://developer.prodege.com/surveys-feed/term-reasons
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
status_map = defaultdict(lambda: Status.FAIL, **{"1": Status.COMPLETE})
-status_code_map = {
+status_code_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: [],
StatusCode1.BUYER_FAIL: ["1", "2"],
StatusCode1.BUYER_QUALITY_FAIL: ["10", "12"],
@@ -33,7 +33,11 @@ status_code_map = {
status_class = dict()
for k, v in status_code_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
status_class[status_code_map.get(vv, vv)] = k
@@ -41,7 +45,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: status from redirect url
:params ext_status_code_2: termreason from redirect url
diff --git a/generalresearch/wall_status_codes/repdata.py b/generalresearch/wall_status_codes/repdata.py
index 5e575c4..c502532 100644
--- a/generalresearch/wall_status_codes/repdata.py
+++ b/generalresearch/wall_status_codes/repdata.py
@@ -3,11 +3,11 @@ Status codes are in a xlsx file. See thl-repdata readme
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_name = {
+status_codes_name: Dict[str, str] = {
"2": "Search Failed",
"3": "Activity Failed",
"4": "Review Failed",
@@ -26,7 +26,7 @@ status_codes_name = {
"6003": "In-Survey maximum exceeded (Research Desk)",
}
# See: 02, and 13 are de-dupes
-rd_threat_name = {
+rd_threat_name: Dict[str, str] = {
"02": "Duplicate entrant into survey",
"03": "Emulator Usage",
"04": "VPN usage detected",
@@ -47,7 +47,7 @@ rd_threat_name = {
}
status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE})
-status_code_map = {
+status_code_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["1000"],
StatusCode1.BUYER_FAIL: ["2000", "4000"],
StatusCode1.BUYER_QUALITY_FAIL: ["3000"],
@@ -60,7 +60,11 @@ status_code_map = {
status_class = dict()
for k, v in status_code_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
status_class[status_code_map.get(vv, vv)] = k
@@ -68,7 +72,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: the redirect urls category (as defined in url param 549f3710b)
{'term', 'overquota', 'fraud', 'complete'}
diff --git a/generalresearch/wall_status_codes/sago.py b/generalresearch/wall_status_codes/sago.py
index ac354ce..9c8710f 100644
--- a/generalresearch/wall_status_codes/sago.py
+++ b/generalresearch/wall_status_codes/sago.py
@@ -3,11 +3,11 @@ https://developer-beta.market-cube.com/api-details#api=definition-api&operation=
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_schlesinger = {
+status_codes_schlesinger: Dict[str, str] = {
"1": "Complete",
"2": "Buyer Fail",
"3": "Buyer Fail",
@@ -20,7 +20,7 @@ status_codes_schlesinger = {
"11": "Abandon", # really it is "Buyer Abandon"
}
-status_reason_name = {
+status_reason_name: Dict[str, str] = {
"1": "Not a Unique Sample Cube User",
"4": "GeoIP - wrong country",
"7": "Duplicate - not a unique IP",
@@ -121,7 +121,7 @@ status_map = defaultdict(
lambda: Status.FAIL, **{"1": Status.COMPLETE, "0": Status.ABANDON}
)
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["48"],
StatusCode1.BUYER_FAIL: ["16", "29", "49", "50", "78", "114", "110", "114"],
StatusCode1.BUYER_QUALITY_FAIL: ["26", "52", "68", "81", "84"],
@@ -167,9 +167,13 @@ status_codes_ext_map = {
StatusCode1.PS_FAIL: ["7", "29", "36", "47", "56", "58", "64"],
StatusCode1.PS_OVERQUOTA: ["29", "46", "33", "31"],
}
-ext_status_code_map = dict()
+ext_status_code_map: Dict[str, StatusCode1] = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -177,17 +181,20 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: from callback url params: scstatus
:params ext_status_code_2: from callback url params: scsecuritystatus
:params ext_status_code_3: not used
+
returns: (status, status_code_1, status_code_2)
"""
status = status_map[ext_status_code_1]
status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN)
# According to personal communication, scsecuritystatus may not always
# come back for completes. Going to ignore it if the status is complete
+
if status == Status.COMPLETE:
status_code = StatusCode1.COMPLETE
+
return status, status_code, None
diff --git a/generalresearch/wall_status_codes/spectrum.py b/generalresearch/wall_status_codes/spectrum.py
index 0cf5814..9e9e0a2 100644
--- a/generalresearch/wall_status_codes/spectrum.py
+++ b/generalresearch/wall_status_codes/spectrum.py
@@ -3,11 +3,11 @@ https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33613201/Minimizing+Clic
"""
from collections import defaultdict
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from generalresearch.models.thl.definitions import Status, StatusCode1
-status_codes_spectrum = {
+status_codes_spectrum: Dict[str, str] = {
"11": "PS Drop",
"12": "PS Quota Full Core",
"13": "PS Termination Core",
@@ -80,7 +80,7 @@ status_codes_spectrum = {
"88": "PS_Supplier_Allocation_Throttle",
}
status_map = defaultdict(lambda: Status.FAIL, **{"21": Status.COMPLETE})
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[str]] = {
StatusCode1.COMPLETE: ["21"],
StatusCode1.BUYER_FAIL: ["16", "17", "18", "19", "30", "59", "84"],
StatusCode1.BUYER_QUALITY_FAIL: ["20", "31"],
@@ -142,7 +142,11 @@ status_codes_ext_map = {
}
ext_status_code_map = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[str]
+
for vv in v:
+ vv: str
ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k
@@ -150,7 +154,7 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[Any]]:
"""
:params ext_status_code_1: from url params: ps_rstatus
https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33613201/Minimizing+Clickwaste+with+ps+rstatus
diff --git a/generalresearch/wall_status_codes/wxet.py b/generalresearch/wall_status_codes/wxet.py
index 1cbf568..e7cf67d 100644
--- a/generalresearch/wall_status_codes/wxet.py
+++ b/generalresearch/wall_status_codes/wxet.py
@@ -1,7 +1,7 @@
from collections import defaultdict
-from typing import Optional, Dict, Tuple
+from typing import Dict, List, Optional, Tuple
-from generalresearch.models.thl.definitions import StatusCode1, Status
+from generalresearch.models.thl.definitions import Status, StatusCode1
from generalresearch.wxet.models.definitions import (
WXETStatus,
WXETStatusCode1,
@@ -11,7 +11,7 @@ from generalresearch.wxet.models.definitions import (
status_map: Dict[WXETStatus, Status] = defaultdict(
lambda: Status.FAIL, **{WXETStatus.COMPLETE: Status.COMPLETE}
)
-status_codes_ext_map = {
+status_codes_ext_map: Dict[StatusCode1, List[WXETStatusCode1]] = {
StatusCode1.COMPLETE: [WXETStatusCode1.COMPLETE],
StatusCode1.BUYER_FAIL: [
WXETStatusCode1.BUYER_DUPLICATE,
@@ -32,10 +32,14 @@ status_codes_ext_map = {
}
ext_status_code_map = dict()
for k, v in status_codes_ext_map.items():
+ k: StatusCode1
+ v: List[WXETStatusCode1]
+
for vv in v:
+ vv: WXETStatusCode1
ext_status_code_map[vv] = k
-status_code2_map = {
+status_code2_map: Dict[StatusCode1, List[WXETStatusCode2]] = {
StatusCode1.PS_QUALITY: [],
StatusCode1.PS_DUPLICATE: [
WXETStatusCode2.WORKER_INELIGIBLE,
@@ -65,11 +69,12 @@ def annotate_status_code(
ext_status_code_1: str,
ext_status_code_2: Optional[str] = None,
ext_status_code_3: Optional[str] = None,
-) -> Tuple:
+) -> Tuple[Status, StatusCode1, Optional[WXETStatusCode2]]:
"""
:params ext_status_code_1: WXETStatus
:params ext_status_code_2: WXETStatusCode1
:params ext_status_code_3: WXETStatusCode2
+
returns: (status, status_code_1, status_code_2)
"""
ext_status_code_1 = WXETStatus(ext_status_code_1)
diff --git a/generalresearch/wxet/models/definitions.py b/generalresearch/wxet/models/definitions.py
index 50853dd..801346d 100644
--- a/generalresearch/wxet/models/definitions.py
+++ b/generalresearch/wxet/models/definitions.py
@@ -39,10 +39,13 @@ class WXETStatus(str, Enum, metaclass=ReprEnumMeta):
class WXETAdjustedStatus(str, Enum, metaclass=ReprEnumMeta):
# Task was reconciled to complete
ADJUSTED_TO_COMPLETE = "ac"
+
# Task was reconciled to incomplete
ADJUSTED_TO_FAIL = "af"
+
# The cpi for a task was adjusted
CPI_ADJUSTMENT = "ca"
+
# The user was redirected without a Postback, but the postback was then "immediately"
# received. The supplier thinks this was a failure. This is distinct from an
# actual adjustment to complete.
diff --git a/generalresearch/wxet/models/finish_type.py b/generalresearch/wxet/models/finish_type.py
index 2aa4c7f..af60fe6 100644
--- a/generalresearch/wxet/models/finish_type.py
+++ b/generalresearch/wxet/models/finish_type.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Set, Optional
+from typing import Optional, Set
from generalresearch.utils.enum import ReprEnumMeta
from generalresearch.wxet.models.definitions import WXETStatus, WXETStatusCode1
diff --git a/test_utils/conftest.py b/test_utils/conftest.py
index e074301..54fb682 100644
--- a/test_utils/conftest.py
+++ b/test_utils/conftest.py
@@ -247,13 +247,16 @@ def utc_30days_ago() -> "datetime":
@pytest.fixture(scope="function")
-def delete_df_collection(thl_web_rw, create_main_accounts) -> Callable:
+def delete_df_collection(
+ thl_web_rw: PostgresConfig, create_main_accounts: Callable[..., None]
+) -> Callable[..., None]:
+
from generalresearch.incite.collections import (
DFCollection,
DFCollectionType,
)
- def _teardown_events(coll: "DFCollection"):
+ def _inner(coll: "DFCollection"):
match coll.data_type:
case DFCollectionType.LEDGER:
for table in [
@@ -290,7 +293,7 @@ def delete_df_collection(thl_web_rw, create_main_accounts) -> Callable:
query=f"DELETE FROM {coll.data_type.value};",
)
- return _teardown_events
+ return _inner
# === GR Related ===
diff --git a/test_utils/managers/conftest.py b/test_utils/managers/conftest.py
index f1f774e..94dabae 100644
--- a/test_utils/managers/conftest.py
+++ b/test_utils/managers/conftest.py
@@ -1,15 +1,18 @@
from typing import TYPE_CHECKING, Callable
-import pymysql
import pytest
from generalresearch.managers.base import Permission
from generalresearch.models import Source
+from generalresearch.pg_helper import PostgresConfig
+from generalresearch.redis_helper import RedisConfig
+from generalresearch.sql_helper import SqlHelper
from test_utils.managers.cashout_methods import (
EXAMPLE_TANGO_CASHOUT_METHODS,
)
if TYPE_CHECKING:
+ from generalresearch.config import GRLBaseSettings
from generalresearch.grliq.managers.forensic_data import (
GrlIqDataManager,
)
@@ -32,6 +35,7 @@ if TYPE_CHECKING:
MembershipManager,
TeamManager,
)
+ from generalresearch.managers.thl.category import CategoryManager
from generalresearch.managers.thl.contest_manager import ContestManager
from generalresearch.managers.thl.ipinfo import (
GeoIpInfoManager,
@@ -84,7 +88,9 @@ if TYPE_CHECKING:
@pytest.fixture(scope="session")
-def ltxm(thl_web_rw, thl_redis_config) -> "LedgerTransactionManager":
+def ltxm(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "LedgerTransactionManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ledger_manager.ledger import (
@@ -100,7 +106,9 @@ def ltxm(thl_web_rw, thl_redis_config) -> "LedgerTransactionManager":
@pytest.fixture(scope="session")
-def lam(thl_web_rw, thl_redis_config) -> "LedgerAccountManager":
+def lam(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "LedgerAccountManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ledger_manager.ledger import (
@@ -116,7 +124,7 @@ def lam(thl_web_rw, thl_redis_config) -> "LedgerAccountManager":
@pytest.fixture(scope="session")
-def lm(thl_web_rw, thl_redis_config) -> "LedgerManager":
+def lm(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig) -> "LedgerManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ledger_manager.ledger import (
@@ -137,7 +145,9 @@ def lm(thl_web_rw, thl_redis_config) -> "LedgerManager":
@pytest.fixture(scope="session")
-def thl_lm(thl_web_rw, thl_redis_config) -> "ThlLedgerManager":
+def thl_lm(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "ThlLedgerManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ledger_manager.thl_ledger import (
@@ -158,7 +168,9 @@ def thl_lm(thl_web_rw, thl_redis_config) -> "ThlLedgerManager":
@pytest.fixture(scope="session")
-def payout_event_manager(thl_web_rw, thl_redis_config) -> "PayoutEventManager":
+def payout_event_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "PayoutEventManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.payout import PayoutEventManager
@@ -171,7 +183,9 @@ def payout_event_manager(thl_web_rw, thl_redis_config) -> "PayoutEventManager":
@pytest.fixture(scope="session")
-def user_payout_event_manager(thl_web_rw, thl_redis_config) -> "UserPayoutEventManager":
+def user_payout_event_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "UserPayoutEventManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.payout import UserPayoutEventManager
@@ -185,7 +199,7 @@ def user_payout_event_manager(thl_web_rw, thl_redis_config) -> "UserPayoutEventM
@pytest.fixture(scope="session")
def brokerage_product_payout_event_manager(
- thl_web_rw, thl_redis_config
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
) -> "BrokerageProductPayoutEventManager":
assert "/unittest-" in thl_web_rw.dsn.path
@@ -202,7 +216,7 @@ def brokerage_product_payout_event_manager(
@pytest.fixture(scope="session")
def business_payout_event_manager(
- thl_web_rw, thl_redis_config
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
) -> "BusinessPayoutEventManager":
assert "/unittest-" in thl_web_rw.dsn.path
@@ -218,7 +232,7 @@ def business_payout_event_manager(
@pytest.fixture(scope="session")
-def product_manager(thl_web_rw) -> "ProductManager":
+def product_manager(thl_web_rw: PostgresConfig) -> "ProductManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.product import ProductManager
@@ -227,7 +241,9 @@ def product_manager(thl_web_rw) -> "ProductManager":
@pytest.fixture(scope="session")
-def user_manager(settings, thl_web_rw, thl_web_rr) -> "UserManager":
+def user_manager(
+ settings: "GRLBaseSettings", thl_web_rw: PostgresConfig, thl_web_rr: PostgresConfig
+) -> "UserManager":
assert "/unittest-" in thl_web_rw.dsn.path
assert "/unittest-" in thl_web_rr.dsn.path
@@ -243,7 +259,7 @@ def user_manager(settings, thl_web_rw, thl_web_rr) -> "UserManager":
@pytest.fixture(scope="session")
-def user_metadata_manager(thl_web_rw) -> "UserMetadataManager":
+def user_metadata_manager(thl_web_rw: PostgresConfig) -> "UserMetadataManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.user_manager.user_metadata_manager import (
@@ -254,7 +270,7 @@ def user_metadata_manager(thl_web_rw) -> "UserMetadataManager":
@pytest.fixture(scope="session")
-def session_manager(thl_web_rw) -> "SessionManager":
+def session_manager(thl_web_rw: PostgresConfig) -> "SessionManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.session import SessionManager
@@ -263,7 +279,7 @@ def session_manager(thl_web_rw) -> "SessionManager":
@pytest.fixture(scope="session")
-def wall_manager(thl_web_rw) -> "WallManager":
+def wall_manager(thl_web_rw: PostgresConfig) -> "WallManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.wall import WallManager
@@ -272,7 +288,9 @@ def wall_manager(thl_web_rw) -> "WallManager":
@pytest.fixture(scope="session")
-def wall_cache_manager(thl_web_rw, thl_redis_config) -> "WallCacheManager":
+def wall_cache_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "WallCacheManager":
# assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.wall import WallCacheManager
@@ -281,7 +299,7 @@ def wall_cache_manager(thl_web_rw, thl_redis_config) -> "WallCacheManager":
@pytest.fixture(scope="session")
-def task_adjustment_manager(thl_web_rw) -> "TaskAdjustmentManager":
+def task_adjustment_manager(thl_web_rw: PostgresConfig) -> "TaskAdjustmentManager":
# assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.task_adjustment import (
@@ -292,7 +310,7 @@ def task_adjustment_manager(thl_web_rw) -> "TaskAdjustmentManager":
@pytest.fixture(scope="session")
-def contest_manager(thl_web_rw) -> "ContestManager":
+def contest_manager(thl_web_rw: PostgresConfig) -> "ContestManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.contest_manager import ContestManager
@@ -309,7 +327,7 @@ def contest_manager(thl_web_rw) -> "ContestManager":
@pytest.fixture(scope="session")
-def category_manager(thl_web_rw):
+def category_manager(thl_web_rw: PostgresConfig) -> "CategoryManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.category import CategoryManager
@@ -317,7 +335,7 @@ def category_manager(thl_web_rw):
@pytest.fixture(scope="session")
-def buyer_manager(thl_web_rw):
+def buyer_manager(thl_web_rw: PostgresConfig):
# assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.buyer import BuyerManager
@@ -325,7 +343,7 @@ def buyer_manager(thl_web_rw):
@pytest.fixture(scope="session")
-def survey_manager(thl_web_rw):
+def survey_manager(thl_web_rw: PostgresConfig):
# assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.survey import SurveyManager
@@ -333,7 +351,7 @@ def survey_manager(thl_web_rw):
@pytest.fixture(scope="session")
-def surveystat_manager(thl_web_rw):
+def surveystat_manager(thl_web_rw: PostgresConfig):
# assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.survey import SurveyStatManager
@@ -348,7 +366,7 @@ def surveypenalty_manager(thl_redis_config):
@pytest.fixture(scope="session")
-def upk_schema_manager(thl_web_rw):
+def upk_schema_manager(thl_web_rw: PostgresConfig):
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.profiling.schema import (
UpkSchemaManager,
@@ -358,7 +376,7 @@ def upk_schema_manager(thl_web_rw):
@pytest.fixture(scope="session")
-def user_upk_manager(thl_web_rw, thl_redis_config):
+def user_upk_manager(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig):
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.profiling.user_upk import (
UserUpkManager,
@@ -368,7 +386,7 @@ def user_upk_manager(thl_web_rw, thl_redis_config):
@pytest.fixture(scope="session")
-def question_manager(thl_web_rw, thl_redis_config):
+def question_manager(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig):
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.profiling.question import (
QuestionManager,
@@ -378,7 +396,7 @@ def question_manager(thl_web_rw, thl_redis_config):
@pytest.fixture(scope="session")
-def uqa_manager(thl_web_rw, thl_redis_config):
+def uqa_manager(thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig):
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.profiling.uqa import UQAManager
@@ -395,7 +413,7 @@ def uqa_manager_clear_cache(uqa_manager, user):
@pytest.fixture(scope="session")
-def audit_log_manager(thl_web_rw) -> "AuditLogManager":
+def audit_log_manager(thl_web_rw: PostgresConfig) -> "AuditLogManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.userhealth import AuditLogManager
@@ -404,7 +422,7 @@ def audit_log_manager(thl_web_rw) -> "AuditLogManager":
@pytest.fixture(scope="session")
-def ip_geoname_manager(thl_web_rw) -> "IPGeonameManager":
+def ip_geoname_manager(thl_web_rw: PostgresConfig) -> "IPGeonameManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ipinfo import IPGeonameManager
@@ -413,7 +431,7 @@ def ip_geoname_manager(thl_web_rw) -> "IPGeonameManager":
@pytest.fixture(scope="session")
-def ip_information_manager(thl_web_rw) -> "IPInformationManager":
+def ip_information_manager(thl_web_rw: PostgresConfig) -> "IPInformationManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ipinfo import IPInformationManager
@@ -422,7 +440,9 @@ def ip_information_manager(thl_web_rw) -> "IPInformationManager":
@pytest.fixture(scope="session")
-def ip_record_manager(thl_web_rw, thl_redis_config) -> "IPRecordManager":
+def ip_record_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "IPRecordManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.userhealth import IPRecordManager
@@ -431,7 +451,9 @@ def ip_record_manager(thl_web_rw, thl_redis_config) -> "IPRecordManager":
@pytest.fixture(scope="session")
-def user_iphistory_manager(thl_web_rw, thl_redis_config) -> "UserIpHistoryManager":
+def user_iphistory_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "UserIpHistoryManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.userhealth import (
@@ -451,7 +473,9 @@ def user_iphistory_manager_clear_cache(user_iphistory_manager, user):
@pytest.fixture(scope="session")
-def geoipinfo_manager(thl_web_rw, thl_redis_config) -> "GeoIpInfoManager":
+def geoipinfo_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "GeoIpInfoManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.ipinfo import GeoIpInfoManager
@@ -469,7 +493,9 @@ def maxmind_basic_manager() -> "MaxmindBasicManager":
@pytest.fixture(scope="session")
-def maxmind_manager(thl_web_rw, thl_redis_config) -> "MaxmindManager":
+def maxmind_manager(
+ thl_web_rw: PostgresConfig, thl_redis_config: RedisConfig
+) -> "MaxmindManager":
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.maxmind import MaxmindManager
@@ -478,7 +504,7 @@ def maxmind_manager(thl_web_rw, thl_redis_config) -> "MaxmindManager":
@pytest.fixture(scope="session")
-def cashout_method_manager(thl_web_rw):
+def cashout_method_manager(thl_web_rw: PostgresConfig):
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.cashout_method import (
CashoutMethodManager,
@@ -488,14 +514,14 @@ def cashout_method_manager(thl_web_rw):
@pytest.fixture(scope="session")
-def event_manager(thl_redis_config):
+def event_manager(thl_redis_config: RedisConfig):
from generalresearch.managers.events import EventManager
return EventManager(redis_config=thl_redis_config)
@pytest.fixture(scope="session")
-def user_streak_manager(thl_web_rw):
+def user_streak_manager(thl_web_rw: PostgresConfig):
assert "/unittest-" in thl_web_rw.dsn.path
from generalresearch.managers.thl.user_streak import (
UserStreakManager,
@@ -505,7 +531,7 @@ def user_streak_manager(thl_web_rw):
@pytest.fixture(scope="session")
-def uqa_db_index(thl_web_rw):
+def uqa_db_index(thl_web_rw: PostgresConfig):
# There were some custom indices created not through django.
# Make sure the index used in the index hint exists
assert "/unittest-" in thl_web_rw.dsn.path
@@ -521,7 +547,7 @@ def uqa_db_index(thl_web_rw):
@pytest.fixture(scope="session")
-def delete_cashoutmethod_db(thl_web_rw) -> Callable:
+def delete_cashoutmethod_db(thl_web_rw: PostgresConfig) -> Callable[..., None]:
def _delete_cashoutmethod_db():
thl_web_rw.execute_write(
query="DELETE FROM accounting_cashoutmethod;",
@@ -531,14 +557,21 @@ def delete_cashoutmethod_db(thl_web_rw) -> Callable:
@pytest.fixture(scope="session")
-def setup_cashoutmethod_db(settings, cashout_method_manager, delete_cashoutmethod_db):
- settings.amt_
-
+def setup_cashoutmethod_db(
+ settings: "GRLBaseSettings", cashout_method_manager, delete_cashoutmethod_db
+):
delete_cashoutmethod_db()
for x in EXAMPLE_TANGO_CASHOUT_METHODS:
cashout_method_manager.create(x)
- cashout_method_manager.create(AMT_ASSIGNMENT_CASHOUT_METHOD)
- cashout_method_manager.create(AMT_BONUS_CASHOUT_METHOD)
+
+ # TODO: convert these ids into instances to use.
+ # settings.amt_bonus_cashout_method_id
+ # settings.amt_assignment_cashout_method_id
+
+ # cashout_method_manager.create(AMT_ASSIGNMENT_CASHOUT_METHOD)
+ # cashout_method_manager.create(AMT_BONUS_CASHOUT_METHOD)
+ raise NotImplementedError("Need to implement setup_cashoutmethod_db")
+
return None
@@ -546,7 +579,7 @@ def setup_cashoutmethod_db(settings, cashout_method_manager, delete_cashoutmetho
@pytest.fixture(scope="session")
-def spectrum_manager(spectrum_rw):
+def spectrum_manager(spectrum_rw: SqlHelper) -> "SpectrumSurveyManager":
from generalresearch.managers.spectrum.survey import (
SpectrumSurveyManager,
)
@@ -556,7 +589,9 @@ def spectrum_manager(spectrum_rw):
# === GR ===
@pytest.fixture(scope="session")
-def business_manager(gr_db, gr_redis_config) -> "BusinessManager":
+def business_manager(
+ gr_db: PostgresConfig, gr_redis_config: RedisConfig
+) -> "BusinessManager":
from generalresearch.redis_helper import RedisConfig
assert "/unittest-" in gr_db.dsn.path
@@ -571,7 +606,7 @@ def business_manager(gr_db, gr_redis_config) -> "BusinessManager":
@pytest.fixture(scope="session")
-def business_address_manager(gr_db) -> "BusinessAddressManager":
+def business_address_manager(gr_db: PostgresConfig) -> "BusinessAddressManager":
assert "/unittest-" in gr_db.dsn.path
from generalresearch.managers.gr.business import BusinessAddressManager
@@ -580,7 +615,9 @@ def business_address_manager(gr_db) -> "BusinessAddressManager":
@pytest.fixture(scope="session")
-def business_bank_account_manager(gr_db) -> "BusinessBankAccountManager":
+def business_bank_account_manager(
+ gr_db: PostgresConfig,
+) -> "BusinessBankAccountManager":
assert "/unittest-" in gr_db.dsn.path
from generalresearch.managers.gr.business import (
@@ -591,7 +628,7 @@ def business_bank_account_manager(gr_db) -> "BusinessBankAccountManager":
@pytest.fixture(scope="session")
-def team_manager(gr_db, gr_redis_config) -> "TeamManager":
+def team_manager(gr_db: PostgresConfig, gr_redis_config: RedisConfig) -> "TeamManager":
assert "/unittest-" in gr_db.dsn.path
from generalresearch.managers.gr.team import TeamManager
@@ -600,7 +637,7 @@ def team_manager(gr_db, gr_redis_config) -> "TeamManager":
@pytest.fixture(scope="session")
-def gr_um(gr_db, gr_redis_config) -> "GRUserManager":
+def gr_um(gr_db: PostgresConfig, gr_redis_config: RedisConfig) -> "GRUserManager":
assert "/unittest-" in gr_db.dsn.path
from generalresearch.managers.gr.authentication import GRUserManager
@@ -609,7 +646,7 @@ def gr_um(gr_db, gr_redis_config) -> "GRUserManager":
@pytest.fixture(scope="session")
-def gr_tm(gr_db) -> "GRTokenManager":
+def gr_tm(gr_db: PostgresConfig) -> "GRTokenManager":
assert "/unittest-" in gr_db.dsn.path
from generalresearch.managers.gr.authentication import GRTokenManager
@@ -618,7 +655,7 @@ def gr_tm(gr_db) -> "GRTokenManager":
@pytest.fixture(scope="session")
-def membership_manager(gr_db) -> "MembershipManager":
+def membership_manager(gr_db: PostgresConfig) -> "MembershipManager":
assert "/unittest-" in gr_db.dsn.path
from generalresearch.managers.gr.team import MembershipManager
@@ -630,7 +667,7 @@ def membership_manager(gr_db) -> "MembershipManager":
@pytest.fixture(scope="session")
-def grliq_dm(grliq_db) -> "GrlIqDataManager":
+def grliq_dm(grliq_db: PostgresConfig) -> "GrlIqDataManager":
assert "/unittest-" in grliq_db.dsn.path
from generalresearch.grliq.managers.forensic_data import (
@@ -641,7 +678,7 @@ def grliq_dm(grliq_db) -> "GrlIqDataManager":
@pytest.fixture(scope="session")
-def grliq_em(grliq_db) -> "GrlIqEventManager":
+def grliq_em(grliq_db: PostgresConfig) -> "GrlIqEventManager":
assert "/unittest-" in grliq_db.dsn.path
from generalresearch.grliq.managers.forensic_events import (
@@ -652,7 +689,7 @@ def grliq_em(grliq_db) -> "GrlIqEventManager":
@pytest.fixture(scope="session")
-def grliq_crr(grliq_db) -> "GrlIqCategoryResultsReader":
+def grliq_crr(grliq_db: PostgresConfig) -> "GrlIqCategoryResultsReader":
assert "/unittest-" in grliq_db.dsn.path
from generalresearch.grliq.managers.forensic_results import (
@@ -663,7 +700,7 @@ def grliq_crr(grliq_db) -> "GrlIqCategoryResultsReader":
@pytest.fixture(scope="session")
-def delete_buyers_surveys(thl_web_rw, buyer_manager):
+def delete_buyers_surveys(thl_web_rw: PostgresConfig, buyer_manager):
# assert "/unittest-" in thl_web_rw.dsn.path
thl_web_rw.execute_write(
"""
diff --git a/test_utils/managers/upk/conftest.py b/test_utils/managers/upk/conftest.py
index 61be924..e28d085 100644
--- a/test_utils/managers/upk/conftest.py
+++ b/test_utils/managers/upk/conftest.py
@@ -1,6 +1,6 @@
import os
import time
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from uuid import UUID
import pandas as pd
@@ -8,6 +8,9 @@ import pytest
from generalresearch.pg_helper import PostgresConfig
+if TYPE_CHECKING:
+ from generalresearch.managers.thl.category import CategoryManager
+
def insert_data_from_csv(
thl_web_rw: PostgresConfig,
@@ -39,7 +42,9 @@ def insert_data_from_csv(
@pytest.fixture(scope="session")
-def category_data(thl_web_rw, category_manager) -> None:
+def category_data(
+ thl_web_rw: PostgresConfig, category_manager: "CategoryManager"
+) -> None:
fp = os.path.join(os.path.dirname(__file__), "marketplace_category.csv.gz")
insert_data_from_csv(
thl_web_rw,
@@ -66,20 +71,23 @@ def category_data(thl_web_rw, category_manager) -> None:
@pytest.fixture(scope="session")
-def property_data(thl_web_rw) -> None:
+def property_data(thl_web_rw: PostgresConfig) -> None:
fp = os.path.join(os.path.dirname(__file__), "marketplace_property.csv.gz")
insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_property")
@pytest.fixture(scope="session")
-def item_data(thl_web_rw) -> None:
+def item_data(thl_web_rw: PostgresConfig) -> None:
fp = os.path.join(os.path.dirname(__file__), "marketplace_item.csv.gz")
insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_item")
@pytest.fixture(scope="session")
def propertycategoryassociation_data(
- thl_web_rw, category_data, property_data, category_manager
+ thl_web_rw: PostgresConfig,
+ category_data,
+ property_data,
+ category_manager: "CategoryManager",
) -> None:
table_name = "marketplace_propertycategoryassociation"
fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz")
@@ -93,27 +101,31 @@ def propertycategoryassociation_data(
@pytest.fixture(scope="session")
-def propertycountry_data(thl_web_rw, property_data) -> None:
+def propertycountry_data(thl_web_rw: PostgresConfig, property_data) -> None:
fp = os.path.join(os.path.dirname(__file__), "marketplace_propertycountry.csv.gz")
insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_propertycountry")
@pytest.fixture(scope="session")
-def propertymarketplaceassociation_data(thl_web_rw, property_data) -> None:
+def propertymarketplaceassociation_data(
+ thl_web_rw: PostgresConfig, property_data
+) -> None:
table_name = "marketplace_propertymarketplaceassociation"
fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz")
insert_data_from_csv(thl_web_rw, fp=fp, table_name=table_name)
@pytest.fixture(scope="session")
-def propertyitemrange_data(thl_web_rw, property_data, item_data) -> None:
+def propertyitemrange_data(
+ thl_web_rw: PostgresConfig, property_data, item_data
+) -> None:
table_name = "marketplace_propertyitemrange"
fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz")
insert_data_from_csv(thl_web_rw, fp=fp, table_name=table_name)
@pytest.fixture(scope="session")
-def question_data(thl_web_rw) -> None:
+def question_data(thl_web_rw: PostgresConfig) -> None:
table_name = "marketplace_question"
fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz")
insert_data_from_csv(
@@ -122,7 +134,7 @@ def question_data(thl_web_rw) -> None:
@pytest.fixture(scope="session")
-def clear_upk_tables(thl_web_rw):
+def clear_upk_tables(thl_web_rw: PostgresConfig):
tables = [
"marketplace_propertyitemrange",
"marketplace_propertymarketplaceassociation",
diff --git a/test_utils/models/conftest.py b/test_utils/models/conftest.py
index dcc3b66..81dc11e 100644
--- a/test_utils/models/conftest.py
+++ b/test_utils/models/conftest.py
@@ -5,7 +5,6 @@ from random import randint
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from uuid import uuid4
-import pytest
from pydantic import AwareDatetime, PositiveInt
from generalresearch.models import Source
@@ -14,6 +13,7 @@ from generalresearch.models.thl.definitions import (
Status,
)
from generalresearch.models.thl.survey.model import Buyer, Survey
+from generalresearch.pg_helper import PostgresConfig
from test_utils.managers.conftest import (
business_address_manager,
business_manager,
@@ -28,6 +28,7 @@ from test_utils.managers.conftest import (
if TYPE_CHECKING:
from generalresearch.currency import USDCent
+ from generalresearch.managers.thl.session import SessionManager
from generalresearch.models.gr.authentication import GRToken, GRUser
from generalresearch.models.gr.business import (
Business,
@@ -53,7 +54,7 @@ if TYPE_CHECKING:
@pytest.fixture(scope="function")
-def user(request, product_manager, user_manager, thl_web_rr) -> "User":
+def user(request, product_manager, user_manager, thl_web_rr: PostgresConfig) -> "User":
product = getattr(request, "product", None)
if product is None:
@@ -74,25 +75,29 @@ def user_with_wallet(
@pytest.fixture
-def user_with_wallet_amt(request, user_factory, product_amt_true: "Product") -> "User":
+def user_with_wallet_amt(
+ request, user_factory: Callable[..., "User"], product_amt_true: "Product"
+) -> "User":
# A user on a product with user wallet enabled, on AMT, but they have no money
return user_factory(product=product_amt_true)
@pytest.fixture(scope="function")
-def user_factory(user_manager, thl_web_rr) -> Callable:
- def _create_user(product: "Product", created: Optional[datetime] = None):
+def user_factory(user_manager, thl_web_rr: PostgresConfig) -> Callable[..., "User"]:
+
+ def _inner(product: "Product", created: Optional[datetime] = None):
u = user_manager.create_dummy(product=product, created=created)
u.prefetch_product(pg_config=thl_web_rr)
return u
- return _create_user
+ return _inner
@pytest.fixture(scope="function")
-def wall_factory(wall_manager) -> Callable:
- def _create_wall(
+def wall_factory(wall_manager) -> Callable[..., "Wall"]:
+
+ def _inner(
session: "Session", wall_status: "Status", req_cpi: Optional[Decimal] = None
):
@@ -126,10 +131,10 @@ def wall_factory(wall_manager) -> Callable:
return wall
- return _create_wall
+ return _inner
-@pytest.fixture(scope="function")
+@pytest.fixture
def wall(session, user, wall_manager) -> Optional["Wall"]:
from generalresearch.models.thl.task_status import StatusCode1
@@ -143,9 +148,12 @@ def wall(session, user, wall_manager) -> Optional["Wall"]:
return wall
-@pytest.fixture(scope="function")
+@pytest.fixture
def session_factory(
- wall_factory, session_manager, wall_manager, utc_hour_ago
+ wall_factory,
+ session_manager: "SessionManager",
+ wall_manager,
+ utc_hour_ago: datetime,
) -> Callable[..., "Session"]:
from generalresearch.models.thl.session import Source
@@ -208,11 +216,13 @@ def session_factory(
@pytest.fixture(scope="function")
def finished_session_factory(
- session_factory, session_manager, utc_hour_ago
-) -> Callable:
+ session_factory: Callable[..., "Session"],
+ session_manager: "SessionManager",
+ utc_hour_ago: datetime,
+) -> Callable[..., "Session"]:
from generalresearch.models.thl.session import Source
- def _create_finished_session(
+ def _inner(
user: "User",
# Wall details
wall_count: int = 5,
@@ -246,11 +256,11 @@ def finished_session_factory(
)
return s
- return _create_finished_session
+ return _inner
@pytest.fixture(scope="function")
-def session(user, session_manager, wall_manager) -> "Session":
+def session(user, session_manager: "SessionManager", wall_manager) -> "Session":
from generalresearch.models.thl.session import Session, Wall
session: Session = session_manager.create_dummy(user=user, country_iso="us")
@@ -294,7 +304,7 @@ def product_factory(product_manager) -> Callable:
return _create_product
-@pytest.fixture(scope="function")
+@pytest.fixture
def payout_config(request) -> "PayoutConfig":
from generalresearch.models.thl.product import (
PayoutConfig,
@@ -315,7 +325,7 @@ def payout_config(request) -> "PayoutConfig":
)
-@pytest.fixture(scope="function")
+@pytest.fixture
def product_user_wallet_yes(payout_config, product_manager) -> "Product":
from generalresearch.managers.thl.product import ProductManager
from generalresearch.models.thl.product import UserWalletConfig
@@ -326,7 +336,7 @@ def product_user_wallet_yes(payout_config, product_manager) -> "Product":
)
-@pytest.fixture(scope="function")
+@pytest.fixture
def product_user_wallet_no(product_manager) -> "Product":
from generalresearch.managers.thl.product import ProductManager
from generalresearch.models.thl.product import UserWalletConfig
@@ -337,7 +347,7 @@ def product_user_wallet_no(product_manager) -> "Product":
)
-@pytest.fixture(scope="function")
+@pytest.fixture
def product_amt_true(product_manager, payout_config) -> "Product":
from generalresearch.models.thl.product import UserWalletConfig
@@ -347,7 +357,7 @@ def product_amt_true(product_manager, payout_config) -> "Product":
)
-@pytest.fixture(scope="function")
+@pytest.fixture
def bp_payout_factory(
thl_lm, product_manager, business_payout_event_manager
) -> Callable:
@@ -380,7 +390,7 @@ def bp_payout_factory(
# === GR ===
-@pytest.fixture(scope="function")
+@pytest.fixture
def business(request, business_manager) -> "Business":
from generalresearch.managers.gr.business import BusinessManager
@@ -388,7 +398,7 @@ def business(request, business_manager) -> "Business":
return business_manager.create_dummy()
-@pytest.fixture(scope="function")
+@pytest.fixture
def business_address(request, business, business_address_manager) -> "BusinessAddress":
from generalresearch.managers.gr.business import BusinessAddressManager
@@ -396,7 +406,7 @@ def business_address(request, business, business_address_manager) -> "BusinessAd
return business_address_manager.create_dummy(business_id=business.id)
-@pytest.fixture(scope="function")
+@pytest.fixture
def business_bank_account(
request, business, business_bank_account_manager
) -> "BusinessBankAccount":
@@ -406,7 +416,7 @@ def business_bank_account(
return business_bank_account_manager.create_dummy(business_id=business.id)
-@pytest.fixture(scope="function")
+@pytest.fixture
def team(request, team_manager) -> "Team":
from generalresearch.managers.gr.team import TeamManager
@@ -414,7 +424,7 @@ def team(request, team_manager) -> "Team":
return team_manager.create_dummy()
-@pytest.fixture(scope="function")
+@pytest.fixture
def gr_user(gr_um) -> "GRUser":
from generalresearch.managers.gr.authentication import GRUserManager
@@ -422,7 +432,7 @@ def gr_user(gr_um) -> "GRUser":
return gr_um.create_dummy()
-@pytest.fixture(scope="function")
+@pytest.fixture
def gr_user_cache(gr_user, gr_db, thl_web_rr, gr_redis_config):
gr_user.set_cache(
pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config
@@ -430,12 +440,12 @@ def gr_user_cache(gr_user, gr_db, thl_web_rr, gr_redis_config):
return gr_user
-@pytest.fixture(scope="function")
-def gr_user_factory(gr_um) -> Callable:
- def _create_gr_user():
+@pytest.fixture
+def gr_user_factory(gr_um) -> Callable[..., "GRUser"]:
+ def _inner():
return gr_um.create_dummy()
- return _create_gr_user
+ return _inner
@pytest.fixture()
@@ -447,7 +457,7 @@ def gr_user_token(gr_user, gr_tm, gr_db) -> "GRToken":
@pytest.fixture()
-def gr_user_token_header(gr_user_token) -> Dict:
+def gr_user_token_header(gr_user_token: "GRToken") -> Dict[str, str]:
return gr_user_token.auth_header
@@ -461,41 +471,41 @@ def membership(request, team, gr_user, team_manager) -> "Membership":
@pytest.fixture(scope="function")
def membership_factory(
team: "Team", gr_user: "GRUser", membership_manager, team_manager, gr_um
-) -> Callable:
+) -> Callable[..., "Membership"]:
from generalresearch.managers.gr.team import MembershipManager
membership_manager: MembershipManager
- def _create_membership(**kwargs):
+ def _inner(**kwargs) -> "Membership":
_team = kwargs.get("team", team_manager.create_dummy())
_gr_user = kwargs.get("gr_user", gr_um.create_dummy())
return membership_manager.create(team=_team, gr_user=_gr_user)
- return _create_membership
+ return _inner
-@pytest.fixture(scope="function")
-def audit_log(audit_log_manager, user) -> "AuditLog":
+@pytest.fixture
+def audit_log(audit_log_manager, user: "User") -> "AuditLog":
from generalresearch.managers.thl.userhealth import AuditLogManager
audit_log_manager: AuditLogManager
return audit_log_manager.create_dummy(user_id=user.user_id)
-@pytest.fixture(scope="function")
-def audit_log_factory(audit_log_manager) -> Callable:
+@pytest.fixture
+def audit_log_factory(audit_log_manager) -> Callable[..., "AuditLog"]:
from generalresearch.managers.thl.userhealth import AuditLogManager
audit_log_manager: AuditLogManager
- def _create_audit_log(
+ def _inner(
user_id: PositiveInt,
level: Optional["AuditLogLevel"] = None,
event_type: Optional[str] = None,
event_msg: Optional[str] = None,
event_value: Optional[float] = None,
- ):
+ ) -> "AuditLog":
return audit_log_manager.create_dummy(
user_id=user_id,
level=level,
@@ -504,10 +514,10 @@ def audit_log_factory(audit_log_manager) -> Callable:
event_value=event_value,
)
- return _create_audit_log
+ return _inner
-@pytest.fixture(scope="function")
+@pytest.fixture
def ip_geoname(ip_geoname_manager) -> "IPGeoname":
from generalresearch.managers.thl.ipinfo import IPGeonameManager
@@ -515,7 +525,7 @@ def ip_geoname(ip_geoname_manager) -> "IPGeoname":
return ip_geoname_manager.create_dummy()
-@pytest.fixture(scope="function")
+@pytest.fixture
def ip_information(ip_information_manager, ip_geoname) -> "IPInformation":
from generalresearch.managers.thl.ipinfo import IPInformationManager
@@ -525,7 +535,7 @@ def ip_information(ip_information_manager, ip_geoname) -> "IPInformation":
)
-@pytest.fixture(scope="function")
+@pytest.fixture
def ip_information_factory(ip_information_manager) -> Callable:
from generalresearch.managers.thl.ipinfo import IPInformationManager
@@ -542,8 +552,8 @@ def ip_information_factory(ip_information_manager) -> Callable:
return _create_ip_info
-@pytest.fixture(scope="function")
-def ip_record(ip_record_manager, ip_geoname, user) -> "IPRecord":
+@pytest.fixture
+def ip_record(ip_record_manager, ip_geoname, user: "User") -> "IPRecord":
from generalresearch.managers.thl.userhealth import IPRecordManager
ip_record_manager: IPRecordManager
@@ -551,8 +561,8 @@ def ip_record(ip_record_manager, ip_geoname, user) -> "IPRecord":
return ip_record_manager.create_dummy(user_id=user.user_id)
-@pytest.fixture(scope="function")
-def ip_record_factory(ip_record_manager, user) -> Callable:
+@pytest.fixture
+def ip_record_factory(ip_record_manager, user: "User") -> Callable:
from generalresearch.managers.thl.userhealth import IPRecordManager
ip_record_manager: IPRecordManager
diff --git a/test_utils/spectrum/conftest.py b/test_utils/spectrum/conftest.py
index d737730..0afc3f5 100644
--- a/test_utils/spectrum/conftest.py
+++ b/test_utils/spectrum/conftest.py
@@ -34,21 +34,21 @@ def spectrum_rw(settings: "GRLBaseSettings") -> SqlHelper:
@pytest.fixture(scope="session")
-def spectrum_criteria_manager(spectrum_rw) -> SpectrumCriteriaManager:
+def spectrum_criteria_manager(spectrum_rw: SqlHelper) -> SpectrumCriteriaManager:
assert "/unittest-" in spectrum_rw.dsn.path
return SpectrumCriteriaManager(spectrum_rw)
@pytest.fixture(scope="session")
-def spectrum_survey_manager(spectrum_rw) -> SpectrumSurveyManager:
+def spectrum_survey_manager(spectrum_rw: SqlHelper) -> SpectrumSurveyManager:
assert "/unittest-" in spectrum_rw.dsn.path
return SpectrumSurveyManager(spectrum_rw)
@pytest.fixture(scope="session")
def setup_spectrum_surveys(
- spectrum_rw, spectrum_survey_manager, spectrum_criteria_manager
-):
+ spectrum_rw: SqlHelper, spectrum_survey_manager, spectrum_criteria_manager
+) -> None:
now = datetime.now(timezone.utc)
# make sure these example surveys exist in db
surveys = [SpectrumSurvey.model_validate_json(x) for x in SURVEYS_JSON]