diff options
| author | Max Nanis | 2026-03-07 21:06:45 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-03-07 21:06:45 -0500 |
| commit | b2e56ec3ebc5eb91eb10cc37cc9e02102f441660 (patch) | |
| tree | bc6fa4d29ed83e2294f79e24977f77e7263997a1 | |
| parent | 9833e57ccd2f9ec2090ab1f7da97500a071664b9 (diff) | |
| download | generalresearch-b2e56ec3ebc5eb91eb10cc37cc9e02102f441660.tar.gz generalresearch-b2e56ec3ebc5eb91eb10cc37cc9e02102f441660.zip | |
Simple typing changes, Ruff import formatter.
96 files changed, 867 insertions, 602 deletions
diff --git a/generalresearch/grliq/managers/event_plotter.py b/generalresearch/grliq/managers/event_plotter.py index b879c5c..54105ce 100644 --- a/generalresearch/grliq/managers/event_plotter.py +++ b/generalresearch/grliq/managers/event_plotter.py @@ -1,12 +1,13 @@ import html -from typing import List import webbrowser +from typing import List + import numpy as np from more_itertools import windowed from scipy.spatial.distance import euclidean from generalresearch.grliq.managers.colormap import turbo_colormap_data -from generalresearch.grliq.models.events import MouseEvent, KeyboardEvent +from generalresearch.grliq.models.events import KeyboardEvent, MouseEvent def make_events_svg( @@ -30,6 +31,10 @@ def make_events_svg( for ee in windowed(move_events, 2): e1 = ee[0] e2 = ee[1] + + assert e1 is not None + assert e2 is not None + ts_idx = (e2.timeStamp - t.min()) / t_diff r, g, b = turbo_colormap_data[round(ts_idx * 255)] color = f"rgb({int(r*255)},{int(g*255)},{int(b*255)})" @@ -115,7 +120,7 @@ def svg_multiline_text( def group_input_events_by_xy( mouse_events: List[MouseEvent], keyboard_events: List[KeyboardEvent] -): +) -> List[tuple[tuple[float, float], List[str]]]: """ Each keypress is its own event. For plotting, we want to group together all keypresses that were made when the mouse was at the same position, diff --git a/generalresearch/grliq/managers/forensic_data.py b/generalresearch/grliq/managers/forensic_data.py index 0f58d54..d7e362d 100644 --- a/generalresearch/grliq/managers/forensic_data.py +++ b/generalresearch/grliq/managers/forensic_data.py @@ -1,16 +1,16 @@ from datetime import datetime, timezone -from typing import Optional, List, Collection, Dict, Tuple, Any +from typing import Any, Collection, Dict, List, Optional, Tuple from uuid import uuid4 from psycopg import sql -from pydantic import PositiveInt, NonNegativeInt +from pydantic import NonNegativeInt, PositiveInt from generalresearch.grliq.managers import DUMMY_GRLIQ_DATA from generalresearch.grliq.models.events import PointerMove, TimingData from generalresearch.grliq.models.forensic_data import GrlIqData from generalresearch.grliq.models.forensic_result import ( - GrlIqForensicCategoryResult, GrlIqCheckerResults, + GrlIqForensicCategoryResult, Phase, ) from generalresearch.models.custom_types import UUIDStr @@ -25,7 +25,7 @@ class GrlIqDataManager: def create_dummy( self, - is_attempt_allowed: True, + is_attempt_allowed: bool = True, product_id: Optional[str] = None, product_user_id: Optional[str] = None, uuid: Optional[str] = None, @@ -118,7 +118,7 @@ class GrlIqDataManager: with self.postgres_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, data) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() iq_data.id = pk @@ -316,16 +316,22 @@ class GrlIqDataManager: def filter_timing_data( self, - created_between: Optional[Tuple[datetime, datetime]] = None, + created_between: Tuple[datetime, datetime], limit: Optional[int] = None, offset: Optional[int] = None, - ) -> List[Dict]: + ) -> List[Dict[str, Any]]: + + # TODO! created_between used to be marked as Optional, but it would + # break the query. Evaluate it's use to determine best behavior. + limit_str = f"LIMIT {limit}" if limit is not None else "" offset_str = f"OFFSET {offset}" if offset is not None else "" + params = { "created_after": created_between[0], "created_before": created_between[1], } + query = f""" SELECT d.id, d.session_uuid, d.client_ip, d.country_iso, @@ -342,10 +348,12 @@ class GrlIqDataManager: WHERE d.created_at BETWEEN %(created_after)s AND %(created_before)s {limit_str} {offset_str}; """ + with self.postgres_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, params) - res: List[Dict] = c.fetchall() + res: List[Dict[str, Any]] = c.fetchall() # type: ignore + for x in res: x["timing_data"] = TimingData.model_validate(x["timing_data"]) @@ -353,12 +361,14 @@ class GrlIqDataManager: def get_unique_user_count_by_fingerprint( self, - product_id: str, + product_id: UUIDStr, fingerprint: str, product_user_id_not: str, ) -> NonNegativeInt: + # This is used for filtering for other forensic posts with a certain # fingerprint, in this product_id, but NOT for this user. + query = sql.SQL( """ SELECT COUNT(DISTINCT product_user_id) as user_count @@ -378,7 +388,8 @@ class GrlIqDataManager: with self.postgres_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, params) - user_count = c.fetchone()["user_count"] + user_count = c.fetchone()["user_count"] # type: ignore + return int(user_count) def filter_data( @@ -695,7 +706,7 @@ class GrlIqDataManager: order_by: str = "created_at DESC", limit: Optional[int] = None, offset: Optional[int] = None, - ) -> List[Dict]: + ) -> List[Dict[str, Any]]: """ Accepts lots of optional filters. """ @@ -738,7 +749,7 @@ class GrlIqDataManager: with self.postgres_config.make_connection() as conn: with conn.cursor() as c: c.execute(query=query, params=params) - res: List = c.fetchall() + res: List[Dict[str, Any]] = c.fetchall() # type: ignore for x in res: @@ -766,7 +777,7 @@ class GrlIqDataManager: return res @staticmethod - def temporary_add_missing_fields(d: Dict): + def temporary_add_missing_fields(d: Dict[str, Any]) -> None: # The following fields were added recently, and so we must give them # a value or old db rows won't be parseable. Once logs are backfilled # then this can be removed @@ -785,8 +796,9 @@ class GrlIqDataManager: if k not in d: d[k] = v - # We made a mistake once and saved the grliq data object with the events fields set. - # Make sure they are not set here. We load them from the events table, not here! + # We made a mistake once and saved the grliq data object with the + # events fields set. Make sure they are not set here. We load them + # from the events table, not here! d.pop("events", None) d.pop("pointer_move_events", None) d.pop("mouse_events", None) diff --git a/generalresearch/grliq/managers/forensic_events.py b/generalresearch/grliq/managers/forensic_events.py index 2633db7..bbc6b6d 100644 --- a/generalresearch/grliq/managers/forensic_events.py +++ b/generalresearch/grliq/managers/forensic_events.py @@ -1,16 +1,17 @@ import json from datetime import datetime -from typing import Optional, List, Collection, Dict +from typing import Any, Collection, Dict, List, Optional from uuid import uuid4 from psycopg import sql +from pydantic import PositiveInt from generalresearch.grliq.models.events import ( - TimingData, - PointerMove, - MouseEvent, - KeyboardEvent, Bounds, + KeyboardEvent, + MouseEvent, + PointerMove, + TimingData, ) from generalresearch.models.custom_types import UUIDStr from generalresearch.pg_helper import PostgresConfig @@ -24,8 +25,8 @@ class GrlIqEventManager: def update_or_create_timing( self, session_uuid: UUIDStr, - timing_data: TimingData, - ): + timing_data: Optional[TimingData] = None, + ) -> PositiveInt: data = { "session_uuid": session_uuid, "timing_data": ( @@ -69,7 +70,7 @@ class GrlIqEventManager: pk = c.fetchone()["id"] conn.commit() - return pk + return int(pk) def update_or_create_events( self, @@ -78,7 +79,7 @@ class GrlIqEventManager: event_end: datetime, events: Optional[List[Dict]] = None, mouse_events: Optional[List[Dict]] = None, - ): + ) -> PositiveInt: data = { "uuid": uuid4().hex, "session_uuid": session_uuid, @@ -130,7 +131,7 @@ class GrlIqEventManager: pk = c.fetchone()["id"] conn.commit() - return pk + return int(pk) def filter( self, @@ -141,14 +142,16 @@ class GrlIqEventManager: started_since: Optional[datetime] = None, limit: Optional[int] = None, order_by: str = "event_start DESC", - ) -> List[Dict]: - """ """ + ) -> List[Dict[str, Any]]: + if not limit: limit = 100 if not select_str: select_str = "*" + filters = [] params = {} + if session_uuid: params["session_uuid"] = session_uuid filters.append("session_uuid = %(session_uuid)s") @@ -164,6 +167,7 @@ class GrlIqEventManager: filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + query = f""" SELECT {select_str} FROM grliq_forensicevents @@ -189,12 +193,13 @@ class GrlIqEventManager: events=events, pointer_moves=pointer_moves ) x["keyboard_events"] = self.process_keyboard_events(events=events) + return res def filter_distinct_timing( self, session_uuids: Collection[str], - ) -> List[Dict]: + ) -> List[Dict[str, Any]]: params = {"session_uuids": list(session_uuids)} query = sql.SQL( """ @@ -226,9 +231,10 @@ class GrlIqEventManager: @staticmethod def process_mouse_events(pointer_moves: List[PointerMove], events: List[Dict]): """ - In the db column 'mouse_events' we put all 'pointermove' events. Pull those - out, and then any 'pointerdown' and 'pointerup' events from the 'events' column, - and merge them all together into a list of MouseEvent objects + In the db column 'mouse_events' we put all 'pointermove' events. Pull + those out, and then any 'pointerdown' and 'pointerup' events from the + 'events' column, and merge them all together into a list of MouseEvent + objects """ mouse_events = [ # these contain only pointermove events diff --git a/generalresearch/grliq/managers/forensic_results.py b/generalresearch/grliq/managers/forensic_results.py index dd1b039..30db53d 100644 --- a/generalresearch/grliq/managers/forensic_results.py +++ b/generalresearch/grliq/managers/forensic_results.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, List, Collection, Dict, Tuple +from typing import Any, Collection, Dict, List, Optional, Tuple from generalresearch.grliq.models.forensic_result import ( GrlIqForensicCategoryResult, @@ -25,9 +25,10 @@ class GrlIqCategoryResultsReader: created_between: Optional[Tuple[datetime, datetime]] = None, user: Optional[User] = None, limit: Optional[int] = None, - ) -> List[Dict]: + ) -> List[Dict[str, Any]]: """ For retrieving GrlIqForensicCategoryResult objects from db. + :return: List of Dict. Keys are below in the 'select_str'. """ select_str = ( @@ -37,10 +38,11 @@ class GrlIqCategoryResultsReader: " category_result, is_attempt_allowed, fraud_score" ) if not limit: - limit = 5000 + limit = 5_000 filters = [] params = {} + if session_uuid: params["session_uuid"] = session_uuid filters.append("d.session_uuid = %(session_uuid)s") @@ -77,6 +79,7 @@ class GrlIqCategoryResultsReader: filters.append( "(d.product_id = %(product_id)s AND d.product_user_id = %(product_user_id)s)" ) + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" query = f""" @@ -85,6 +88,7 @@ class GrlIqCategoryResultsReader: {filter_str} ORDER BY created_at DESC LIMIT {limit} """ + with self.postgres_config.make_connection() as conn: with conn.cursor() as c: c.execute(query, params) diff --git a/generalresearch/grliq/managers/forensic_summary.py b/generalresearch/grliq/managers/forensic_summary.py index c44daf2..21b7e4b 100644 --- a/generalresearch/grliq/managers/forensic_summary.py +++ b/generalresearch/grliq/managers/forensic_summary.py @@ -2,8 +2,8 @@ from __future__ import annotations import statistics from collections import defaultdict -from datetime import datetime, timezone, timedelta -from typing import List, Dict +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List import numpy as np @@ -12,15 +12,15 @@ from generalresearch.grliq.managers.forensic_events import ( GrlIqEventManager, ) from generalresearch.grliq.models.forensic_result import ( - GrlIqForensicCategoryResult, GrlIqCheckerResults, + GrlIqForensicCategoryResult, ) from generalresearch.grliq.models.forensic_summary import ( - GrlIqForensicCategorySummary, - GrlIqCheckerResultsSummary, - UserForensicSummary, CountryRTTDistribution, + GrlIqCheckerResultsSummary, + GrlIqForensicCategorySummary, TimingDataCountrySummary, + UserForensicSummary, ) from generalresearch.models.thl.user import User from generalresearch.redis_helper import RedisConfig @@ -85,8 +85,9 @@ def calculate_checker_summary( def calculate_timing_summary( - redis_config: RedisConfig, timing_res + redis_config: RedisConfig, timing_res: List[Dict[str, Any]] ) -> Dict[str, TimingDataCountrySummary]: + country_median_rtts = defaultdict(list) for x in timing_res: s = x["timing_data"].summarize @@ -135,6 +136,7 @@ def run_user_forensic_summary( redis_config: RedisConfig, user: User, ) -> UserForensicSummary: + now = datetime.now(tz=timezone.utc) created_between = (now - timedelta(days=90), now) select_str = "id, session_uuid, product_id, product_user_id, created_at, result_data, category_result" diff --git a/generalresearch/grliq/models/events.py b/generalresearch/grliq/models/events.py index 2c8fa64..7e9ebee 100644 --- a/generalresearch/grliq/models/events.py +++ b/generalresearch/grliq/models/events.py @@ -3,15 +3,15 @@ from __future__ import annotations from collections import namedtuple from dataclasses import dataclass, fields from functools import cached_property -from typing import List, Optional +from typing import Any, Dict, List, Optional import numpy as np from pydantic import ( BaseModel, ConfigDict, Field, - NonNegativeInt, NonNegativeFloat, + NonNegativeInt, PositiveFloat, ) from typing_extensions import Self @@ -44,7 +44,7 @@ class Event: _elementBounds: Optional[Bounds] = None @classmethod - def from_dict(cls, data: dict) -> Self: + def from_dict(cls, data: Dict[str, Any]) -> Self: data = {k: v for k, v in data.items() if k in cls.__dataclass_fields__} bounds = data.get("_elementBounds") if bounds is not None and not isinstance(bounds, Bounds): @@ -56,16 +56,21 @@ class Event: class PointerMove(Event): # should always be 'pointermove' type: str + # (mouse, touch, pen) pointerType: str - # coordinate relative to the screen + + # Coordinate relative to the screen screenX: float screenY: float - # coordinate relative to the document (unaffected by scrolling) + + # Coordinate relative to the document (unaffected by scrolling) pageX: float pageY: float - # pageX/Y divided by the document Width/Height. This is calculated in JS and sent, which - # it must be b/c we don't know the document width/height at each time otherwise. + + # PageX/Y divided by the document Width/Height. This is calculated in JS + # and sent, which it must be b/c we don't know the document width/height + # at each time otherwise. normalizedX: float normalizedY: float @@ -82,8 +87,10 @@ class MouseEvent(Event): # should be {'pointerdown', 'pointerup', 'pointermove', 'click'} type: str + # Type of input (mouse, touch, pen) pointerType: str + # coordinate relative to the document (unaffected by scrolling) pageX: float pageY: float @@ -91,15 +98,17 @@ class MouseEvent(Event): @dataclass class KeyboardEvent(Event): - """ """ # should be {'keydown', 'input'} type: str + # "insertText", "insertCompositionText", "deleteCompositionText", # "insertFromComposition", "deleteContentBackward" inputType: Optional[str] + # e.g., 'Enter', 'a', 'Backspace' key: Optional[str] = None + # This is the actual text, if applicable data: Optional[str] = None @@ -180,7 +189,7 @@ class TimingData(BaseModel): def has_data(self): return len(self.client_rtts) > 0 and len(self.server_rtts) > 0 - def filter_rtts(self, rtts): + def filter_rtts(self, rtts: List[float]) -> List[float]: # Skip the first 5 pings, unless we have <10 pings, then get the last # 5 instead. # The first couple pings are usually outliers as they are running @@ -189,6 +198,7 @@ class TimingData(BaseModel): rtts = rtts[5:] else: rtts = rtts[-5:] + return rtts @cached_property diff --git a/generalresearch/grliq/models/forensic_data.py b/generalresearch/grliq/models/forensic_data.py index c182f43..1eb3648 100644 --- a/generalresearch/grliq/models/forensic_data.py +++ b/generalresearch/grliq/models/forensic_data.py @@ -1,55 +1,55 @@ import hashlib import re from collections import Counter -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from enum import Enum from functools import cached_property -from typing import Literal, Optional, Dict, List, Set, Any +from typing import Any, Dict, List, Literal, Optional, Set from uuid import uuid4 import pycountry from faker import Faker from pydantic import ( + AfterValidator, + AwareDatetime, BaseModel, ConfigDict, Field, - field_validator, - StringConstraints, - AfterValidator, - AwareDatetime, NonNegativeInt, + StringConstraints, + field_validator, ) from pydantic.json_schema import SkipJsonSchema from pydantic_extra_types.timezone_name import TimeZoneName -from typing_extensions import Self, Annotated +from typing_extensions import Annotated, Self from generalresearch.grliq.models import ( AUDIO_CODEC_NAMES, - VIDEO_CODEC_NAMES, SUPPORTED_FONTS, + VIDEO_CODEC_NAMES, ) from generalresearch.grliq.models.events import ( + KeyboardEvent, + MouseEvent, PointerMove, TimingData, - MouseEvent, - KeyboardEvent, ) from generalresearch.grliq.models.forensic_result import ( - GrlIqForensicCategoryResult, GrlIqCheckerResults, + GrlIqForensicCategoryResult, + Phase, ) -from generalresearch.grliq.models.forensic_result import Phase from generalresearch.grliq.models.useragents import ( GrlUserAgent, OSFamily, UserAgentHints, ) from generalresearch.models.custom_types import ( - UUIDStr, - IPvAnyAddressStr, - BigAutoInteger, AwareDatetimeISO, + BigAutoInteger, CountryISOLike, + IPvAnyAddressStr, + UUIDStr, ) from generalresearch.models.thl.ipinfo import GeoIPInformation from generalresearch.models.thl.session import Session @@ -265,6 +265,7 @@ class GrlIqData(BaseModel): navigator_java_enabled: bool = Field() do_not_track_enabled: str = Field(description="unspecified or '' ? ") mime_types_length: int = Field() + # todo: some report actual values, some are (always) faked by the os/browser # as anti-fingerprint 10737418240 (10gb) typical on firefox windows, # 2147483648 (20gb) chrome, iPhones typically only 8 different values @@ -542,7 +543,7 @@ class GrlIqData(BaseModel): return hashlib.md5(s.encode()).hexdigest() @cached_property - def audio_codecs_named(self) -> Dict: + def audio_codecs_named(self) -> Dict[str, bool]: return dict( zip( AUDIO_CODEC_NAMES, @@ -551,7 +552,7 @@ class GrlIqData(BaseModel): ) @cached_property - def video_codecs_named(self) -> Dict: + def video_codecs_named(self) -> Dict[str, bool]: return dict( zip( VIDEO_CODEC_NAMES, @@ -618,7 +619,7 @@ class GrlIqData(BaseModel): mode="before", ) @classmethod - def str_to_float_or_null(cls, value: str) -> Optional[int]: + def str_to_float_or_null(cls, value: str) -> Optional[float]: return float(value) if value not in {None, ""} else None @field_validator( @@ -787,7 +788,7 @@ class GrlIqData(BaseModel): return d @classmethod - def from_db(cls, d: Dict) -> Self: + def from_db(cls, d: Dict[str, Any]) -> Self: res = GrlIqData.model_validate(d["data"]) if d.get("category_result"): diff --git a/generalresearch/grliq/models/forensic_result.py b/generalresearch/grliq/models/forensic_result.py index 7f906c2..d89f681 100644 --- a/generalresearch/grliq/models/forensic_result.py +++ b/generalresearch/grliq/models/forensic_result.py @@ -1,18 +1,18 @@ from __future__ import annotations from enum import Enum -from typing import Optional, List, Set +from typing import List, Optional, Set from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, computed_field from generalresearch.grliq.models.custom_types import GrlIqScore from generalresearch.grliq.models.decider import ( - Decider, AttemptDecision, + Decider, GrlIqAttemptResult, ) -from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr class Phase(str, Enum): @@ -21,10 +21,13 @@ class Phase(str, Enum): # Within a custom offerwall. Very optional, as most BPs won't be running our code OFFERWALL = "offerwall" + # When a user clicks on a bucket. Each session should go through this OFFERWALL_ENTER = "offerwall-enter" + # Running in GRS. Not every session will have this. PROFILING = "profiling" + # We could run grl-iq again when a user continues a session SESSION_CONTINUE = "session-continue" @@ -33,8 +36,9 @@ class GrlIqForensicCategoryResult(BaseModel): """ This is for reporting external to GRL. - There is a balance between exposing enough to answer "why did this user get blocked?" without - giving away technical knowledge that could be used to bypass. + There is a balance between exposing enough to answer "why did this user + get blocked?" without giving away technical knowledge that could be + used to bypass. """ model_config = ConfigDict(extra="forbid", validate_assignment=True) diff --git a/generalresearch/grliq/models/forensic_summary.py b/generalresearch/grliq/models/forensic_summary.py index 2d38435..a40e103 100644 --- a/generalresearch/grliq/models/forensic_summary.py +++ b/generalresearch/grliq/models/forensic_summary.py @@ -2,15 +2,15 @@ from __future__ import annotations import random from typing import ( + Dict, List, Literal, Optional, Tuple, - Dict, - get_type_hints, - get_origin, Union, get_args, + get_origin, + get_type_hints, ) import numpy as np @@ -19,17 +19,17 @@ from pydantic import ( ConfigDict, Field, NonNegativeInt, - create_model, computed_field, + create_model, ) from scipy.stats import lognorm from generalresearch.grliq.models.custom_types import GrlIqAvgScore, GrlIqRate from generalresearch.grliq.models.forensic_result import ( - GrlIqCheckerResults, GrlIqCheckerResult, + GrlIqCheckerResults, ) -from generalresearch.models.custom_types import IPvAnyAddressStr, AwareDatetimeISO +from generalresearch.models.custom_types import AwareDatetimeISO, IPvAnyAddressStr from generalresearch.models.thl.locales import CountryISO from generalresearch.models.thl.maxmind.definitions import UserType diff --git a/generalresearch/grliq/utils.py b/generalresearch/grliq/utils.py index 6e563f9..c772e7f 100644 --- a/generalresearch/grliq/utils.py +++ b/generalresearch/grliq/utils.py @@ -1,11 +1,11 @@ import os from datetime import datetime, timezone +from pathlib import Path from typing import Optional, Union from uuid import UUID # from generalresearch.config import from generalresearch.models.custom_types import UUIDStr -from pathlib import Path def get_screenshot_fp( diff --git a/generalresearch/incite/base.py b/generalresearch/incite/base.py index bd346f9..6f1a9be 100644 --- a/generalresearch/incite/base.py +++ b/generalresearch/incite/base.py @@ -8,20 +8,21 @@ import shutil import subprocess import warnings from concurrent.futures import Future -from datetime import datetime, timezone, timedelta -from os import access, R_OK, listdir -from os.path import join as pjoin, isdir +from datetime import datetime, timedelta, timezone +from os import R_OK, access, listdir +from os.path import isdir +from os.path import join as pjoin from pathlib import Path from sys import platform from typing import ( - Optional, - Tuple, + TYPE_CHECKING, + Any, + Callable, List, + Optional, Sequence, - Any, + Tuple, Union, - Callable, - TYPE_CHECKING, ) from uuid import uuid4 @@ -35,12 +36,12 @@ from pydantic import ( BaseModel, ConfigDict, DirectoryPath, - PrivateAttr, Field, - model_validator, FilePath, - field_validator, + PrivateAttr, ValidationInfo, + field_validator, + model_validator, ) from pydantic.json_schema import SkipJsonSchema from sentry_sdk import capture_exception @@ -54,11 +55,11 @@ from generalresearch.incite.schemas import ( from generalresearch.models.custom_types import AwareDatetimeISO if TYPE_CHECKING: - from generalresearch.incite.mergers import MergeType, MergeCollection + from generalresearch.incite.collections import DFCollection from generalresearch.incite.collections.thl_marketplaces import ( DFCollectionType, ) - from generalresearch.incite.collections import DFCollection + from generalresearch.incite.mergers import MergeCollection, MergeType Collection = Union[DFCollection, MergeCollection] @@ -93,10 +94,10 @@ class GRLDatasets(BaseModel): @model_validator(mode="after") def check_data_src_and_et_path(self) -> Self: - from generalresearch.incite.mergers import MergeType from generalresearch.incite.collections.thl_marketplaces import ( DFCollectionType, ) + from generalresearch.incite.mergers import MergeType # Create the base folders and confirm we have read access self.data_src.mkdir(parents=True, exist_ok=True) diff --git a/generalresearch/incite/collections/__init__.py b/generalresearch/incite/collections/__init__.py index 051c5a1..9049e8f 100644 --- a/generalresearch/incite/collections/__init__.py +++ b/generalresearch/incite/collections/__init__.py @@ -5,7 +5,7 @@ import time from datetime import datetime from enum import Enum from sys import platform -from typing import Optional, List +from typing import Any, Dict, List, Optional import dask import dask.dataframe as dd @@ -16,13 +16,13 @@ from distributed import Client, as_completed from more_itertools import chunked from pandera import DataFrameSchema from psycopg import Cursor -from pydantic import Field, FilePath, field_validator, ValidationInfo +from pydantic import Field, FilePath, ValidationInfo, field_validator from sentry_sdk import capture_exception from generalresearch.incite.base import CollectionBase, CollectionItemBase from generalresearch.incite.schemas import ( - ORDER_KEY, ARCHIVE_AFTER, + ORDER_KEY, PARTITION_ON, empty_dataframe_from_schema, ) @@ -33,18 +33,18 @@ from generalresearch.incite.schemas.thl_marketplaces import ( SpectrumSurveyTimeseriesSchema, ) from generalresearch.incite.schemas.thl_web import ( - TxSchema, - TxMetaSchema, - THLUserSchema, + LedgerSchema, + THLIPInfoSchema, + THLSessionSchema, THLTaskAdjustmentSchema, + THLUserSchema, THLWallSchema, - THLSessionSchema, - THLIPInfoSchema, TransactionMetadataColumns, - UserHealthIPHistorySchema, + TxMetaSchema, + TxSchema, UserHealthAuditLogSchema, + UserHealthIPHistorySchema, UserHealthIPHistoryWSSchema, - LedgerSchema, ) from generalresearch.pg_helper import PostgresConfig from generalresearch.sql_helper import SqlHelper @@ -167,7 +167,7 @@ class DFCollectionItem(CollectionItemBase): return self.to_archive(ddf=dd.from_pandas(_df, npartitions=1), is_partial=True) # --- ORM / Data handlers--- - def to_dict(self, *args, **kwargs) -> dict: + def to_dict(self, *args, **kwargs) -> Dict[str, Any]: return self._to_dict() def from_mysql(self, since: Optional[datetime] = None) -> Optional[pd.DataFrame]: @@ -503,7 +503,7 @@ class DFCollectionItem(CollectionItemBase): subprocess.call(["mv", "-T", tmp_path.as_posix(), self.path.as_posix()]) return True - def to_archive_numbered_partial(self, ddf: dd.DataFrame) -> bool: + def to_archive_numbered_partial(self, ddf: Optional[dd.DataFrame] = None) -> bool: """ For partial files/dirs only. Writes the .partial file with a number at the end (.partial.####) and then creates a symlink @@ -513,13 +513,14 @@ class DFCollectionItem(CollectionItemBase): """ if ddf is None: return False + collection = self._collection schema = collection._schema client: Optional[Client] = collection._client next_numbered_path = self.next_numbered_path(self.partial_path) partial_path = self.partial_path - finish = self.finish + # finish = self.finish # Make sure these are in the same dir. b/c the symlink has to be # relative, not an absolute path @@ -571,7 +572,7 @@ class DFCollectionItem(CollectionItemBase): assert self.should_archive(), "not ready to archive!" - df: pd.DataFrame = self.from_mysql() + df: Optional[pd.DataFrame] = self.from_mysql() if df is None: self.set_empty() @@ -648,9 +649,9 @@ class DFCollection(CollectionBase): def initial_load( self, client: Optional[Client] = None, - sync=True, + sync: bool = True, since: Optional[datetime] = None, - client_resources=None, + client_resources: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, ) -> List[Future]: # This can be used to just build all local archive files @@ -733,10 +734,15 @@ class DFCollection(CollectionBase): return sources def force_rr_latest( - self, client: Client, client_resources=None, sync: bool = True + self, + client: Client, + client_resources: Optional[Dict[str, Any]] = None, + sync: bool = True, ) -> List[Future]: + # For forcing update of any partials asynchronously if desired LOG.info(f"{self.data_type.value}.force_rr_latest({client=})") + rr_items = [ i for i in self.items if not i.should_archive() and not i.is_empty() ] diff --git a/generalresearch/incite/defaults.py b/generalresearch/incite/defaults.py index e200ddc..421710e 100644 --- a/generalresearch/incite/defaults.py +++ b/generalresearch/incite/defaults.py @@ -9,11 +9,11 @@ from generalresearch.incite.collections.thl_marketplaces import ( SpectrumSurveyTimeseriesCollection, ) from generalresearch.incite.collections.thl_web import ( + LedgerDFCollection, SessionDFCollection, - WallDFCollection, - UserDFCollection, TaskAdjustmentDFCollection, - LedgerDFCollection, + UserDFCollection, + WallDFCollection, ) from generalresearch.incite.mergers import MergeType from generalresearch.incite.mergers.foundations.enriched_session import ( @@ -33,7 +33,6 @@ from generalresearch.incite.mergers.ym_survey_wall import YMSurveyWallMerge from generalresearch.pg_helper import PostgresConfig from generalresearch.sql_helper import SqlHelper - # --- THL Web --- # diff --git a/generalresearch/incite/mergers/foundations/__init__.py b/generalresearch/incite/mergers/foundations/__init__.py index d3db74c..9ce9d91 100644 --- a/generalresearch/incite/mergers/foundations/__init__.py +++ b/generalresearch/incite/mergers/foundations/__init__.py @@ -1,8 +1,9 @@ import logging -from typing import Collection, List, Dict +from typing import Any, Collection, Dict, List import pandas as pd from more_itertools import chunked +from pydantic import PositiveInt from generalresearch.pg_helper import PostgresConfig @@ -10,7 +11,7 @@ LOG = logging.getLogger("incite") def annotate_product_id( - df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500 + df: pd.DataFrame, pg_config: PostgresConfig, chunksize: PositiveInt = 500 ) -> pd.DataFrame: """ Dask map_partitions is being called on a dask df. However, the function @@ -27,7 +28,7 @@ def annotate_product_id( assert len(user_ids) >= 1, "must have user_ids" LOG.warning(f"annotate_product_id.len(user_ids): {len(user_ids)}") - res: List[Dict] = [] + res: List[Dict[str, Any]] = [] with pg_config.make_connection() as conn: for chunk in chunked(user_ids, chunksize): try: @@ -53,15 +54,17 @@ def annotate_product_id( def lookup_product_and_team_id( user_ids: Collection[int], pg_config: PostgresConfig, -) -> List[Dict]: +) -> List[Dict[str, Any]]: + user_ids = set(user_ids) LOG.info(f"lookup_product_and_team_id: {len(user_ids)}") LOG.info({type(x) for x in user_ids}) + assert all(type(x) is int for x in user_ids), "must pass all integers" assert len(user_ids) >= 1, "must have user_ids" assert len(user_ids) <= 1000, "you should chunk this bro" - res: List[Dict] = [] + res: List[Dict[str, Any]] = [] with pg_config.make_connection() as conn: try: with conn.cursor() as c: @@ -87,7 +90,7 @@ def lookup_product_and_team_id( def annotate_product_and_team_id( - df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500 + df: pd.DataFrame, pg_config: PostgresConfig, chunksize: PositiveInt = 500 ) -> pd.DataFrame: """ Dask map_partitions is being called on a dask df. However, the function @@ -95,8 +98,8 @@ def annotate_product_and_team_id( df AS a pandas df. expects column 'user_id', adds column 'product_id' and team_id - """ + LOG.info(f"annotate_product_and_team_id.chunk: {df.shape}") assert "user_id" in df.columns, "must have a user_id column to join on" @@ -105,7 +108,7 @@ def annotate_product_and_team_id( assert len(user_ids) >= 1, "must have user_ids" LOG.warning(f"annotate_product_and_team_id.len(user_ids): {len(user_ids)}") - res: List[Dict] = [] + res: List[Dict[str, Any]] = [] with pg_config.make_connection() as conn: for chunk in chunked(user_ids, chunksize): try: @@ -133,7 +136,7 @@ def annotate_product_and_team_id( def annotate_product_user( - df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500 + df: pd.DataFrame, pg_config: PostgresConfig, chunksize: PositiveInt = 500 ) -> pd.DataFrame: LOG.info(f"annotate_product_user.chunk: {df.shape}") assert "user_id" in df.columns, "must have a user_id column to join on" @@ -143,7 +146,7 @@ def annotate_product_user( assert len(user_ids) >= 1, "must have user_ids" LOG.warning(f"annotate_product_user.len(user_ids): {len(user_ids)}") - res: List[Dict] = [] + res: List[Dict[str, Any]] = [] with pg_config.make_connection() as conn: for chunk in chunked(user_ids, chunksize): try: diff --git a/generalresearch/incite/mergers/foundations/enriched_session.py b/generalresearch/incite/mergers/foundations/enriched_session.py index 7fdcb50..4b87df7 100644 --- a/generalresearch/incite/mergers/foundations/enriched_session.py +++ b/generalresearch/incite/mergers/foundations/enriched_session.py @@ -1,6 +1,6 @@ import logging from datetime import timedelta -from typing import Literal, Optional, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional import dask.dataframe as dd import pandas as pd @@ -45,7 +45,7 @@ class EnrichedSessionMergeItem(MergeCollectionItem): wall_coll: WallDFCollection, pg_config: PostgresConfig, client: Optional[Client] = None, - client_resources=None, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: ir: pd.Interval = self.interval diff --git a/generalresearch/incite/mergers/foundations/enriched_task_adjust.py b/generalresearch/incite/mergers/foundations/enriched_task_adjust.py index 1749f9a..d2a8aa1 100644 --- a/generalresearch/incite/mergers/foundations/enriched_task_adjust.py +++ b/generalresearch/incite/mergers/foundations/enriched_task_adjust.py @@ -1,5 +1,5 @@ import logging -from typing import Literal +from typing import Any, Dict, Literal, Optional import dask.dataframe as dd import pandas as pd @@ -11,8 +11,8 @@ from generalresearch.incite.collections.thl_web import ( ) from generalresearch.incite.mergers import ( MergeCollection, - MergeType, MergeCollectionItem, + MergeType, ) from generalresearch.incite.mergers.foundations import ( annotate_product_and_team_id, @@ -40,7 +40,7 @@ class EnrichedTaskAdjustMergeItem(MergeCollectionItem): enriched_wall: EnrichedWallMerge, pg_config: PostgresConfig, client: Client, - client_resources=None, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: """ TaskAdjustments are always partial because they could be revoked @@ -61,7 +61,7 @@ class EnrichedTaskAdjustMergeItem(MergeCollectionItem): if len(task_adj_coll_items) == 0: raise Exception("TaskAdjColl item collection failed") - ddf: dd.DataFrame = task_adj_coll.ddf( + ddf: Optional[dd.DataFrame] = task_adj_coll.ddf( items=task_adj_coll_items, include_partial=True, force_rr_latest=False, @@ -114,6 +114,7 @@ class EnrichedTaskAdjustMergeItem(MergeCollectionItem): assert str(ddf.wall_uuid.dtype) == "string" assert str(wall_ddf.index.dtype) == "string" + ddf = ddf.merge( wall_ddf, left_on="wall_uuid", diff --git a/generalresearch/incite/mergers/foundations/enriched_wall.py b/generalresearch/incite/mergers/foundations/enriched_wall.py index 241a239..d69293b 100644 --- a/generalresearch/incite/mergers/foundations/enriched_wall.py +++ b/generalresearch/incite/mergers/foundations/enriched_wall.py @@ -1,6 +1,6 @@ import logging from datetime import timedelta -from typing import Literal, Optional, TYPE_CHECKING, List +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional import dask.dataframe as dd import pandas as pd @@ -40,7 +40,7 @@ class EnrichedWallMergeItem(MergeCollectionItem): session_coll: SessionDFCollection, pg_config: PostgresConfig, client: Optional[Client] = None, - client_resources=None, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: ir: pd.Interval = self.interval diff --git a/generalresearch/incite/mergers/foundations/user_id_product.py b/generalresearch/incite/mergers/foundations/user_id_product.py index e467179..73ee36a 100644 --- a/generalresearch/incite/mergers/foundations/user_id_product.py +++ b/generalresearch/incite/mergers/foundations/user_id_product.py @@ -1,12 +1,12 @@ import logging -from typing import Literal +from typing import Any, Dict, Literal, Optional from distributed import Client from generalresearch.incite.collections.thl_web import UserDFCollection from generalresearch.incite.mergers import ( - MergeCollectionItem, MergeCollection, + MergeCollectionItem, MergeType, ) @@ -16,7 +16,10 @@ LOG = logging.getLogger("incite") class UserIdProductMergeItem(MergeCollectionItem): def build( - self, client: Client, user_coll: UserDFCollection, client_resources=None + self, + client: Client, + user_coll: UserDFCollection, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: LOG.warning(f"UserIdProductMergeItem.build({self.interval})") @@ -34,10 +37,13 @@ class UserIdProductMergeItem(MergeCollectionItem): class UserIdProductMerge(MergeCollection): merge_type: Literal[MergeType.USER_ID_PRODUCT] = MergeType.USER_ID_PRODUCT collection_item_class: Literal[UserIdProductMergeItem] = UserIdProductMergeItem - offset: None = None + offset: Optional[str] = None def build( - self, client: Client, user_coll: UserDFCollection, client_resources=None + self, + client: Client, + user_coll: UserDFCollection, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: LOG.info(f"UserIdProductMerge.build(user_coll={user_coll.signature()})") diff --git a/generalresearch/incite/mergers/nginx_core.py b/generalresearch/incite/mergers/nginx_core.py index 1f471a6..343da8f 100644 --- a/generalresearch/incite/mergers/nginx_core.py +++ b/generalresearch/incite/mergers/nginx_core.py @@ -1,7 +1,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Literal, List +from typing import List, Literal from urllib.parse import parse_qs, urlsplit import dask.bag as db diff --git a/generalresearch/incite/mergers/nginx_fsb.py b/generalresearch/incite/mergers/nginx_fsb.py index 1f1039b..9cb71d3 100644 --- a/generalresearch/incite/mergers/nginx_fsb.py +++ b/generalresearch/incite/mergers/nginx_fsb.py @@ -1,14 +1,13 @@ import json import logging - -from sentry_sdk import capture_exception import re from datetime import datetime, timedelta -from typing import Literal, List +from typing import List, Literal from urllib.parse import parse_qs, urlsplit import dask.bag as db import pandas as pd +from sentry_sdk import capture_exception from generalresearch.incite.mergers import ( MergeCollection, diff --git a/generalresearch/incite/mergers/nginx_grs.py b/generalresearch/incite/mergers/nginx_grs.py index 0242b0b..fb22070 100644 --- a/generalresearch/incite/mergers/nginx_grs.py +++ b/generalresearch/incite/mergers/nginx_grs.py @@ -1,7 +1,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Literal, List +from typing import List, Literal from urllib.parse import parse_qs, urlsplit import dask.bag as db diff --git a/generalresearch/incite/mergers/pop_ledger.py b/generalresearch/incite/mergers/pop_ledger.py index 4915abb..0475df2 100644 --- a/generalresearch/incite/mergers/pop_ledger.py +++ b/generalresearch/incite/mergers/pop_ledger.py @@ -1,5 +1,5 @@ import logging -from typing import Literal, Optional +from typing import Any, Dict, Literal, Optional import dask.dataframe as dd import pandas as pd @@ -24,7 +24,7 @@ class PopLedgerMergeItem(MergeCollectionItem): self, ledger_coll: LedgerDFCollection, client: Optional[Client] = None, - client_resources=None, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: ir: pd.Interval = self.interval diff --git a/generalresearch/incite/mergers/ym_survey_wall.py b/generalresearch/incite/mergers/ym_survey_wall.py index 7f6c31c..4c2defb 100644 --- a/generalresearch/incite/mergers/ym_survey_wall.py +++ b/generalresearch/incite/mergers/ym_survey_wall.py @@ -1,6 +1,6 @@ import logging from datetime import timedelta -from typing import Optional, Literal +from typing import Any, Dict, Literal, Optional import dask.dataframe as dd import pandas as pd @@ -10,8 +10,8 @@ from sentry_sdk import capture_exception from generalresearch.incite.collections.thl_web import WallDFCollection from generalresearch.incite.mergers import ( MergeCollection, - MergeType, MergeCollectionItem, + MergeType, ) from generalresearch.incite.mergers.foundations.enriched_session import ( EnrichedSessionMerge, @@ -31,7 +31,7 @@ class YMSurveyWallMergeCollectionItem(MergeCollectionItem): wall_coll: WallDFCollection, enriched_session: EnrichedSessionMerge, client: Optional[Client] = None, - client_resources=None, + client_resources: Optional[Dict[str, Any]] = None, ) -> None: LOG.info(f"YMSurveyWallMerge.build({self.start=}, {self.finish=})") ir: pd.Interval = self.interval diff --git a/generalresearch/incite/mergers/ym_wall_summary.py b/generalresearch/incite/mergers/ym_wall_summary.py index 419994d..c7b871c 100644 --- a/generalresearch/incite/mergers/ym_wall_summary.py +++ b/generalresearch/incite/mergers/ym_wall_summary.py @@ -1,5 +1,5 @@ -from datetime import timedelta, datetime, time -from typing import Literal, List, Optional, Type +from datetime import datetime, time, timedelta +from typing import List, Literal, Optional, Type import dask.dataframe as dd import pandas as pd @@ -12,8 +12,8 @@ from generalresearch.incite.collections.thl_web import ( ) from generalresearch.incite.mergers import ( MergeCollection, - MergeType, MergeCollectionItem, + MergeType, ) from generalresearch.incite.mergers.foundations.user_id_product import ( UserIdProductMerge, diff --git a/generalresearch/incite/schemas/admin_responses.py b/generalresearch/incite/schemas/admin_responses.py index 73c0aaa..ccfe1cf 100644 --- a/generalresearch/incite/schemas/admin_responses.py +++ b/generalresearch/incite/schemas/admin_responses.py @@ -1,12 +1,12 @@ from datetime import datetime from pandera import ( - DataFrameSchema, - Column, Check, - Parser, - MultiIndex, + Column, + DataFrameSchema, Index, + MultiIndex, + Parser, Timestamp, ) diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_session.py b/generalresearch/incite/schemas/mergers/foundations/enriched_session.py index 4badfac..ee17bda 100644 --- a/generalresearch/incite/schemas/mergers/foundations/enriched_session.py +++ b/generalresearch/incite/schemas/mergers/foundations/enriched_session.py @@ -1,6 +1,6 @@ from datetime import timedelta -from pandera import DataFrameSchema, Column, Check +from pandera import Check, Column, DataFrameSchema from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY, PARTITION_ON from generalresearch.incite.schemas.thl_web import THLSessionSchema diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py b/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py index 35b579b..7b43589 100644 --- a/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py +++ b/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py @@ -1,7 +1,8 @@ -import pandas as pd -from pandera import DataFrameSchema, Column, Check, Index from typing import Set +import pandas as pd +from pandera import Check, Column, DataFrameSchema, Index + from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY from generalresearch.incite.schemas.thl_web import THLTaskAdjustmentSchema from generalresearch.locales import Localelator diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py b/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py index 1b78fde..8d16270 100644 --- a/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py +++ b/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py @@ -1,15 +1,15 @@ from datetime import timedelta import pandas as pd -from pandera import DataFrameSchema, Column, Check, Index +from pandera import Check, Column, DataFrameSchema, Index -from generalresearch.incite.schemas import PARTITION_ON, ARCHIVE_AFTER +from generalresearch.incite.schemas import ARCHIVE_AFTER, PARTITION_ON from generalresearch.locales import Localelator from generalresearch.models import DeviceType, Source from generalresearch.models.thl.definitions import ( + ReportValue, Status, StatusCode1, - ReportValue, WallStatusCode2, ) diff --git a/generalresearch/incite/schemas/mergers/foundations/user_id_product.py b/generalresearch/incite/schemas/mergers/foundations/user_id_product.py index 780a3f2..160f05b 100644 --- a/generalresearch/incite/schemas/mergers/foundations/user_id_product.py +++ b/generalresearch/incite/schemas/mergers/foundations/user_id_product.py @@ -1,6 +1,6 @@ from datetime import timedelta -from pandera import Column, Check, Index, Category, DataFrameSchema +from pandera import Category, Check, Column, DataFrameSchema, Index from generalresearch.incite.schemas import ARCHIVE_AFTER diff --git a/generalresearch/incite/schemas/mergers/nginx.py b/generalresearch/incite/schemas/mergers/nginx.py index 30e6fec..3d16738 100644 --- a/generalresearch/incite/schemas/mergers/nginx.py +++ b/generalresearch/incite/schemas/mergers/nginx.py @@ -5,9 +5,9 @@ from datetime import timedelta import pandas as pd -from pandera import DataFrameSchema, Column, Check, Index +from pandera import Check, Column, DataFrameSchema, Index -from generalresearch.incite.schemas import PARTITION_ON, ARCHIVE_AFTER +from generalresearch.incite.schemas import ARCHIVE_AFTER, PARTITION_ON NGINXBaseSchema = DataFrameSchema( columns={ diff --git a/generalresearch/incite/schemas/mergers/pop_ledger.py b/generalresearch/incite/schemas/mergers/pop_ledger.py index 25c7e68..af8e0c8 100644 --- a/generalresearch/incite/schemas/mergers/pop_ledger.py +++ b/generalresearch/incite/schemas/mergers/pop_ledger.py @@ -2,11 +2,11 @@ from datetime import timedelta import pandas as pd from more_itertools import flatten -from pandera import DataFrameSchema, Column, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY, PARTITION_ON from generalresearch.incite.schemas.thl_web import TxSchema -from generalresearch.models.thl.ledger import TransactionType, Direction +from generalresearch.models.thl.ledger import Direction, TransactionType """ - In reality, a multi-index would be appropriate here, but dask does not support this, so we're keeping it flat. diff --git a/generalresearch/incite/schemas/mergers/ym_survey_wall.py b/generalresearch/incite/schemas/mergers/ym_survey_wall.py index 2b2d266..3882bd1 100644 --- a/generalresearch/incite/schemas/mergers/ym_survey_wall.py +++ b/generalresearch/incite/schemas/mergers/ym_survey_wall.py @@ -1,9 +1,9 @@ from datetime import timedelta -from pandera import DataFrameSchema, Column, Check, Index +from pandera import Check, Column, DataFrameSchema, Index -from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER -from generalresearch.incite.schemas.thl_web import THLWallSchema, THLSessionSchema +from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY +from generalresearch.incite.schemas.thl_web import THLSessionSchema, THLWallSchema thl_wall_columns = THLWallSchema.columns.copy() diff --git a/generalresearch/incite/schemas/mergers/ym_wall_summary.py b/generalresearch/incite/schemas/mergers/ym_wall_summary.py index fb34dec..5e97b47 100644 --- a/generalresearch/incite/schemas/mergers/ym_wall_summary.py +++ b/generalresearch/incite/schemas/mergers/ym_wall_summary.py @@ -1,7 +1,7 @@ from datetime import timedelta from typing import Set -from pandera import DataFrameSchema, Column, Check, Index +from pandera import Check, Column, DataFrameSchema, Index from generalresearch.incite.schemas import ARCHIVE_AFTER from generalresearch.locales import Localelator diff --git a/generalresearch/incite/schemas/thl_marketplaces.py b/generalresearch/incite/schemas/thl_marketplaces.py index 286db6a..6d7b83b 100644 --- a/generalresearch/incite/schemas/thl_marketplaces.py +++ b/generalresearch/incite/schemas/thl_marketplaces.py @@ -2,9 +2,9 @@ import copy from datetime import timedelta import pandas as pd -from pandera import Column, Check, Index, DataFrameSchema +from pandera import Check, Column, DataFrameSchema, Index -from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER +from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY BIGINT = 9223372036854775807 diff --git a/generalresearch/incite/schemas/thl_web.py b/generalresearch/incite/schemas/thl_web.py index a644720..c57a233 100644 --- a/generalresearch/incite/schemas/thl_web.py +++ b/generalresearch/incite/schemas/thl_web.py @@ -1,19 +1,19 @@ -from datetime import timezone, datetime, timedelta +from datetime import datetime, timedelta, timezone import pandas as pd -from pandera import DataFrameSchema, Column, Check, Index, MultiIndex +from pandera import Check, Column, DataFrameSchema, Index, MultiIndex -from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER +from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY from generalresearch.locales import Localelator from generalresearch.models import DeviceType, Source from generalresearch.models.thl.definitions import ( - StatusCode1, - WallStatusCode2, ReportValue, - WallAdjustedStatus, - Status, - SessionStatusCode2, SessionAdjustedStatus, + SessionStatusCode2, + Status, + StatusCode1, + WallAdjustedStatus, + WallStatusCode2, ) from generalresearch.models.thl.ledger import TransactionMetadataColumns from generalresearch.models.thl.maxmind.definitions import UserType diff --git a/generalresearch/managers/__init__.py b/generalresearch/managers/__init__.py index 8af988c..bc745fd 100644 --- a/generalresearch/managers/__init__.py +++ b/generalresearch/managers/__init__.py @@ -1,4 +1,4 @@ -def parse_order_by(order_by_str: str): +def parse_order_by(order_by_str: str) -> str: """ Converts django-rest-framework ordering str to mysql clause :param order_by_str: e.g. 'created,-name' diff --git a/generalresearch/managers/cint/profiling.py b/generalresearch/managers/cint/profiling.py index 4a7fc69..5e6e46a 100644 --- a/generalresearch/managers/cint/profiling.py +++ b/generalresearch/managers/cint/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.cint.question import CintQuestion from generalresearch.sql_helper import SqlHelper @@ -24,10 +24,13 @@ def get_profiling_library( :param is_live: filters on is_live field :param pks: The pk is (question_key, country_iso, language_iso). pks accepts a collection of len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', 'us', 'eng'), ('AGE', 'us', 'spa')] + :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,8 +49,10 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( f""" SELECT * @@ -56,7 +61,9 @@ def get_profiling_library( """, params, ) + for x in res: x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [CintQuestion.from_db(x) for x in res] return qs diff --git a/generalresearch/managers/cint/survey.py b/generalresearch/managers/cint/survey.py index b7045c9..e8298fd 100644 --- a/generalresearch/managers/cint/survey.py +++ b/generalresearch/managers/cint/survey.py @@ -1,15 +1,15 @@ from __future__ import annotations import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional, Set +from datetime import datetime, timezone +from typing import Collection, List, Optional, Set import pymysql from pymysql import IntegrityError from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager -from generalresearch.models.cint.survey import CintSurvey, CintCondition +from generalresearch.models.cint.survey import CintCondition, CintSurvey logger = logging.getLogger() diff --git a/generalresearch/managers/dynata/profiling.py b/generalresearch/managers/dynata/profiling.py index 10d6c69..39e6591 100644 --- a/generalresearch/managers/dynata/profiling.py +++ b/generalresearch/managers/dynata/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.dynata.question import DynataQuestion from generalresearch.sql_helper import SqlHelper @@ -24,10 +24,13 @@ def get_profiling_library( :param is_live: filters on is_live field :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,8 +49,10 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( f""" SELECT * diff --git a/generalresearch/managers/dynata/survey.py b/generalresearch/managers/dynata/survey.py index a20345a..d20c1c8 100644 --- a/generalresearch/managers/dynata/survey.py +++ b/generalresearch/managers/dynata/survey.py @@ -1,15 +1,15 @@ from __future__ import annotations import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional +from datetime import datetime, timezone +from typing import Collection, List, Optional import pymysql from pymysql import IntegrityError from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager -from generalresearch.models.dynata.survey import DynataSurvey, DynataCondition +from generalresearch.models.dynata.survey import DynataCondition, DynataSurvey logger = logging.getLogger() @@ -64,8 +64,10 @@ class DynataSurveyManager(SurveyManager): :param is_live: filters on is_live field :param updated_since: filters on "> last_updated" """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -81,6 +83,7 @@ class DynataSurveyManager(SurveyManager): if updated_since is not None: params["updated_since"] = updated_since filters.append("last_updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" diff --git a/generalresearch/managers/events.py b/generalresearch/managers/events.py index 3e87879..c0c0d0f 100644 --- a/generalresearch/managers/events.py +++ b/generalresearch/managers/events.py @@ -1,34 +1,34 @@ import logging +import math import socket import threading import time -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from decimal import Decimal -from typing import Set, Optional, TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union -import math from redis.client import PubSub, Redis from generalresearch.managers.base import RedisManager from generalresearch.models import Source from generalresearch.models.custom_types import UUIDStr from generalresearch.models.events import ( - StatsMessage, - EventMessage, + AggregateBySource, EventEnvelope, + EventMessage, EventType, - TaskEnterPayload, - ServerToClientMessageAdapter, + MaxGaugeBySource, ServerToClientMessage, - TaskFinishPayload, + ServerToClientMessageAdapter, SessionEnterPayload, SessionFinishPayload, - AggregateBySource, - MaxGaugeBySource, + StatsMessage, + TaskEnterPayload, + TaskFinishPayload, TaskStatsSnapshot, ) from generalresearch.models.thl.definitions import Status -from generalresearch.models.thl.session import Wall, Session +from generalresearch.models.thl.session import Session, Wall from generalresearch.models.thl.user import User if TYPE_CHECKING: @@ -309,7 +309,9 @@ class TaskStatsManager(RedisManager): def get_active_sources(self) -> List[Source]: return [Source(x) for x in self.redis_client.hkeys("live_task_count")] - def get_task_stats_raw(self): + def get_task_stats_raw( + self, + ) -> Dict[str, Union[AggregateBySource, MaxGaugeBySource]]: sources = self.get_active_sources() pipe = self.redis_client.pipeline(transaction=False) @@ -365,7 +367,7 @@ class TaskStatsManager(RedisManager): "live_tasks_max_payout": live_tasks_max_payout, } - def clear_task_stats(self): + def clear_task_stats(self) -> None: keys = self.task_stats.copy() keys.extend([f"task_created_count_last_1h:{source.value}" for source in Source]) keys.extend( @@ -373,6 +375,8 @@ class TaskStatsManager(RedisManager): ) self.redis_client.delete(*keys) + return None + class SessionStatsManager(RedisManager): """ @@ -633,9 +637,6 @@ class EventManager(StatsManager): def get_active_subscribers(self) -> Set[UUIDStr]: res = self.redis_client.pubsub_channels(f"{self.cache_prefix}:event-channel:*") product_ids = {x.rsplit(":", 1)[-1] for x in res} - # product_ids.update( - # {"fc14e741b5004581b30e6478363414df", "888dbc589987425fa846d6e2a8daed04"} - # ) return product_ids def stats_worker(self): @@ -647,7 +648,7 @@ class EventManager(StatsManager): finally: time.sleep(60) - def stats_worker_task(self): + def stats_worker_task(self) -> None: """ Only a single worker will be running. It'll be responsible for periodic publication of summary/stats messages. @@ -682,6 +683,8 @@ class EventManager(StatsManager): self.redis_client.delete(lock_key) + return None + def make_influx_point(self, channel: str, numsub: int): return { "measurement": "redis_pubsub_subscribers", diff --git a/generalresearch/managers/gr/authentication.py b/generalresearch/managers/gr/authentication.py index f851cfa..7b8e526 100644 --- a/generalresearch/managers/gr/authentication.py +++ b/generalresearch/managers/gr/authentication.py @@ -2,22 +2,21 @@ import binascii import logging import os from datetime import datetime, timezone -from typing import Optional, List, TYPE_CHECKING, Dict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from uuid import uuid4 from psycopg import sql -from pydantic import AnyHttpUrl -from pydantic import PositiveInt +from pydantic import AnyHttpUrl, PositiveInt -from generalresearch.managers.base import PostgresManagerWithRedis, PostgresManager +from generalresearch.managers.base import PostgresManager, PostgresManagerWithRedis from generalresearch.models.custom_types import UUIDStr -from generalresearch.redis_helper import RedisConfig from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig LOG = logging.getLogger("gr") if TYPE_CHECKING: - from generalresearch.models.gr.authentication import GRUser, GRToken + from generalresearch.models.gr.authentication import GRToken, GRUser class GRUserManager(PostgresManagerWithRedis): @@ -194,7 +193,7 @@ class GRTokenManager(PostgresManager): def get_by_key( self, api_key: str, - jwks: Optional[Dict] = None, + jwks: Optional[Dict[str, Any]] = None, audience: Optional[str] = None, issuer: Optional[Union[AnyHttpUrl, str]] = None, gr_redis_config: Optional[RedisConfig] = None, @@ -210,7 +209,7 @@ class GRTokenManager(PostgresManager): :return GRToken instance (minified version, no relationships) :raises NotFoundException """ - from generalresearch.models.gr.authentication import GRToken, Claims + from generalresearch.models.gr.authentication import Claims, GRToken # SSO Key if GRToken.is_sso(api_key): @@ -324,7 +323,7 @@ class GRTokenManager(PostgresManager): res = result[0] - for k, v in res.items(): + for k, _ in res.items(): if isinstance(res[k], datetime): res[k] = res[k].replace(tzinfo=timezone.utc) diff --git a/generalresearch/managers/gr/business.py b/generalresearch/managers/gr/business.py index 001a9e8..e9ec580 100644 --- a/generalresearch/managers/gr/business.py +++ b/generalresearch/managers/gr/business.py @@ -1,4 +1,4 @@ -from typing import Optional, List, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from uuid import UUID, uuid4 from psycopg import sql @@ -6,20 +6,20 @@ from pydantic import PositiveInt from pydantic_extra_types.phone_numbers import PhoneNumber from generalresearch.managers.base import ( - PostgresManagerWithRedis, PostgresManager, + PostgresManagerWithRedis, ) from generalresearch.models.custom_types import UUIDStr if TYPE_CHECKING: - from generalresearch.models.gr.team import Team from generalresearch.models.gr.business import ( Business, - BusinessType, BusinessAddress, BusinessBankAccount, + BusinessType, TransferMethod, ) + from generalresearch.models.gr.team import Team class BusinessBankAccountManager(PostgresManager): @@ -27,7 +27,7 @@ class BusinessBankAccountManager(PostgresManager): def create_dummy( self, business_id: PositiveInt, - uuid: Optional[UUID] = None, + uuid: Optional[UUIDStr] = None, transfer_method: Optional["TransferMethod"] = None, account_number: Optional[str] = None, routing_number: Optional[str] = None, @@ -36,21 +36,14 @@ class BusinessBankAccountManager(PostgresManager): ): from generalresearch.models.gr.business import TransferMethod - uuid = uuid or uuid4().hex - transfer_method = transfer_method or TransferMethod.ACH - account_number = account_number or uuid4().hex[:6] - routing_number = routing_number or uuid4().hex[:6] - iban = iban or uuid4().hex[:6] - swift = swift or uuid4().hex[:6] - return self.create( business_id=business_id, - uuid=uuid, - transfer_method=transfer_method, - account_number=account_number, - routing_number=routing_number, - iban=iban, - swift=swift, + uuid=uuid or uuid4().hex, + transfer_method=transfer_method or TransferMethod.ACH, + account_number=account_number or uuid4().hex[:6], + routing_number=routing_number or uuid4().hex[:6], + iban=iban or uuid4().hex[:6], + swift=swift or uuid4().hex[:6], ) def create( @@ -95,7 +88,7 @@ class BusinessBankAccountManager(PostgresManager): ), params=data, ) - ba_id = c.fetchone()["id"] + ba_id = c.fetchone()["id"] # type: ignore conn.commit() ba.id = ba_id @@ -194,14 +187,15 @@ class BusinessAddressManager(PostgresManager): (uuid, line_1, line_2, city, country, state, postal_code, phone_number, business_id) VALUES - (%(uuid)s, %(line_1)s, %(line_2)s, %(city)s, %(country)s, %(state)s, - %(postal_code)s, %(phone_number)s, %(business_id)s) + (%(uuid)s, %(line_1)s, %(line_2)s, %(city)s, + %(country)s, %(state)s, %(postal_code)s, + %(phone_number)s, %(business_id)s) RETURNING id """ ), params=data, ) - ba_id = c.fetchone()["id"] + ba_id = c.fetchone()["id"] # type: ignore conn.commit() ba.id = ba_id @@ -304,7 +298,7 @@ class BusinessManager(PostgresManagerWithRedis): ), params=data, ) - business_id = c.fetchone()["id"] + business_id = c.fetchone()["id"] # type: ignore conn.commit() business.id = business_id diff --git a/generalresearch/managers/gr/team.py b/generalresearch/managers/gr/team.py index a57ef5f..c6709d0 100644 --- a/generalresearch/managers/gr/team.py +++ b/generalresearch/managers/gr/team.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, List, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from uuid import uuid4 from psycopg import sql @@ -13,13 +13,13 @@ from generalresearch.models.custom_types import UUIDStr from generalresearch.models.gr.team import Membership, MembershipPrivilege if TYPE_CHECKING: + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.business import Business from generalresearch.models.gr.team import ( Membership, - Team, MembershipPrivilege, + Team, ) - from generalresearch.models.gr.authentication import GRUser - from generalresearch.models.gr.business import Business class MembershipManager(PostgresManager): @@ -77,7 +77,7 @@ class MembershipManager(PostgresManager): ), params=data, ) - membership_id: int = c.fetchone()["id"] + membership_id: int = c.fetchone()["id"] # type: ignore conn.commit() membership.id = membership_id @@ -205,7 +205,7 @@ class TeamManager(PostgresManagerWithRedis): ), params=[team.uuid, team.name], ) - team_id = c.fetchone()["id"] + team_id = c.fetchone()["id"] # type: ignore conn.commit() team.id = team_id diff --git a/generalresearch/managers/innovate/profiling.py b/generalresearch/managers/innovate/profiling.py index a3939a7..81cb29e 100644 --- a/generalresearch/managers/innovate/profiling.py +++ b/generalresearch/managers/innovate/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.innovate.question import InnovateQuestion from generalresearch.sql_helper import SqlHelper @@ -22,12 +22,16 @@ def get_profiling_library( :param question_keys: filters on question_key field, accepts multiple values :param max_options: filters on max_options field :param is_live: filters on is_live field - :param pks: The pk is (question_key, country_iso, language_iso). pks accepts a collection of - len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', 'us', 'eng'), ('AGE', 'us', 'spa')] + :param pks: The pk is (question_key, country_iso, language_iso). pks + accepts a collection of len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', + 'us', 'eng'), ('AGE', 'us', 'spa')] + :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,8 +50,10 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_key, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( f""" SELECT * @@ -56,7 +62,9 @@ def get_profiling_library( """, params, ) + for x in res: x["options"] = json.loads(x["options"]) if x["options"] else None qs = [InnovateQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/innovate/survey.py b/generalresearch/managers/innovate/survey.py index 0e19065..bd86e34 100644 --- a/generalresearch/managers/innovate/survey.py +++ b/generalresearch/managers/innovate/survey.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional, Set +from datetime import datetime, timezone +from typing import Collection, List, Optional, Set import pymysql from pymysql import IntegrityError @@ -10,8 +10,8 @@ from pymysql import IntegrityError from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager from generalresearch.models.innovate.survey import ( - InnovateSurvey, InnovateCondition, + InnovateSurvey, ) logger = logging.getLogger() @@ -77,8 +77,10 @@ class InnovateSurveyManager(SurveyManager): :param is_live: filters on is_live field :param updated_since: filters on "> updated" """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -88,18 +90,22 @@ class InnovateSurveyManager(SurveyManager): if survey_ids is not None: params["survey_ids"] = survey_ids filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: if is_live: filters.append("is_live") else: filters.append("NOT is_live") + if updated_since is not None: params["updated_since"] = updated_since filters.append("updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" fields = set(self.SURVEY_FIELDS) | {"created", "updated"} + if exclude_fields: fields -= exclude_fields fields_str = ", ".join([f"`{v}`" for v in fields]) diff --git a/generalresearch/managers/leaderboard/manager.py b/generalresearch/managers/leaderboard/manager.py index 86b3d80..18dec92 100644 --- a/generalresearch/managers/leaderboard/manager.py +++ b/generalresearch/managers/leaderboard/manager.py @@ -1,19 +1,19 @@ -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import cached_property -from typing import Optional, TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import pandas as pd from pandas import Period -from pydantic import NaiveDatetime, AwareDatetime +from pydantic import AwareDatetime, NaiveDatetime from redis import Redis from generalresearch.managers.leaderboard import country_timezone from generalresearch.models.thl.leaderboard import ( + Leaderboard, LeaderboardCode, LeaderboardFrequency, LeaderboardRow, - Leaderboard, ) if TYPE_CHECKING: @@ -130,7 +130,7 @@ class LeaderboardManager: def get_leaderboard( self, limit: Optional[int] = None, - bp_user_id=None, + bp_user_id: Optional[str] = None, ) -> Leaderboard: if bp_user_id: diff --git a/generalresearch/managers/leaderboard/tasks.py b/generalresearch/managers/leaderboard/tasks.py index e25c4d8..072a1aa 100644 --- a/generalresearch/managers/leaderboard/tasks.py +++ b/generalresearch/managers/leaderboard/tasks.py @@ -3,11 +3,11 @@ import logging from redis import Redis from generalresearch.managers.leaderboard.manager import LeaderboardManager -from generalresearch.models.thl.session import Session from generalresearch.models.thl.leaderboard import ( - LeaderboardFrequency, LeaderboardCode, + LeaderboardFrequency, ) +from generalresearch.models.thl.session import Session logger = logging.getLogger() diff --git a/generalresearch/managers/lucid/profiling.py b/generalresearch/managers/lucid/profiling.py index 5f084b1..ac2556d 100644 --- a/generalresearch/managers/lucid/profiling.py +++ b/generalresearch/managers/lucid/profiling.py @@ -1,8 +1,9 @@ import json -from typing import List, Collection, Optional, Tuple -from generalresearch.decorators import LOG +from typing import Collection, List, Optional, Tuple + from pydantic import ValidationError +from generalresearch.decorators import LOG from generalresearch.models.lucid.question import LucidQuestion, LucidQuestionType from generalresearch.sql_helper import SqlHelper @@ -24,8 +25,10 @@ def get_profiling_library( len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] :return: """ + filters = ["`q`.question_type != 'o'"] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`q`.`country_iso` = %(country_iso)s") @@ -40,8 +43,10 @@ def get_profiling_library( pks = [(int(x[0]), x[1], x[2]) for x in pks] params["pks"] = pks filters.append("(q.question_id, q.country_iso, q.language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + db_name = sql_helper.db_name res = sql_helper.execute_sql_query( query=f""" @@ -74,6 +79,7 @@ def get_profiling_library( if x["question_id"] in {"116", "120", "121"}: x["question_type"] = LucidQuestionType.TEXT_ENTRY qs = [] + for x in res: try: qs.append(LucidQuestion.from_db(x)) diff --git a/generalresearch/managers/marketplace/user_pid.py b/generalresearch/managers/marketplace/user_pid.py index f731179..cadaea9 100644 --- a/generalresearch/managers/marketplace/user_pid.py +++ b/generalresearch/managers/marketplace/user_pid.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Collection, Optional, List, Dict +from typing import Collection, Dict, List, Optional from uuid import UUID from generalresearch.managers.base import SqlManager @@ -12,7 +12,7 @@ class UserPidManager(SqlManager, ABC): For getting user pids across marketplaces """ - SOURCE: Source = None + SOURCE: Optional[Source] = None TABLE_NAME = None def filter( diff --git a/generalresearch/managers/morning/profiling.py b/generalresearch/managers/morning/profiling.py index c397748..494e406 100644 --- a/generalresearch/managers/morning/profiling.py +++ b/generalresearch/managers/morning/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.morning.question import MorningQuestion from generalresearch.sql_helper import SqlHelper @@ -28,8 +28,10 @@ def get_profiling_library( len(3) tuples. e.g. [('employer_size', 'us', 'eng'), ('employer_size', 'us', 'spa')] :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -51,8 +53,10 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( f""" SELECT * @@ -61,6 +65,7 @@ def get_profiling_library( """, params, ) + for x in res: x["options"] = json.loads(x["options"]) if x["options"] else None qs = [MorningQuestion.from_db(x) for x in res] diff --git a/generalresearch/managers/morning/survey.py b/generalresearch/managers/morning/survey.py index 9b08a65..3478e92 100644 --- a/generalresearch/managers/morning/survey.py +++ b/generalresearch/managers/morning/survey.py @@ -2,8 +2,8 @@ from __future__ import annotations import json import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional +from datetime import datetime, timezone +from typing import Collection, List, Optional import pymysql from pymysql import IntegrityError @@ -184,7 +184,7 @@ class MorningSurveyManager(SurveyManager): for survey in surveys: self.update_one(survey, now=now) - def update_one(self, bid: MorningBid, now=None) -> bool: + def update_one(self, bid: MorningBid, now: Optional[datetime] = None) -> bool: if now is None: now = datetime.now(tz=timezone.utc) d = bid.to_mysql() diff --git a/generalresearch/managers/pollfish/profiling.py b/generalresearch/managers/pollfish/profiling.py index 735e824..4431784 100644 --- a/generalresearch/managers/pollfish/profiling.py +++ b/generalresearch/managers/pollfish/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.pollfish.question import PollfishQuestion from generalresearch.sql_helper import SqlHelper @@ -26,8 +26,10 @@ def get_profiling_library( len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -43,9 +45,11 @@ def get_profiling_library( if is_live is not None: params["is_live"] = is_live filters.append("is_live = %(is_live)s") + if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" res = sql_helper.execute_sql_query( diff --git a/generalresearch/managers/precision/profiling.py b/generalresearch/managers/precision/profiling.py index 5f687ce..c4b24d8 100644 --- a/generalresearch/managers/precision/profiling.py +++ b/generalresearch/managers/precision/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.precision.question import PrecisionQuestion from generalresearch.sql_helper import SqlHelper @@ -26,8 +26,10 @@ def get_profiling_library( len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,6 +48,7 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" res = sql_helper.execute_sql_query( diff --git a/generalresearch/managers/precision/survey.py b/generalresearch/managers/precision/survey.py index fc4e037..112c938 100644 --- a/generalresearch/managers/precision/survey.py +++ b/generalresearch/managers/precision/survey.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional +from datetime import datetime, timezone +from typing import Collection, List, Optional import pymysql from pymysql import IntegrityError @@ -10,8 +10,8 @@ from pymysql import IntegrityError from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager from generalresearch.models.precision.survey import ( - PrecisionSurvey, PrecisionCondition, + PrecisionSurvey, ) logger = logging.getLogger() @@ -62,8 +62,10 @@ class PrecisionSurveyManager(SurveyManager): :param is_live: filters on is_live field :param updated_since: filters on "> last_updated" """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -186,12 +188,12 @@ class PrecisionSurveyManager(SurveyManager): country_data = [(survey.survey_id, c) for c in survey.country_isos] # Turn ON countries in this survey's list of countries, insert row, if already exists, set active. c.executemany( - f""" + query=f""" INSERT INTO `thl-precision`.`precision_survey_country` (survey_id, country_iso, is_active) VALUES (%s, %s, TRUE) ON DUPLICATE KEY UPDATE is_active = TRUE; """, - country_data, + args=country_data, ) # Same thing with languages @@ -205,12 +207,12 @@ class PrecisionSurveyManager(SurveyManager): ) language_data = [(survey.survey_id, c) for c in survey.language_isos] c.executemany( - f""" + query=f""" INSERT INTO `thl-precision`.`precision_survey_language` (survey_id, language_iso, is_active) VALUES (%s, %s, TRUE) ON DUPLICATE KEY UPDATE is_active = TRUE; """, - language_data, + args=language_data, ) conn.commit() diff --git a/generalresearch/managers/prodege/profiling.py b/generalresearch/managers/prodege/profiling.py index a33364b..bf4e3cf 100644 --- a/generalresearch/managers/prodege/profiling.py +++ b/generalresearch/managers/prodege/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.prodege.question import ProdegeQuestion from generalresearch.sql_helper import SqlHelper @@ -26,8 +26,10 @@ def get_profiling_library( len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,6 +48,7 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" res = sql_helper.execute_sql_query( @@ -56,6 +59,7 @@ def get_profiling_library( """, params, ) + for x in res: x["options"] = json.loads(x["options"]) if x["options"] else None qs = [ProdegeQuestion.from_db(x) for x in res] diff --git a/generalresearch/managers/prodege/survey.py b/generalresearch/managers/prodege/survey.py index ec5665e..75fa34e 100644 --- a/generalresearch/managers/prodege/survey.py +++ b/generalresearch/managers/prodege/survey.py @@ -1,13 +1,13 @@ from __future__ import annotations -from datetime import timezone, datetime -from typing import List, Collection, Optional +from datetime import datetime, timezone +from typing import Collection, List, Optional import pymysql from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager -from generalresearch.models.prodege.survey import ProdegeSurvey, ProdegeCondition +from generalresearch.models.prodege.survey import ProdegeCondition, ProdegeSurvey class ProdegeCriteriaManager(CriteriaManager): @@ -57,8 +57,10 @@ class ProdegeSurveyManager(SurveyManager): :param is_live: filters on is_live field :param updated_since: filters on "> updated" """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") diff --git a/generalresearch/managers/repdata/profiling.py b/generalresearch/managers/repdata/profiling.py index c508764..6a63c38 100644 --- a/generalresearch/managers/repdata/profiling.py +++ b/generalresearch/managers/repdata/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.repdata.question import RepDataQuestion from generalresearch.sql_helper import SqlHelper diff --git a/generalresearch/managers/repdata/survey.py b/generalresearch/managers/repdata/survey.py index 05465e9..eecdc04 100644 --- a/generalresearch/managers/repdata/survey.py +++ b/generalresearch/managers/repdata/survey.py @@ -1,18 +1,18 @@ from __future__ import annotations import json -from datetime import timezone, datetime -from typing import List, Collection, Optional +from datetime import datetime, timezone +from typing import Collection, List, Optional import pymysql from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager from generalresearch.models.repdata.survey import ( + RepDataCondition, + RepDataStreamHashed, RepDataSurvey, RepDataSurveyHashed, - RepDataStreamHashed, - RepDataCondition, ) diff --git a/generalresearch/managers/sago/profiling.py b/generalresearch/managers/sago/profiling.py index e1ed97f..6d00ec4 100644 --- a/generalresearch/managers/sago/profiling.py +++ b/generalresearch/managers/sago/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.sago.question import SagoQuestion from generalresearch.sql_helper import SqlHelper @@ -24,10 +24,13 @@ def get_profiling_library( :param is_live: filters on is_live field :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: """ + filters = [] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,6 +49,7 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" res = sql_helper.execute_sql_query( @@ -56,6 +60,7 @@ def get_profiling_library( """, params, ) + for x in res: x["options"] = json.loads(x["options"]) if x["options"] else None qs = [SagoQuestion.from_db(x) for x in res] diff --git a/generalresearch/managers/sago/survey.py b/generalresearch/managers/sago/survey.py index 535d8bb..12b37e0 100644 --- a/generalresearch/managers/sago/survey.py +++ b/generalresearch/managers/sago/survey.py @@ -1,15 +1,15 @@ from __future__ import annotations import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional, Set +from datetime import datetime, timezone +from typing import Collection, List, Optional, Set import pymysql from pymysql import IntegrityError from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager -from generalresearch.models.sago.survey import SagoSurvey, SagoCondition +from generalresearch.models.sago.survey import SagoCondition, SagoSurvey logger = logging.getLogger() diff --git a/generalresearch/managers/spectrum/profiling.py b/generalresearch/managers/spectrum/profiling.py index 21575c6..8a9b9a9 100644 --- a/generalresearch/managers/spectrum/profiling.py +++ b/generalresearch/managers/spectrum/profiling.py @@ -1,5 +1,5 @@ import json -from typing import List, Collection, Optional, Tuple +from typing import Collection, List, Optional, Tuple from generalresearch.models.spectrum.question import SpectrumQuestion from generalresearch.sql_helper import SqlHelper @@ -26,8 +26,10 @@ def get_profiling_library( len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] :return: """ + filters = ["is_valid"] params = {} + if country_iso: params["country_iso"] = country_iso filters.append("`country_iso` = %(country_iso)s") @@ -46,8 +48,10 @@ def get_profiling_library( if pks: params["pks"] = pks filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( f""" SELECT * diff --git a/generalresearch/managers/spectrum/survey.py b/generalresearch/managers/spectrum/survey.py index 0dcc232..190450d 100644 --- a/generalresearch/managers/spectrum/survey.py +++ b/generalresearch/managers/spectrum/survey.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from datetime import timezone, datetime -from typing import List, Collection, Optional +from datetime import datetime, timezone +from typing import Collection, List, Optional import pymysql from pymysql import IntegrityError @@ -10,8 +10,8 @@ from pymysql import IntegrityError from generalresearch.managers.criteria import CriteriaManager from generalresearch.managers.survey import SurveyManager from generalresearch.models.spectrum.survey import ( - SpectrumSurvey, SpectrumCondition, + SpectrumSurvey, ) logger = logging.getLogger() @@ -61,7 +61,7 @@ class SpectrumSurveyManager(SurveyManager): survey_ids: Optional[Collection[str]] = None, is_live: Optional[bool] = None, updated_since: Optional[datetime] = None, - fields=None, + fields: List[str] = None, ) -> List[SpectrumSurvey]: """ Accepts lots of optional filters. @@ -93,6 +93,7 @@ class SpectrumSurveyManager(SurveyManager): fields_str = "*" if fields: fields_str = ",".join(fields) + filter_str = " AND ".join(filters) filter_str = "WHERE " + filter_str if filter_str else "" diff --git a/generalresearch/managers/thl/buyer.py b/generalresearch/managers/thl/buyer.py index dd0b4f2..b264945 100644 --- a/generalresearch/managers/thl/buyer.py +++ b/generalresearch/managers/thl/buyer.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone -from typing import Collection, Dict, Optional +from typing import Collection, Dict, List, Optional -from generalresearch.managers.base import PostgresManager, Permission +from generalresearch.managers.base import Permission, PostgresManager from generalresearch.models import Source from generalresearch.models.thl.survey.buyer import Buyer from generalresearch.pg_helper import PostgresConfig @@ -42,10 +42,11 @@ class BuyerManager(PostgresManager): except KeyError: return None - def bulk_get_or_create(self, source: Source, codes: Collection[str]): + def bulk_get_or_create(self, source: Source, codes: Collection[str]) -> List[Buyer]: now = datetime.now(tz=timezone.utc) buyers = [] params_seq = [] + for code in codes: source_code = f"{source.value}:{code}" if source_code in self.source_code_buyer: @@ -83,9 +84,10 @@ class BuyerManager(PostgresManager): # Not required, just for ease of testing/deterministic buyers = sorted(buyers, key=lambda x: (x.source, x.code)) assert len(buyers) == len(codes), "something went wrong" + return buyers - def update(self, buyer: Buyer): + def update(self, buyer: Buyer) -> None: # label is the only thing that can be updated query = """ UPDATE marketplace_buyer @@ -103,7 +105,7 @@ class BuyerManager(PostgresManager): with conn.cursor() as c: c.execute(query, params=params) assert c.rowcount == 1 - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore if buyer.id is not None: assert buyer.id == pk else: diff --git a/generalresearch/managers/thl/cashout_method.py b/generalresearch/managers/thl/cashout_method.py index 98e8e66..accbb41 100644 --- a/generalresearch/managers/thl/cashout_method.py +++ b/generalresearch/managers/thl/cashout_method.py @@ -1,14 +1,16 @@ from copy import copy -from datetime import timezone, datetime -from typing import List, Optional, Collection, Dict -from uuid import uuid4, UUID +from datetime import datetime, timezone +from typing import Any, Collection, Dict, List, Optional +from uuid import UUID, uuid4 + +from pydantic import NonNegativeInt from generalresearch.managers.base import PostgresManager from generalresearch.models.thl.user import User from generalresearch.models.thl.wallet import PayoutType from generalresearch.models.thl.wallet.cashout_method import ( - CashoutMethod, CashMailCashoutMethodData, + CashoutMethod, PaypalCashoutMethodData, ) @@ -43,24 +45,25 @@ class CashoutMethodManager(PostgresManager): def delete_cashout_method(self, cm_id: str): db_res = self.pg_config.execute_sql_query( - f""" + query=""" SELECT id::uuid, user_id FROM accounting_cashoutmethod WHERE id = %s AND is_live LIMIT 1;""", - [cm_id], + params=[cm_id], ) res = next(iter(db_res), None) assert res, f"cashout method id {cm_id} not found" # Don't let anyone delete a non-user-scoped cashout method assert ( res["user_id"] is not None - ), f"error trying to delete non user-scoped cashout method" + ), "error trying to delete non user-scoped cashout method" + self.pg_config.execute_write( - f""" - UPDATE accounting_cashoutmethod SET is_live = FALSE - WHERE id = %s;""", - [cm_id], + query=""" + UPDATE accounting_cashoutmethod SET is_live = FALSE + WHERE id = %s;""", + params=[cm_id], ) def create_cash_in_mail_cashout_method( @@ -185,7 +188,7 @@ class CashoutMethodManager(PostgresManager): ext_id: Optional[str] = None, payout_types: Optional[Collection[PayoutType]] = None, is_live: Optional[bool] = True, - ) -> int: + ) -> NonNegativeInt: filter_str, params = self.make_filter_str( uuid=uuid, user=user, @@ -201,7 +204,7 @@ class CashoutMethodManager(PostgresManager): """, params=params, ) - return res[0]["cnt"] + return int(res[0]["cnt"]) # type: ignore def filter( self, @@ -248,7 +251,7 @@ class CashoutMethodManager(PostgresManager): "supported_payout_types": [x.value for x in supported_payout_types], "user_id": user.user_id, } - query = f""" + query = """ SELECT id::uuid, provider, ext_id, data::jsonb as _data_, user_id FROM accounting_cashoutmethod WHERE is_live @@ -276,14 +279,16 @@ class CashoutMethodManager(PostgresManager): return cms @staticmethod - def format_from_db(x: Dict, user: Optional[User] = None) -> CashoutMethod: + def format_from_db(x: Dict[str, Any], user: Optional[User] = None) -> CashoutMethod: x["id"] = UUID(x["id"]).hex + # The data column here is inconsistent. Pulling keys from the mysql 'data' col # and putting them into the base level. Renamed so that we don't overwrite # a col called "data" within the "_data_" field. for k in list(x["_data_"].keys()): if k in CashoutMethod.model_fields: x[k] = x["_data_"].pop(k) + x["type"] = PayoutType(x["provider"].upper()) if "data" not in x: x["data"] = dict() diff --git a/generalresearch/managers/thl/contest_manager.py b/generalresearch/managers/thl/contest_manager.py index 4dc02e3..696a497 100644 --- a/generalresearch/managers/thl/contest_manager.py +++ b/generalresearch/managers/thl/contest_manager.py @@ -1,9 +1,9 @@ -from datetime import timezone, datetime -from typing import List, Optional, Literal, cast, Collection, Tuple, Dict +from datetime import datetime, timezone +from typing import Any, Collection, Dict, List, Literal, Optional, Tuple, cast from uuid import UUID import redis -from pydantic import PositiveInt, NonNegativeInt +from pydantic import NonNegativeInt, PositiveInt from redis import Redis from generalresearch.managers.base import PostgresManager @@ -15,8 +15,8 @@ from generalresearch.managers.thl.user_manager.user_manager import ( ) from generalresearch.models.custom_types import UUIDStr from generalresearch.models.thl.contest import ( - ContestWinner, ContestPrize, + ContestWinner, ) from generalresearch.models.thl.contest.contest import ( Contest, @@ -34,20 +34,20 @@ from generalresearch.models.thl.contest.io import ( user_model_cls, ) from generalresearch.models.thl.contest.leaderboard import ( - LeaderboardContestUserView, LeaderboardContest, + LeaderboardContestUserView, ) from generalresearch.models.thl.contest.milestone import ( - MilestoneUserView, - MilestoneEntry, ContestEntryTrigger, MilestoneContest, + MilestoneEntry, + MilestoneUserView, ) from generalresearch.models.thl.contest.raffle import ( ContestEntry, ContestEntryType, - RaffleUserView, RaffleContest, + RaffleUserView, ) from generalresearch.models.thl.user import User @@ -139,7 +139,7 @@ class ContestBaseManager(PostgresManager): query=query, params=data, ) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() contest.id = pk @@ -182,9 +182,10 @@ class ContestBaseManager(PostgresManager): name_contains: Optional[str] = None, uuids: Optional[Collection[str]] = None, has_participants: Optional[bool] = None, - ) -> Tuple[str, Dict]: + ) -> Tuple[str, Dict[str, Any]]: filters = [] params = dict() + if product_id: params["product_id"] = product_id filters.append("product_id = %(product_id)s") @@ -207,6 +208,7 @@ class ContestBaseManager(PostgresManager): if name_contains is not None: params["name_contains"] = f"%{name_contains}%" filters.append("name ILIKE %(name_contains)s") + if uuids is not None: if len(uuids) == 0: # If we pass an empty list, the sql query will have a syntax error. Make it @@ -214,6 +216,7 @@ class ContestBaseManager(PostgresManager): uuids = ["0" * 32] params["uuids"] = uuids filters.append("uuid = ANY(%(uuids)s)") + if has_participants: filters.append("current_participants > 0") @@ -711,7 +714,7 @@ class RaffleContestManager(ContestBaseManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute( - query=f""" + query=""" INSERT INTO contest_contestentry (uuid, amount, user_id, created_at, updated_at, contest_id) @@ -721,7 +724,7 @@ class RaffleContestManager(ContestBaseManager): params=data, ) c.execute( - query=f""" + query=""" UPDATE contest_contest SET current_amount = %(current_amount)s, current_participants = %(current_participants)s @@ -985,11 +988,14 @@ class LeaderboardContestManager(ContestBaseManager): def end_contest_if_over( self, contest: Contest, ledger_manager: ThlLedgerManager ) -> None: - decision, reason = contest.should_end() + decision, _ = contest.should_end() + if decision: contest.end_contest() return self.end_contest_with_winners(contest, ledger_manager) + return None + class ContestManager( RaffleContestManager, MilestoneContestManager, LeaderboardContestManager diff --git a/generalresearch/managers/thl/ipinfo.py b/generalresearch/managers/thl/ipinfo.py index 563b179..f86573b 100644 --- a/generalresearch/managers/thl/ipinfo.py +++ b/generalresearch/managers/thl/ipinfo.py @@ -1,7 +1,7 @@ import ipaddress from decimal import Decimal from random import randint -from typing import List, Optional, Dict, Collection +from typing import Collection, Dict, List, Optional import faker import pymysql @@ -14,12 +14,12 @@ from generalresearch.managers.base import ( PostgresManagerWithRedis, ) from generalresearch.models.custom_types import ( - IPvAnyAddressStr, CountryISOLike, + IPvAnyAddressStr, ) from generalresearch.models.thl.ipinfo import ( - IPGeoname, GeoIPInformation, + IPGeoname, IPInformation, normalize_ip, ) @@ -83,14 +83,15 @@ class IPGeonameManager(PostgresManager): } ) self.pg_config.execute_write( - query=f""" + query=""" INSERT INTO thl_geoname ( geoname_id, country_iso, is_in_european_union, country_name, continent_code, continent_name, updated ) VALUES ( - %(geoname_id)s, %(country_iso)s, %(is_in_european_union)s, %(country_name)s, - %(continent_code)s, %(continent_name)s, %(updated)s + %(geoname_id)s, %(country_iso)s, %(is_in_european_union)s, + %(country_name)s, %(continent_code)s, %(continent_name)s, + %(updated)s ) ON CONFLICT (geoname_id) DO NOTHING; """, @@ -151,7 +152,7 @@ class IPGeonameManager(PostgresManager): ) self.pg_config.execute_write( - query=f""" + query=""" INSERT INTO thl_geoname ( geoname_id, continent_code, continent_name, country_iso, country_name, @@ -165,8 +166,8 @@ class IPGeonameManager(PostgresManager): %(country_iso)s, %(country_name)s, %(subdivision_1_iso)s, %(subdivision_1_name)s, %(subdivision_2_iso)s, %(subdivision_2_name)s, - %(city_name)s, %(metro_code)s, %(time_zone)s, %(is_in_european_union)s, - %(updated)s + %(city_name)s, %(metro_code)s, %(time_zone)s, + %(is_in_european_union)s, %(updated)s ) ON CONFLICT (geoname_id) DO NOTHING; """, @@ -208,7 +209,7 @@ class IPGeonameManager(PostgresManager): assert len(filter_ids) <= 500, "chunk me" c.execute( - query=f""" + query=""" SELECT g.geoname_id, g.continent_code, g.continent_name, g.country_iso, g.country_name, @@ -298,10 +299,11 @@ class IPInformationManager(PostgresManager): ) instance.normalize_ip() self.pg_config.execute_write( - query=f""" + query=""" INSERT INTO thl_ipinformation - (ip, country_iso, registered_country_iso, geoname_id, updated) - VALUES (%(ip)s, %(country_iso)s, %(registered_country_iso)s, %(geoname_id)s, %(updated)s) + (ip, country_iso, registered_country_iso, geoname_id, updated) + VALUES (%(ip)s, %(country_iso)s, %(registered_country_iso)s, + %(geoname_id)s, %(updated)s) ON CONFLICT (ip) DO NOTHING; """, params=instance.model_dump(mode="json"), @@ -367,7 +369,7 @@ class IPInformationManager(PostgresManager): instance.normalize_ip() self.pg_config.execute_write( - query=f""" + query=""" INSERT INTO thl_ipinformation ( ip, geoname_id, country_iso, registered_country_iso, @@ -466,7 +468,7 @@ class IPInformationManager(PostgresManager): normalized_ips = set(normalized_ip_lookup.values()) c.execute( - query=f""" + query=""" SELECT i.ip, i.geoname_id, i.country_iso, i.registered_country_iso, i.is_anonymous, i.is_anonymous_vpn, i.is_hosting_provider, @@ -500,9 +502,9 @@ class IPInformationManager(PostgresManager): WHERE updated >= NOW() - INTERVAL '12 hours' AND country_iso IS NULL; """ - numerator = list(pg_config.execute_sql_query(query=query))[0]["numerator"] + # numerator = list(pg_config.execute_sql_query(query=query))[0]["numerator"] - query = f""" + query = """ SELECT COUNT(1) AS denominator FROM thl_ipinformation WHERE updated >= NOW() - INTERVAL '12 hours' @@ -510,10 +512,11 @@ class IPInformationManager(PostgresManager): denominator = list(pg_config.execute_sql_query(query=query))[0]["denominator"] if denominator == 0: pass - percent_empty = numerator / (denominator or 1) + + # percent_empty = numerator / (denominator or 1) # TODO: Post to telegraf / grafana - return + return None class GeoIpInfoManager(PostgresManagerWithRedis): @@ -748,7 +751,7 @@ class GeoIpInfoManager(PostgresManagerWithRedis): normalized_ips.update({self.compress_ip(ip) for ip in ips}) c.execute( - query=f""" + query=""" SELECT geo.geoname_id, geo.continent_name, @@ -801,9 +804,11 @@ class GeoIpInfoManager(PostgresManagerWithRedis): raise ValueError( f'mismatch between ipinfo country {d["country_iso"]} and geoname country {d["geo_country_iso"]}' ) + gs = [GeoIPInformation.from_mysql(i) for i in res] gs = {g.ip: g for g in gs} res2 = dict() + for ip, (normalized_ip, lookup_prefix) in ip_norm_lookup.items(): if normalized_ip not in gs: # also can remove 28 days after 2025-11-15 diff --git a/generalresearch/managers/thl/ledger_manager/conditions.py b/generalresearch/managers/thl/ledger_manager/conditions.py index 3c03300..399d28f 100644 --- a/generalresearch/managers/thl/ledger_manager/conditions.py +++ b/generalresearch/managers/thl/ledger_manager/conditions.py @@ -1,12 +1,12 @@ import logging -from datetime import datetime, timezone, timedelta -from typing import Callable, Optional, Tuple, TYPE_CHECKING +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Callable, Optional, Tuple -from generalresearch.config import JAMES_BILLINGS_TX_CUTOFF, JAMES_BILLINGS_BPID +from generalresearch.config import JAMES_BILLINGS_BPID, JAMES_BILLINGS_TX_CUTOFF from generalresearch.currency import USDCent from generalresearch.models.custom_types import UUIDStr from generalresearch.models.thl.product import Product -from generalresearch.models.thl.session import Wall, Session +from generalresearch.models.thl.session import Session, Wall from generalresearch.models.thl.user import User logging.basicConfig() @@ -22,7 +22,7 @@ if TYPE_CHECKING: ) -def generate_condition_mp_payment(wall: "Wall") -> Callable: +def generate_condition_mp_payment(wall: "Wall") -> Callable[..., bool]: """This returns a function that checks if the payment for this wall event exists already. This function gets run after we acquire a lock. It should return True if we want to continue (create a tx). @@ -37,7 +37,7 @@ def generate_condition_mp_payment(wall: "Wall") -> Callable: return _condition -def generate_condition_bp_payment(session: "Session") -> Callable: +def generate_condition_bp_payment(session: "Session") -> Callable[..., bool]: """This returns a function that checks if the payment for this Session exists already. This function gets run after we acquire a lock. It should return True if we want to continue (create a tx). @@ -52,7 +52,7 @@ def generate_condition_bp_payment(session: "Session") -> Callable: return _condition -def generate_condition_tag_exists(tag: str) -> Callable: +def generate_condition_tag_exists(tag: str) -> Callable[..., bool]: """This returns a function that checks if a tx with this tag already exists. It should return True if we want to continue (create a tx). """ @@ -70,7 +70,7 @@ def generate_condition_bp_payout( payoutevent_uuid: UUIDStr, skip_one_per_day_check: bool = False, skip_wallet_balance_check: bool = False, -) -> Callable: +) -> Callable[..., Tuple[bool, str]]: created = datetime.now(tz=timezone.utc) def _condition( @@ -110,7 +110,7 @@ def generate_condition_bp_payout( def generate_condition_user_payout_request( user: User, payoutevent_uuid: UUIDStr, min_balance: Optional[int] = None -) -> Callable: +) -> Callable[..., bool]: """This returns a function that checks if `user` has at least `min_balance` in their wallet and that a payout request hasn't already been issued with this payoutevent_uuid. @@ -150,7 +150,7 @@ def generate_condition_user_payout_request( def generate_condition_enter_contest( user: User, tag: str, min_balance: USDCent -) -> Callable: +) -> Callable[..., Tuple[bool, str]]: """This returns a function that checks if `user` has at least `min_balance` in their wallet and that a tx doesn't already exist with this tag @@ -177,7 +177,7 @@ def generate_condition_enter_contest( def generate_condition_user_payout_action( payoutevent_uuid: UUIDStr, action: str -) -> Callable: +) -> Callable[..., bool]: """The balance has already been taken from the user's wallet, so there is no balance check. We only just check that the ledger transaction doesn't already exist. diff --git a/generalresearch/managers/thl/ledger_manager/ledger.py b/generalresearch/managers/thl/ledger_manager/ledger.py index 3c65cbd..d463b55 100644 --- a/generalresearch/managers/thl/ledger_manager/ledger.py +++ b/generalresearch/managers/thl/ledger_manager/ledger.py @@ -1,12 +1,12 @@ import logging from collections import defaultdict from datetime import datetime, timedelta, timezone -from typing import Callable, Collection, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Union from uuid import UUID import redis from more_itertools import chunked, flatten -from pydantic import AwareDatetime, PositiveInt +from pydantic import AwareDatetime, NonNegativeInt, PositiveInt from redis.exceptions import LockError, LockNotOwnedError from generalresearch.currency import LedgerCurrency @@ -206,10 +206,10 @@ class LedgerTransactionManager(LedgerManagerBasePostgres): def create_tx_protected( self, lock_key: str, - condition: Callable, + condition: Callable[..., Union[bool, Tuple[bool, str]]], create_tx_func: Callable, flag_key: Optional[str] = None, - skip_flag_check=False, + skip_flag_check: bool = False, ) -> LedgerTransaction: """ The complexity here is that even we protect a transaction logic with @@ -287,12 +287,15 @@ class LedgerTransactionManager(LedgerManagerBasePostgres): # taking a long time, this would allow the tx to get retried # over and over every 4 seconds, which is not good. rc.set(name=flag_name, value=1, ex=3600 * 24) + # Condition returns either bool or Tuple[bool, str] condition_res = condition(self) + if isinstance(condition_res, tuple): condition_res, condition_msg = condition_res else: condition_msg = "" + if condition_res is False: rc.delete(flag_name) raise LedgerTransactionConditionFailedError(condition_msg) @@ -404,7 +407,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres): @staticmethod def process_get_tx_mysql_rows_json( - rows: Collection[Dict], + rows: Collection[Dict[str, Any]], ) -> List[LedgerTransaction]: """Columns: transaction_id, created, ext_description, tag, key_value_pairs, entries_json @@ -493,7 +496,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres): account_uuid: UUIDStr, time_start: Optional[datetime] = None, time_end: Optional[datetime] = None, - ): + ) -> NonNegativeInt: filter_str, params = self.make_filter_str( time_start=time_start, time_end=time_end, @@ -533,7 +536,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres): account_uuid: str, oldest_created: datetime, exclude_txs_before: Optional[datetime] = None, - ): + ) -> NonNegativeInt: """ In a paginated list of txs, if I want to calculate a running balance, I need the balance in that account @@ -560,6 +563,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres): AND lt.created < %(oldest_created)s {exclude_str};""" res = self.pg_config.execute_sql_query(query, params=params) + return res[0]["balance_before_page"] def include_running_balance( @@ -735,7 +739,7 @@ class LedgerMetadataManager(LedgerManagerBasePostgres): def get_tx_metadata_by_txs( self, transactions: List[LedgerTransaction] - ) -> Dict[PositiveInt, Dict]: + ) -> Dict[PositiveInt, Dict[str, Any]]: """ Each transaction can have 1 metadata dictionary. However, each metadata dictionary can have multiple key/value pairs that @@ -847,7 +851,7 @@ class LedgerAccountManager(LedgerManagerBasePostgres): return account def get_account( - self, qualified_name: str, raise_on_error=True + self, qualified_name: str, raise_on_error: bool = True ) -> Optional[LedgerAccount]: res = self.get_account_many( qualified_names=[qualified_name], raise_on_error=raise_on_error @@ -855,8 +859,8 @@ class LedgerAccountManager(LedgerManagerBasePostgres): return res[0] if len(res) == 1 else None def get_account_many_( - self, qualified_names: List[str], raise_on_error=True - ) -> List[Dict]: + self, qualified_names: List[str], raise_on_error: bool = True + ) -> List[Dict[str, Any]]: assert len(qualified_names) <= 500, "chunk me" # qualified_name has a unique index so there can only be 0 or 1 match. @@ -878,7 +882,7 @@ class LedgerAccountManager(LedgerManagerBasePostgres): return list(res) def get_account_many( - self, qualified_names: List[str], raise_on_error=True + self, qualified_names: List[str], raise_on_error: bool = True ) -> List[LedgerAccount]: res = flatten( [ @@ -1103,7 +1107,7 @@ class LedgerManager( self, time_start: Optional[AwareDatetime] = None, time_end: Optional[AwareDatetime] = None, - ) -> Dict: + ) -> Dict[str, Any]: filter_str, params = self.make_filter_str( time_end=time_end, diff --git a/generalresearch/managers/thl/ledger_manager/thl_ledger.py b/generalresearch/managers/thl/ledger_manager/thl_ledger.py index 977b68f..8004320 100644 --- a/generalresearch/managers/thl/ledger_manager/thl_ledger.py +++ b/generalresearch/managers/thl/ledger_manager/thl_ledger.py @@ -1,7 +1,7 @@ import logging -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from decimal import Decimal -from typing import Optional, Callable, Collection, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Collection, List, Optional from uuid import UUID import numpy as np @@ -15,13 +15,13 @@ from generalresearch.config import ( from generalresearch.currency import USDCent from generalresearch.managers.base import Permission from generalresearch.managers.thl.ledger_manager.conditions import ( - generate_condition_mp_payment, generate_condition_bp_payment, generate_condition_bp_payout, - generate_condition_user_payout_request, - generate_condition_user_payout_action, - generate_condition_tag_exists, generate_condition_enter_contest, + generate_condition_mp_payment, + generate_condition_tag_exists, + generate_condition_user_payout_action, + generate_condition_user_payout_request, ) from generalresearch.managers.thl.ledger_manager.ledger import ( LedgerManager, @@ -39,18 +39,20 @@ from generalresearch.models.thl.contest.raffle import ( RaffleContest, ) from generalresearch.models.thl.ledger import ( - LedgerAccount, + AccountType, Direction, - LedgerTransaction, + LedgerAccount, LedgerEntry, - AccountType, + LedgerTransaction, TransactionType, - TransactionMetadataColumns as tmc, UserLedgerTransactions, ) +from generalresearch.models.thl.ledger import ( + TransactionMetadataColumns as tmc, +) from generalresearch.models.thl.payout import UserPayoutEvent from generalresearch.models.thl.product import Product -from generalresearch.models.thl.session import Status, Session, Wall +from generalresearch.models.thl.session import Session, Status, Wall from generalresearch.models.thl.user import User from generalresearch.models.thl.wallet import PayoutType @@ -207,12 +209,18 @@ class ThlLedgerManager(LedgerManager): return self.get_account_or_create(account=account) def get_account_task_complete_revenue(self) -> "LedgerAccount": - return self.get_account( + res = self.get_account( qualified_name=f"{self.currency.value}:revenue:task_complete" ) + assert res is not None, "Revenue account for task complete does not exist" + + return res def get_account_cash(self) -> "LedgerAccount": - return self.get_account(qualified_name=f"{self.currency.value}:cash") + res = self.get_account(qualified_name=f"{self.currency.value}:cash") + assert res is not None, "Cash account does not exist" + + return res def get_accounts_bp_wallet_for_products( self, product_uuids: Collection[UUIDStr] @@ -263,7 +271,7 @@ class ThlLedgerManager(LedgerManager): wall: Wall, user: User, created: Optional[datetime] = None, - force=False, + force: bool = False, ) -> PositiveInt: """ Create a transaction when we complete a task from a marketplace, @@ -302,7 +310,10 @@ class ThlLedgerManager(LedgerManager): } # This tag should uniquely identify this transaction (which should only happen once!) tag = f"{self.currency.value}:mp_payment:{wall.uuid}" + + assert wall.cpi is not None amount = round(wall.cpi * 100) + entries = [ LedgerEntry( direction=Direction.CREDIT, @@ -327,7 +338,7 @@ class ThlLedgerManager(LedgerManager): return t def create_tx_bp_payment( - self, session: Session, created: Optional[datetime] = None, force=False + self, session: Session, created: Optional[datetime] = None, force: bool = False ) -> LedgerTransaction: """ Create a transaction when we decide to report a session as complete @@ -750,9 +761,9 @@ class ThlLedgerManager(LedgerManager): amount: USDCent, payoutevent_uuid: UUIDStr, created: AwareDatetime, - skip_wallet_balance_check=False, - skip_one_per_day_check=False, - skip_flag_check=False, + skip_wallet_balance_check: bool = False, + skip_one_per_day_check: bool = False, + skip_flag_check: bool = False, ) -> LedgerTransaction: """This is when we pay "OUT" a BP their wallet balance. (Not a payment for a task complete) @@ -859,7 +870,7 @@ class ThlLedgerManager(LedgerManager): created: AwareDatetime, direction: Direction = Direction.DEBIT, description: Optional[str] = None, - skip_flag_check=False, + skip_flag_check: bool = False, ) -> LedgerTransaction: """https://en.wikipedia.org/wiki/Plug_(accounting) @@ -1869,6 +1880,7 @@ class ThlLedgerManager(LedgerManager): ), columns=["finished", "user_payout"], ) + if wall.empty: reserve = 0 else: @@ -1878,16 +1890,17 @@ class ThlLedgerManager(LedgerManager): wall["pct_rdm"] = wall["days_since_complete"].apply(self.get_redeemable_pct) wall.loc[wall["pct_rdm"] > 0.95, "pct_rdm"] = 1 wall["redeemable"] = wall["pct_rdm"] * wall["user_payout_int"] - # Calculate money needed to save in reserve to cover the difference between - # money earned from completes and $ redeemable, subtract that from the - # wall balance. + # Calculate money needed to save in reserve to cover the difference + # between money earned from completes and $ redeemable, subtract + # that from the wall balance. reserve = round(wall["user_payout_int"].sum() - wall["redeemable"].sum()) + redeemable_balance = user_wallet_balance - reserve redeemable_balance = 0 if redeemable_balance < 0 else redeemable_balance if redeemable_balance > 0: - # it is possible the user_wallet_balance is negative, in which case the redeemable - # balance is 0. Don't fail assertion if that happens. + # it is possible the user_wallet_balance is negative, in which case + # the redeemable balance is 0. Don't fail assertion if that happens. assert redeemable_balance <= user_wallet_balance return redeemable_balance diff --git a/generalresearch/managers/thl/maxmind/__init__.py b/generalresearch/managers/thl/maxmind/__init__.py index 59e0af8..3bf0d07 100644 --- a/generalresearch/managers/thl/maxmind/__init__.py +++ b/generalresearch/managers/thl/maxmind/__init__.py @@ -7,9 +7,9 @@ from generalresearch.managers.base import ( PostgresManagerWithRedis, ) from generalresearch.managers.thl.ipinfo import ( - IPInformationManager, - IPGeonameManager, GeoIpInfoManager, + IPGeonameManager, + IPInformationManager, ) from generalresearch.managers.thl.maxmind.basic import MaxmindBasicManager from generalresearch.managers.thl.maxmind.insights import ( @@ -18,9 +18,9 @@ from generalresearch.managers.thl.maxmind.insights import ( ) from generalresearch.models.custom_types import IPvAnyAddressStr from generalresearch.models.thl.ipinfo import ( - IPInformation, - IPGeoname, GeoIPInformation, + IPGeoname, + IPInformation, normalize_ip, ) from generalresearch.pg_helper import PostgresConfig diff --git a/generalresearch/managers/thl/maxmind/basic.py b/generalresearch/managers/thl/maxmind/basic.py index d065c13..72479df 100644 --- a/generalresearch/managers/thl/maxmind/basic.py +++ b/generalresearch/managers/thl/maxmind/basic.py @@ -10,16 +10,15 @@ from uuid import uuid4 import geoip2.database import geoip2.models import requests -from cachetools import cached, TTLCache +from cachetools import TTLCache, cached from geoip2.errors import AddressNotFoundError from generalresearch.managers.base import Manager from generalresearch.models.custom_types import ( - IPvAnyAddressStr, CountryISOLike, + IPvAnyAddressStr, ) - logger = logging.getLogger() diff --git a/generalresearch/managers/thl/maxmind/insights.py b/generalresearch/managers/thl/maxmind/insights.py index b83bded..13dc3e8 100644 --- a/generalresearch/managers/thl/maxmind/insights.py +++ b/generalresearch/managers/thl/maxmind/insights.py @@ -1,10 +1,8 @@ import logging from typing import Optional -import geoip2.database import geoip2.models import geoip2.webservice -import slack from geoip2.errors import ( AddressNotFoundError, AuthenticationError, diff --git a/generalresearch/managers/thl/payout.py b/generalresearch/managers/thl/payout.py index e99cc25..0d00edb 100644 --- a/generalresearch/managers/thl/payout.py +++ b/generalresearch/managers/thl/payout.py @@ -1,14 +1,15 @@ from collections import defaultdict -from datetime import timezone, datetime, timedelta -from random import randint, choice as rand_choice +from datetime import datetime, timedelta, timezone +from random import choice as rand_choice +from random import randint from time import sleep -from typing import Collection, Optional, Dict, List, Union +from typing import Any, Collection, Dict, List, Optional, Union from uuid import UUID, uuid4 import numpy as np import pandas as pd from psycopg import sql -from pydantic import AwareDatetime, PositiveInt, NonNegativeInt +from pydantic import AwareDatetime, NonNegativeInt, PositiveInt from generalresearch.currency import USDCent from generalresearch.decorators import LOG @@ -23,21 +24,21 @@ from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr from generalresearch.models.gr.business import Business from generalresearch.models.thl.definitions import PayoutStatus from generalresearch.models.thl.ledger import ( - LedgerAccount, Direction, + LedgerAccount, OrderBy, ) from generalresearch.models.thl.payout import ( - PayoutEvent, - UserPayoutEvent, BrokerageProductPayoutEvent, BusinessPayoutEvent, + PayoutEvent, + UserPayoutEvent, ) from generalresearch.models.thl.product import Product from generalresearch.models.thl.wallet import PayoutType from generalresearch.models.thl.wallet.cashout_method import ( - CashoutRequestInfo, CashMailOrderData, + CashoutRequestInfo, ) @@ -93,7 +94,7 @@ class PayoutEventManager(PostgresManagerWithRedis): payout_event: Union[UserPayoutEvent, BrokerageProductPayoutEvent], status: PayoutStatus, ext_ref_id: Optional[str] = None, - order_data: Optional[Dict] = None, + order_data: Optional[Dict[str, Any]] = None, ) -> None: # These 3 things are the only modifiable attributes ext_ref_id = ext_ref_id if ext_ref_id is not None else payout_event.ext_ref_id @@ -126,7 +127,7 @@ class UserPayoutEventManager(PayoutEventManager): def get_by_uuid(self, pe_uuid: UUIDStr) -> UserPayoutEvent: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT ep.uuid, ep.debit_account_uuid, ep.cashout_method_uuid, @@ -162,7 +163,7 @@ class UserPayoutEventManager(PayoutEventManager): pe = self.get_by_uuid(pe_uuid=pe_uuid) transaction_info = dict() - order: Dict = pe.order_data + order: Dict[str, Any] = pe.order_data if pe.payout_type == PayoutType.TANGO and pe.status == PayoutStatus.COMPLETE: reward = order["reward"] if "credentialList" in reward: @@ -215,10 +216,12 @@ class UserPayoutEventManager(PayoutEventManager): """ args = [] filters = [] + if reference_uuid: # This could be a product_id or a user_uuid filters.append("la.reference_uuid = %s") args.append(reference_uuid) + if debit_account_uuids: # Or we could use the bp_wallet or user_wallet's account uuid # instead of looking up by the product/user @@ -290,13 +293,13 @@ class UserPayoutEventManager(PayoutEventManager): uuid: Optional[UUIDStr] = None, status: Optional[PayoutStatus] = None, created: Optional[AwareDatetimeISO] = None, - request_data: Optional[Dict] = None, + request_data: Optional[Dict[str, Any]] = None, # --- Optional: None --- account_reference_type: Optional[str] = None, account_reference_uuid: Optional[UUIDStr] = None, description: Optional[str] = None, ext_ref_id: Optional[str] = None, - order_data: Optional[Dict | CashMailOrderData] = None, + order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None, ) -> UserPayoutEvent: payout_event = UserPayoutEvent( @@ -319,10 +322,12 @@ class UserPayoutEventManager(PayoutEventManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute( - query=f""" + query=""" INSERT INTO event_payout ( - uuid, debit_account_uuid, created, cashout_method_uuid, amount, - status, ext_ref_id, payout_type, order_data, request_data + uuid, debit_account_uuid, created, + cashout_method_uuid, amount, status, + ext_ref_id, payout_type, order_data, + request_data ) VALUES ( %(uuid)s, %(debit_account_uuid)s, %(created)s, %(cashout_method_uuid)s, %(amount)s, %(status)s, @@ -350,9 +355,10 @@ class UserPayoutEventManager(PayoutEventManager): status: Optional[PayoutStatus] = None, ext_ref_id: Optional[str] = None, payout_type: Optional[PayoutType] = None, - request_data: Optional[Dict] = None, - order_data: Optional[Dict | CashMailOrderData] = None, + request_data: Optional[Dict[str, Any]] = None, + order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None, ) -> UserPayoutEvent: + debit_account_uuid = debit_account_uuid or uuid4().hex cashout_method_uuid = cashout_method_uuid or uuid4().hex # account_reference_type = account_reference_type or f"acct-ref-{uuid4().hex}" @@ -396,7 +402,7 @@ class BrokerageProductPayoutEventManager(PayoutEventManager): ) -> BrokerageProductPayoutEvent: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT ep.uuid, ep.debit_account_uuid, ep.cashout_method_uuid, @@ -418,6 +424,7 @@ class BrokerageProductPayoutEventManager(PayoutEventManager): rc = self.redis_client account_product_mapping: Dict = rc.hgetall(name="pem:account_to_product") assert isinstance(account_product_mapping, dict) + d["product_id"] = account_product_mapping[d["debit_account_uuid"]] return BrokerageProductPayoutEvent.model_validate(d) @@ -477,11 +484,12 @@ class BrokerageProductPayoutEventManager(PayoutEventManager): status: Optional[PayoutStatus] = None, ext_ref_id: Optional[str] = None, payout_type: PayoutType = None, - request_data: Dict = None, - order_data: Optional[Dict | CashMailOrderData] = None, + request_data: Optional[Dict[str, Any]] = None, + order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None, # --- Support resources --- account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, ) -> BrokerageProductPayoutEvent: + if request_data is None: request_data = dict() @@ -509,7 +517,7 @@ class BrokerageProductPayoutEventManager(PayoutEventManager): d = bp_payout_event.model_dump_mysql() self.pg_config.execute_write( - query=f""" + query=""" INSERT INTO event_payout ( uuid, debit_account_uuid, created, cashout_method_uuid, amount, status, ext_ref_id, payout_type, order_data, request_data @@ -685,7 +693,8 @@ class BrokerageProductPayoutEventManager(PayoutEventManager): skip_wallet_balance_check: bool = False, skip_one_per_day_check: bool = False, ) -> BrokerageProductPayoutEvent: - """If a create_bp_payout_event call fails, this can be called with + """ + If a create_bp_payout_event call fails, this can be called with the associated payoutevent. """ bp_pe: BrokerageProductPayoutEvent = self.get_by_uuid(payout_event_uuid) @@ -1005,7 +1014,7 @@ class BusinessPayoutEventManager(BrokerageProductPayoutEventManager): grouped[bp_pe.ext_ref_id].append(bp_pe) res = [] - for ex_ref_id, members in grouped.items(): + for _, members in grouped.items(): res.append(BusinessPayoutEvent.model_validate({"bp_payouts": members})) return res @@ -1077,8 +1086,8 @@ class BusinessPayoutEventManager(BrokerageProductPayoutEventManager): def distribute_amount( df: pd.DataFrame, amount: USDCent, - weight_col="weight", - balance_col="remaining_balance", + weight_col: str = "weight", + balance_col: str = "remaining_balance", ) -> pd.Series: """ Distributes an integer amount across dataframe rows proportionally, @@ -1129,7 +1138,7 @@ class BusinessPayoutEventManager(BrokerageProductPayoutEventManager): from itertools import islice # Add 1 cent to the top 'shortage' rows - for idx, value in islice(remainders.items(), shortage): + for idx, _ in islice(remainders.items(), shortage): # Only add if it doesn't exceed the balance if allocation.loc[idx] < df[balance_col].loc[idx]: allocation.loc[idx] += 1 diff --git a/generalresearch/managers/thl/product.py b/generalresearch/managers/thl/product.py index c7d38c2..a3a8da6 100644 --- a/generalresearch/managers/thl/product.py +++ b/generalresearch/managers/thl/product.py @@ -1,11 +1,11 @@ import json import logging import operator -from datetime import timezone, datetime +from datetime import datetime, timezone from decimal import Decimal from threading import Lock -from typing import Collection, Optional, List, TYPE_CHECKING, Union -from uuid import uuid4, UUID +from typing import TYPE_CHECKING, Collection, List, Optional, Union +from uuid import UUID, uuid4 from cachetools import TTLCache, cachedmethod, keys from more_itertools import chunked @@ -24,16 +24,16 @@ from generalresearch.pg_helper import PostgresConfig logger = logging.getLogger() if TYPE_CHECKING: - from generalresearch.models.thl.product import Product from generalresearch.models.thl.product import ( - UserCreateConfig, PayoutConfig, + Product, + ProfilingConfig, SessionConfig, - UserWalletConfig, SourcesConfig, - UserHealthConfig, - ProfilingConfig, SupplyConfigs, + UserCreateConfig, + UserHealthConfig, + UserWalletConfig, ) @@ -113,7 +113,7 @@ class ProductManager(PostgresManager): if rand_limit: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT p.id::uuid FROM userprofile_brokerageproduct AS p ORDER BY RANDOM() @@ -124,7 +124,7 @@ class ProductManager(PostgresManager): else: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT p.id::uuid FROM userprofile_brokerageproduct AS p """ @@ -322,14 +322,14 @@ class ProductManager(PostgresManager): ) -> "Product": """Create a Product with all the basic defaults and return the instance""" from generalresearch.models.thl.product import ( - UserCreateConfig, PayoutConfig, + Product, + ProfilingConfig, SessionConfig, - UserWalletConfig, SourcesConfig, + UserCreateConfig, UserHealthConfig, - ProfilingConfig, - Product, + UserWalletConfig, ) now = datetime.now(tz=timezone.utc) @@ -402,7 +402,7 @@ class ProductManager(PostgresManager): try: insert_data["id_int"] = list( self.pg_config.execute_sql_query( - f""" + query=""" SELECT COALESCE(MAX(id_int), 0) + 1 as id_int FROM userprofile_brokerageproduct """ diff --git a/generalresearch/managers/thl/profiling/question.py b/generalresearch/managers/thl/profiling/question.py index 1ad27ac..7b2a7ad 100644 --- a/generalresearch/managers/thl/profiling/question.py +++ b/generalresearch/managers/thl/profiling/question.py @@ -1,15 +1,15 @@ import random import threading -from typing import Collection, List, Tuple +from typing import Any, Collection, Dict, List, Tuple -from cachetools import cached, TTLCache +from cachetools import TTLCache, cached from pydantic import ValidationError from generalresearch.decorators import LOG from generalresearch.managers.base import PostgresManager from generalresearch.models.thl.profiling.upk_question import ( - UpkQuestion, UPKImportance, + UpkQuestion, ) @@ -21,8 +21,8 @@ class QuestionManager(PostgresManager): FROM marketplace_question WHERE id = ANY(%(question_ids)s); """ - res = self.pg_config.execute_sql_query( - query, {"question_ids": list(question_ids)} + res: List[Dict[str, Any]] = self.pg_config.execute_sql_query( + query=query, params={"question_ids": list(question_ids)} ) for x in res: x["data"]["ext_question_id"] = x["property_code"] @@ -31,6 +31,7 @@ class QuestionManager(PostgresManager): "explanation_fragment_template" ] x["data"].pop("categories", None) + return [UpkQuestion.model_validate(x["data"]) for x in res] @cached( @@ -50,7 +51,7 @@ class QuestionManager(PostgresManager): AND property_code NOT LIKE 'g:%%' AND is_live """ - res = self.pg_config.execute_sql_query( + res: List[Dict[str, Any]] = self.pg_config.execute_sql_query( query=query, params={"country_iso": country_iso, "language_iso": language_iso}, ) @@ -94,9 +95,12 @@ class QuestionManager(PostgresManager): "country_iso": country_iso, "language_iso": language_iso, } - res = self.pg_config.execute_sql_query(query=query, params=params) + res: List[Dict[str, Any]] = self.pg_config.execute_sql_query( + query=query, params=params + ) assert len(res) == 1, f"expected 1, got {len(res)} results" x = res[0] + x["data"]["ext_question_id"] = x["property_code"] x["data"]["explanation_template"] = x["explanation_template"] x["data"]["explanation_fragment_template"] = x["explanation_fragment_template"] @@ -119,7 +123,9 @@ class QuestionManager(PostgresManager): WHERE {where_str} """ flat_params = [item for tup in lookup for item in tup] - res = self.pg_config.execute_sql_query(query, params=flat_params) + res: List[Dict[str, Any]] = self.pg_config.execute_sql_query( + query=query, params=flat_params + ) for x in res: x["data"]["ext_question_id"] = x["property_code"] x["data"]["explanation_template"] = x["explanation_template"] diff --git a/generalresearch/managers/thl/profiling/schema.py b/generalresearch/managers/thl/profiling/schema.py index 581270b..e209067 100644 --- a/generalresearch/managers/thl/profiling/schema.py +++ b/generalresearch/managers/thl/profiling/schema.py @@ -2,7 +2,7 @@ from threading import RLock from typing import List from uuid import UUID -from cachetools import cached, TTLCache +from cachetools import TTLCache, cached from generalresearch.managers.base import PostgresManager from generalresearch.models.thl.profiling.upk_property import ( diff --git a/generalresearch/managers/thl/profiling/uqa.py b/generalresearch/managers/thl/profiling/uqa.py index 6800d32..1cab6c2 100644 --- a/generalresearch/managers/thl/profiling/uqa.py +++ b/generalresearch/managers/thl/profiling/uqa.py @@ -132,7 +132,7 @@ class UQAManager(PostgresManagerWithRedis): # 1) the cache expired and the user hasn't sent an answer recently # or 2) The user just sent an answer, so we'll make sure it gets put into the results # after this query runs. - query = f""" + query = """ WITH ranked AS ( SELECT uqa.*, diff --git a/generalresearch/managers/thl/profiling/user_upk.py b/generalresearch/managers/thl/profiling/user_upk.py index 4449afa..b3ea52e 100644 --- a/generalresearch/managers/thl/profiling/user_upk.py +++ b/generalresearch/managers/thl/profiling/user_upk.py @@ -1,10 +1,11 @@ import json from collections import defaultdict -from datetime import timedelta, datetime, timezone -from typing import Dict, Union, Set, List, Collection, Optional, Tuple +from datetime import datetime, timedelta, timezone +from typing import Any, Collection, Dict, List, Optional, Set, Tuple, Union from uuid import UUID from psycopg import Cursor +from pydantic import PositiveInt from generalresearch.managers.base import ( Permission, @@ -117,8 +118,9 @@ class UserUpkManager(PostgresManagerWithRedis): return [UpkQuestionAnswer.model_validate(x) for x in res] def get_user_upk_simple( - self, user_id, country_iso="us" + self, user_id: PositiveInt, country_iso: str = "us" ) -> Dict[str, Union[Set[str], str, float]]: + res = self.get_user_upk(user_id=user_id) res = [x for x in res if x.country_iso == country_iso] d: Dict[str, Union[Set[str], str, float]] = defaultdict(set) @@ -127,16 +129,19 @@ class UserUpkManager(PostgresManagerWithRedis): d[x.property_label] = x.value else: d[x.property_label].add(x.value) + return dict(d) def get_age_gender( - self, user_id, country_iso="us" + self, user_id: PositiveInt, country_iso: str = "us" ) -> Tuple[Optional[int], Optional[str]]: + # Returns an integer year for age, and {'male', 'female', 'other_gender'} d = self.get_user_upk_simple(user_id, country_iso) age = d.get("age_in_years") if age is not None: age = int(age) + gender = d.get("gender") return age, gender @@ -145,7 +150,9 @@ class UserUpkManager(PostgresManagerWithRedis): country_iso=country_iso ) - def populate_user_upk_from_dict(self, upk_ans_dict): + def populate_user_upk_from_dict( + self, upk_ans_dict: List[Dict[str, Any]] + ) -> List[UpkQuestionAnswer]: country_isos = {x["country_iso"] for x in upk_ans_dict} assert len(country_isos) == 1 diff --git a/generalresearch/managers/thl/session.py b/generalresearch/managers/thl/session.py index d328f7d..6e77031 100644 --- a/generalresearch/managers/thl/session.py +++ b/generalresearch/managers/thl/session.py @@ -1,11 +1,11 @@ from datetime import datetime, timedelta, timezone from decimal import Decimal -from typing import Optional, Dict, Tuple, List, Any, Collection -from uuid import uuid4, UUID +from typing import Any, Collection, Dict, List, Optional, Tuple +from uuid import UUID, uuid4 from faker import Faker from psycopg import sql -from pydantic import NonNegativeInt +from pydantic import NonNegativeInt, PositiveInt from generalresearch.managers import parse_order_by from generalresearch.managers.base import ( @@ -17,17 +17,17 @@ from generalresearch.models import DeviceType from generalresearch.models.custom_types import UUIDStr from generalresearch.models.legacy.bucket import Bucket from generalresearch.models.thl.definitions import ( + SessionStatusCode2, Status, StatusCode1, - SessionStatusCode2, ) from generalresearch.models.thl.session import ( Session, Wall, ) from generalresearch.models.thl.task_status import ( - TaskStatusResponse, TasksStatusResponse, + TaskStatusResponse, ) from generalresearch.models.thl.user import User @@ -48,7 +48,7 @@ class SessionManager(PostgresManager): device_type: Optional[DeviceType] = None, ip: Optional[str] = None, bucket: Optional[Bucket] = None, - url_metadata: Optional[Dict] = None, + url_metadata: Optional[Dict[str, str]] = None, uuid_id: Optional[str] = None, ) -> Session: """Creates a Session. Prefer to use this rather than instantiating the @@ -86,7 +86,7 @@ class SessionManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute(query=query, params=d) - session.id = c.fetchone()["id"] + session.id = c.fetchone()["id"] # type: ignore conn.commit() return session @@ -100,7 +100,7 @@ class SessionManager(PostgresManager): device_type: Optional[DeviceType] = None, ip: Optional[str] = None, bucket: Optional[Bucket] = None, - url_metadata: Optional[Dict] = None, + url_metadata: Optional[Dict[str, str]] = None, uuid_id: Optional[str] = None, ) -> Session: """To be used in tests, where we don't care about certain fields""" @@ -125,7 +125,7 @@ class SessionManager(PostgresManager): ) def get_from_uuid(self, session_uuid: UUIDStr) -> Session: - query = f""" + query = """ SELECT s.id AS session_id, s.uuid AS session_uuid, @@ -148,7 +148,7 @@ class SessionManager(PostgresManager): return self.session_from_mysql(res[0]) def get_from_id(self, session_id: int) -> Session: - query = f""" + query = """ SELECT s.id AS session_id, s.uuid AS session_uuid, @@ -234,7 +234,7 @@ class SessionManager(PostgresManager): ) d = session.model_dump_mysql() self.pg_config.execute_write( - query=f""" + query=""" UPDATE thl_session SET status = %(status)s, status_code_1 = %(status_code_1)s, status_code_2 = %(status_code_2)s, finished = %(finished)s, @@ -286,7 +286,7 @@ class SessionManager(PostgresManager): def filter_paginated( self, - user_id: Optional[int] = None, + user_id: Optional[PositiveInt] = None, session_uuids: Optional[List[UUIDStr]] = None, product_uuids: Optional[List[UUIDStr]] = None, started_after: Optional[datetime] = None, diff --git a/generalresearch/managers/thl/survey.py b/generalresearch/managers/thl/survey.py index 871ab83..1ce81d4 100644 --- a/generalresearch/managers/thl/survey.py +++ b/generalresearch/managers/thl/survey.py @@ -1,12 +1,13 @@ from collections import defaultdict from datetime import datetime, timezone -from typing import Collection, List, Tuple, Optional +from typing import Any, Collection, Dict, List, Optional, Tuple import pandas as pd from more_itertools import chunked from psycopg import sql +from pydantic import NonNegativeInt -from generalresearch.managers.base import PostgresManager, Permission +from generalresearch.managers.base import Permission, PostgresManager from generalresearch.managers.thl.buyer import BuyerManager from generalresearch.managers.thl.category import CategoryManager from generalresearch.models import Source @@ -126,7 +127,8 @@ class SurveyManager(PostgresManager): self, survey_keys: Collection[SurveyKey], include_categories: bool = False, - ): + ) -> List[Survey]: + assert len(survey_keys) <= 1000 if len(survey_keys) == 0: return [] @@ -195,15 +197,20 @@ class SurveyManager(PostgresManager): query, params=params, ) + return [Survey.model_validate(x) for x in res] - def filter_by_natural_key(self, source: Source, survey_ids: Collection[str]): + def filter_by_natural_key( + self, source: Source, survey_ids: Collection[str] + ) -> List[Survey]: res = [] for chunk in chunked(survey_ids, 1000): res.extend(self.filter_by_natural_key_chunk(source, chunk)) return res - def filter_by_natural_key_chunk(self, source: Source, survey_ids: Collection[str]): + def filter_by_natural_key_chunk( + self, source: Source, survey_ids: Collection[str] + ) -> List[Survey]: query = """ SELECT id, source, survey_id, created_at, updated_at, is_live, is_recontact, buyer_id, eligibility_criteria @@ -217,7 +224,7 @@ class SurveyManager(PostgresManager): ) return [Survey.model_validate(x) for x in res] - def filter_by_source_live(self, source: Source): + def filter_by_source_live(self, source: Source) -> List[Survey]: """ Return all live surveys for this source """ @@ -230,7 +237,7 @@ class SurveyManager(PostgresManager): res = self.pg_config.execute_sql_query(query, params={"source": source.value}) return [Survey.model_validate(x) for x in res] - def filter_by_live(self, fields: Optional[List[str]] = None): + def filter_by_live(self, fields: Optional[List[str]] = None) -> List[Survey]: """ Return all live surveys """ @@ -245,7 +252,9 @@ class SurveyManager(PostgresManager): res = self.pg_config.execute_sql_query(query) return [Survey.model_validate(x) for x in res] - def turn_off_by_natural_key(self, source: Source, survey_ids: Collection[str]): + def turn_off_by_natural_key( + self, source: Source, survey_ids: Collection[str] + ) -> None: params = {"survey_ids": list(survey_ids), "source": source.value} query = """ UPDATE marketplace_survey @@ -267,7 +276,7 @@ class SurveyManager(PostgresManager): survey_id = ANY(%(survey_pks)s); """ self.pg_config.execute_write( - query, + query=query, params={"survey_pks": survey_pks}, ) return None @@ -447,13 +456,16 @@ class SurveyStatManager(PostgresManager): # info = CompositeInfo.fetch(conn, "surveystat_key") # info.register(conn) - def update_or_create(self, survey_stats: List[SurveyStat]): + def update_or_create( + self, survey_stats: List[SurveyStat] + ) -> Optional[List[SurveyStat]]: """ This manager is NOT responsible for creating surveys or buyers. It will check to make sure they exist """ if len(survey_stats) == 0: return [] + assert all(s.survey_survey_id is not None for s in survey_stats) assert all(s.survey_source is not None for s in survey_stats) assert ( @@ -478,6 +490,7 @@ class SurveyStatManager(PostgresManager): ) # print(f"----aa-----: {datetime.now().isoformat()}") self.upsert_sql(survey_stats=survey_stats) + # print(f"----ab-----: {datetime.now().isoformat()}") return None # keys = [s.unique_key for s in survey_stats] @@ -489,7 +502,7 @@ class SurveyStatManager(PostgresManager): # survey_stats = sorted(survey_stats, key=lambda s: s.natural_key) # return survey_stats - def upsert_sql(self, survey_stats: List[SurveyStat]): + def upsert_sql(self, survey_stats: List[SurveyStat]) -> None: for chunk in chunked(survey_stats, 1000): self.upsert_sql_chunk(survey_stats=chunk) return None @@ -519,7 +532,7 @@ class SurveyStatManager(PostgresManager): # conn.commit() # return None - def upsert_sql_chunk(self, survey_stats: List[SurveyStat]): + def upsert_sql_chunk(self, survey_stats: List[SurveyStat]) -> None: assert len(survey_stats) <= 1000, "chunk me" keys = self.KEYS keys_str = ", ".join(keys) @@ -538,16 +551,19 @@ class SurveyStatManager(PostgresManager): DO UPDATE SET {update_str};""" now = datetime.now(tz=timezone.utc) params = [ss.model_dump_sql() | {"updated_at": now} for ss in survey_stats] + with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.executemany(query=query, params_seq=params) conn.commit() + return None - def filter_by_unique_keys(self, keys: Collection[Tuple]): + def filter_by_unique_keys(self, keys: Collection[Tuple]) -> List[SurveyStat]: res = [] for chunk in chunked(keys, 5000): res.extend(self.filter_by_unique_keys_chunk(chunk)) + return res def filter_by_unique_keys_chunk(self, keys: Collection[Tuple]): @@ -610,7 +626,7 @@ class SurveyStatManager(PostgresManager): return res - def filter_by_updated_since(self, since): + def filter_by_updated_since(self, since: datetime): return self.filter(updated_after=since, is_live=None) def filter_by_live(self): @@ -624,7 +640,7 @@ class SurveyStatManager(PostgresManager): survey_keys: Optional[Collection[SurveyKey]] = None, sources: Optional[Collection[Source]] = None, country_iso: Optional[str] = None, - ): + ) -> Tuple[str, Dict[str, Any]]: filters = [] params = dict() if updated_after is not None: @@ -665,6 +681,7 @@ class SurveyStatManager(PostgresManager): filters.append(f"({' OR '.join(sk_filters)})") filter_str = "WHERE " + " AND ".join(filters) if filters else "" + return filter_str, params def filter_count( @@ -675,7 +692,7 @@ class SurveyStatManager(PostgresManager): survey_keys: Optional[Collection[SurveyKey]] = None, sources: Optional[Collection[Source]] = None, country_iso: Optional[str] = None, - ) -> int: + ) -> NonNegativeInt: filter_str, params = self.make_filter_str( is_live=is_live, updated_after=updated_after, @@ -703,7 +720,7 @@ class SurveyStatManager(PostgresManager): size: Optional[int] = None, order_by: Optional[str] = None, debug: Optional[bool] = False, - ): + ) -> List[SurveyStat]: filter_str, params = self.make_filter_str( is_live=is_live, updated_after=updated_after, @@ -745,15 +762,18 @@ class SurveyStatManager(PostgresManager): {order_by_str} {paginated_filter_str} ; """ + if debug: print(query) print(params) + with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute("SET work_mem = '256MB';") c.execute("SET statement_timeout = '10s';") c.execute(query, params=params) res = c.fetchall() + return [SurveyStat.model_validate(x) for x in res] def filter_to_merge_table( @@ -761,12 +781,14 @@ class SurveyStatManager(PostgresManager): is_live: Optional[bool] = True, updated_after: Optional[datetime] = None, min_score: Optional[float] = 0.0001, - ): + ) -> Optional[pd.DataFrame]: + survey_stats = self.filter( is_live=is_live, updated_after=updated_after, min_score=min_score ) if not survey_stats: return None + extra_cols = { "survey_id", "quota_id", diff --git a/generalresearch/managers/thl/survey_penalty.py b/generalresearch/managers/thl/survey_penalty.py index 4e2c104..3c402ee 100644 --- a/generalresearch/managers/thl/survey_penalty.py +++ b/generalresearch/managers/thl/survey_penalty.py @@ -2,7 +2,9 @@ import json import threading from collections import defaultdict from datetime import timedelta -from typing import Optional, List, Tuple, Dict +from typing import Dict, List, Optional, Tuple + +from cachetools import TTLCache, cachedmethod from generalresearch.decorators import LOG from generalresearch.managers.base import RedisManager @@ -11,12 +13,11 @@ from generalresearch.models.custom_types import ( ) from generalresearch.models.thl.survey.penalty import ( BPSurveyPenalty, - TeamSurveyPenalty, - PenaltyListAdapter, Penalty, + PenaltyListAdapter, + TeamSurveyPenalty, ) from generalresearch.redis_helper import RedisConfig -from cachetools import cachedmethod, TTLCache class SurveyPenaltyManager(RedisManager): diff --git a/generalresearch/managers/thl/task_adjustment.py b/generalresearch/managers/thl/task_adjustment.py index 15802e3..b1b509b 100644 --- a/generalresearch/managers/thl/task_adjustment.py +++ b/generalresearch/managers/thl/task_adjustment.py @@ -2,7 +2,7 @@ import logging from datetime import datetime, timezone from decimal import Decimal from functools import cached_property -from typing import Optional +from typing import List, Optional from generalresearch.managers import parse_order_by from generalresearch.managers.base import ( @@ -13,9 +13,10 @@ from generalresearch.managers.thl.ledger_manager.thl_ledger import ( ) from generalresearch.managers.thl.session import SessionManager from generalresearch.managers.thl.wall import WallManager +from generalresearch.models.custom_types import UUIDStr from generalresearch.models.thl.definitions import ( - WallAdjustedStatus, Status, + WallAdjustedStatus, ) from generalresearch.models.thl.session import ( _check_adjusted_status_wall_consistent, @@ -35,11 +36,11 @@ class TaskAdjustmentManager(PostgresManager): def filter_by_wall_uuid( self, - wall_uuid, + wall_uuid: UUIDStr, page: int = 1, size: int = 100, order_by: Optional[str] = "-created", - ): + ) -> List[TaskAdjustmentEvent]: params = {"wall_uuid": wall_uuid} order_by_str = parse_order_by(order_by) paginated_filter_str = "LIMIT %(limit)s OFFSET %(offset)s" @@ -67,26 +68,34 @@ class TaskAdjustmentManager(PostgresManager): ) return [TaskAdjustmentEvent.model_validate(x) for x in res] - def create_task_adjustment_event(self, event: TaskAdjustmentEvent): - # Only insert a new record into thl_taskadjustment if the status for this wall_uuid - # is different from the last one. Don't need the same thing twice + def create_task_adjustment_event( + self, event: TaskAdjustmentEvent + ) -> TaskAdjustmentEvent: + + # Only insert a new record into thl_taskadjustment if the status + # for this wall_uuid is different from the last one. Don't + # need the same thing twice res = self.filter_by_wall_uuid( wall_uuid=event.wall_uuid, page=1, size=1, order_by="-created" ) if res and event.adjusted_status == res[0].adjusted_status: - # We already have this and it's the same change. Still call the wall_manager.adjust_status - # and ledger code b/c 1) it also won't do the same thing twice, and 2) we could be out of sync - # so check anyway. + # We already have this and it's the same change. Still call + # the wall_manager.adjust_status and ledger code b/c 1) it + # also won't do the same thing twice, and 2) we could be out + # of sync so check anyway. return res[0] self.pg_config.execute_write( - """ + query=""" INSERT INTO thl_taskadjustment - (uuid, adjusted_status, ext_status_code, amount, alerted, - created, user_id, wall_uuid, started, source, survey_id) - VALUES (%(uuid)s, %(adjusted_status)s, %(ext_status_code)s, %(amount)s, %(alerted)s, - %(created)s, %(user_id)s, %(wall_uuid)s, %(started)s, %(source)s, %(survey_id)s) + (uuid, adjusted_status, ext_status_code, + amount, alerted, created, user_id, + wall_uuid, started, source, survey_id) + VALUES ( + %(uuid)s, %(adjusted_status)s, %(ext_status_code)s, + %(amount)s, %(alerted)s, %(created)s, %(user_id)s, + %(wall_uuid)s, %(started)s, %(source)s, %(survey_id)s) """, params=event.model_dump(mode="json"), ) @@ -100,13 +109,15 @@ class TaskAdjustmentManager(PostgresManager): alert_time: Optional[datetime] = None, ext_status_code: Optional[str] = None, adjusted_cpi: Optional[Decimal] = None, - ): + ) -> None: """ We just got an adjustment notification from a marketplace. See note on TaskAdjustmentEvent.adjusted_status. - These fields (specifically adjusted_status and adjusted_cpi) are CHANGES/DELTAS - as just communicated by the marketplace, not what the Wall's final adjusted_* will be. + + These fields (specifically adjusted_status and adjusted_cpi) are + CHANGES/DELTAS as just communicated by the marketplace, not + what the Wall's final adjusted_* will be. """ alert_time = alert_time or datetime.now(tz=timezone.utc) assert alert_time.tzinfo == timezone.utc @@ -129,9 +140,10 @@ class TaskAdjustmentManager(PostgresManager): else: raise ValueError - # If the wall event is a complete -> fail -> complete, we are going to - # receive an adjusted_status.adjust_to_complete, but internally, - # this is going to set the adjusted_status to None (b/c it was already a complete) + # If the wall event is a complete -> fail -> complete, we are going + # to receive an adjusted_status.adjust_to_complete, but internally, + # this is going to set the adjusted_status to None (b/c it was + # already a complete) if ( wall.status == Status.COMPLETE and adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE @@ -185,3 +197,5 @@ class TaskAdjustmentManager(PostgresManager): session.wall_events = self.wall_manager.get_wall_events(session.id) self.session_manager.adjust_status(session) ledger_manager.create_tx_bp_adjustment(session, created=alert_time) + + return None diff --git a/generalresearch/managers/thl/user_compensate.py b/generalresearch/managers/thl/user_compensate.py index c5018d9..7cd16fe 100644 --- a/generalresearch/managers/thl/user_compensate.py +++ b/generalresearch/managers/thl/user_compensate.py @@ -3,6 +3,8 @@ from decimal import Decimal from typing import Optional from uuid import uuid4 +from pydantic import NonNegativeInt + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( ThlLedgerManager, ) @@ -13,14 +15,14 @@ from generalresearch.models.thl.user import User def user_compensate( ledger_manager: ThlLedgerManager, user: User, - amount_int: int, - ext_ref=None, - description=None, + amount_int: NonNegativeInt, + ext_ref: Optional[str] = None, + description: Optional[str] = None, skip_flag_check: Optional[bool] = False, ) -> UUIDStr: """ - Compensate a user. aka "bribe". The money is paid out of the BP's wallet balance. - Amount is in USD cents. + Compensate a user. aka "bribe". The money is paid out of the BP's + wallet balance. Amount is in USD cents. """ pg_config = ledger_manager.pg_config redis_client = ledger_manager.redis_client @@ -44,7 +46,7 @@ def user_compensate( # If there is an external reference ID, don't allow it to be used twice if ext_ref: res = pg_config.execute_sql_query( - query=f""" + query=""" SELECT 1 FROM event_bribe WHERE ext_ref_id = %s @@ -59,9 +61,10 @@ def user_compensate( # Create a new bribe instance account = ledger_manager.get_account_or_create_user_wallet(user) pg_config.execute_write( - query=f""" + query=""" INSERT INTO event_bribe - (uuid, credit_account_uuid, created, amount, ext_ref_id, description, data) + (uuid, credit_account_uuid, created, + amount, ext_ref_id, description, data) VALUES (%s, %s, %s, %s, %s, %s, %s) """, params=[ diff --git a/generalresearch/managers/thl/user_manager/__init__.py b/generalresearch/managers/thl/user_manager/__init__.py index b8b1c35..8875de8 100644 --- a/generalresearch/managers/thl/user_manager/__init__.py +++ b/generalresearch/managers/thl/user_manager/__init__.py @@ -5,11 +5,10 @@ import threading import time from pathlib import Path from threading import RLock -from typing import Dict +from typing import Any, Dict, Union -from cachetools import cached, TTLCache +from cachetools import TTLCache, cached -from generalresearch.managers.thl.user_manager import mysql_user_manager from generalresearch.models.thl.product import Product logger = logging.getLogger() @@ -54,7 +53,7 @@ def get_bp_trust_df(): convert_int = lambda x: int(float(x)) -def parse_bp_trust_df(fp) -> Dict: +def parse_bp_trust_df(fp: Union[str, Path]) -> Dict[str, Any]: dtype = { "bp_trust": float, "team_trust": float, @@ -67,6 +66,7 @@ def parse_bp_trust_df(fp) -> Dict: with open(fp, newline="") as csvfile: reader = csv.reader(csvfile) header = next(reader) + for row in reader: d = dict(zip(header, row)) for k, v in dtype.items(): diff --git a/generalresearch/managers/thl/user_manager/mysql_user_manager.py b/generalresearch/managers/thl/user_manager/mysql_user_manager.py index ab2c6c3..299b167 100644 --- a/generalresearch/managers/thl/user_manager/mysql_user_manager.py +++ b/generalresearch/managers/thl/user_manager/mysql_user_manager.py @@ -1,7 +1,7 @@ import logging from datetime import datetime, timezone from functools import lru_cache -from typing import Optional, Collection, List +from typing import Collection, List, Optional from uuid import uuid4 import psycopg @@ -41,7 +41,7 @@ class MysqlUserManager: product_user_id: Optional[str] = None, user_id: Optional[int] = None, user_uuid: Optional[UUIDStr] = None, - can_use_read_replica=True, + can_use_read_replica: bool = True, ) -> Optional[User]: logger.info( @@ -64,7 +64,7 @@ class MysqlUserManager: if product_id: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT id AS user_id, product_id, product_user_id, uuid, blocked, created, last_seen FROM thl_user @@ -77,7 +77,7 @@ class MysqlUserManager: elif user_id: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT id AS user_id, product_id, product_user_id, uuid, blocked, created, last_seen FROM thl_user @@ -89,7 +89,7 @@ class MysqlUserManager: else: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT id AS user_id, product_id, product_user_id, uuid, blocked, created, last_seen FROM thl_user @@ -205,7 +205,7 @@ class MysqlUserManager: def is_whitelisted(self, user: User): res = self.pg_config.execute_sql_query( - f""" + """ SELECT value FROM userprofile_userstat WHERE user_id = %s @@ -260,7 +260,7 @@ class MysqlUserManager: ), "must pass a collection of user_ids" res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT id AS user_id, product_id, product_user_id, uuid, blocked, created, last_seen FROM thl_user @@ -275,7 +275,7 @@ class MysqlUserManager: user_uuids, (list, set) ), "must pass a collection of user_uuids" res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT id AS user_id, product_id, product_user_id, uuid, blocked, created, last_seen FROM thl_user diff --git a/generalresearch/managers/thl/user_manager/rate_limit.py b/generalresearch/managers/thl/user_manager/rate_limit.py index bd0f9ac..f938664 100644 --- a/generalresearch/managers/thl/user_manager/rate_limit.py +++ b/generalresearch/managers/thl/user_manager/rate_limit.py @@ -1,6 +1,6 @@ import logging -from limits import storage, strategies, RateLimitItemPerHour, RateLimitItem +from limits import RateLimitItem, RateLimitItemPerHour, storage, strategies from limits.limits import TIME_TYPES, safe_string from pydantic import RedisDsn diff --git a/generalresearch/managers/thl/user_manager/redis_user_manager.py b/generalresearch/managers/thl/user_manager/redis_user_manager.py index 8e69740..6c869bc 100644 --- a/generalresearch/managers/thl/user_manager/redis_user_manager.py +++ b/generalresearch/managers/thl/user_manager/redis_user_manager.py @@ -36,7 +36,7 @@ class RedisUserManager: product_user_id: Optional[str] = None, user_id: Optional[int] = None, user_uuid: Optional[str] = None, - ) -> User: + ) -> Optional[User]: # assume we did input validation in user_manager.get_user() function if user_uuid: d = self.client.get(f"{self.cache_prefix}:uuid:{user_uuid}") @@ -52,6 +52,8 @@ class RedisUserManager: if d: return User.model_validate_json(d) + return None + def set_user(self, user: User) -> None: d = user.to_json() with self.client.pipeline(transaction=False) as p: diff --git a/generalresearch/managers/thl/user_manager/user_manager.py b/generalresearch/managers/thl/user_manager/user_manager.py index dd32f8b..ffd8191 100644 --- a/generalresearch/managers/thl/user_manager/user_manager.py +++ b/generalresearch/managers/thl/user_manager/user_manager.py @@ -1,7 +1,7 @@ import logging from datetime import datetime from functools import lru_cache -from typing import Collection, Optional, List +from typing import Collection, List, Optional from uuid import uuid4 from pydantic import RedisDsn @@ -35,7 +35,7 @@ class UserManager: redis: Optional[RedisDsn] = None, pg_config: Optional[PostgresConfig] = None, pg_config_rr: Optional[PostgresConfig] = None, - sql_permissions: Collection[Permission] = None, + sql_permissions: Optional[Collection[Permission]] = None, cache_prefix: Optional[str] = None, redis_timeout: Optional[float] = None, ): diff --git a/generalresearch/managers/thl/user_manager/user_metadata_manager.py b/generalresearch/managers/thl/user_manager/user_metadata_manager.py index 23c9d3c..a97b214 100644 --- a/generalresearch/managers/thl/user_manager/user_metadata_manager.py +++ b/generalresearch/managers/thl/user_manager/user_metadata_manager.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Collection +from typing import Collection, List, Optional from generalresearch.managers.base import PostgresManager from generalresearch.models.thl.user_profile import UserMetadata @@ -23,8 +23,10 @@ class UserMetadataManager(PostgresManager): assert arg is None or isinstance( arg, (set, list) ), "must pass a collection of objects" + filters = [] params = {} + if user_ids: params["user_id"] = list(set(user_ids)) filters.append("user_id = ANY(%(user_id)s)") @@ -50,6 +52,7 @@ class UserMetadataManager(PostgresManager): """, params, ) + return [UserMetadata.from_db(**x) for x in res] def get_if_exists( @@ -104,8 +107,9 @@ class UserMetadataManager(PostgresManager): def update(self, user_metadata: UserMetadata) -> int: """ - The row in the thl_usermetadata might not exist. We'll implicitly create it - if it doesn't yet exist. The caller does not need to know this detail. + The row in the thl_usermetadata might not exist. We'll + implicitly create it if it doesn't yet exist. The caller + does not need to know this detail. """ res = self.get_if_exists(user_id=user_metadata.user_id) @@ -132,10 +136,11 @@ class UserMetadataManager(PostgresManager): def _create(self, user_metadata: UserMetadata) -> int: return self.pg_config.execute_write( - """ + query=""" INSERT INTO thl_usermetadata (user_id, email_address, email_sha256, email_sha1, email_md5) - VALUES (%(user_id)s, %(email_address)s, %(email_sha256)s, %(email_sha1)s, %(email_md5)s); + VALUES (%(user_id)s, %(email_address)s, %(email_sha256)s, + %(email_sha1)s, %(email_md5)s); """, params=user_metadata.to_db(), ) diff --git a/generalresearch/managers/thl/user_streak.py b/generalresearch/managers/thl/user_streak.py index fd3c1ba..4bc7e70 100644 --- a/generalresearch/managers/thl/user_streak.py +++ b/generalresearch/managers/thl/user_streak.py @@ -1,16 +1,16 @@ from datetime import date, datetime -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple import pandas as pd from generalresearch.managers.base import PostgresManager from generalresearch.managers.leaderboard import country_timezone from generalresearch.models.thl.user_streak import ( - UserStreak, - StreakPeriod, + PERIOD_TO_PD_FREQ, StreakFulfillment, + StreakPeriod, StreakState, - PERIOD_TO_PD_FREQ, + UserStreak, ) @@ -30,7 +30,7 @@ class UserStreakManager(PostgresManager): {"user_id": user_id}, ) if res: - return res[0]["country_iso"] + return res[0]["country_iso"] # type: ignore return None diff --git a/generalresearch/managers/thl/userhealth.py b/generalresearch/managers/thl/userhealth.py index 6f66ce3..dab35d1 100644 --- a/generalresearch/managers/thl/userhealth.py +++ b/generalresearch/managers/thl/userhealth.py @@ -1,11 +1,12 @@ import ipaddress -from datetime import timezone, datetime, timedelta +from datetime import datetime, timedelta, timezone from itertools import zip_longest -from random import choice as rchoice, random -from typing import List, Collection, Optional, Dict, Tuple +from random import choice as rchoice +from random import random +from typing import Any, Collection, Dict, List, Optional, Tuple import faker -from pydantic import PositiveInt, NonNegativeInt +from pydantic import NonNegativeInt, PositiveInt from generalresearch.decorators import LOG from generalresearch.managers.base import ( @@ -19,8 +20,8 @@ from generalresearch.models.thl.product import Product from generalresearch.models.thl.user import User from generalresearch.models.thl.user_iphistory import ( IPRecord, - UserIPRecord, UserIPHistory, + UserIPRecord, ) from generalresearch.models.thl.userhealth import AuditLog, AuditLogLevel from generalresearch.pg_helper import PostgresConfig @@ -56,7 +57,7 @@ class UserIpHistoryManager(PostgresManagerWithRedis): # The IP metadata is ONLY for the 'ip', NOT for any forwarded ips. # This might get called immediately after a write, so use the non-rr res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT iph.ip, iph.created, iph.user_id, geo.subdivision_1_iso, ipinfo.country_iso, @@ -263,7 +264,7 @@ class IPRecordManager(PostgresManagerWithRedis): data[col] = ipaddress.ip_address(ip).exploded if ip else ip self.pg_config.execute_write( - query=f""" + query=""" INSERT INTO userhealth_iphistory ( user_id, ip, created, forwarded_ip1, forwarded_ip2, forwarded_ip3, @@ -396,16 +397,17 @@ class AuditLogManager(PostgresManager): with self.pg_config.make_connection() as conn: with conn.cursor() as c: c.execute( - query=f""" + query=""" INSERT INTO userhealth_auditlog - (user_id, created, level, event_type, event_msg, event_value) - VALUES ( %(user_id)s , %(created)s, %(level)s, %(event_type)s, - %(event_msg)s, %(event_value)s) + (user_id, created, level, + event_type, event_msg, event_value) + VALUES ( %(user_id)s , %(created)s, %(level)s, + %(event_type)s, %(event_msg)s, %(event_value)s) RETURNING id; """, params=al.model_dump_mysql(), ) - pk = c.fetchone()["id"] + pk = c.fetchone()["id"] # type: ignore conn.commit() al.id = pk @@ -414,7 +416,7 @@ class AuditLogManager(PostgresManager): def get_by_id(self, auditlog_id: PositiveInt) -> AuditLog: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT al.* FROM userhealth_auditlog AS al WHERE al.id = %s @@ -434,7 +436,7 @@ class AuditLogManager(PostgresManager): def filter_by_product(self, product: Product) -> List[AuditLog]: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT al.* FROM userhealth_auditlog AS al INNER JOIN thl_user AS u @@ -450,7 +452,7 @@ class AuditLogManager(PostgresManager): def filter_by_user_id(self, user_id: PositiveInt) -> List[AuditLog]: res = self.pg_config.execute_sql_query( - query=f""" + query=""" SELECT * FROM userhealth_auditlog AS al WHERE al.user_id = %s @@ -527,7 +529,7 @@ class AuditLogManager(PostgresManager): assert len(res) == 1 - return int(res[0]["c"]) + return int(res[0]["c"]) # type: ignore @staticmethod def make_filter_str( @@ -538,7 +540,7 @@ class AuditLogManager(PostgresManager): event_type_like: Optional[str] = None, event_msg: Optional[str] = None, created_after: Optional[datetime] = None, - ) -> Tuple[str, Dict]: + ) -> Tuple[str, Dict[str, Any]]: assert user_ids, "must pass at least 1 user_id" assert all( [isinstance(uid, int) for uid in user_ids] diff --git a/generalresearch/managers/thl/wall.py b/generalresearch/managers/thl/wall.py index cd1bdbf..3ddbf51 100644 --- a/generalresearch/managers/thl/wall.py +++ b/generalresearch/managers/thl/wall.py @@ -1,16 +1,16 @@ import logging from collections import defaultdict -from datetime import datetime, timezone, timedelta -from decimal import Decimal, ROUND_DOWN +from datetime import datetime, timedelta, timezone +from decimal import ROUND_DOWN, Decimal from functools import cached_property from random import choice as rchoice -from typing import Optional, Collection, List +from typing import Collection, List, Optional from uuid import uuid4 from faker import Faker from psycopg import sql from psycopg.rows import dict_row -from pydantic import AwareDatetime, PositiveInt, PostgresDsn, RedisDsn +from pydantic import AwareDatetime, PositiveInt from generalresearch.managers import parse_order_by from generalresearch.managers.base import ( @@ -19,23 +19,22 @@ from generalresearch.managers.base import ( PostgresManagerWithRedis, ) from generalresearch.models import Source -from generalresearch.models.custom_types import UUIDStr, SurveyKey +from generalresearch.models.custom_types import SurveyKey, UUIDStr from generalresearch.models.thl.definitions import ( + ReportValue, Status, StatusCode1, - WallStatusCode2, - ReportValue, WallAdjustedStatus, + WallStatusCode2, ) from generalresearch.models.thl.ledger import OrderBy from generalresearch.models.thl.session import ( - check_adjusted_status_wall_consistent, Wall, WallAttempt, + check_adjusted_status_wall_consistent, ) from generalresearch.models.thl.survey.model import TaskActivity from generalresearch.pg_helper import PostgresConfig -from generalresearch.redis_helper import RedisConfig logger = logging.getLogger("WallManager") fake = Faker() @@ -379,7 +378,7 @@ class WallManager(PostgresManager): a "progress bar" for eligible, live, surveys they've already attempted. """ - query = f""" + query = """ SELECT COUNT(1) as cnt FROM thl_wall w @@ -396,7 +395,7 @@ class WallManager(PostgresManager): query=query, params=params, ) - return res[0]["cnt"] + return res[0]["cnt"] # type: ignore[union-attr] def filter_wall_attempts_paginated( self, @@ -628,6 +627,7 @@ class WallCacheManager(PostgresManagerWithRedis): def update_attempts_redis_(self, attempts: List[WallAttempt], user_id: int) -> None: if not attempts: return None + redis_key = self.get_cache_key_(user_id=user_id) # Make sure attempts is ordered, so the most recent is last # "LPUSH mylist a b c will result into a list containing c as first element, @@ -636,11 +636,12 @@ class WallCacheManager(PostgresManagerWithRedis): json_res = [attempt.model_dump_json() for attempt in attempts] res = self.redis_client.lpush(redis_key, *json_res) self.redis_client.expire(redis_key, time=60 * 60 * 24) + # So this doesn't grow forever, keep only the most recent 5k self.redis_client.ltrim(redis_key, 0, 4999) return None - def get_attempts(self, user_id: int) -> List[WallAttempt]: + def get_attempts(self, user_id: PositiveInt) -> List[WallAttempt]: """ This is used in the GetOpportunityIDs call to get a list of surveys (& surveygroups) which should be excluded for this user. We don't diff --git a/generalresearch/managers/thl/wallet/__init__.py b/generalresearch/managers/thl/wallet/__init__.py index 3b34756..0826263 100644 --- a/generalresearch/managers/thl/wallet/__init__.py +++ b/generalresearch/managers/thl/wallet/__init__.py @@ -1,20 +1,20 @@ from decimal import Decimal -from typing import Optional, Dict +from typing import Any, Dict, Optional, Union from generalresearch.managers.thl.ledger_manager.thl_ledger import ( ThlLedgerManager, ) from generalresearch.managers.thl.payout import ( - UserPayoutEventManager, PayoutEventManager, + UserPayoutEventManager, ) from generalresearch.managers.thl.user_manager.user_manager import ( UserManager, ) from generalresearch.managers.thl.userhealth import UserIpHistoryManager from generalresearch.managers.thl.wallet.approve import ( - approve_paypal_order, approve_amt_cashout, + approve_paypal_order, ) from generalresearch.models.thl.definitions import PayoutStatus from generalresearch.models.thl.payout import UserPayoutEvent @@ -31,7 +31,7 @@ def manage_pending_cashout( user_ip_history_manager: UserIpHistoryManager, user_manager: UserManager, ledger_manager: ThlLedgerManager, - order_data: Optional[Dict | CashMailOrderData] = None, + order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None, ) -> UserPayoutEvent: """ Called by a UI actions performed by Todd. This rejects/approves/cancels diff --git a/generalresearch/managers/thl/wallet/tango.py b/generalresearch/managers/thl/wallet/tango.py index 5587435..7665ef0 100644 --- a/generalresearch/managers/thl/wallet/tango.py +++ b/generalresearch/managers/thl/wallet/tango.py @@ -1,7 +1,5 @@ from typing import Any, Dict -import slack - from generalresearch.config import ( is_debug, ) @@ -92,7 +90,7 @@ def get_tango_order(ref_id: str): # return json.loads(APIHelper.json_serialize(orders[0])) -def create_tango_order(request_data: Dict, ref_id: str) -> Dict[str, Any]: +def create_tango_order(request_data: Dict[str, Any], ref_id: str) -> Dict[str, Any]: """ Create a tango gift card order. Throws exception if anything is not right. diff --git a/tests/managers/thl/test_ledger/test_lm_accounts.py b/tests/managers/thl/test_ledger/test_lm_accounts.py index be9cf5b..5cfaac1 100644 --- a/tests/managers/thl/test_ledger/test_lm_accounts.py +++ b/tests/managers/thl/test_ledger/test_lm_accounts.py @@ -19,8 +19,18 @@ from generalresearch.models.thl.ledger import ( ) if TYPE_CHECKING: + from pydantic import PositiveInt + from generalresearch.config import GRLSettings from generalresearch.currency import LedgerCurrency + from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager + from generalresearch.models.custom_types import AccountType, Direction, UUIDStr + from generalresearch.models.thl import Direction + from generalresearch.models.thl.ledger import ( + AccountType, + LedgerAccount, + LedgerTransaction, + ) from generalresearch.models.thl.product import Product from generalresearch.models.thl.session import Session from generalresearch.models.thl.user import User @@ -40,7 +50,11 @@ if TYPE_CHECKING: class TestLedgerAccountManagerNoResults: def test_get_account_no_results( - self, currency: "LedgerCurrency", kind, acct_id, lm + self, + currency: "LedgerCurrency", + kind: str, + acct_id: "UUIDStr", + lm: "LedgerManager", ): """Try to query for accounts that we know don't exist and confirm that we either get the expected None result or it raises the correct @@ -59,7 +73,11 @@ class TestLedgerAccountManagerNoResults: assert lm.get_account(qualified_name=qn, raise_on_error=False) is None def test_get_account_no_results_many( - self, currency: "LedgerCurrency", kind, acct_id, lm + self, + currency: "LedgerCurrency", + kind: str, + acct_id: "UUIDStr", + lm: "LedgerManager", ): qn = ":".join([currency, kind, acct_id]) @@ -95,7 +113,11 @@ class TestLedgerAccountManagerNoResults: class TestLedgerAccountManagerCreate: def test_create_account_error_permission( - self, currency: "LedgerCurrency", account_type, direction, lm + self, + currency: "LedgerCurrency", + account_type: "AccountType", + direction: "Direction", + lm: "LedgerManager", ): """Confirm that the Permission values that are set on the Ledger Manger allow the Creation action to occur. @@ -140,7 +162,13 @@ class TestLedgerAccountManagerCreate: str(excinfo.value) == "LedgerManager does not have sufficient permissions" ) - def test_create(self, currency: "LedgerCurrency", account_type, direction, lm): + def test_create( + self, + currency: "LedgerCurrency", + account_type: "AccountType", + direction: "Direction", + lm: "LedgerManager", + ): """Confirm that the Permission values that are set on the Ledger Manger allow the Creation action to occur. """ @@ -161,10 +189,15 @@ class TestLedgerAccountManagerCreate: # Query for, and make sure the Account was saved in the DB res = lm.get_account(qualified_name=qn, raise_on_error=True) + assert res is not None assert account.uuid == res.uuid def test_get_or_create( - self, currency: "LedgerCurrency", account_type, direction, lm + self, + currency: "LedgerCurrency", + account_type: "AccountType", + direction: "Direction", + lm: "LedgerManager", ): """Confirm that the Permission values that are set on the Ledger Manger allow the Creation action to occur. @@ -186,13 +219,15 @@ class TestLedgerAccountManagerCreate: # Query for, and make sure the Account was saved in the DB res = lm.get_account(qualified_name=qn, raise_on_error=True) + assert res is not None assert account.uuid == res.uuid class TestLedgerAccountManagerGet: - def test_get(self, ledger_account, lm): + def test_get(self, ledger_account: "LedgerAccount", lm: "LedgerManager"): res = lm.get_account(qualified_name=ledger_account.qualified_name) + assert res is not None assert res.uuid == ledger_account.uuid res = lm.get_account_many(qualified_names=[ledger_account.qualified_name]) @@ -207,7 +242,12 @@ class TestLedgerAccountManagerGet: # creation working def test_get_balance_empty( - self, ledger_account, ledger_account_credit, ledger_account_debit, ledger_tx, lm + self, + ledger_account: "LedgerAccount", + ledger_account_credit: "LedgerAccount", + ledger_account_debit: "LedgerAccount", + ledger_tx: "LedgerTransaction", + lm: "LedgerManager", ): res = lm.get_account_balance(account=ledger_account) assert res == 0 @@ -221,12 +261,12 @@ class TestLedgerAccountManagerGet: @pytest.mark.parametrize("n_times", range(5)) def test_get_account_filtered_balance( self, - ledger_account, - ledger_account_credit, - ledger_account_debit, - ledger_tx, - n_times, - lm, + ledger_account: "LedgerAccount", + ledger_account_credit: "LedgerAccount", + ledger_account_debit: "LedgerAccount", + ledger_tx: "LedgerTransaction", + n_times: "PositiveInt", + lm: "LedgerManager", ): """Try searching for random metadata and confirm it's always 0 because Tx can be found. @@ -279,6 +319,8 @@ class TestLedgerAccountManagerGet: == rand_amount ) - def test_get_balance_timerange_empty(self, ledger_account, lm): + def test_get_balance_timerange_empty( + self, ledger_account: "LedgerAccount", lm: "LedgerManager" + ): res = lm.get_account_balance_timerange(account=ledger_account) assert res == 0 |
