aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--generalresearch/grliq/managers/event_plotter.py11
-rw-r--r--generalresearch/grliq/managers/forensic_data.py42
-rw-r--r--generalresearch/grliq/managers/forensic_events.py38
-rw-r--r--generalresearch/grliq/managers/forensic_results.py10
-rw-r--r--generalresearch/grliq/managers/forensic_summary.py16
-rw-r--r--generalresearch/grliq/models/events.py28
-rw-r--r--generalresearch/grliq/models/forensic_data.py39
-rw-r--r--generalresearch/grliq/models/forensic_result.py14
-rw-r--r--generalresearch/grliq/models/forensic_summary.py12
-rw-r--r--generalresearch/grliq/utils.py2
-rw-r--r--generalresearch/incite/base.py29
-rw-r--r--generalresearch/incite/collections/__init__.py40
-rw-r--r--generalresearch/incite/defaults.py7
-rw-r--r--generalresearch/incite/mergers/foundations/__init__.py23
-rw-r--r--generalresearch/incite/mergers/foundations/enriched_session.py4
-rw-r--r--generalresearch/incite/mergers/foundations/enriched_task_adjust.py9
-rw-r--r--generalresearch/incite/mergers/foundations/enriched_wall.py4
-rw-r--r--generalresearch/incite/mergers/foundations/user_id_product.py16
-rw-r--r--generalresearch/incite/mergers/nginx_core.py2
-rw-r--r--generalresearch/incite/mergers/nginx_fsb.py5
-rw-r--r--generalresearch/incite/mergers/nginx_grs.py2
-rw-r--r--generalresearch/incite/mergers/pop_ledger.py4
-rw-r--r--generalresearch/incite/mergers/ym_survey_wall.py6
-rw-r--r--generalresearch/incite/mergers/ym_wall_summary.py6
-rw-r--r--generalresearch/incite/schemas/admin_responses.py8
-rw-r--r--generalresearch/incite/schemas/mergers/foundations/enriched_session.py2
-rw-r--r--generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py5
-rw-r--r--generalresearch/incite/schemas/mergers/foundations/enriched_wall.py6
-rw-r--r--generalresearch/incite/schemas/mergers/foundations/user_id_product.py2
-rw-r--r--generalresearch/incite/schemas/mergers/nginx.py4
-rw-r--r--generalresearch/incite/schemas/mergers/pop_ledger.py4
-rw-r--r--generalresearch/incite/schemas/mergers/ym_survey_wall.py6
-rw-r--r--generalresearch/incite/schemas/mergers/ym_wall_summary.py2
-rw-r--r--generalresearch/incite/schemas/thl_marketplaces.py4
-rw-r--r--generalresearch/incite/schemas/thl_web.py16
-rw-r--r--generalresearch/managers/__init__.py2
-rw-r--r--generalresearch/managers/cint/profiling.py9
-rw-r--r--generalresearch/managers/cint/survey.py6
-rw-r--r--generalresearch/managers/dynata/profiling.py7
-rw-r--r--generalresearch/managers/dynata/survey.py9
-rw-r--r--generalresearch/managers/events.py37
-rw-r--r--generalresearch/managers/gr/authentication.py17
-rw-r--r--generalresearch/managers/gr/business.py40
-rw-r--r--generalresearch/managers/gr/team.py12
-rw-r--r--generalresearch/managers/innovate/profiling.py14
-rw-r--r--generalresearch/managers/innovate/survey.py12
-rw-r--r--generalresearch/managers/leaderboard/manager.py10
-rw-r--r--generalresearch/managers/leaderboard/tasks.py4
-rw-r--r--generalresearch/managers/lucid/profiling.py10
-rw-r--r--generalresearch/managers/marketplace/user_pid.py4
-rw-r--r--generalresearch/managers/morning/profiling.py7
-rw-r--r--generalresearch/managers/morning/survey.py6
-rw-r--r--generalresearch/managers/pollfish/profiling.py6
-rw-r--r--generalresearch/managers/precision/profiling.py5
-rw-r--r--generalresearch/managers/precision/survey.py16
-rw-r--r--generalresearch/managers/prodege/profiling.py6
-rw-r--r--generalresearch/managers/prodege/survey.py8
-rw-r--r--generalresearch/managers/repdata/profiling.py2
-rw-r--r--generalresearch/managers/repdata/survey.py8
-rw-r--r--generalresearch/managers/sago/profiling.py7
-rw-r--r--generalresearch/managers/sago/survey.py6
-rw-r--r--generalresearch/managers/spectrum/profiling.py6
-rw-r--r--generalresearch/managers/spectrum/survey.py9
-rw-r--r--generalresearch/managers/thl/buyer.py12
-rw-r--r--generalresearch/managers/thl/cashout_method.py35
-rw-r--r--generalresearch/managers/thl/contest_manager.py32
-rw-r--r--generalresearch/managers/thl/ipinfo.py45
-rw-r--r--generalresearch/managers/thl/ledger_manager/conditions.py22
-rw-r--r--generalresearch/managers/thl/ledger_manager/ledger.py30
-rw-r--r--generalresearch/managers/thl/ledger_manager/thl_ledger.py61
-rw-r--r--generalresearch/managers/thl/maxmind/__init__.py8
-rw-r--r--generalresearch/managers/thl/maxmind/basic.py5
-rw-r--r--generalresearch/managers/thl/maxmind/insights.py2
-rw-r--r--generalresearch/managers/thl/payout.py63
-rw-r--r--generalresearch/managers/thl/product.py30
-rw-r--r--generalresearch/managers/thl/profiling/question.py22
-rw-r--r--generalresearch/managers/thl/profiling/schema.py2
-rw-r--r--generalresearch/managers/thl/profiling/uqa.py2
-rw-r--r--generalresearch/managers/thl/profiling/user_upk.py17
-rw-r--r--generalresearch/managers/thl/session.py24
-rw-r--r--generalresearch/managers/thl/survey.py58
-rw-r--r--generalresearch/managers/thl/survey_penalty.py9
-rw-r--r--generalresearch/managers/thl/task_adjustment.py56
-rw-r--r--generalresearch/managers/thl/user_compensate.py19
-rw-r--r--generalresearch/managers/thl/user_manager/__init__.py8
-rw-r--r--generalresearch/managers/thl/user_manager/mysql_user_manager.py16
-rw-r--r--generalresearch/managers/thl/user_manager/rate_limit.py2
-rw-r--r--generalresearch/managers/thl/user_manager/redis_user_manager.py4
-rw-r--r--generalresearch/managers/thl/user_manager/user_manager.py4
-rw-r--r--generalresearch/managers/thl/user_manager/user_metadata_manager.py15
-rw-r--r--generalresearch/managers/thl/user_streak.py10
-rw-r--r--generalresearch/managers/thl/userhealth.py36
-rw-r--r--generalresearch/managers/thl/wall.py25
-rw-r--r--generalresearch/managers/thl/wallet/__init__.py8
-rw-r--r--generalresearch/managers/thl/wallet/tango.py4
-rw-r--r--tests/managers/thl/test_ledger/test_lm_accounts.py70
96 files changed, 867 insertions, 602 deletions
diff --git a/generalresearch/grliq/managers/event_plotter.py b/generalresearch/grliq/managers/event_plotter.py
index b879c5c..54105ce 100644
--- a/generalresearch/grliq/managers/event_plotter.py
+++ b/generalresearch/grliq/managers/event_plotter.py
@@ -1,12 +1,13 @@
import html
-from typing import List
import webbrowser
+from typing import List
+
import numpy as np
from more_itertools import windowed
from scipy.spatial.distance import euclidean
from generalresearch.grliq.managers.colormap import turbo_colormap_data
-from generalresearch.grliq.models.events import MouseEvent, KeyboardEvent
+from generalresearch.grliq.models.events import KeyboardEvent, MouseEvent
def make_events_svg(
@@ -30,6 +31,10 @@ def make_events_svg(
for ee in windowed(move_events, 2):
e1 = ee[0]
e2 = ee[1]
+
+ assert e1 is not None
+ assert e2 is not None
+
ts_idx = (e2.timeStamp - t.min()) / t_diff
r, g, b = turbo_colormap_data[round(ts_idx * 255)]
color = f"rgb({int(r*255)},{int(g*255)},{int(b*255)})"
@@ -115,7 +120,7 @@ def svg_multiline_text(
def group_input_events_by_xy(
mouse_events: List[MouseEvent], keyboard_events: List[KeyboardEvent]
-):
+) -> List[tuple[tuple[float, float], List[str]]]:
"""
Each keypress is its own event. For plotting, we want to group together
all keypresses that were made when the mouse was at the same position,
diff --git a/generalresearch/grliq/managers/forensic_data.py b/generalresearch/grliq/managers/forensic_data.py
index 0f58d54..d7e362d 100644
--- a/generalresearch/grliq/managers/forensic_data.py
+++ b/generalresearch/grliq/managers/forensic_data.py
@@ -1,16 +1,16 @@
from datetime import datetime, timezone
-from typing import Optional, List, Collection, Dict, Tuple, Any
+from typing import Any, Collection, Dict, List, Optional, Tuple
from uuid import uuid4
from psycopg import sql
-from pydantic import PositiveInt, NonNegativeInt
+from pydantic import NonNegativeInt, PositiveInt
from generalresearch.grliq.managers import DUMMY_GRLIQ_DATA
from generalresearch.grliq.models.events import PointerMove, TimingData
from generalresearch.grliq.models.forensic_data import GrlIqData
from generalresearch.grliq.models.forensic_result import (
- GrlIqForensicCategoryResult,
GrlIqCheckerResults,
+ GrlIqForensicCategoryResult,
Phase,
)
from generalresearch.models.custom_types import UUIDStr
@@ -25,7 +25,7 @@ class GrlIqDataManager:
def create_dummy(
self,
- is_attempt_allowed: True,
+ is_attempt_allowed: bool = True,
product_id: Optional[str] = None,
product_user_id: Optional[str] = None,
uuid: Optional[str] = None,
@@ -118,7 +118,7 @@ class GrlIqDataManager:
with self.postgres_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, data)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
iq_data.id = pk
@@ -316,16 +316,22 @@ class GrlIqDataManager:
def filter_timing_data(
self,
- created_between: Optional[Tuple[datetime, datetime]] = None,
+ created_between: Tuple[datetime, datetime],
limit: Optional[int] = None,
offset: Optional[int] = None,
- ) -> List[Dict]:
+ ) -> List[Dict[str, Any]]:
+
+ # TODO! created_between used to be marked as Optional, but it would
+ # break the query. Evaluate it's use to determine best behavior.
+
limit_str = f"LIMIT {limit}" if limit is not None else ""
offset_str = f"OFFSET {offset}" if offset is not None else ""
+
params = {
"created_after": created_between[0],
"created_before": created_between[1],
}
+
query = f"""
SELECT
d.id, d.session_uuid, d.client_ip, d.country_iso,
@@ -342,10 +348,12 @@ class GrlIqDataManager:
WHERE d.created_at BETWEEN %(created_after)s AND %(created_before)s
{limit_str} {offset_str};
"""
+
with self.postgres_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, params)
- res: List[Dict] = c.fetchall()
+ res: List[Dict[str, Any]] = c.fetchall() # type: ignore
+
for x in res:
x["timing_data"] = TimingData.model_validate(x["timing_data"])
@@ -353,12 +361,14 @@ class GrlIqDataManager:
def get_unique_user_count_by_fingerprint(
self,
- product_id: str,
+ product_id: UUIDStr,
fingerprint: str,
product_user_id_not: str,
) -> NonNegativeInt:
+
# This is used for filtering for other forensic posts with a certain
# fingerprint, in this product_id, but NOT for this user.
+
query = sql.SQL(
"""
SELECT COUNT(DISTINCT product_user_id) as user_count
@@ -378,7 +388,8 @@ class GrlIqDataManager:
with self.postgres_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, params)
- user_count = c.fetchone()["user_count"]
+ user_count = c.fetchone()["user_count"] # type: ignore
+
return int(user_count)
def filter_data(
@@ -695,7 +706,7 @@ class GrlIqDataManager:
order_by: str = "created_at DESC",
limit: Optional[int] = None,
offset: Optional[int] = None,
- ) -> List[Dict]:
+ ) -> List[Dict[str, Any]]:
"""
Accepts lots of optional filters.
"""
@@ -738,7 +749,7 @@ class GrlIqDataManager:
with self.postgres_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query=query, params=params)
- res: List = c.fetchall()
+ res: List[Dict[str, Any]] = c.fetchall() # type: ignore
for x in res:
@@ -766,7 +777,7 @@ class GrlIqDataManager:
return res
@staticmethod
- def temporary_add_missing_fields(d: Dict):
+ def temporary_add_missing_fields(d: Dict[str, Any]) -> None:
# The following fields were added recently, and so we must give them
# a value or old db rows won't be parseable. Once logs are backfilled
# then this can be removed
@@ -785,8 +796,9 @@ class GrlIqDataManager:
if k not in d:
d[k] = v
- # We made a mistake once and saved the grliq data object with the events fields set.
- # Make sure they are not set here. We load them from the events table, not here!
+ # We made a mistake once and saved the grliq data object with the
+ # events fields set. Make sure they are not set here. We load them
+ # from the events table, not here!
d.pop("events", None)
d.pop("pointer_move_events", None)
d.pop("mouse_events", None)
diff --git a/generalresearch/grliq/managers/forensic_events.py b/generalresearch/grliq/managers/forensic_events.py
index 2633db7..bbc6b6d 100644
--- a/generalresearch/grliq/managers/forensic_events.py
+++ b/generalresearch/grliq/managers/forensic_events.py
@@ -1,16 +1,17 @@
import json
from datetime import datetime
-from typing import Optional, List, Collection, Dict
+from typing import Any, Collection, Dict, List, Optional
from uuid import uuid4
from psycopg import sql
+from pydantic import PositiveInt
from generalresearch.grliq.models.events import (
- TimingData,
- PointerMove,
- MouseEvent,
- KeyboardEvent,
Bounds,
+ KeyboardEvent,
+ MouseEvent,
+ PointerMove,
+ TimingData,
)
from generalresearch.models.custom_types import UUIDStr
from generalresearch.pg_helper import PostgresConfig
@@ -24,8 +25,8 @@ class GrlIqEventManager:
def update_or_create_timing(
self,
session_uuid: UUIDStr,
- timing_data: TimingData,
- ):
+ timing_data: Optional[TimingData] = None,
+ ) -> PositiveInt:
data = {
"session_uuid": session_uuid,
"timing_data": (
@@ -69,7 +70,7 @@ class GrlIqEventManager:
pk = c.fetchone()["id"]
conn.commit()
- return pk
+ return int(pk)
def update_or_create_events(
self,
@@ -78,7 +79,7 @@ class GrlIqEventManager:
event_end: datetime,
events: Optional[List[Dict]] = None,
mouse_events: Optional[List[Dict]] = None,
- ):
+ ) -> PositiveInt:
data = {
"uuid": uuid4().hex,
"session_uuid": session_uuid,
@@ -130,7 +131,7 @@ class GrlIqEventManager:
pk = c.fetchone()["id"]
conn.commit()
- return pk
+ return int(pk)
def filter(
self,
@@ -141,14 +142,16 @@ class GrlIqEventManager:
started_since: Optional[datetime] = None,
limit: Optional[int] = None,
order_by: str = "event_start DESC",
- ) -> List[Dict]:
- """ """
+ ) -> List[Dict[str, Any]]:
+
if not limit:
limit = 100
if not select_str:
select_str = "*"
+
filters = []
params = {}
+
if session_uuid:
params["session_uuid"] = session_uuid
filters.append("session_uuid = %(session_uuid)s")
@@ -164,6 +167,7 @@ class GrlIqEventManager:
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
query = f"""
SELECT {select_str}
FROM grliq_forensicevents
@@ -189,12 +193,13 @@ class GrlIqEventManager:
events=events, pointer_moves=pointer_moves
)
x["keyboard_events"] = self.process_keyboard_events(events=events)
+
return res
def filter_distinct_timing(
self,
session_uuids: Collection[str],
- ) -> List[Dict]:
+ ) -> List[Dict[str, Any]]:
params = {"session_uuids": list(session_uuids)}
query = sql.SQL(
"""
@@ -226,9 +231,10 @@ class GrlIqEventManager:
@staticmethod
def process_mouse_events(pointer_moves: List[PointerMove], events: List[Dict]):
"""
- In the db column 'mouse_events' we put all 'pointermove' events. Pull those
- out, and then any 'pointerdown' and 'pointerup' events from the 'events' column,
- and merge them all together into a list of MouseEvent objects
+ In the db column 'mouse_events' we put all 'pointermove' events. Pull
+ those out, and then any 'pointerdown' and 'pointerup' events from the
+ 'events' column, and merge them all together into a list of MouseEvent
+ objects
"""
mouse_events = [
# these contain only pointermove events
diff --git a/generalresearch/grliq/managers/forensic_results.py b/generalresearch/grliq/managers/forensic_results.py
index dd1b039..30db53d 100644
--- a/generalresearch/grliq/managers/forensic_results.py
+++ b/generalresearch/grliq/managers/forensic_results.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from typing import Optional, List, Collection, Dict, Tuple
+from typing import Any, Collection, Dict, List, Optional, Tuple
from generalresearch.grliq.models.forensic_result import (
GrlIqForensicCategoryResult,
@@ -25,9 +25,10 @@ class GrlIqCategoryResultsReader:
created_between: Optional[Tuple[datetime, datetime]] = None,
user: Optional[User] = None,
limit: Optional[int] = None,
- ) -> List[Dict]:
+ ) -> List[Dict[str, Any]]:
"""
For retrieving GrlIqForensicCategoryResult objects from db.
+
:return: List of Dict. Keys are below in the 'select_str'.
"""
select_str = (
@@ -37,10 +38,11 @@ class GrlIqCategoryResultsReader:
" category_result, is_attempt_allowed, fraud_score"
)
if not limit:
- limit = 5000
+ limit = 5_000
filters = []
params = {}
+
if session_uuid:
params["session_uuid"] = session_uuid
filters.append("d.session_uuid = %(session_uuid)s")
@@ -77,6 +79,7 @@ class GrlIqCategoryResultsReader:
filters.append(
"(d.product_id = %(product_id)s AND d.product_user_id = %(product_user_id)s)"
)
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
query = f"""
@@ -85,6 +88,7 @@ class GrlIqCategoryResultsReader:
{filter_str}
ORDER BY created_at DESC LIMIT {limit}
"""
+
with self.postgres_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query, params)
diff --git a/generalresearch/grliq/managers/forensic_summary.py b/generalresearch/grliq/managers/forensic_summary.py
index c44daf2..21b7e4b 100644
--- a/generalresearch/grliq/managers/forensic_summary.py
+++ b/generalresearch/grliq/managers/forensic_summary.py
@@ -2,8 +2,8 @@ from __future__ import annotations
import statistics
from collections import defaultdict
-from datetime import datetime, timezone, timedelta
-from typing import List, Dict
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict, List
import numpy as np
@@ -12,15 +12,15 @@ from generalresearch.grliq.managers.forensic_events import (
GrlIqEventManager,
)
from generalresearch.grliq.models.forensic_result import (
- GrlIqForensicCategoryResult,
GrlIqCheckerResults,
+ GrlIqForensicCategoryResult,
)
from generalresearch.grliq.models.forensic_summary import (
- GrlIqForensicCategorySummary,
- GrlIqCheckerResultsSummary,
- UserForensicSummary,
CountryRTTDistribution,
+ GrlIqCheckerResultsSummary,
+ GrlIqForensicCategorySummary,
TimingDataCountrySummary,
+ UserForensicSummary,
)
from generalresearch.models.thl.user import User
from generalresearch.redis_helper import RedisConfig
@@ -85,8 +85,9 @@ def calculate_checker_summary(
def calculate_timing_summary(
- redis_config: RedisConfig, timing_res
+ redis_config: RedisConfig, timing_res: List[Dict[str, Any]]
) -> Dict[str, TimingDataCountrySummary]:
+
country_median_rtts = defaultdict(list)
for x in timing_res:
s = x["timing_data"].summarize
@@ -135,6 +136,7 @@ def run_user_forensic_summary(
redis_config: RedisConfig,
user: User,
) -> UserForensicSummary:
+
now = datetime.now(tz=timezone.utc)
created_between = (now - timedelta(days=90), now)
select_str = "id, session_uuid, product_id, product_user_id, created_at, result_data, category_result"
diff --git a/generalresearch/grliq/models/events.py b/generalresearch/grliq/models/events.py
index 2c8fa64..7e9ebee 100644
--- a/generalresearch/grliq/models/events.py
+++ b/generalresearch/grliq/models/events.py
@@ -3,15 +3,15 @@ from __future__ import annotations
from collections import namedtuple
from dataclasses import dataclass, fields
from functools import cached_property
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
import numpy as np
from pydantic import (
BaseModel,
ConfigDict,
Field,
- NonNegativeInt,
NonNegativeFloat,
+ NonNegativeInt,
PositiveFloat,
)
from typing_extensions import Self
@@ -44,7 +44,7 @@ class Event:
_elementBounds: Optional[Bounds] = None
@classmethod
- def from_dict(cls, data: dict) -> Self:
+ def from_dict(cls, data: Dict[str, Any]) -> Self:
data = {k: v for k, v in data.items() if k in cls.__dataclass_fields__}
bounds = data.get("_elementBounds")
if bounds is not None and not isinstance(bounds, Bounds):
@@ -56,16 +56,21 @@ class Event:
class PointerMove(Event):
# should always be 'pointermove'
type: str
+
# (mouse, touch, pen)
pointerType: str
- # coordinate relative to the screen
+
+ # Coordinate relative to the screen
screenX: float
screenY: float
- # coordinate relative to the document (unaffected by scrolling)
+
+ # Coordinate relative to the document (unaffected by scrolling)
pageX: float
pageY: float
- # pageX/Y divided by the document Width/Height. This is calculated in JS and sent, which
- # it must be b/c we don't know the document width/height at each time otherwise.
+
+ # PageX/Y divided by the document Width/Height. This is calculated in JS
+ # and sent, which it must be b/c we don't know the document width/height
+ # at each time otherwise.
normalizedX: float
normalizedY: float
@@ -82,8 +87,10 @@ class MouseEvent(Event):
# should be {'pointerdown', 'pointerup', 'pointermove', 'click'}
type: str
+
# Type of input (mouse, touch, pen)
pointerType: str
+
# coordinate relative to the document (unaffected by scrolling)
pageX: float
pageY: float
@@ -91,15 +98,17 @@ class MouseEvent(Event):
@dataclass
class KeyboardEvent(Event):
- """ """
# should be {'keydown', 'input'}
type: str
+
# "insertText", "insertCompositionText", "deleteCompositionText",
# "insertFromComposition", "deleteContentBackward"
inputType: Optional[str]
+
# e.g., 'Enter', 'a', 'Backspace'
key: Optional[str] = None
+
# This is the actual text, if applicable
data: Optional[str] = None
@@ -180,7 +189,7 @@ class TimingData(BaseModel):
def has_data(self):
return len(self.client_rtts) > 0 and len(self.server_rtts) > 0
- def filter_rtts(self, rtts):
+ def filter_rtts(self, rtts: List[float]) -> List[float]:
# Skip the first 5 pings, unless we have <10 pings, then get the last
# 5 instead.
# The first couple pings are usually outliers as they are running
@@ -189,6 +198,7 @@ class TimingData(BaseModel):
rtts = rtts[5:]
else:
rtts = rtts[-5:]
+
return rtts
@cached_property
diff --git a/generalresearch/grliq/models/forensic_data.py b/generalresearch/grliq/models/forensic_data.py
index c182f43..1eb3648 100644
--- a/generalresearch/grliq/models/forensic_data.py
+++ b/generalresearch/grliq/models/forensic_data.py
@@ -1,55 +1,55 @@
import hashlib
import re
from collections import Counter
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import cached_property
-from typing import Literal, Optional, Dict, List, Set, Any
+from typing import Any, Dict, List, Literal, Optional, Set
from uuid import uuid4
import pycountry
from faker import Faker
from pydantic import (
+ AfterValidator,
+ AwareDatetime,
BaseModel,
ConfigDict,
Field,
- field_validator,
- StringConstraints,
- AfterValidator,
- AwareDatetime,
NonNegativeInt,
+ StringConstraints,
+ field_validator,
)
from pydantic.json_schema import SkipJsonSchema
from pydantic_extra_types.timezone_name import TimeZoneName
-from typing_extensions import Self, Annotated
+from typing_extensions import Annotated, Self
from generalresearch.grliq.models import (
AUDIO_CODEC_NAMES,
- VIDEO_CODEC_NAMES,
SUPPORTED_FONTS,
+ VIDEO_CODEC_NAMES,
)
from generalresearch.grliq.models.events import (
+ KeyboardEvent,
+ MouseEvent,
PointerMove,
TimingData,
- MouseEvent,
- KeyboardEvent,
)
from generalresearch.grliq.models.forensic_result import (
- GrlIqForensicCategoryResult,
GrlIqCheckerResults,
+ GrlIqForensicCategoryResult,
+ Phase,
)
-from generalresearch.grliq.models.forensic_result import Phase
from generalresearch.grliq.models.useragents import (
GrlUserAgent,
OSFamily,
UserAgentHints,
)
from generalresearch.models.custom_types import (
- UUIDStr,
- IPvAnyAddressStr,
- BigAutoInteger,
AwareDatetimeISO,
+ BigAutoInteger,
CountryISOLike,
+ IPvAnyAddressStr,
+ UUIDStr,
)
from generalresearch.models.thl.ipinfo import GeoIPInformation
from generalresearch.models.thl.session import Session
@@ -265,6 +265,7 @@ class GrlIqData(BaseModel):
navigator_java_enabled: bool = Field()
do_not_track_enabled: str = Field(description="unspecified or '' ? ")
mime_types_length: int = Field()
+
# todo: some report actual values, some are (always) faked by the os/browser
# as anti-fingerprint 10737418240 (10gb) typical on firefox windows,
# 2147483648 (20gb) chrome, iPhones typically only 8 different values
@@ -542,7 +543,7 @@ class GrlIqData(BaseModel):
return hashlib.md5(s.encode()).hexdigest()
@cached_property
- def audio_codecs_named(self) -> Dict:
+ def audio_codecs_named(self) -> Dict[str, bool]:
return dict(
zip(
AUDIO_CODEC_NAMES,
@@ -551,7 +552,7 @@ class GrlIqData(BaseModel):
)
@cached_property
- def video_codecs_named(self) -> Dict:
+ def video_codecs_named(self) -> Dict[str, bool]:
return dict(
zip(
VIDEO_CODEC_NAMES,
@@ -618,7 +619,7 @@ class GrlIqData(BaseModel):
mode="before",
)
@classmethod
- def str_to_float_or_null(cls, value: str) -> Optional[int]:
+ def str_to_float_or_null(cls, value: str) -> Optional[float]:
return float(value) if value not in {None, ""} else None
@field_validator(
@@ -787,7 +788,7 @@ class GrlIqData(BaseModel):
return d
@classmethod
- def from_db(cls, d: Dict) -> Self:
+ def from_db(cls, d: Dict[str, Any]) -> Self:
res = GrlIqData.model_validate(d["data"])
if d.get("category_result"):
diff --git a/generalresearch/grliq/models/forensic_result.py b/generalresearch/grliq/models/forensic_result.py
index 7f906c2..d89f681 100644
--- a/generalresearch/grliq/models/forensic_result.py
+++ b/generalresearch/grliq/models/forensic_result.py
@@ -1,18 +1,18 @@
from __future__ import annotations
from enum import Enum
-from typing import Optional, List, Set
+from typing import List, Optional, Set
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, computed_field
from generalresearch.grliq.models.custom_types import GrlIqScore
from generalresearch.grliq.models.decider import (
- Decider,
AttemptDecision,
+ Decider,
GrlIqAttemptResult,
)
-from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
class Phase(str, Enum):
@@ -21,10 +21,13 @@ class Phase(str, Enum):
# Within a custom offerwall. Very optional, as most BPs won't be running our code
OFFERWALL = "offerwall"
+
# When a user clicks on a bucket. Each session should go through this
OFFERWALL_ENTER = "offerwall-enter"
+
# Running in GRS. Not every session will have this.
PROFILING = "profiling"
+
# We could run grl-iq again when a user continues a session
SESSION_CONTINUE = "session-continue"
@@ -33,8 +36,9 @@ class GrlIqForensicCategoryResult(BaseModel):
"""
This is for reporting external to GRL.
- There is a balance between exposing enough to answer "why did this user get blocked?" without
- giving away technical knowledge that could be used to bypass.
+ There is a balance between exposing enough to answer "why did this user
+ get blocked?" without giving away technical knowledge that could be
+ used to bypass.
"""
model_config = ConfigDict(extra="forbid", validate_assignment=True)
diff --git a/generalresearch/grliq/models/forensic_summary.py b/generalresearch/grliq/models/forensic_summary.py
index 2d38435..a40e103 100644
--- a/generalresearch/grliq/models/forensic_summary.py
+++ b/generalresearch/grliq/models/forensic_summary.py
@@ -2,15 +2,15 @@ from __future__ import annotations
import random
from typing import (
+ Dict,
List,
Literal,
Optional,
Tuple,
- Dict,
- get_type_hints,
- get_origin,
Union,
get_args,
+ get_origin,
+ get_type_hints,
)
import numpy as np
@@ -19,17 +19,17 @@ from pydantic import (
ConfigDict,
Field,
NonNegativeInt,
- create_model,
computed_field,
+ create_model,
)
from scipy.stats import lognorm
from generalresearch.grliq.models.custom_types import GrlIqAvgScore, GrlIqRate
from generalresearch.grliq.models.forensic_result import (
- GrlIqCheckerResults,
GrlIqCheckerResult,
+ GrlIqCheckerResults,
)
-from generalresearch.models.custom_types import IPvAnyAddressStr, AwareDatetimeISO
+from generalresearch.models.custom_types import AwareDatetimeISO, IPvAnyAddressStr
from generalresearch.models.thl.locales import CountryISO
from generalresearch.models.thl.maxmind.definitions import UserType
diff --git a/generalresearch/grliq/utils.py b/generalresearch/grliq/utils.py
index 6e563f9..c772e7f 100644
--- a/generalresearch/grliq/utils.py
+++ b/generalresearch/grliq/utils.py
@@ -1,11 +1,11 @@
import os
from datetime import datetime, timezone
+from pathlib import Path
from typing import Optional, Union
from uuid import UUID
# from generalresearch.config import
from generalresearch.models.custom_types import UUIDStr
-from pathlib import Path
def get_screenshot_fp(
diff --git a/generalresearch/incite/base.py b/generalresearch/incite/base.py
index bd346f9..6f1a9be 100644
--- a/generalresearch/incite/base.py
+++ b/generalresearch/incite/base.py
@@ -8,20 +8,21 @@ import shutil
import subprocess
import warnings
from concurrent.futures import Future
-from datetime import datetime, timezone, timedelta
-from os import access, R_OK, listdir
-from os.path import join as pjoin, isdir
+from datetime import datetime, timedelta, timezone
+from os import R_OK, access, listdir
+from os.path import isdir
+from os.path import join as pjoin
from pathlib import Path
from sys import platform
from typing import (
- Optional,
- Tuple,
+ TYPE_CHECKING,
+ Any,
+ Callable,
List,
+ Optional,
Sequence,
- Any,
+ Tuple,
Union,
- Callable,
- TYPE_CHECKING,
)
from uuid import uuid4
@@ -35,12 +36,12 @@ from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
- PrivateAttr,
Field,
- model_validator,
FilePath,
- field_validator,
+ PrivateAttr,
ValidationInfo,
+ field_validator,
+ model_validator,
)
from pydantic.json_schema import SkipJsonSchema
from sentry_sdk import capture_exception
@@ -54,11 +55,11 @@ from generalresearch.incite.schemas import (
from generalresearch.models.custom_types import AwareDatetimeISO
if TYPE_CHECKING:
- from generalresearch.incite.mergers import MergeType, MergeCollection
+ from generalresearch.incite.collections import DFCollection
from generalresearch.incite.collections.thl_marketplaces import (
DFCollectionType,
)
- from generalresearch.incite.collections import DFCollection
+ from generalresearch.incite.mergers import MergeCollection, MergeType
Collection = Union[DFCollection, MergeCollection]
@@ -93,10 +94,10 @@ class GRLDatasets(BaseModel):
@model_validator(mode="after")
def check_data_src_and_et_path(self) -> Self:
- from generalresearch.incite.mergers import MergeType
from generalresearch.incite.collections.thl_marketplaces import (
DFCollectionType,
)
+ from generalresearch.incite.mergers import MergeType
# Create the base folders and confirm we have read access
self.data_src.mkdir(parents=True, exist_ok=True)
diff --git a/generalresearch/incite/collections/__init__.py b/generalresearch/incite/collections/__init__.py
index 051c5a1..9049e8f 100644
--- a/generalresearch/incite/collections/__init__.py
+++ b/generalresearch/incite/collections/__init__.py
@@ -5,7 +5,7 @@ import time
from datetime import datetime
from enum import Enum
from sys import platform
-from typing import Optional, List
+from typing import Any, Dict, List, Optional
import dask
import dask.dataframe as dd
@@ -16,13 +16,13 @@ from distributed import Client, as_completed
from more_itertools import chunked
from pandera import DataFrameSchema
from psycopg import Cursor
-from pydantic import Field, FilePath, field_validator, ValidationInfo
+from pydantic import Field, FilePath, ValidationInfo, field_validator
from sentry_sdk import capture_exception
from generalresearch.incite.base import CollectionBase, CollectionItemBase
from generalresearch.incite.schemas import (
- ORDER_KEY,
ARCHIVE_AFTER,
+ ORDER_KEY,
PARTITION_ON,
empty_dataframe_from_schema,
)
@@ -33,18 +33,18 @@ from generalresearch.incite.schemas.thl_marketplaces import (
SpectrumSurveyTimeseriesSchema,
)
from generalresearch.incite.schemas.thl_web import (
- TxSchema,
- TxMetaSchema,
- THLUserSchema,
+ LedgerSchema,
+ THLIPInfoSchema,
+ THLSessionSchema,
THLTaskAdjustmentSchema,
+ THLUserSchema,
THLWallSchema,
- THLSessionSchema,
- THLIPInfoSchema,
TransactionMetadataColumns,
- UserHealthIPHistorySchema,
+ TxMetaSchema,
+ TxSchema,
UserHealthAuditLogSchema,
+ UserHealthIPHistorySchema,
UserHealthIPHistoryWSSchema,
- LedgerSchema,
)
from generalresearch.pg_helper import PostgresConfig
from generalresearch.sql_helper import SqlHelper
@@ -167,7 +167,7 @@ class DFCollectionItem(CollectionItemBase):
return self.to_archive(ddf=dd.from_pandas(_df, npartitions=1), is_partial=True)
# --- ORM / Data handlers---
- def to_dict(self, *args, **kwargs) -> dict:
+ def to_dict(self, *args, **kwargs) -> Dict[str, Any]:
return self._to_dict()
def from_mysql(self, since: Optional[datetime] = None) -> Optional[pd.DataFrame]:
@@ -503,7 +503,7 @@ class DFCollectionItem(CollectionItemBase):
subprocess.call(["mv", "-T", tmp_path.as_posix(), self.path.as_posix()])
return True
- def to_archive_numbered_partial(self, ddf: dd.DataFrame) -> bool:
+ def to_archive_numbered_partial(self, ddf: Optional[dd.DataFrame] = None) -> bool:
"""
For partial files/dirs only. Writes the .partial file with a number
at the end (.partial.####) and then creates a symlink
@@ -513,13 +513,14 @@ class DFCollectionItem(CollectionItemBase):
"""
if ddf is None:
return False
+
collection = self._collection
schema = collection._schema
client: Optional[Client] = collection._client
next_numbered_path = self.next_numbered_path(self.partial_path)
partial_path = self.partial_path
- finish = self.finish
+ # finish = self.finish
# Make sure these are in the same dir. b/c the symlink has to be
# relative, not an absolute path
@@ -571,7 +572,7 @@ class DFCollectionItem(CollectionItemBase):
assert self.should_archive(), "not ready to archive!"
- df: pd.DataFrame = self.from_mysql()
+ df: Optional[pd.DataFrame] = self.from_mysql()
if df is None:
self.set_empty()
@@ -648,9 +649,9 @@ class DFCollection(CollectionBase):
def initial_load(
self,
client: Optional[Client] = None,
- sync=True,
+ sync: bool = True,
since: Optional[datetime] = None,
- client_resources=None,
+ client_resources: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
) -> List[Future]:
# This can be used to just build all local archive files
@@ -733,10 +734,15 @@ class DFCollection(CollectionBase):
return sources
def force_rr_latest(
- self, client: Client, client_resources=None, sync: bool = True
+ self,
+ client: Client,
+ client_resources: Optional[Dict[str, Any]] = None,
+ sync: bool = True,
) -> List[Future]:
+
# For forcing update of any partials asynchronously if desired
LOG.info(f"{self.data_type.value}.force_rr_latest({client=})")
+
rr_items = [
i for i in self.items if not i.should_archive() and not i.is_empty()
]
diff --git a/generalresearch/incite/defaults.py b/generalresearch/incite/defaults.py
index e200ddc..421710e 100644
--- a/generalresearch/incite/defaults.py
+++ b/generalresearch/incite/defaults.py
@@ -9,11 +9,11 @@ from generalresearch.incite.collections.thl_marketplaces import (
SpectrumSurveyTimeseriesCollection,
)
from generalresearch.incite.collections.thl_web import (
+ LedgerDFCollection,
SessionDFCollection,
- WallDFCollection,
- UserDFCollection,
TaskAdjustmentDFCollection,
- LedgerDFCollection,
+ UserDFCollection,
+ WallDFCollection,
)
from generalresearch.incite.mergers import MergeType
from generalresearch.incite.mergers.foundations.enriched_session import (
@@ -33,7 +33,6 @@ from generalresearch.incite.mergers.ym_survey_wall import YMSurveyWallMerge
from generalresearch.pg_helper import PostgresConfig
from generalresearch.sql_helper import SqlHelper
-
# --- THL Web --- #
diff --git a/generalresearch/incite/mergers/foundations/__init__.py b/generalresearch/incite/mergers/foundations/__init__.py
index d3db74c..9ce9d91 100644
--- a/generalresearch/incite/mergers/foundations/__init__.py
+++ b/generalresearch/incite/mergers/foundations/__init__.py
@@ -1,8 +1,9 @@
import logging
-from typing import Collection, List, Dict
+from typing import Any, Collection, Dict, List
import pandas as pd
from more_itertools import chunked
+from pydantic import PositiveInt
from generalresearch.pg_helper import PostgresConfig
@@ -10,7 +11,7 @@ LOG = logging.getLogger("incite")
def annotate_product_id(
- df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500
+ df: pd.DataFrame, pg_config: PostgresConfig, chunksize: PositiveInt = 500
) -> pd.DataFrame:
"""
Dask map_partitions is being called on a dask df. However, the function
@@ -27,7 +28,7 @@ def annotate_product_id(
assert len(user_ids) >= 1, "must have user_ids"
LOG.warning(f"annotate_product_id.len(user_ids): {len(user_ids)}")
- res: List[Dict] = []
+ res: List[Dict[str, Any]] = []
with pg_config.make_connection() as conn:
for chunk in chunked(user_ids, chunksize):
try:
@@ -53,15 +54,17 @@ def annotate_product_id(
def lookup_product_and_team_id(
user_ids: Collection[int],
pg_config: PostgresConfig,
-) -> List[Dict]:
+) -> List[Dict[str, Any]]:
+
user_ids = set(user_ids)
LOG.info(f"lookup_product_and_team_id: {len(user_ids)}")
LOG.info({type(x) for x in user_ids})
+
assert all(type(x) is int for x in user_ids), "must pass all integers"
assert len(user_ids) >= 1, "must have user_ids"
assert len(user_ids) <= 1000, "you should chunk this bro"
- res: List[Dict] = []
+ res: List[Dict[str, Any]] = []
with pg_config.make_connection() as conn:
try:
with conn.cursor() as c:
@@ -87,7 +90,7 @@ def lookup_product_and_team_id(
def annotate_product_and_team_id(
- df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500
+ df: pd.DataFrame, pg_config: PostgresConfig, chunksize: PositiveInt = 500
) -> pd.DataFrame:
"""
Dask map_partitions is being called on a dask df. However, the function
@@ -95,8 +98,8 @@ def annotate_product_and_team_id(
df AS a pandas df.
expects column 'user_id', adds column 'product_id' and team_id
-
"""
+
LOG.info(f"annotate_product_and_team_id.chunk: {df.shape}")
assert "user_id" in df.columns, "must have a user_id column to join on"
@@ -105,7 +108,7 @@ def annotate_product_and_team_id(
assert len(user_ids) >= 1, "must have user_ids"
LOG.warning(f"annotate_product_and_team_id.len(user_ids): {len(user_ids)}")
- res: List[Dict] = []
+ res: List[Dict[str, Any]] = []
with pg_config.make_connection() as conn:
for chunk in chunked(user_ids, chunksize):
try:
@@ -133,7 +136,7 @@ def annotate_product_and_team_id(
def annotate_product_user(
- df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500
+ df: pd.DataFrame, pg_config: PostgresConfig, chunksize: PositiveInt = 500
) -> pd.DataFrame:
LOG.info(f"annotate_product_user.chunk: {df.shape}")
assert "user_id" in df.columns, "must have a user_id column to join on"
@@ -143,7 +146,7 @@ def annotate_product_user(
assert len(user_ids) >= 1, "must have user_ids"
LOG.warning(f"annotate_product_user.len(user_ids): {len(user_ids)}")
- res: List[Dict] = []
+ res: List[Dict[str, Any]] = []
with pg_config.make_connection() as conn:
for chunk in chunked(user_ids, chunksize):
try:
diff --git a/generalresearch/incite/mergers/foundations/enriched_session.py b/generalresearch/incite/mergers/foundations/enriched_session.py
index 7fdcb50..4b87df7 100644
--- a/generalresearch/incite/mergers/foundations/enriched_session.py
+++ b/generalresearch/incite/mergers/foundations/enriched_session.py
@@ -1,6 +1,6 @@
import logging
from datetime import timedelta
-from typing import Literal, Optional, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
import dask.dataframe as dd
import pandas as pd
@@ -45,7 +45,7 @@ class EnrichedSessionMergeItem(MergeCollectionItem):
wall_coll: WallDFCollection,
pg_config: PostgresConfig,
client: Optional[Client] = None,
- client_resources=None,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
ir: pd.Interval = self.interval
diff --git a/generalresearch/incite/mergers/foundations/enriched_task_adjust.py b/generalresearch/incite/mergers/foundations/enriched_task_adjust.py
index 1749f9a..d2a8aa1 100644
--- a/generalresearch/incite/mergers/foundations/enriched_task_adjust.py
+++ b/generalresearch/incite/mergers/foundations/enriched_task_adjust.py
@@ -1,5 +1,5 @@
import logging
-from typing import Literal
+from typing import Any, Dict, Literal, Optional
import dask.dataframe as dd
import pandas as pd
@@ -11,8 +11,8 @@ from generalresearch.incite.collections.thl_web import (
)
from generalresearch.incite.mergers import (
MergeCollection,
- MergeType,
MergeCollectionItem,
+ MergeType,
)
from generalresearch.incite.mergers.foundations import (
annotate_product_and_team_id,
@@ -40,7 +40,7 @@ class EnrichedTaskAdjustMergeItem(MergeCollectionItem):
enriched_wall: EnrichedWallMerge,
pg_config: PostgresConfig,
client: Client,
- client_resources=None,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
"""
TaskAdjustments are always partial because they could be revoked
@@ -61,7 +61,7 @@ class EnrichedTaskAdjustMergeItem(MergeCollectionItem):
if len(task_adj_coll_items) == 0:
raise Exception("TaskAdjColl item collection failed")
- ddf: dd.DataFrame = task_adj_coll.ddf(
+ ddf: Optional[dd.DataFrame] = task_adj_coll.ddf(
items=task_adj_coll_items,
include_partial=True,
force_rr_latest=False,
@@ -114,6 +114,7 @@ class EnrichedTaskAdjustMergeItem(MergeCollectionItem):
assert str(ddf.wall_uuid.dtype) == "string"
assert str(wall_ddf.index.dtype) == "string"
+
ddf = ddf.merge(
wall_ddf,
left_on="wall_uuid",
diff --git a/generalresearch/incite/mergers/foundations/enriched_wall.py b/generalresearch/incite/mergers/foundations/enriched_wall.py
index 241a239..d69293b 100644
--- a/generalresearch/incite/mergers/foundations/enriched_wall.py
+++ b/generalresearch/incite/mergers/foundations/enriched_wall.py
@@ -1,6 +1,6 @@
import logging
from datetime import timedelta
-from typing import Literal, Optional, TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
import dask.dataframe as dd
import pandas as pd
@@ -40,7 +40,7 @@ class EnrichedWallMergeItem(MergeCollectionItem):
session_coll: SessionDFCollection,
pg_config: PostgresConfig,
client: Optional[Client] = None,
- client_resources=None,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
ir: pd.Interval = self.interval
diff --git a/generalresearch/incite/mergers/foundations/user_id_product.py b/generalresearch/incite/mergers/foundations/user_id_product.py
index e467179..73ee36a 100644
--- a/generalresearch/incite/mergers/foundations/user_id_product.py
+++ b/generalresearch/incite/mergers/foundations/user_id_product.py
@@ -1,12 +1,12 @@
import logging
-from typing import Literal
+from typing import Any, Dict, Literal, Optional
from distributed import Client
from generalresearch.incite.collections.thl_web import UserDFCollection
from generalresearch.incite.mergers import (
- MergeCollectionItem,
MergeCollection,
+ MergeCollectionItem,
MergeType,
)
@@ -16,7 +16,10 @@ LOG = logging.getLogger("incite")
class UserIdProductMergeItem(MergeCollectionItem):
def build(
- self, client: Client, user_coll: UserDFCollection, client_resources=None
+ self,
+ client: Client,
+ user_coll: UserDFCollection,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
LOG.warning(f"UserIdProductMergeItem.build({self.interval})")
@@ -34,10 +37,13 @@ class UserIdProductMergeItem(MergeCollectionItem):
class UserIdProductMerge(MergeCollection):
merge_type: Literal[MergeType.USER_ID_PRODUCT] = MergeType.USER_ID_PRODUCT
collection_item_class: Literal[UserIdProductMergeItem] = UserIdProductMergeItem
- offset: None = None
+ offset: Optional[str] = None
def build(
- self, client: Client, user_coll: UserDFCollection, client_resources=None
+ self,
+ client: Client,
+ user_coll: UserDFCollection,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
LOG.info(f"UserIdProductMerge.build(user_coll={user_coll.signature()})")
diff --git a/generalresearch/incite/mergers/nginx_core.py b/generalresearch/incite/mergers/nginx_core.py
index 1f471a6..343da8f 100644
--- a/generalresearch/incite/mergers/nginx_core.py
+++ b/generalresearch/incite/mergers/nginx_core.py
@@ -1,7 +1,7 @@
import json
import logging
from datetime import datetime, timedelta
-from typing import Literal, List
+from typing import List, Literal
from urllib.parse import parse_qs, urlsplit
import dask.bag as db
diff --git a/generalresearch/incite/mergers/nginx_fsb.py b/generalresearch/incite/mergers/nginx_fsb.py
index 1f1039b..9cb71d3 100644
--- a/generalresearch/incite/mergers/nginx_fsb.py
+++ b/generalresearch/incite/mergers/nginx_fsb.py
@@ -1,14 +1,13 @@
import json
import logging
-
-from sentry_sdk import capture_exception
import re
from datetime import datetime, timedelta
-from typing import Literal, List
+from typing import List, Literal
from urllib.parse import parse_qs, urlsplit
import dask.bag as db
import pandas as pd
+from sentry_sdk import capture_exception
from generalresearch.incite.mergers import (
MergeCollection,
diff --git a/generalresearch/incite/mergers/nginx_grs.py b/generalresearch/incite/mergers/nginx_grs.py
index 0242b0b..fb22070 100644
--- a/generalresearch/incite/mergers/nginx_grs.py
+++ b/generalresearch/incite/mergers/nginx_grs.py
@@ -1,7 +1,7 @@
import json
import logging
from datetime import datetime, timedelta
-from typing import Literal, List
+from typing import List, Literal
from urllib.parse import parse_qs, urlsplit
import dask.bag as db
diff --git a/generalresearch/incite/mergers/pop_ledger.py b/generalresearch/incite/mergers/pop_ledger.py
index 4915abb..0475df2 100644
--- a/generalresearch/incite/mergers/pop_ledger.py
+++ b/generalresearch/incite/mergers/pop_ledger.py
@@ -1,5 +1,5 @@
import logging
-from typing import Literal, Optional
+from typing import Any, Dict, Literal, Optional
import dask.dataframe as dd
import pandas as pd
@@ -24,7 +24,7 @@ class PopLedgerMergeItem(MergeCollectionItem):
self,
ledger_coll: LedgerDFCollection,
client: Optional[Client] = None,
- client_resources=None,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
ir: pd.Interval = self.interval
diff --git a/generalresearch/incite/mergers/ym_survey_wall.py b/generalresearch/incite/mergers/ym_survey_wall.py
index 7f6c31c..4c2defb 100644
--- a/generalresearch/incite/mergers/ym_survey_wall.py
+++ b/generalresearch/incite/mergers/ym_survey_wall.py
@@ -1,6 +1,6 @@
import logging
from datetime import timedelta
-from typing import Optional, Literal
+from typing import Any, Dict, Literal, Optional
import dask.dataframe as dd
import pandas as pd
@@ -10,8 +10,8 @@ from sentry_sdk import capture_exception
from generalresearch.incite.collections.thl_web import WallDFCollection
from generalresearch.incite.mergers import (
MergeCollection,
- MergeType,
MergeCollectionItem,
+ MergeType,
)
from generalresearch.incite.mergers.foundations.enriched_session import (
EnrichedSessionMerge,
@@ -31,7 +31,7 @@ class YMSurveyWallMergeCollectionItem(MergeCollectionItem):
wall_coll: WallDFCollection,
enriched_session: EnrichedSessionMerge,
client: Optional[Client] = None,
- client_resources=None,
+ client_resources: Optional[Dict[str, Any]] = None,
) -> None:
LOG.info(f"YMSurveyWallMerge.build({self.start=}, {self.finish=})")
ir: pd.Interval = self.interval
diff --git a/generalresearch/incite/mergers/ym_wall_summary.py b/generalresearch/incite/mergers/ym_wall_summary.py
index 419994d..c7b871c 100644
--- a/generalresearch/incite/mergers/ym_wall_summary.py
+++ b/generalresearch/incite/mergers/ym_wall_summary.py
@@ -1,5 +1,5 @@
-from datetime import timedelta, datetime, time
-from typing import Literal, List, Optional, Type
+from datetime import datetime, time, timedelta
+from typing import List, Literal, Optional, Type
import dask.dataframe as dd
import pandas as pd
@@ -12,8 +12,8 @@ from generalresearch.incite.collections.thl_web import (
)
from generalresearch.incite.mergers import (
MergeCollection,
- MergeType,
MergeCollectionItem,
+ MergeType,
)
from generalresearch.incite.mergers.foundations.user_id_product import (
UserIdProductMerge,
diff --git a/generalresearch/incite/schemas/admin_responses.py b/generalresearch/incite/schemas/admin_responses.py
index 73c0aaa..ccfe1cf 100644
--- a/generalresearch/incite/schemas/admin_responses.py
+++ b/generalresearch/incite/schemas/admin_responses.py
@@ -1,12 +1,12 @@
from datetime import datetime
from pandera import (
- DataFrameSchema,
- Column,
Check,
- Parser,
- MultiIndex,
+ Column,
+ DataFrameSchema,
Index,
+ MultiIndex,
+ Parser,
Timestamp,
)
diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_session.py b/generalresearch/incite/schemas/mergers/foundations/enriched_session.py
index 4badfac..ee17bda 100644
--- a/generalresearch/incite/schemas/mergers/foundations/enriched_session.py
+++ b/generalresearch/incite/schemas/mergers/foundations/enriched_session.py
@@ -1,6 +1,6 @@
from datetime import timedelta
-from pandera import DataFrameSchema, Column, Check
+from pandera import Check, Column, DataFrameSchema
from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY, PARTITION_ON
from generalresearch.incite.schemas.thl_web import THLSessionSchema
diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py b/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py
index 35b579b..7b43589 100644
--- a/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py
+++ b/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py
@@ -1,7 +1,8 @@
-import pandas as pd
-from pandera import DataFrameSchema, Column, Check, Index
from typing import Set
+import pandas as pd
+from pandera import Check, Column, DataFrameSchema, Index
+
from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY
from generalresearch.incite.schemas.thl_web import THLTaskAdjustmentSchema
from generalresearch.locales import Localelator
diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py b/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py
index 1b78fde..8d16270 100644
--- a/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py
+++ b/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py
@@ -1,15 +1,15 @@
from datetime import timedelta
import pandas as pd
-from pandera import DataFrameSchema, Column, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
-from generalresearch.incite.schemas import PARTITION_ON, ARCHIVE_AFTER
+from generalresearch.incite.schemas import ARCHIVE_AFTER, PARTITION_ON
from generalresearch.locales import Localelator
from generalresearch.models import DeviceType, Source
from generalresearch.models.thl.definitions import (
+ ReportValue,
Status,
StatusCode1,
- ReportValue,
WallStatusCode2,
)
diff --git a/generalresearch/incite/schemas/mergers/foundations/user_id_product.py b/generalresearch/incite/schemas/mergers/foundations/user_id_product.py
index 780a3f2..160f05b 100644
--- a/generalresearch/incite/schemas/mergers/foundations/user_id_product.py
+++ b/generalresearch/incite/schemas/mergers/foundations/user_id_product.py
@@ -1,6 +1,6 @@
from datetime import timedelta
-from pandera import Column, Check, Index, Category, DataFrameSchema
+from pandera import Category, Check, Column, DataFrameSchema, Index
from generalresearch.incite.schemas import ARCHIVE_AFTER
diff --git a/generalresearch/incite/schemas/mergers/nginx.py b/generalresearch/incite/schemas/mergers/nginx.py
index 30e6fec..3d16738 100644
--- a/generalresearch/incite/schemas/mergers/nginx.py
+++ b/generalresearch/incite/schemas/mergers/nginx.py
@@ -5,9 +5,9 @@
from datetime import timedelta
import pandas as pd
-from pandera import DataFrameSchema, Column, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
-from generalresearch.incite.schemas import PARTITION_ON, ARCHIVE_AFTER
+from generalresearch.incite.schemas import ARCHIVE_AFTER, PARTITION_ON
NGINXBaseSchema = DataFrameSchema(
columns={
diff --git a/generalresearch/incite/schemas/mergers/pop_ledger.py b/generalresearch/incite/schemas/mergers/pop_ledger.py
index 25c7e68..af8e0c8 100644
--- a/generalresearch/incite/schemas/mergers/pop_ledger.py
+++ b/generalresearch/incite/schemas/mergers/pop_ledger.py
@@ -2,11 +2,11 @@ from datetime import timedelta
import pandas as pd
from more_itertools import flatten
-from pandera import DataFrameSchema, Column, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY, PARTITION_ON
from generalresearch.incite.schemas.thl_web import TxSchema
-from generalresearch.models.thl.ledger import TransactionType, Direction
+from generalresearch.models.thl.ledger import Direction, TransactionType
"""
- In reality, a multi-index would be appropriate here, but dask does not support this, so we're keeping it flat.
diff --git a/generalresearch/incite/schemas/mergers/ym_survey_wall.py b/generalresearch/incite/schemas/mergers/ym_survey_wall.py
index 2b2d266..3882bd1 100644
--- a/generalresearch/incite/schemas/mergers/ym_survey_wall.py
+++ b/generalresearch/incite/schemas/mergers/ym_survey_wall.py
@@ -1,9 +1,9 @@
from datetime import timedelta
-from pandera import DataFrameSchema, Column, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
-from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER
-from generalresearch.incite.schemas.thl_web import THLWallSchema, THLSessionSchema
+from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY
+from generalresearch.incite.schemas.thl_web import THLSessionSchema, THLWallSchema
thl_wall_columns = THLWallSchema.columns.copy()
diff --git a/generalresearch/incite/schemas/mergers/ym_wall_summary.py b/generalresearch/incite/schemas/mergers/ym_wall_summary.py
index fb34dec..5e97b47 100644
--- a/generalresearch/incite/schemas/mergers/ym_wall_summary.py
+++ b/generalresearch/incite/schemas/mergers/ym_wall_summary.py
@@ -1,7 +1,7 @@
from datetime import timedelta
from typing import Set
-from pandera import DataFrameSchema, Column, Check, Index
+from pandera import Check, Column, DataFrameSchema, Index
from generalresearch.incite.schemas import ARCHIVE_AFTER
from generalresearch.locales import Localelator
diff --git a/generalresearch/incite/schemas/thl_marketplaces.py b/generalresearch/incite/schemas/thl_marketplaces.py
index 286db6a..6d7b83b 100644
--- a/generalresearch/incite/schemas/thl_marketplaces.py
+++ b/generalresearch/incite/schemas/thl_marketplaces.py
@@ -2,9 +2,9 @@ import copy
from datetime import timedelta
import pandas as pd
-from pandera import Column, Check, Index, DataFrameSchema
+from pandera import Check, Column, DataFrameSchema, Index
-from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER
+from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY
BIGINT = 9223372036854775807
diff --git a/generalresearch/incite/schemas/thl_web.py b/generalresearch/incite/schemas/thl_web.py
index a644720..c57a233 100644
--- a/generalresearch/incite/schemas/thl_web.py
+++ b/generalresearch/incite/schemas/thl_web.py
@@ -1,19 +1,19 @@
-from datetime import timezone, datetime, timedelta
+from datetime import datetime, timedelta, timezone
import pandas as pd
-from pandera import DataFrameSchema, Column, Check, Index, MultiIndex
+from pandera import Check, Column, DataFrameSchema, Index, MultiIndex
-from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER
+from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY
from generalresearch.locales import Localelator
from generalresearch.models import DeviceType, Source
from generalresearch.models.thl.definitions import (
- StatusCode1,
- WallStatusCode2,
ReportValue,
- WallAdjustedStatus,
- Status,
- SessionStatusCode2,
SessionAdjustedStatus,
+ SessionStatusCode2,
+ Status,
+ StatusCode1,
+ WallAdjustedStatus,
+ WallStatusCode2,
)
from generalresearch.models.thl.ledger import TransactionMetadataColumns
from generalresearch.models.thl.maxmind.definitions import UserType
diff --git a/generalresearch/managers/__init__.py b/generalresearch/managers/__init__.py
index 8af988c..bc745fd 100644
--- a/generalresearch/managers/__init__.py
+++ b/generalresearch/managers/__init__.py
@@ -1,4 +1,4 @@
-def parse_order_by(order_by_str: str):
+def parse_order_by(order_by_str: str) -> str:
"""
Converts django-rest-framework ordering str to mysql clause
:param order_by_str: e.g. 'created,-name'
diff --git a/generalresearch/managers/cint/profiling.py b/generalresearch/managers/cint/profiling.py
index 4a7fc69..5e6e46a 100644
--- a/generalresearch/managers/cint/profiling.py
+++ b/generalresearch/managers/cint/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.cint.question import CintQuestion
from generalresearch.sql_helper import SqlHelper
@@ -24,10 +24,13 @@ def get_profiling_library(
:param is_live: filters on is_live field
:param pks: The pk is (question_key, country_iso, language_iso). pks accepts a collection of
len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', 'us', 'eng'), ('AGE', 'us', 'spa')]
+
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,8 +49,10 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
res = sql_helper.execute_sql_query(
f"""
SELECT *
@@ -56,7 +61,9 @@ def get_profiling_library(
""",
params,
)
+
for x in res:
x["options"] = json.loads(x["options"]) if x["options"] else None
+
qs = [CintQuestion.from_db(x) for x in res]
return qs
diff --git a/generalresearch/managers/cint/survey.py b/generalresearch/managers/cint/survey.py
index b7045c9..e8298fd 100644
--- a/generalresearch/managers/cint/survey.py
+++ b/generalresearch/managers/cint/survey.py
@@ -1,15 +1,15 @@
from __future__ import annotations
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional, Set
+from datetime import datetime, timezone
+from typing import Collection, List, Optional, Set
import pymysql
from pymysql import IntegrityError
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
-from generalresearch.models.cint.survey import CintSurvey, CintCondition
+from generalresearch.models.cint.survey import CintCondition, CintSurvey
logger = logging.getLogger()
diff --git a/generalresearch/managers/dynata/profiling.py b/generalresearch/managers/dynata/profiling.py
index 10d6c69..39e6591 100644
--- a/generalresearch/managers/dynata/profiling.py
+++ b/generalresearch/managers/dynata/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.dynata.question import DynataQuestion
from generalresearch.sql_helper import SqlHelper
@@ -24,10 +24,13 @@ def get_profiling_library(
:param is_live: filters on is_live field
:param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
+
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,8 +49,10 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
res = sql_helper.execute_sql_query(
f"""
SELECT *
diff --git a/generalresearch/managers/dynata/survey.py b/generalresearch/managers/dynata/survey.py
index a20345a..d20c1c8 100644
--- a/generalresearch/managers/dynata/survey.py
+++ b/generalresearch/managers/dynata/survey.py
@@ -1,15 +1,15 @@
from __future__ import annotations
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional
+from datetime import datetime, timezone
+from typing import Collection, List, Optional
import pymysql
from pymysql import IntegrityError
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
-from generalresearch.models.dynata.survey import DynataSurvey, DynataCondition
+from generalresearch.models.dynata.survey import DynataCondition, DynataSurvey
logger = logging.getLogger()
@@ -64,8 +64,10 @@ class DynataSurveyManager(SurveyManager):
:param is_live: filters on is_live field
:param updated_since: filters on "> last_updated"
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -81,6 +83,7 @@ class DynataSurveyManager(SurveyManager):
if updated_since is not None:
params["updated_since"] = updated_since
filters.append("last_updated > %(updated_since)s")
+
assert filters, "Must set at least 1 filter"
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
diff --git a/generalresearch/managers/events.py b/generalresearch/managers/events.py
index 3e87879..c0c0d0f 100644
--- a/generalresearch/managers/events.py
+++ b/generalresearch/managers/events.py
@@ -1,34 +1,34 @@
import logging
+import math
import socket
import threading
import time
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
from decimal import Decimal
-from typing import Set, Optional, TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
-import math
from redis.client import PubSub, Redis
from generalresearch.managers.base import RedisManager
from generalresearch.models import Source
from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.events import (
- StatsMessage,
- EventMessage,
+ AggregateBySource,
EventEnvelope,
+ EventMessage,
EventType,
- TaskEnterPayload,
- ServerToClientMessageAdapter,
+ MaxGaugeBySource,
ServerToClientMessage,
- TaskFinishPayload,
+ ServerToClientMessageAdapter,
SessionEnterPayload,
SessionFinishPayload,
- AggregateBySource,
- MaxGaugeBySource,
+ StatsMessage,
+ TaskEnterPayload,
+ TaskFinishPayload,
TaskStatsSnapshot,
)
from generalresearch.models.thl.definitions import Status
-from generalresearch.models.thl.session import Wall, Session
+from generalresearch.models.thl.session import Session, Wall
from generalresearch.models.thl.user import User
if TYPE_CHECKING:
@@ -309,7 +309,9 @@ class TaskStatsManager(RedisManager):
def get_active_sources(self) -> List[Source]:
return [Source(x) for x in self.redis_client.hkeys("live_task_count")]
- def get_task_stats_raw(self):
+ def get_task_stats_raw(
+ self,
+ ) -> Dict[str, Union[AggregateBySource, MaxGaugeBySource]]:
sources = self.get_active_sources()
pipe = self.redis_client.pipeline(transaction=False)
@@ -365,7 +367,7 @@ class TaskStatsManager(RedisManager):
"live_tasks_max_payout": live_tasks_max_payout,
}
- def clear_task_stats(self):
+ def clear_task_stats(self) -> None:
keys = self.task_stats.copy()
keys.extend([f"task_created_count_last_1h:{source.value}" for source in Source])
keys.extend(
@@ -373,6 +375,8 @@ class TaskStatsManager(RedisManager):
)
self.redis_client.delete(*keys)
+ return None
+
class SessionStatsManager(RedisManager):
"""
@@ -633,9 +637,6 @@ class EventManager(StatsManager):
def get_active_subscribers(self) -> Set[UUIDStr]:
res = self.redis_client.pubsub_channels(f"{self.cache_prefix}:event-channel:*")
product_ids = {x.rsplit(":", 1)[-1] for x in res}
- # product_ids.update(
- # {"fc14e741b5004581b30e6478363414df", "888dbc589987425fa846d6e2a8daed04"}
- # )
return product_ids
def stats_worker(self):
@@ -647,7 +648,7 @@ class EventManager(StatsManager):
finally:
time.sleep(60)
- def stats_worker_task(self):
+ def stats_worker_task(self) -> None:
"""
Only a single worker will be running. It'll be responsible
for periodic publication of summary/stats messages.
@@ -682,6 +683,8 @@ class EventManager(StatsManager):
self.redis_client.delete(lock_key)
+ return None
+
def make_influx_point(self, channel: str, numsub: int):
return {
"measurement": "redis_pubsub_subscribers",
diff --git a/generalresearch/managers/gr/authentication.py b/generalresearch/managers/gr/authentication.py
index f851cfa..7b8e526 100644
--- a/generalresearch/managers/gr/authentication.py
+++ b/generalresearch/managers/gr/authentication.py
@@ -2,22 +2,21 @@ import binascii
import logging
import os
from datetime import datetime, timezone
-from typing import Optional, List, TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from uuid import uuid4
from psycopg import sql
-from pydantic import AnyHttpUrl
-from pydantic import PositiveInt
+from pydantic import AnyHttpUrl, PositiveInt
-from generalresearch.managers.base import PostgresManagerWithRedis, PostgresManager
+from generalresearch.managers.base import PostgresManager, PostgresManagerWithRedis
from generalresearch.models.custom_types import UUIDStr
-from generalresearch.redis_helper import RedisConfig
from generalresearch.pg_helper import PostgresConfig
+from generalresearch.redis_helper import RedisConfig
LOG = logging.getLogger("gr")
if TYPE_CHECKING:
- from generalresearch.models.gr.authentication import GRUser, GRToken
+ from generalresearch.models.gr.authentication import GRToken, GRUser
class GRUserManager(PostgresManagerWithRedis):
@@ -194,7 +193,7 @@ class GRTokenManager(PostgresManager):
def get_by_key(
self,
api_key: str,
- jwks: Optional[Dict] = None,
+ jwks: Optional[Dict[str, Any]] = None,
audience: Optional[str] = None,
issuer: Optional[Union[AnyHttpUrl, str]] = None,
gr_redis_config: Optional[RedisConfig] = None,
@@ -210,7 +209,7 @@ class GRTokenManager(PostgresManager):
:return GRToken instance (minified version, no relationships)
:raises NotFoundException
"""
- from generalresearch.models.gr.authentication import GRToken, Claims
+ from generalresearch.models.gr.authentication import Claims, GRToken
# SSO Key
if GRToken.is_sso(api_key):
@@ -324,7 +323,7 @@ class GRTokenManager(PostgresManager):
res = result[0]
- for k, v in res.items():
+        for k in res:
if isinstance(res[k], datetime):
res[k] = res[k].replace(tzinfo=timezone.utc)
diff --git a/generalresearch/managers/gr/business.py b/generalresearch/managers/gr/business.py
index 001a9e8..e9ec580 100644
--- a/generalresearch/managers/gr/business.py
+++ b/generalresearch/managers/gr/business.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
from uuid import UUID, uuid4
from psycopg import sql
@@ -6,20 +6,20 @@ from pydantic import PositiveInt
from pydantic_extra_types.phone_numbers import PhoneNumber
from generalresearch.managers.base import (
- PostgresManagerWithRedis,
PostgresManager,
+ PostgresManagerWithRedis,
)
from generalresearch.models.custom_types import UUIDStr
if TYPE_CHECKING:
- from generalresearch.models.gr.team import Team
from generalresearch.models.gr.business import (
Business,
- BusinessType,
BusinessAddress,
BusinessBankAccount,
+ BusinessType,
TransferMethod,
)
+ from generalresearch.models.gr.team import Team
class BusinessBankAccountManager(PostgresManager):
@@ -27,7 +27,7 @@ class BusinessBankAccountManager(PostgresManager):
def create_dummy(
self,
business_id: PositiveInt,
- uuid: Optional[UUID] = None,
+ uuid: Optional[UUIDStr] = None,
transfer_method: Optional["TransferMethod"] = None,
account_number: Optional[str] = None,
routing_number: Optional[str] = None,
@@ -36,21 +36,14 @@ class BusinessBankAccountManager(PostgresManager):
):
from generalresearch.models.gr.business import TransferMethod
- uuid = uuid or uuid4().hex
- transfer_method = transfer_method or TransferMethod.ACH
- account_number = account_number or uuid4().hex[:6]
- routing_number = routing_number or uuid4().hex[:6]
- iban = iban or uuid4().hex[:6]
- swift = swift or uuid4().hex[:6]
-
return self.create(
business_id=business_id,
- uuid=uuid,
- transfer_method=transfer_method,
- account_number=account_number,
- routing_number=routing_number,
- iban=iban,
- swift=swift,
+ uuid=uuid or uuid4().hex,
+ transfer_method=transfer_method or TransferMethod.ACH,
+ account_number=account_number or uuid4().hex[:6],
+ routing_number=routing_number or uuid4().hex[:6],
+ iban=iban or uuid4().hex[:6],
+ swift=swift or uuid4().hex[:6],
)
def create(
@@ -95,7 +88,7 @@ class BusinessBankAccountManager(PostgresManager):
),
params=data,
)
- ba_id = c.fetchone()["id"]
+ ba_id = c.fetchone()["id"] # type: ignore
conn.commit()
ba.id = ba_id
@@ -194,14 +187,15 @@ class BusinessAddressManager(PostgresManager):
(uuid, line_1, line_2, city, country, state,
postal_code, phone_number, business_id)
VALUES
- (%(uuid)s, %(line_1)s, %(line_2)s, %(city)s, %(country)s, %(state)s,
- %(postal_code)s, %(phone_number)s, %(business_id)s)
+ (%(uuid)s, %(line_1)s, %(line_2)s, %(city)s,
+ %(country)s, %(state)s, %(postal_code)s,
+ %(phone_number)s, %(business_id)s)
RETURNING id
"""
),
params=data,
)
- ba_id = c.fetchone()["id"]
+ ba_id = c.fetchone()["id"] # type: ignore
conn.commit()
ba.id = ba_id
@@ -304,7 +298,7 @@ class BusinessManager(PostgresManagerWithRedis):
),
params=data,
)
- business_id = c.fetchone()["id"]
+ business_id = c.fetchone()["id"] # type: ignore
conn.commit()
business.id = business_id
diff --git a/generalresearch/managers/gr/team.py b/generalresearch/managers/gr/team.py
index a57ef5f..c6709d0 100644
--- a/generalresearch/managers/gr/team.py
+++ b/generalresearch/managers/gr/team.py
@@ -1,5 +1,5 @@
from datetime import datetime, timezone
-from typing import Optional, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
from uuid import uuid4
from psycopg import sql
@@ -13,13 +13,13 @@ from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.gr.team import Membership, MembershipPrivilege
if TYPE_CHECKING:
+ from generalresearch.models.gr.authentication import GRUser
+ from generalresearch.models.gr.business import Business
from generalresearch.models.gr.team import (
Membership,
- Team,
MembershipPrivilege,
+ Team,
)
- from generalresearch.models.gr.authentication import GRUser
- from generalresearch.models.gr.business import Business
class MembershipManager(PostgresManager):
@@ -77,7 +77,7 @@ class MembershipManager(PostgresManager):
),
params=data,
)
- membership_id: int = c.fetchone()["id"]
+ membership_id: int = c.fetchone()["id"] # type: ignore
conn.commit()
membership.id = membership_id
@@ -205,7 +205,7 @@ class TeamManager(PostgresManagerWithRedis):
),
params=[team.uuid, team.name],
)
- team_id = c.fetchone()["id"]
+ team_id = c.fetchone()["id"] # type: ignore
conn.commit()
team.id = team_id
diff --git a/generalresearch/managers/innovate/profiling.py b/generalresearch/managers/innovate/profiling.py
index a3939a7..81cb29e 100644
--- a/generalresearch/managers/innovate/profiling.py
+++ b/generalresearch/managers/innovate/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.innovate.question import InnovateQuestion
from generalresearch.sql_helper import SqlHelper
@@ -22,12 +22,16 @@ def get_profiling_library(
:param question_keys: filters on question_key field, accepts multiple values
:param max_options: filters on max_options field
:param is_live: filters on is_live field
- :param pks: The pk is (question_key, country_iso, language_iso). pks accepts a collection of
- len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', 'us', 'eng'), ('AGE', 'us', 'spa')]
+ :param pks: The pk is (question_key, country_iso, language_iso). pks
+ accepts a collection of len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002',
+ 'us', 'eng'), ('AGE', 'us', 'spa')]
+
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,8 +50,10 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_key, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
res = sql_helper.execute_sql_query(
f"""
SELECT *
@@ -56,7 +62,9 @@ def get_profiling_library(
""",
params,
)
+
for x in res:
x["options"] = json.loads(x["options"]) if x["options"] else None
qs = [InnovateQuestion.from_db(x) for x in res]
+
return qs
diff --git a/generalresearch/managers/innovate/survey.py b/generalresearch/managers/innovate/survey.py
index 0e19065..bd86e34 100644
--- a/generalresearch/managers/innovate/survey.py
+++ b/generalresearch/managers/innovate/survey.py
@@ -1,8 +1,8 @@
from __future__ import annotations
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional, Set
+from datetime import datetime, timezone
+from typing import Collection, List, Optional, Set
import pymysql
from pymysql import IntegrityError
@@ -10,8 +10,8 @@ from pymysql import IntegrityError
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
from generalresearch.models.innovate.survey import (
- InnovateSurvey,
InnovateCondition,
+ InnovateSurvey,
)
logger = logging.getLogger()
@@ -77,8 +77,10 @@ class InnovateSurveyManager(SurveyManager):
:param is_live: filters on is_live field
:param updated_since: filters on "> updated"
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -88,18 +90,22 @@ class InnovateSurveyManager(SurveyManager):
if survey_ids is not None:
params["survey_ids"] = survey_ids
filters.append("survey_id IN %(survey_ids)s")
+
if is_live is not None:
if is_live:
filters.append("is_live")
else:
filters.append("NOT is_live")
+
if updated_since is not None:
params["updated_since"] = updated_since
filters.append("updated > %(updated_since)s")
+
assert filters, "Must set at least 1 filter"
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
fields = set(self.SURVEY_FIELDS) | {"created", "updated"}
+
if exclude_fields:
fields -= exclude_fields
fields_str = ", ".join([f"`{v}`" for v in fields])
diff --git a/generalresearch/managers/leaderboard/manager.py b/generalresearch/managers/leaderboard/manager.py
index 86b3d80..18dec92 100644
--- a/generalresearch/managers/leaderboard/manager.py
+++ b/generalresearch/managers/leaderboard/manager.py
@@ -1,19 +1,19 @@
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
from decimal import Decimal
from functools import cached_property
-from typing import Optional, TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
import pandas as pd
from pandas import Period
-from pydantic import NaiveDatetime, AwareDatetime
+from pydantic import AwareDatetime, NaiveDatetime
from redis import Redis
from generalresearch.managers.leaderboard import country_timezone
from generalresearch.models.thl.leaderboard import (
+ Leaderboard,
LeaderboardCode,
LeaderboardFrequency,
LeaderboardRow,
- Leaderboard,
)
if TYPE_CHECKING:
@@ -130,7 +130,7 @@ class LeaderboardManager:
def get_leaderboard(
self,
limit: Optional[int] = None,
- bp_user_id=None,
+ bp_user_id: Optional[str] = None,
) -> Leaderboard:
if bp_user_id:
diff --git a/generalresearch/managers/leaderboard/tasks.py b/generalresearch/managers/leaderboard/tasks.py
index e25c4d8..072a1aa 100644
--- a/generalresearch/managers/leaderboard/tasks.py
+++ b/generalresearch/managers/leaderboard/tasks.py
@@ -3,11 +3,11 @@ import logging
from redis import Redis
from generalresearch.managers.leaderboard.manager import LeaderboardManager
-from generalresearch.models.thl.session import Session
from generalresearch.models.thl.leaderboard import (
- LeaderboardFrequency,
LeaderboardCode,
+ LeaderboardFrequency,
)
+from generalresearch.models.thl.session import Session
logger = logging.getLogger()
diff --git a/generalresearch/managers/lucid/profiling.py b/generalresearch/managers/lucid/profiling.py
index 5f084b1..ac2556d 100644
--- a/generalresearch/managers/lucid/profiling.py
+++ b/generalresearch/managers/lucid/profiling.py
@@ -1,8 +1,9 @@
import json
-from typing import List, Collection, Optional, Tuple
-from generalresearch.decorators import LOG
+from typing import Collection, List, Optional, Tuple
+
from pydantic import ValidationError
+from generalresearch.decorators import LOG
from generalresearch.models.lucid.question import LucidQuestion, LucidQuestionType
from generalresearch.sql_helper import SqlHelper
@@ -24,8 +25,10 @@ def get_profiling_library(
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
:return:
"""
+
filters = ["`q`.question_type != 'o'"]
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`q`.`country_iso` = %(country_iso)s")
@@ -40,8 +43,10 @@ def get_profiling_library(
pks = [(int(x[0]), x[1], x[2]) for x in pks]
params["pks"] = pks
filters.append("(q.question_id, q.country_iso, q.language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
db_name = sql_helper.db_name
res = sql_helper.execute_sql_query(
query=f"""
@@ -74,6 +79,7 @@ def get_profiling_library(
if x["question_id"] in {"116", "120", "121"}:
x["question_type"] = LucidQuestionType.TEXT_ENTRY
qs = []
+
for x in res:
try:
qs.append(LucidQuestion.from_db(x))
diff --git a/generalresearch/managers/marketplace/user_pid.py b/generalresearch/managers/marketplace/user_pid.py
index f731179..cadaea9 100644
--- a/generalresearch/managers/marketplace/user_pid.py
+++ b/generalresearch/managers/marketplace/user_pid.py
@@ -1,5 +1,5 @@
from abc import ABC
-from typing import Collection, Optional, List, Dict
+from typing import Collection, Dict, List, Optional
from uuid import UUID
from generalresearch.managers.base import SqlManager
@@ -12,7 +12,7 @@ class UserPidManager(SqlManager, ABC):
For getting user pids across marketplaces
"""
- SOURCE: Source = None
+ SOURCE: Optional[Source] = None
TABLE_NAME = None
def filter(
diff --git a/generalresearch/managers/morning/profiling.py b/generalresearch/managers/morning/profiling.py
index c397748..494e406 100644
--- a/generalresearch/managers/morning/profiling.py
+++ b/generalresearch/managers/morning/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.morning.question import MorningQuestion
from generalresearch.sql_helper import SqlHelper
@@ -28,8 +28,10 @@ def get_profiling_library(
len(3) tuples. e.g. [('employer_size', 'us', 'eng'), ('employer_size', 'us', 'spa')]
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -51,8 +53,10 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
res = sql_helper.execute_sql_query(
f"""
SELECT *
@@ -61,6 +65,7 @@ def get_profiling_library(
""",
params,
)
+
for x in res:
x["options"] = json.loads(x["options"]) if x["options"] else None
qs = [MorningQuestion.from_db(x) for x in res]
diff --git a/generalresearch/managers/morning/survey.py b/generalresearch/managers/morning/survey.py
index 9b08a65..3478e92 100644
--- a/generalresearch/managers/morning/survey.py
+++ b/generalresearch/managers/morning/survey.py
@@ -2,8 +2,8 @@ from __future__ import annotations
import json
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional
+from datetime import datetime, timezone
+from typing import Collection, List, Optional
import pymysql
from pymysql import IntegrityError
@@ -184,7 +184,7 @@ class MorningSurveyManager(SurveyManager):
for survey in surveys:
self.update_one(survey, now=now)
- def update_one(self, bid: MorningBid, now=None) -> bool:
+ def update_one(self, bid: MorningBid, now: Optional[datetime] = None) -> bool:
if now is None:
now = datetime.now(tz=timezone.utc)
d = bid.to_mysql()
diff --git a/generalresearch/managers/pollfish/profiling.py b/generalresearch/managers/pollfish/profiling.py
index 735e824..4431784 100644
--- a/generalresearch/managers/pollfish/profiling.py
+++ b/generalresearch/managers/pollfish/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.pollfish.question import PollfishQuestion
from generalresearch.sql_helper import SqlHelper
@@ -26,8 +26,10 @@ def get_profiling_library(
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -43,9 +45,11 @@ def get_profiling_library(
if is_live is not None:
params["is_live"] = is_live
filters.append("is_live = %(is_live)s")
+
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
res = sql_helper.execute_sql_query(
diff --git a/generalresearch/managers/precision/profiling.py b/generalresearch/managers/precision/profiling.py
index 5f687ce..c4b24d8 100644
--- a/generalresearch/managers/precision/profiling.py
+++ b/generalresearch/managers/precision/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.precision.question import PrecisionQuestion
from generalresearch.sql_helper import SqlHelper
@@ -26,8 +26,10 @@ def get_profiling_library(
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,6 +48,7 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
res = sql_helper.execute_sql_query(
diff --git a/generalresearch/managers/precision/survey.py b/generalresearch/managers/precision/survey.py
index fc4e037..112c938 100644
--- a/generalresearch/managers/precision/survey.py
+++ b/generalresearch/managers/precision/survey.py
@@ -1,8 +1,8 @@
from __future__ import annotations
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional
+from datetime import datetime, timezone
+from typing import Collection, List, Optional
import pymysql
from pymysql import IntegrityError
@@ -10,8 +10,8 @@ from pymysql import IntegrityError
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
from generalresearch.models.precision.survey import (
- PrecisionSurvey,
PrecisionCondition,
+ PrecisionSurvey,
)
logger = logging.getLogger()
@@ -62,8 +62,10 @@ class PrecisionSurveyManager(SurveyManager):
:param is_live: filters on is_live field
:param updated_since: filters on "> last_updated"
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -186,12 +188,12 @@ class PrecisionSurveyManager(SurveyManager):
country_data = [(survey.survey_id, c) for c in survey.country_isos]
# Turn ON countries in this survey's list of countries, insert row, if already exists, set active.
c.executemany(
- f"""
+ query=f"""
INSERT INTO `thl-precision`.`precision_survey_country`
(survey_id, country_iso, is_active) VALUES
(%s, %s, TRUE) ON DUPLICATE KEY UPDATE is_active = TRUE;
""",
- country_data,
+ args=country_data,
)
# Same thing with languages
@@ -205,12 +207,12 @@ class PrecisionSurveyManager(SurveyManager):
)
language_data = [(survey.survey_id, c) for c in survey.language_isos]
c.executemany(
- f"""
+ query=f"""
INSERT INTO `thl-precision`.`precision_survey_language`
(survey_id, language_iso, is_active) VALUES
(%s, %s, TRUE) ON DUPLICATE KEY UPDATE is_active = TRUE;
""",
- language_data,
+ args=language_data,
)
conn.commit()
diff --git a/generalresearch/managers/prodege/profiling.py b/generalresearch/managers/prodege/profiling.py
index a33364b..bf4e3cf 100644
--- a/generalresearch/managers/prodege/profiling.py
+++ b/generalresearch/managers/prodege/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.prodege.question import ProdegeQuestion
from generalresearch.sql_helper import SqlHelper
@@ -26,8 +26,10 @@ def get_profiling_library(
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,6 +48,7 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
res = sql_helper.execute_sql_query(
@@ -56,6 +59,7 @@ def get_profiling_library(
""",
params,
)
+
for x in res:
x["options"] = json.loads(x["options"]) if x["options"] else None
qs = [ProdegeQuestion.from_db(x) for x in res]
diff --git a/generalresearch/managers/prodege/survey.py b/generalresearch/managers/prodege/survey.py
index ec5665e..75fa34e 100644
--- a/generalresearch/managers/prodege/survey.py
+++ b/generalresearch/managers/prodege/survey.py
@@ -1,13 +1,13 @@
from __future__ import annotations
-from datetime import timezone, datetime
-from typing import List, Collection, Optional
+from datetime import datetime, timezone
+from typing import Collection, List, Optional
import pymysql
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
-from generalresearch.models.prodege.survey import ProdegeSurvey, ProdegeCondition
+from generalresearch.models.prodege.survey import ProdegeCondition, ProdegeSurvey
class ProdegeCriteriaManager(CriteriaManager):
@@ -57,8 +57,10 @@ class ProdegeSurveyManager(SurveyManager):
:param is_live: filters on is_live field
:param updated_since: filters on "> updated"
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
diff --git a/generalresearch/managers/repdata/profiling.py b/generalresearch/managers/repdata/profiling.py
index c508764..6a63c38 100644
--- a/generalresearch/managers/repdata/profiling.py
+++ b/generalresearch/managers/repdata/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.repdata.question import RepDataQuestion
from generalresearch.sql_helper import SqlHelper
diff --git a/generalresearch/managers/repdata/survey.py b/generalresearch/managers/repdata/survey.py
index 05465e9..eecdc04 100644
--- a/generalresearch/managers/repdata/survey.py
+++ b/generalresearch/managers/repdata/survey.py
@@ -1,18 +1,18 @@
from __future__ import annotations
import json
-from datetime import timezone, datetime
-from typing import List, Collection, Optional
+from datetime import datetime, timezone
+from typing import Collection, List, Optional
import pymysql
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
from generalresearch.models.repdata.survey import (
+ RepDataCondition,
+ RepDataStreamHashed,
RepDataSurvey,
RepDataSurveyHashed,
- RepDataStreamHashed,
- RepDataCondition,
)
diff --git a/generalresearch/managers/sago/profiling.py b/generalresearch/managers/sago/profiling.py
index e1ed97f..6d00ec4 100644
--- a/generalresearch/managers/sago/profiling.py
+++ b/generalresearch/managers/sago/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.sago.question import SagoQuestion
from generalresearch.sql_helper import SqlHelper
@@ -24,10 +24,13 @@ def get_profiling_library(
:param is_live: filters on is_live field
:param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
+
:return:
"""
+
filters = []
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,6 +49,7 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
res = sql_helper.execute_sql_query(
@@ -56,6 +60,7 @@ def get_profiling_library(
""",
params,
)
+
for x in res:
x["options"] = json.loads(x["options"]) if x["options"] else None
qs = [SagoQuestion.from_db(x) for x in res]
diff --git a/generalresearch/managers/sago/survey.py b/generalresearch/managers/sago/survey.py
index 535d8bb..12b37e0 100644
--- a/generalresearch/managers/sago/survey.py
+++ b/generalresearch/managers/sago/survey.py
@@ -1,15 +1,15 @@
from __future__ import annotations
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional, Set
+from datetime import datetime, timezone
+from typing import Collection, List, Optional, Set
import pymysql
from pymysql import IntegrityError
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
-from generalresearch.models.sago.survey import SagoSurvey, SagoCondition
+from generalresearch.models.sago.survey import SagoCondition, SagoSurvey
logger = logging.getLogger()
diff --git a/generalresearch/managers/spectrum/profiling.py b/generalresearch/managers/spectrum/profiling.py
index 21575c6..8a9b9a9 100644
--- a/generalresearch/managers/spectrum/profiling.py
+++ b/generalresearch/managers/spectrum/profiling.py
@@ -1,5 +1,5 @@
import json
-from typing import List, Collection, Optional, Tuple
+from typing import Collection, List, Optional, Tuple
from generalresearch.models.spectrum.question import SpectrumQuestion
from generalresearch.sql_helper import SqlHelper
@@ -26,8 +26,10 @@ def get_profiling_library(
len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')]
:return:
"""
+
filters = ["is_valid"]
params = {}
+
if country_iso:
params["country_iso"] = country_iso
filters.append("`country_iso` = %(country_iso)s")
@@ -46,8 +48,10 @@ def get_profiling_library(
if pks:
params["pks"] = pks
filters.append("(question_id, country_iso, language_iso) IN %(pks)s")
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
+
res = sql_helper.execute_sql_query(
f"""
SELECT *
diff --git a/generalresearch/managers/spectrum/survey.py b/generalresearch/managers/spectrum/survey.py
index 0dcc232..190450d 100644
--- a/generalresearch/managers/spectrum/survey.py
+++ b/generalresearch/managers/spectrum/survey.py
@@ -1,8 +1,8 @@
from __future__ import annotations
import logging
-from datetime import timezone, datetime
-from typing import List, Collection, Optional
+from datetime import datetime, timezone
+from typing import Collection, List, Optional
import pymysql
from pymysql import IntegrityError
@@ -10,8 +10,8 @@ from pymysql import IntegrityError
from generalresearch.managers.criteria import CriteriaManager
from generalresearch.managers.survey import SurveyManager
from generalresearch.models.spectrum.survey import (
- SpectrumSurvey,
SpectrumCondition,
+ SpectrumSurvey,
)
logger = logging.getLogger()
@@ -61,7 +61,7 @@ class SpectrumSurveyManager(SurveyManager):
survey_ids: Optional[Collection[str]] = None,
is_live: Optional[bool] = None,
updated_since: Optional[datetime] = None,
- fields=None,
+        fields: Optional[List[str]] = None,
) -> List[SpectrumSurvey]:
"""
Accepts lots of optional filters.
@@ -93,6 +93,7 @@ class SpectrumSurveyManager(SurveyManager):
fields_str = "*"
if fields:
fields_str = ",".join(fields)
+
filter_str = " AND ".join(filters)
filter_str = "WHERE " + filter_str if filter_str else ""
diff --git a/generalresearch/managers/thl/buyer.py b/generalresearch/managers/thl/buyer.py
index dd0b4f2..b264945 100644
--- a/generalresearch/managers/thl/buyer.py
+++ b/generalresearch/managers/thl/buyer.py
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
-from typing import Collection, Dict, Optional
+from typing import Collection, Dict, List, Optional
-from generalresearch.managers.base import PostgresManager, Permission
+from generalresearch.managers.base import Permission, PostgresManager
from generalresearch.models import Source
from generalresearch.models.thl.survey.buyer import Buyer
from generalresearch.pg_helper import PostgresConfig
@@ -42,10 +42,11 @@ class BuyerManager(PostgresManager):
except KeyError:
return None
- def bulk_get_or_create(self, source: Source, codes: Collection[str]):
+ def bulk_get_or_create(self, source: Source, codes: Collection[str]) -> List[Buyer]:
now = datetime.now(tz=timezone.utc)
buyers = []
params_seq = []
+
for code in codes:
source_code = f"{source.value}:{code}"
if source_code in self.source_code_buyer:
@@ -83,9 +84,10 @@ class BuyerManager(PostgresManager):
# Not required, just for ease of testing/deterministic
buyers = sorted(buyers, key=lambda x: (x.source, x.code))
assert len(buyers) == len(codes), "something went wrong"
+
return buyers
- def update(self, buyer: Buyer):
+ def update(self, buyer: Buyer) -> None:
# label is the only thing that can be updated
query = """
UPDATE marketplace_buyer
@@ -103,7 +105,7 @@ class BuyerManager(PostgresManager):
with conn.cursor() as c:
c.execute(query, params=params)
assert c.rowcount == 1
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
if buyer.id is not None:
assert buyer.id == pk
else:
diff --git a/generalresearch/managers/thl/cashout_method.py b/generalresearch/managers/thl/cashout_method.py
index 98e8e66..accbb41 100644
--- a/generalresearch/managers/thl/cashout_method.py
+++ b/generalresearch/managers/thl/cashout_method.py
@@ -1,14 +1,16 @@
from copy import copy
-from datetime import timezone, datetime
-from typing import List, Optional, Collection, Dict
-from uuid import uuid4, UUID
+from datetime import datetime, timezone
+from typing import Any, Collection, Dict, List, Optional
+from uuid import UUID, uuid4
+
+from pydantic import NonNegativeInt
from generalresearch.managers.base import PostgresManager
from generalresearch.models.thl.user import User
from generalresearch.models.thl.wallet import PayoutType
from generalresearch.models.thl.wallet.cashout_method import (
- CashoutMethod,
CashMailCashoutMethodData,
+ CashoutMethod,
PaypalCashoutMethodData,
)
@@ -43,24 +45,25 @@ class CashoutMethodManager(PostgresManager):
def delete_cashout_method(self, cm_id: str):
db_res = self.pg_config.execute_sql_query(
- f"""
+ query="""
SELECT id::uuid, user_id
FROM accounting_cashoutmethod
WHERE id = %s AND is_live
LIMIT 1;""",
- [cm_id],
+ params=[cm_id],
)
res = next(iter(db_res), None)
assert res, f"cashout method id {cm_id} not found"
# Don't let anyone delete a non-user-scoped cashout method
assert (
res["user_id"] is not None
- ), f"error trying to delete non user-scoped cashout method"
+ ), "error trying to delete non user-scoped cashout method"
+
self.pg_config.execute_write(
- f"""
- UPDATE accounting_cashoutmethod SET is_live = FALSE
- WHERE id = %s;""",
- [cm_id],
+ query="""
+ UPDATE accounting_cashoutmethod SET is_live = FALSE
+ WHERE id = %s;""",
+ params=[cm_id],
)
def create_cash_in_mail_cashout_method(
@@ -185,7 +188,7 @@ class CashoutMethodManager(PostgresManager):
ext_id: Optional[str] = None,
payout_types: Optional[Collection[PayoutType]] = None,
is_live: Optional[bool] = True,
- ) -> int:
+ ) -> NonNegativeInt:
filter_str, params = self.make_filter_str(
uuid=uuid,
user=user,
@@ -201,7 +204,7 @@ class CashoutMethodManager(PostgresManager):
""",
params=params,
)
- return res[0]["cnt"]
+ return int(res[0]["cnt"]) # type: ignore
def filter(
self,
@@ -248,7 +251,7 @@ class CashoutMethodManager(PostgresManager):
"supported_payout_types": [x.value for x in supported_payout_types],
"user_id": user.user_id,
}
- query = f"""
+ query = """
SELECT id::uuid, provider, ext_id, data::jsonb as _data_, user_id
FROM accounting_cashoutmethod
WHERE is_live
@@ -276,14 +279,16 @@ class CashoutMethodManager(PostgresManager):
return cms
@staticmethod
- def format_from_db(x: Dict, user: Optional[User] = None) -> CashoutMethod:
+ def format_from_db(x: Dict[str, Any], user: Optional[User] = None) -> CashoutMethod:
x["id"] = UUID(x["id"]).hex
+
# The data column here is inconsistent. Pulling keys from the mysql 'data' col
# and putting them into the base level. Renamed so that we don't overwrite
# a col called "data" within the "_data_" field.
for k in list(x["_data_"].keys()):
if k in CashoutMethod.model_fields:
x[k] = x["_data_"].pop(k)
+
x["type"] = PayoutType(x["provider"].upper())
if "data" not in x:
x["data"] = dict()
diff --git a/generalresearch/managers/thl/contest_manager.py b/generalresearch/managers/thl/contest_manager.py
index 4dc02e3..696a497 100644
--- a/generalresearch/managers/thl/contest_manager.py
+++ b/generalresearch/managers/thl/contest_manager.py
@@ -1,9 +1,9 @@
-from datetime import timezone, datetime
-from typing import List, Optional, Literal, cast, Collection, Tuple, Dict
+from datetime import datetime, timezone
+from typing import Any, Collection, Dict, List, Literal, Optional, Tuple, cast
from uuid import UUID
import redis
-from pydantic import PositiveInt, NonNegativeInt
+from pydantic import NonNegativeInt, PositiveInt
from redis import Redis
from generalresearch.managers.base import PostgresManager
@@ -15,8 +15,8 @@ from generalresearch.managers.thl.user_manager.user_manager import (
)
from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.thl.contest import (
- ContestWinner,
ContestPrize,
+ ContestWinner,
)
from generalresearch.models.thl.contest.contest import (
Contest,
@@ -34,20 +34,20 @@ from generalresearch.models.thl.contest.io import (
user_model_cls,
)
from generalresearch.models.thl.contest.leaderboard import (
- LeaderboardContestUserView,
LeaderboardContest,
+ LeaderboardContestUserView,
)
from generalresearch.models.thl.contest.milestone import (
- MilestoneUserView,
- MilestoneEntry,
ContestEntryTrigger,
MilestoneContest,
+ MilestoneEntry,
+ MilestoneUserView,
)
from generalresearch.models.thl.contest.raffle import (
ContestEntry,
ContestEntryType,
- RaffleUserView,
RaffleContest,
+ RaffleUserView,
)
from generalresearch.models.thl.user import User
@@ -139,7 +139,7 @@ class ContestBaseManager(PostgresManager):
query=query,
params=data,
)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
contest.id = pk
@@ -182,9 +182,10 @@ class ContestBaseManager(PostgresManager):
name_contains: Optional[str] = None,
uuids: Optional[Collection[str]] = None,
has_participants: Optional[bool] = None,
- ) -> Tuple[str, Dict]:
+ ) -> Tuple[str, Dict[str, Any]]:
filters = []
params = dict()
+
if product_id:
params["product_id"] = product_id
filters.append("product_id = %(product_id)s")
@@ -207,6 +208,7 @@ class ContestBaseManager(PostgresManager):
if name_contains is not None:
params["name_contains"] = f"%{name_contains}%"
filters.append("name ILIKE %(name_contains)s")
+
if uuids is not None:
if len(uuids) == 0:
# If we pass an empty list, the sql query will have a syntax error. Make it
@@ -214,6 +216,7 @@ class ContestBaseManager(PostgresManager):
uuids = ["0" * 32]
params["uuids"] = uuids
filters.append("uuid = ANY(%(uuids)s)")
+
if has_participants:
filters.append("current_participants > 0")
@@ -711,7 +714,7 @@ class RaffleContestManager(ContestBaseManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(
- query=f"""
+ query="""
INSERT INTO contest_contestentry
(uuid, amount, user_id,
created_at, updated_at, contest_id)
@@ -721,7 +724,7 @@ class RaffleContestManager(ContestBaseManager):
params=data,
)
c.execute(
- query=f"""
+ query="""
UPDATE contest_contest
SET current_amount = %(current_amount)s,
current_participants = %(current_participants)s
@@ -985,11 +988,14 @@ class LeaderboardContestManager(ContestBaseManager):
def end_contest_if_over(
self, contest: Contest, ledger_manager: ThlLedgerManager
) -> None:
- decision, reason = contest.should_end()
+ decision, _ = contest.should_end()
+
if decision:
contest.end_contest()
return self.end_contest_with_winners(contest, ledger_manager)
+ return None
+
class ContestManager(
RaffleContestManager, MilestoneContestManager, LeaderboardContestManager
diff --git a/generalresearch/managers/thl/ipinfo.py b/generalresearch/managers/thl/ipinfo.py
index 563b179..f86573b 100644
--- a/generalresearch/managers/thl/ipinfo.py
+++ b/generalresearch/managers/thl/ipinfo.py
@@ -1,7 +1,7 @@
import ipaddress
from decimal import Decimal
from random import randint
-from typing import List, Optional, Dict, Collection
+from typing import Collection, Dict, List, Optional
import faker
import pymysql
@@ -14,12 +14,12 @@ from generalresearch.managers.base import (
PostgresManagerWithRedis,
)
from generalresearch.models.custom_types import (
- IPvAnyAddressStr,
CountryISOLike,
+ IPvAnyAddressStr,
)
from generalresearch.models.thl.ipinfo import (
- IPGeoname,
GeoIPInformation,
+ IPGeoname,
IPInformation,
normalize_ip,
)
@@ -83,14 +83,15 @@ class IPGeonameManager(PostgresManager):
}
)
self.pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO thl_geoname (
geoname_id, country_iso, is_in_european_union, country_name,
continent_code, continent_name, updated
)
VALUES (
- %(geoname_id)s, %(country_iso)s, %(is_in_european_union)s, %(country_name)s,
- %(continent_code)s, %(continent_name)s, %(updated)s
+ %(geoname_id)s, %(country_iso)s, %(is_in_european_union)s,
+ %(country_name)s, %(continent_code)s, %(continent_name)s,
+ %(updated)s
)
ON CONFLICT (geoname_id) DO NOTHING;
""",
@@ -151,7 +152,7 @@ class IPGeonameManager(PostgresManager):
)
self.pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO thl_geoname
( geoname_id, continent_code, continent_name,
country_iso, country_name,
@@ -165,8 +166,8 @@ class IPGeonameManager(PostgresManager):
%(country_iso)s, %(country_name)s,
%(subdivision_1_iso)s, %(subdivision_1_name)s,
%(subdivision_2_iso)s, %(subdivision_2_name)s,
- %(city_name)s, %(metro_code)s, %(time_zone)s, %(is_in_european_union)s,
- %(updated)s
+ %(city_name)s, %(metro_code)s, %(time_zone)s,
+ %(is_in_european_union)s, %(updated)s
)
ON CONFLICT (geoname_id) DO NOTHING;
""",
@@ -208,7 +209,7 @@ class IPGeonameManager(PostgresManager):
assert len(filter_ids) <= 500, "chunk me"
c.execute(
- query=f"""
+ query="""
SELECT g.geoname_id,
g.continent_code, g.continent_name,
g.country_iso, g.country_name,
@@ -298,10 +299,11 @@ class IPInformationManager(PostgresManager):
)
instance.normalize_ip()
self.pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO thl_ipinformation
- (ip, country_iso, registered_country_iso, geoname_id, updated)
- VALUES (%(ip)s, %(country_iso)s, %(registered_country_iso)s, %(geoname_id)s, %(updated)s)
+ (ip, country_iso, registered_country_iso, geoname_id, updated)
+ VALUES (%(ip)s, %(country_iso)s, %(registered_country_iso)s,
+ %(geoname_id)s, %(updated)s)
ON CONFLICT (ip) DO NOTHING;
""",
params=instance.model_dump(mode="json"),
@@ -367,7 +369,7 @@ class IPInformationManager(PostgresManager):
instance.normalize_ip()
self.pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO thl_ipinformation
( ip, geoname_id,
country_iso, registered_country_iso,
@@ -466,7 +468,7 @@ class IPInformationManager(PostgresManager):
normalized_ips = set(normalized_ip_lookup.values())
c.execute(
- query=f"""
+ query="""
SELECT i.ip, i.geoname_id,
i.country_iso, i.registered_country_iso,
i.is_anonymous, i.is_anonymous_vpn, i.is_hosting_provider,
@@ -500,9 +502,9 @@ class IPInformationManager(PostgresManager):
WHERE updated >= NOW() - INTERVAL '12 hours'
AND country_iso IS NULL;
"""
- numerator = list(pg_config.execute_sql_query(query=query))[0]["numerator"]
+ # numerator = list(pg_config.execute_sql_query(query=query))[0]["numerator"]
- query = f"""
+ query = """
SELECT COUNT(1) AS denominator
FROM thl_ipinformation
WHERE updated >= NOW() - INTERVAL '12 hours'
@@ -510,10 +512,11 @@ class IPInformationManager(PostgresManager):
denominator = list(pg_config.execute_sql_query(query=query))[0]["denominator"]
if denominator == 0:
pass
- percent_empty = numerator / (denominator or 1)
+
+ # percent_empty = numerator / (denominator or 1)
# TODO: Post to telegraf / grafana
- return
+ return None
class GeoIpInfoManager(PostgresManagerWithRedis):
@@ -748,7 +751,7 @@ class GeoIpInfoManager(PostgresManagerWithRedis):
normalized_ips.update({self.compress_ip(ip) for ip in ips})
c.execute(
- query=f"""
+ query="""
SELECT
geo.geoname_id,
geo.continent_name,
@@ -801,9 +804,11 @@ class GeoIpInfoManager(PostgresManagerWithRedis):
raise ValueError(
f'mismatch between ipinfo country {d["country_iso"]} and geoname country {d["geo_country_iso"]}'
)
+
gs = [GeoIPInformation.from_mysql(i) for i in res]
gs = {g.ip: g for g in gs}
res2 = dict()
+
for ip, (normalized_ip, lookup_prefix) in ip_norm_lookup.items():
if normalized_ip not in gs:
# also can remove 28 days after 2025-11-15
diff --git a/generalresearch/managers/thl/ledger_manager/conditions.py b/generalresearch/managers/thl/ledger_manager/conditions.py
index 3c03300..399d28f 100644
--- a/generalresearch/managers/thl/ledger_manager/conditions.py
+++ b/generalresearch/managers/thl/ledger_manager/conditions.py
@@ -1,12 +1,12 @@
import logging
-from datetime import datetime, timezone, timedelta
-from typing import Callable, Optional, Tuple, TYPE_CHECKING
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Callable, Optional, Tuple
-from generalresearch.config import JAMES_BILLINGS_TX_CUTOFF, JAMES_BILLINGS_BPID
+from generalresearch.config import JAMES_BILLINGS_BPID, JAMES_BILLINGS_TX_CUTOFF
from generalresearch.currency import USDCent
from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.thl.product import Product
-from generalresearch.models.thl.session import Wall, Session
+from generalresearch.models.thl.session import Session, Wall
from generalresearch.models.thl.user import User
logging.basicConfig()
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
)
-def generate_condition_mp_payment(wall: "Wall") -> Callable:
+def generate_condition_mp_payment(wall: "Wall") -> Callable[..., bool]:
"""This returns a function that checks if the payment for this wall event
exists already. This function gets run after we acquire a lock. It
should return True if we want to continue (create a tx).
@@ -37,7 +37,7 @@ def generate_condition_mp_payment(wall: "Wall") -> Callable:
return _condition
-def generate_condition_bp_payment(session: "Session") -> Callable:
+def generate_condition_bp_payment(session: "Session") -> Callable[..., bool]:
"""This returns a function that checks if the payment for this Session
exists already. This function gets run after we acquire a lock. It
should return True if we want to continue (create a tx).
@@ -52,7 +52,7 @@ def generate_condition_bp_payment(session: "Session") -> Callable:
return _condition
-def generate_condition_tag_exists(tag: str) -> Callable:
+def generate_condition_tag_exists(tag: str) -> Callable[..., bool]:
"""This returns a function that checks if a tx with this tag already
exists. It should return True if we want to continue (create a tx).
"""
@@ -70,7 +70,7 @@ def generate_condition_bp_payout(
payoutevent_uuid: UUIDStr,
skip_one_per_day_check: bool = False,
skip_wallet_balance_check: bool = False,
-) -> Callable:
+) -> Callable[..., Tuple[bool, str]]:
created = datetime.now(tz=timezone.utc)
def _condition(
@@ -110,7 +110,7 @@ def generate_condition_bp_payout(
def generate_condition_user_payout_request(
user: User, payoutevent_uuid: UUIDStr, min_balance: Optional[int] = None
-) -> Callable:
+) -> Callable[..., bool]:
"""This returns a function that checks if `user` has at least
`min_balance` in their wallet and that a payout request hasn't already
been issued with this payoutevent_uuid.
@@ -150,7 +150,7 @@ def generate_condition_user_payout_request(
def generate_condition_enter_contest(
user: User, tag: str, min_balance: USDCent
-) -> Callable:
+) -> Callable[..., Tuple[bool, str]]:
"""This returns a function that checks if `user` has at least
`min_balance` in their wallet and that a tx doesn't already exist
with this tag
@@ -177,7 +177,7 @@ def generate_condition_enter_contest(
def generate_condition_user_payout_action(
payoutevent_uuid: UUIDStr, action: str
-) -> Callable:
+) -> Callable[..., bool]:
"""The balance has already been taken from the user's wallet, so there
is no balance check. We only just check that the ledger transaction
doesn't already exist.
diff --git a/generalresearch/managers/thl/ledger_manager/ledger.py b/generalresearch/managers/thl/ledger_manager/ledger.py
index 3c65cbd..d463b55 100644
--- a/generalresearch/managers/thl/ledger_manager/ledger.py
+++ b/generalresearch/managers/thl/ledger_manager/ledger.py
@@ -1,12 +1,12 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone
-from typing import Callable, Collection, Dict, List, Optional, Set, Tuple
+from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Union
from uuid import UUID
import redis
from more_itertools import chunked, flatten
-from pydantic import AwareDatetime, PositiveInt
+from pydantic import AwareDatetime, NonNegativeInt, PositiveInt
from redis.exceptions import LockError, LockNotOwnedError
from generalresearch.currency import LedgerCurrency
@@ -206,10 +206,10 @@ class LedgerTransactionManager(LedgerManagerBasePostgres):
def create_tx_protected(
self,
lock_key: str,
- condition: Callable,
+ condition: Callable[..., Union[bool, Tuple[bool, str]]],
create_tx_func: Callable,
flag_key: Optional[str] = None,
- skip_flag_check=False,
+ skip_flag_check: bool = False,
) -> LedgerTransaction:
"""
The complexity here is that even we protect a transaction logic with
@@ -287,12 +287,15 @@ class LedgerTransactionManager(LedgerManagerBasePostgres):
# taking a long time, this would allow the tx to get retried
# over and over every 4 seconds, which is not good.
rc.set(name=flag_name, value=1, ex=3600 * 24)
+
# Condition returns either bool or Tuple[bool, str]
condition_res = condition(self)
+
if isinstance(condition_res, tuple):
condition_res, condition_msg = condition_res
else:
condition_msg = ""
+
if condition_res is False:
rc.delete(flag_name)
raise LedgerTransactionConditionFailedError(condition_msg)
@@ -404,7 +407,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres):
@staticmethod
def process_get_tx_mysql_rows_json(
- rows: Collection[Dict],
+ rows: Collection[Dict[str, Any]],
) -> List[LedgerTransaction]:
"""Columns: transaction_id, created, ext_description, tag,
key_value_pairs, entries_json
@@ -493,7 +496,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres):
account_uuid: UUIDStr,
time_start: Optional[datetime] = None,
time_end: Optional[datetime] = None,
- ):
+ ) -> NonNegativeInt:
filter_str, params = self.make_filter_str(
time_start=time_start,
time_end=time_end,
@@ -533,7 +536,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres):
account_uuid: str,
oldest_created: datetime,
exclude_txs_before: Optional[datetime] = None,
- ):
+ ) -> NonNegativeInt:
"""
In a paginated list of txs, if I want to calculate
a running balance, I need the balance in that account
@@ -560,6 +563,7 @@ class LedgerTransactionManager(LedgerManagerBasePostgres):
AND lt.created < %(oldest_created)s
{exclude_str};"""
res = self.pg_config.execute_sql_query(query, params=params)
+
return res[0]["balance_before_page"]
def include_running_balance(
@@ -735,7 +739,7 @@ class LedgerMetadataManager(LedgerManagerBasePostgres):
def get_tx_metadata_by_txs(
self, transactions: List[LedgerTransaction]
- ) -> Dict[PositiveInt, Dict]:
+ ) -> Dict[PositiveInt, Dict[str, Any]]:
"""
Each transaction can have 1 metadata dictionary. However, each
metadata dictionary can have multiple key/value pairs that
@@ -847,7 +851,7 @@ class LedgerAccountManager(LedgerManagerBasePostgres):
return account
def get_account(
- self, qualified_name: str, raise_on_error=True
+ self, qualified_name: str, raise_on_error: bool = True
) -> Optional[LedgerAccount]:
res = self.get_account_many(
qualified_names=[qualified_name], raise_on_error=raise_on_error
@@ -855,8 +859,8 @@ class LedgerAccountManager(LedgerManagerBasePostgres):
return res[0] if len(res) == 1 else None
def get_account_many_(
- self, qualified_names: List[str], raise_on_error=True
- ) -> List[Dict]:
+ self, qualified_names: List[str], raise_on_error: bool = True
+ ) -> List[Dict[str, Any]]:
assert len(qualified_names) <= 500, "chunk me"
# qualified_name has a unique index so there can only be 0 or 1 match.
@@ -878,7 +882,7 @@ class LedgerAccountManager(LedgerManagerBasePostgres):
return list(res)
def get_account_many(
- self, qualified_names: List[str], raise_on_error=True
+ self, qualified_names: List[str], raise_on_error: bool = True
) -> List[LedgerAccount]:
res = flatten(
[
@@ -1103,7 +1107,7 @@ class LedgerManager(
self,
time_start: Optional[AwareDatetime] = None,
time_end: Optional[AwareDatetime] = None,
- ) -> Dict:
+ ) -> Dict[str, Any]:
filter_str, params = self.make_filter_str(
time_end=time_end,
diff --git a/generalresearch/managers/thl/ledger_manager/thl_ledger.py b/generalresearch/managers/thl/ledger_manager/thl_ledger.py
index 977b68f..8004320 100644
--- a/generalresearch/managers/thl/ledger_manager/thl_ledger.py
+++ b/generalresearch/managers/thl/ledger_manager/thl_ledger.py
@@ -1,7 +1,7 @@
import logging
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
from decimal import Decimal
-from typing import Optional, Callable, Collection, List, TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Collection, List, Optional
from uuid import UUID
import numpy as np
@@ -15,13 +15,13 @@ from generalresearch.config import (
from generalresearch.currency import USDCent
from generalresearch.managers.base import Permission
from generalresearch.managers.thl.ledger_manager.conditions import (
- generate_condition_mp_payment,
generate_condition_bp_payment,
generate_condition_bp_payout,
- generate_condition_user_payout_request,
- generate_condition_user_payout_action,
- generate_condition_tag_exists,
generate_condition_enter_contest,
+ generate_condition_mp_payment,
+ generate_condition_tag_exists,
+ generate_condition_user_payout_action,
+ generate_condition_user_payout_request,
)
from generalresearch.managers.thl.ledger_manager.ledger import (
LedgerManager,
@@ -39,18 +39,20 @@ from generalresearch.models.thl.contest.raffle import (
RaffleContest,
)
from generalresearch.models.thl.ledger import (
- LedgerAccount,
+ AccountType,
Direction,
- LedgerTransaction,
+ LedgerAccount,
LedgerEntry,
- AccountType,
+ LedgerTransaction,
TransactionType,
- TransactionMetadataColumns as tmc,
UserLedgerTransactions,
)
+from generalresearch.models.thl.ledger import (
+ TransactionMetadataColumns as tmc,
+)
from generalresearch.models.thl.payout import UserPayoutEvent
from generalresearch.models.thl.product import Product
-from generalresearch.models.thl.session import Status, Session, Wall
+from generalresearch.models.thl.session import Session, Status, Wall
from generalresearch.models.thl.user import User
from generalresearch.models.thl.wallet import PayoutType
@@ -207,12 +209,18 @@ class ThlLedgerManager(LedgerManager):
return self.get_account_or_create(account=account)
def get_account_task_complete_revenue(self) -> "LedgerAccount":
- return self.get_account(
+ res = self.get_account(
qualified_name=f"{self.currency.value}:revenue:task_complete"
)
+ assert res is not None, "Revenue account for task complete does not exist"
+
+ return res
def get_account_cash(self) -> "LedgerAccount":
- return self.get_account(qualified_name=f"{self.currency.value}:cash")
+ res = self.get_account(qualified_name=f"{self.currency.value}:cash")
+ assert res is not None, "Cash account does not exist"
+
+ return res
def get_accounts_bp_wallet_for_products(
self, product_uuids: Collection[UUIDStr]
@@ -263,7 +271,7 @@ class ThlLedgerManager(LedgerManager):
wall: Wall,
user: User,
created: Optional[datetime] = None,
- force=False,
+ force: bool = False,
) -> PositiveInt:
"""
Create a transaction when we complete a task from a marketplace,
@@ -302,7 +310,10 @@ class ThlLedgerManager(LedgerManager):
}
# This tag should uniquely identify this transaction (which should only happen once!)
tag = f"{self.currency.value}:mp_payment:{wall.uuid}"
+
+ assert wall.cpi is not None
amount = round(wall.cpi * 100)
+
entries = [
LedgerEntry(
direction=Direction.CREDIT,
@@ -327,7 +338,7 @@ class ThlLedgerManager(LedgerManager):
return t
def create_tx_bp_payment(
- self, session: Session, created: Optional[datetime] = None, force=False
+ self, session: Session, created: Optional[datetime] = None, force: bool = False
) -> LedgerTransaction:
"""
Create a transaction when we decide to report a session as complete
@@ -750,9 +761,9 @@ class ThlLedgerManager(LedgerManager):
amount: USDCent,
payoutevent_uuid: UUIDStr,
created: AwareDatetime,
- skip_wallet_balance_check=False,
- skip_one_per_day_check=False,
- skip_flag_check=False,
+ skip_wallet_balance_check: bool = False,
+ skip_one_per_day_check: bool = False,
+ skip_flag_check: bool = False,
) -> LedgerTransaction:
"""This is when we pay "OUT" a BP their wallet balance. (Not a
payment for a task complete)
@@ -859,7 +870,7 @@ class ThlLedgerManager(LedgerManager):
created: AwareDatetime,
direction: Direction = Direction.DEBIT,
description: Optional[str] = None,
- skip_flag_check=False,
+ skip_flag_check: bool = False,
) -> LedgerTransaction:
"""https://en.wikipedia.org/wiki/Plug_(accounting)
@@ -1869,6 +1880,7 @@ class ThlLedgerManager(LedgerManager):
),
columns=["finished", "user_payout"],
)
+
if wall.empty:
reserve = 0
else:
@@ -1878,16 +1890,17 @@ class ThlLedgerManager(LedgerManager):
wall["pct_rdm"] = wall["days_since_complete"].apply(self.get_redeemable_pct)
wall.loc[wall["pct_rdm"] > 0.95, "pct_rdm"] = 1
wall["redeemable"] = wall["pct_rdm"] * wall["user_payout_int"]
- # Calculate money needed to save in reserve to cover the difference between
- # money earned from completes and $ redeemable, subtract that from the
- # wall balance.
+ # Calculate money needed to save in reserve to cover the difference
+ # between money earned from completes and $ redeemable, subtract
+ # that from the wall balance.
reserve = round(wall["user_payout_int"].sum() - wall["redeemable"].sum())
+
redeemable_balance = user_wallet_balance - reserve
redeemable_balance = 0 if redeemable_balance < 0 else redeemable_balance
if redeemable_balance > 0:
- # it is possible the user_wallet_balance is negative, in which case the redeemable
- # balance is 0. Don't fail assertion if that happens.
+ # it is possible the user_wallet_balance is negative, in which case
+ # the redeemable balance is 0. Don't fail assertion if that happens.
assert redeemable_balance <= user_wallet_balance
return redeemable_balance
diff --git a/generalresearch/managers/thl/maxmind/__init__.py b/generalresearch/managers/thl/maxmind/__init__.py
index 59e0af8..3bf0d07 100644
--- a/generalresearch/managers/thl/maxmind/__init__.py
+++ b/generalresearch/managers/thl/maxmind/__init__.py
@@ -7,9 +7,9 @@ from generalresearch.managers.base import (
PostgresManagerWithRedis,
)
from generalresearch.managers.thl.ipinfo import (
- IPInformationManager,
- IPGeonameManager,
GeoIpInfoManager,
+ IPGeonameManager,
+ IPInformationManager,
)
from generalresearch.managers.thl.maxmind.basic import MaxmindBasicManager
from generalresearch.managers.thl.maxmind.insights import (
@@ -18,9 +18,9 @@ from generalresearch.managers.thl.maxmind.insights import (
)
from generalresearch.models.custom_types import IPvAnyAddressStr
from generalresearch.models.thl.ipinfo import (
- IPInformation,
- IPGeoname,
GeoIPInformation,
+ IPGeoname,
+ IPInformation,
normalize_ip,
)
from generalresearch.pg_helper import PostgresConfig
diff --git a/generalresearch/managers/thl/maxmind/basic.py b/generalresearch/managers/thl/maxmind/basic.py
index d065c13..72479df 100644
--- a/generalresearch/managers/thl/maxmind/basic.py
+++ b/generalresearch/managers/thl/maxmind/basic.py
@@ -10,16 +10,15 @@ from uuid import uuid4
import geoip2.database
import geoip2.models
import requests
-from cachetools import cached, TTLCache
+from cachetools import TTLCache, cached
from geoip2.errors import AddressNotFoundError
from generalresearch.managers.base import Manager
from generalresearch.models.custom_types import (
- IPvAnyAddressStr,
CountryISOLike,
+ IPvAnyAddressStr,
)
-
logger = logging.getLogger()
diff --git a/generalresearch/managers/thl/maxmind/insights.py b/generalresearch/managers/thl/maxmind/insights.py
index b83bded..13dc3e8 100644
--- a/generalresearch/managers/thl/maxmind/insights.py
+++ b/generalresearch/managers/thl/maxmind/insights.py
@@ -1,10 +1,8 @@
import logging
from typing import Optional
-import geoip2.database
import geoip2.models
import geoip2.webservice
-import slack
from geoip2.errors import (
AddressNotFoundError,
AuthenticationError,
diff --git a/generalresearch/managers/thl/payout.py b/generalresearch/managers/thl/payout.py
index e99cc25..0d00edb 100644
--- a/generalresearch/managers/thl/payout.py
+++ b/generalresearch/managers/thl/payout.py
@@ -1,14 +1,15 @@
from collections import defaultdict
-from datetime import timezone, datetime, timedelta
-from random import randint, choice as rand_choice
+from datetime import datetime, timedelta, timezone
+from random import choice as rand_choice
+from random import randint
from time import sleep
-from typing import Collection, Optional, Dict, List, Union
+from typing import Any, Collection, Dict, List, Optional, Union
from uuid import UUID, uuid4
import numpy as np
import pandas as pd
from psycopg import sql
-from pydantic import AwareDatetime, PositiveInt, NonNegativeInt
+from pydantic import AwareDatetime, NonNegativeInt, PositiveInt
from generalresearch.currency import USDCent
from generalresearch.decorators import LOG
@@ -23,21 +24,21 @@ from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr
from generalresearch.models.gr.business import Business
from generalresearch.models.thl.definitions import PayoutStatus
from generalresearch.models.thl.ledger import (
- LedgerAccount,
Direction,
+ LedgerAccount,
OrderBy,
)
from generalresearch.models.thl.payout import (
- PayoutEvent,
- UserPayoutEvent,
BrokerageProductPayoutEvent,
BusinessPayoutEvent,
+ PayoutEvent,
+ UserPayoutEvent,
)
from generalresearch.models.thl.product import Product
from generalresearch.models.thl.wallet import PayoutType
from generalresearch.models.thl.wallet.cashout_method import (
- CashoutRequestInfo,
CashMailOrderData,
+ CashoutRequestInfo,
)
@@ -93,7 +94,7 @@ class PayoutEventManager(PostgresManagerWithRedis):
payout_event: Union[UserPayoutEvent, BrokerageProductPayoutEvent],
status: PayoutStatus,
ext_ref_id: Optional[str] = None,
- order_data: Optional[Dict] = None,
+ order_data: Optional[Dict[str, Any]] = None,
) -> None:
# These 3 things are the only modifiable attributes
ext_ref_id = ext_ref_id if ext_ref_id is not None else payout_event.ext_ref_id
@@ -126,7 +127,7 @@ class UserPayoutEventManager(PayoutEventManager):
def get_by_uuid(self, pe_uuid: UUIDStr) -> UserPayoutEvent:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT ep.uuid,
ep.debit_account_uuid,
ep.cashout_method_uuid,
@@ -162,7 +163,7 @@ class UserPayoutEventManager(PayoutEventManager):
pe = self.get_by_uuid(pe_uuid=pe_uuid)
transaction_info = dict()
- order: Dict = pe.order_data
+ order: Dict[str, Any] = pe.order_data
if pe.payout_type == PayoutType.TANGO and pe.status == PayoutStatus.COMPLETE:
reward = order["reward"]
if "credentialList" in reward:
@@ -215,10 +216,12 @@ class UserPayoutEventManager(PayoutEventManager):
"""
args = []
filters = []
+
if reference_uuid:
# This could be a product_id or a user_uuid
filters.append("la.reference_uuid = %s")
args.append(reference_uuid)
+
if debit_account_uuids:
# Or we could use the bp_wallet or user_wallet's account uuid
# instead of looking up by the product/user
@@ -290,13 +293,13 @@ class UserPayoutEventManager(PayoutEventManager):
uuid: Optional[UUIDStr] = None,
status: Optional[PayoutStatus] = None,
created: Optional[AwareDatetimeISO] = None,
- request_data: Optional[Dict] = None,
+ request_data: Optional[Dict[str, Any]] = None,
# --- Optional: None ---
account_reference_type: Optional[str] = None,
account_reference_uuid: Optional[UUIDStr] = None,
description: Optional[str] = None,
ext_ref_id: Optional[str] = None,
- order_data: Optional[Dict | CashMailOrderData] = None,
+ order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None,
) -> UserPayoutEvent:
payout_event = UserPayoutEvent(
@@ -319,10 +322,12 @@ class UserPayoutEventManager(PayoutEventManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(
- query=f"""
+ query="""
INSERT INTO event_payout (
- uuid, debit_account_uuid, created, cashout_method_uuid, amount,
- status, ext_ref_id, payout_type, order_data, request_data
+ uuid, debit_account_uuid, created,
+ cashout_method_uuid, amount, status,
+ ext_ref_id, payout_type, order_data,
+ request_data
) VALUES (
%(uuid)s, %(debit_account_uuid)s, %(created)s,
%(cashout_method_uuid)s, %(amount)s, %(status)s,
@@ -350,9 +355,10 @@ class UserPayoutEventManager(PayoutEventManager):
status: Optional[PayoutStatus] = None,
ext_ref_id: Optional[str] = None,
payout_type: Optional[PayoutType] = None,
- request_data: Optional[Dict] = None,
- order_data: Optional[Dict | CashMailOrderData] = None,
+ request_data: Optional[Dict[str, Any]] = None,
+ order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None,
) -> UserPayoutEvent:
+
debit_account_uuid = debit_account_uuid or uuid4().hex
cashout_method_uuid = cashout_method_uuid or uuid4().hex
# account_reference_type = account_reference_type or f"acct-ref-{uuid4().hex}"
@@ -396,7 +402,7 @@ class BrokerageProductPayoutEventManager(PayoutEventManager):
) -> BrokerageProductPayoutEvent:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT ep.uuid,
ep.debit_account_uuid,
ep.cashout_method_uuid,
@@ -418,6 +424,7 @@ class BrokerageProductPayoutEventManager(PayoutEventManager):
rc = self.redis_client
account_product_mapping: Dict = rc.hgetall(name="pem:account_to_product")
assert isinstance(account_product_mapping, dict)
+
d["product_id"] = account_product_mapping[d["debit_account_uuid"]]
return BrokerageProductPayoutEvent.model_validate(d)
@@ -477,11 +484,12 @@ class BrokerageProductPayoutEventManager(PayoutEventManager):
status: Optional[PayoutStatus] = None,
ext_ref_id: Optional[str] = None,
payout_type: PayoutType = None,
- request_data: Dict = None,
- order_data: Optional[Dict | CashMailOrderData] = None,
+ request_data: Optional[Dict[str, Any]] = None,
+ order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None,
# --- Support resources ---
account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None,
) -> BrokerageProductPayoutEvent:
+
if request_data is None:
request_data = dict()
@@ -509,7 +517,7 @@ class BrokerageProductPayoutEventManager(PayoutEventManager):
d = bp_payout_event.model_dump_mysql()
self.pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO event_payout (
uuid, debit_account_uuid, created, cashout_method_uuid, amount,
status, ext_ref_id, payout_type, order_data, request_data
@@ -685,7 +693,8 @@ class BrokerageProductPayoutEventManager(PayoutEventManager):
skip_wallet_balance_check: bool = False,
skip_one_per_day_check: bool = False,
) -> BrokerageProductPayoutEvent:
- """If a create_bp_payout_event call fails, this can be called with
+ """
+ If a create_bp_payout_event call fails, this can be called with
the associated payoutevent.
"""
bp_pe: BrokerageProductPayoutEvent = self.get_by_uuid(payout_event_uuid)
@@ -1005,7 +1014,7 @@ class BusinessPayoutEventManager(BrokerageProductPayoutEventManager):
grouped[bp_pe.ext_ref_id].append(bp_pe)
res = []
- for ex_ref_id, members in grouped.items():
+ for _, members in grouped.items():
res.append(BusinessPayoutEvent.model_validate({"bp_payouts": members}))
return res
@@ -1077,8 +1086,8 @@ class BusinessPayoutEventManager(BrokerageProductPayoutEventManager):
def distribute_amount(
df: pd.DataFrame,
amount: USDCent,
- weight_col="weight",
- balance_col="remaining_balance",
+ weight_col: str = "weight",
+ balance_col: str = "remaining_balance",
) -> pd.Series:
"""
Distributes an integer amount across dataframe rows proportionally,
@@ -1129,7 +1138,7 @@ class BusinessPayoutEventManager(BrokerageProductPayoutEventManager):
from itertools import islice
# Add 1 cent to the top 'shortage' rows
- for idx, value in islice(remainders.items(), shortage):
+ for idx, _ in islice(remainders.items(), shortage):
# Only add if it doesn't exceed the balance
if allocation.loc[idx] < df[balance_col].loc[idx]:
allocation.loc[idx] += 1
diff --git a/generalresearch/managers/thl/product.py b/generalresearch/managers/thl/product.py
index c7d38c2..a3a8da6 100644
--- a/generalresearch/managers/thl/product.py
+++ b/generalresearch/managers/thl/product.py
@@ -1,11 +1,11 @@
import json
import logging
import operator
-from datetime import timezone, datetime
+from datetime import datetime, timezone
from decimal import Decimal
from threading import Lock
-from typing import Collection, Optional, List, TYPE_CHECKING, Union
-from uuid import uuid4, UUID
+from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from uuid import UUID, uuid4
from cachetools import TTLCache, cachedmethod, keys
from more_itertools import chunked
@@ -24,16 +24,16 @@ from generalresearch.pg_helper import PostgresConfig
logger = logging.getLogger()
if TYPE_CHECKING:
- from generalresearch.models.thl.product import Product
from generalresearch.models.thl.product import (
- UserCreateConfig,
PayoutConfig,
+ Product,
+ ProfilingConfig,
SessionConfig,
- UserWalletConfig,
SourcesConfig,
- UserHealthConfig,
- ProfilingConfig,
SupplyConfigs,
+ UserCreateConfig,
+ UserHealthConfig,
+ UserWalletConfig,
)
@@ -113,7 +113,7 @@ class ProductManager(PostgresManager):
if rand_limit:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT p.id::uuid
FROM userprofile_brokerageproduct AS p
ORDER BY RANDOM()
@@ -124,7 +124,7 @@ class ProductManager(PostgresManager):
else:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT p.id::uuid
FROM userprofile_brokerageproduct AS p
"""
@@ -322,14 +322,14 @@ class ProductManager(PostgresManager):
) -> "Product":
"""Create a Product with all the basic defaults and return the instance"""
from generalresearch.models.thl.product import (
- UserCreateConfig,
PayoutConfig,
+ Product,
+ ProfilingConfig,
SessionConfig,
- UserWalletConfig,
SourcesConfig,
+ UserCreateConfig,
UserHealthConfig,
- ProfilingConfig,
- Product,
+ UserWalletConfig,
)
now = datetime.now(tz=timezone.utc)
@@ -402,7 +402,7 @@ class ProductManager(PostgresManager):
try:
insert_data["id_int"] = list(
self.pg_config.execute_sql_query(
- f"""
+ query="""
SELECT COALESCE(MAX(id_int), 0) + 1 as id_int
FROM userprofile_brokerageproduct
"""
diff --git a/generalresearch/managers/thl/profiling/question.py b/generalresearch/managers/thl/profiling/question.py
index 1ad27ac..7b2a7ad 100644
--- a/generalresearch/managers/thl/profiling/question.py
+++ b/generalresearch/managers/thl/profiling/question.py
@@ -1,15 +1,15 @@
import random
import threading
-from typing import Collection, List, Tuple
+from typing import Any, Collection, Dict, List, Tuple
-from cachetools import cached, TTLCache
+from cachetools import TTLCache, cached
from pydantic import ValidationError
from generalresearch.decorators import LOG
from generalresearch.managers.base import PostgresManager
from generalresearch.models.thl.profiling.upk_question import (
- UpkQuestion,
UPKImportance,
+ UpkQuestion,
)
@@ -21,8 +21,8 @@ class QuestionManager(PostgresManager):
FROM marketplace_question
WHERE id = ANY(%(question_ids)s);
"""
- res = self.pg_config.execute_sql_query(
- query, {"question_ids": list(question_ids)}
+ res: List[Dict[str, Any]] = self.pg_config.execute_sql_query(
+ query=query, params={"question_ids": list(question_ids)}
)
for x in res:
x["data"]["ext_question_id"] = x["property_code"]
@@ -31,6 +31,7 @@ class QuestionManager(PostgresManager):
"explanation_fragment_template"
]
x["data"].pop("categories", None)
+
return [UpkQuestion.model_validate(x["data"]) for x in res]
@cached(
@@ -50,7 +51,7 @@ class QuestionManager(PostgresManager):
AND property_code NOT LIKE 'g:%%'
AND is_live
"""
- res = self.pg_config.execute_sql_query(
+ res: List[Dict[str, Any]] = self.pg_config.execute_sql_query(
query=query,
params={"country_iso": country_iso, "language_iso": language_iso},
)
@@ -94,9 +95,12 @@ class QuestionManager(PostgresManager):
"country_iso": country_iso,
"language_iso": language_iso,
}
- res = self.pg_config.execute_sql_query(query=query, params=params)
+ res: List[Dict[str, Any]] = self.pg_config.execute_sql_query(
+ query=query, params=params
+ )
assert len(res) == 1, f"expected 1, got {len(res)} results"
x = res[0]
+
x["data"]["ext_question_id"] = x["property_code"]
x["data"]["explanation_template"] = x["explanation_template"]
x["data"]["explanation_fragment_template"] = x["explanation_fragment_template"]
@@ -119,7 +123,9 @@ class QuestionManager(PostgresManager):
WHERE {where_str}
"""
flat_params = [item for tup in lookup for item in tup]
- res = self.pg_config.execute_sql_query(query, params=flat_params)
+ res: List[Dict[str, Any]] = self.pg_config.execute_sql_query(
+ query=query, params=flat_params
+ )
for x in res:
x["data"]["ext_question_id"] = x["property_code"]
x["data"]["explanation_template"] = x["explanation_template"]
diff --git a/generalresearch/managers/thl/profiling/schema.py b/generalresearch/managers/thl/profiling/schema.py
index 581270b..e209067 100644
--- a/generalresearch/managers/thl/profiling/schema.py
+++ b/generalresearch/managers/thl/profiling/schema.py
@@ -2,7 +2,7 @@ from threading import RLock
from typing import List
from uuid import UUID
-from cachetools import cached, TTLCache
+from cachetools import TTLCache, cached
from generalresearch.managers.base import PostgresManager
from generalresearch.models.thl.profiling.upk_property import (
diff --git a/generalresearch/managers/thl/profiling/uqa.py b/generalresearch/managers/thl/profiling/uqa.py
index 6800d32..1cab6c2 100644
--- a/generalresearch/managers/thl/profiling/uqa.py
+++ b/generalresearch/managers/thl/profiling/uqa.py
@@ -132,7 +132,7 @@ class UQAManager(PostgresManagerWithRedis):
# 1) the cache expired and the user hasn't sent an answer recently
# or 2) The user just sent an answer, so we'll make sure it gets put into the results
# after this query runs.
- query = f"""
+ query = """
WITH ranked AS (
SELECT
uqa.*,
diff --git a/generalresearch/managers/thl/profiling/user_upk.py b/generalresearch/managers/thl/profiling/user_upk.py
index 4449afa..b3ea52e 100644
--- a/generalresearch/managers/thl/profiling/user_upk.py
+++ b/generalresearch/managers/thl/profiling/user_upk.py
@@ -1,10 +1,11 @@
import json
from collections import defaultdict
-from datetime import timedelta, datetime, timezone
-from typing import Dict, Union, Set, List, Collection, Optional, Tuple
+from datetime import datetime, timedelta, timezone
+from typing import Any, Collection, Dict, List, Optional, Set, Tuple, Union
from uuid import UUID
from psycopg import Cursor
+from pydantic import PositiveInt
from generalresearch.managers.base import (
Permission,
@@ -117,8 +118,9 @@ class UserUpkManager(PostgresManagerWithRedis):
return [UpkQuestionAnswer.model_validate(x) for x in res]
def get_user_upk_simple(
- self, user_id, country_iso="us"
+ self, user_id: PositiveInt, country_iso: str = "us"
) -> Dict[str, Union[Set[str], str, float]]:
+
res = self.get_user_upk(user_id=user_id)
res = [x for x in res if x.country_iso == country_iso]
d: Dict[str, Union[Set[str], str, float]] = defaultdict(set)
@@ -127,16 +129,19 @@ class UserUpkManager(PostgresManagerWithRedis):
d[x.property_label] = x.value
else:
d[x.property_label].add(x.value)
+
return dict(d)
def get_age_gender(
- self, user_id, country_iso="us"
+ self, user_id: PositiveInt, country_iso: str = "us"
) -> Tuple[Optional[int], Optional[str]]:
+
# Returns an integer year for age, and {'male', 'female', 'other_gender'}
d = self.get_user_upk_simple(user_id, country_iso)
age = d.get("age_in_years")
if age is not None:
age = int(age)
+
gender = d.get("gender")
return age, gender
@@ -145,7 +150,9 @@ class UserUpkManager(PostgresManagerWithRedis):
country_iso=country_iso
)
- def populate_user_upk_from_dict(self, upk_ans_dict):
+ def populate_user_upk_from_dict(
+ self, upk_ans_dict: List[Dict[str, Any]]
+ ) -> List[UpkQuestionAnswer]:
country_isos = {x["country_iso"] for x in upk_ans_dict}
assert len(country_isos) == 1
diff --git a/generalresearch/managers/thl/session.py b/generalresearch/managers/thl/session.py
index d328f7d..6e77031 100644
--- a/generalresearch/managers/thl/session.py
+++ b/generalresearch/managers/thl/session.py
@@ -1,11 +1,11 @@
from datetime import datetime, timedelta, timezone
from decimal import Decimal
-from typing import Optional, Dict, Tuple, List, Any, Collection
-from uuid import uuid4, UUID
+from typing import Any, Collection, Dict, List, Optional, Tuple
+from uuid import UUID, uuid4
from faker import Faker
from psycopg import sql
-from pydantic import NonNegativeInt
+from pydantic import NonNegativeInt, PositiveInt
from generalresearch.managers import parse_order_by
from generalresearch.managers.base import (
@@ -17,17 +17,17 @@ from generalresearch.models import DeviceType
from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.legacy.bucket import Bucket
from generalresearch.models.thl.definitions import (
+ SessionStatusCode2,
Status,
StatusCode1,
- SessionStatusCode2,
)
from generalresearch.models.thl.session import (
Session,
Wall,
)
from generalresearch.models.thl.task_status import (
- TaskStatusResponse,
TasksStatusResponse,
+ TaskStatusResponse,
)
from generalresearch.models.thl.user import User
@@ -48,7 +48,7 @@ class SessionManager(PostgresManager):
device_type: Optional[DeviceType] = None,
ip: Optional[str] = None,
bucket: Optional[Bucket] = None,
- url_metadata: Optional[Dict] = None,
+ url_metadata: Optional[Dict[str, str]] = None,
uuid_id: Optional[str] = None,
) -> Session:
"""Creates a Session. Prefer to use this rather than instantiating the
@@ -86,7 +86,7 @@ class SessionManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(query=query, params=d)
- session.id = c.fetchone()["id"]
+ session.id = c.fetchone()["id"] # type: ignore
conn.commit()
return session
@@ -100,7 +100,7 @@ class SessionManager(PostgresManager):
device_type: Optional[DeviceType] = None,
ip: Optional[str] = None,
bucket: Optional[Bucket] = None,
- url_metadata: Optional[Dict] = None,
+ url_metadata: Optional[Dict[str, str]] = None,
uuid_id: Optional[str] = None,
) -> Session:
"""To be used in tests, where we don't care about certain fields"""
@@ -125,7 +125,7 @@ class SessionManager(PostgresManager):
)
def get_from_uuid(self, session_uuid: UUIDStr) -> Session:
- query = f"""
+ query = """
SELECT
s.id AS session_id,
s.uuid AS session_uuid,
@@ -148,7 +148,7 @@ class SessionManager(PostgresManager):
return self.session_from_mysql(res[0])
def get_from_id(self, session_id: int) -> Session:
- query = f"""
+ query = """
SELECT
s.id AS session_id,
s.uuid AS session_uuid,
@@ -234,7 +234,7 @@ class SessionManager(PostgresManager):
)
d = session.model_dump_mysql()
self.pg_config.execute_write(
- query=f"""
+ query="""
UPDATE thl_session
SET status = %(status)s, status_code_1 = %(status_code_1)s,
status_code_2 = %(status_code_2)s, finished = %(finished)s,
@@ -286,7 +286,7 @@ class SessionManager(PostgresManager):
def filter_paginated(
self,
- user_id: Optional[int] = None,
+ user_id: Optional[PositiveInt] = None,
session_uuids: Optional[List[UUIDStr]] = None,
product_uuids: Optional[List[UUIDStr]] = None,
started_after: Optional[datetime] = None,
diff --git a/generalresearch/managers/thl/survey.py b/generalresearch/managers/thl/survey.py
index 871ab83..1ce81d4 100644
--- a/generalresearch/managers/thl/survey.py
+++ b/generalresearch/managers/thl/survey.py
@@ -1,12 +1,13 @@
from collections import defaultdict
from datetime import datetime, timezone
-from typing import Collection, List, Tuple, Optional
+from typing import Any, Collection, Dict, List, Optional, Tuple
import pandas as pd
from more_itertools import chunked
from psycopg import sql
+from pydantic import NonNegativeInt
-from generalresearch.managers.base import PostgresManager, Permission
+from generalresearch.managers.base import Permission, PostgresManager
from generalresearch.managers.thl.buyer import BuyerManager
from generalresearch.managers.thl.category import CategoryManager
from generalresearch.models import Source
@@ -126,7 +127,8 @@ class SurveyManager(PostgresManager):
self,
survey_keys: Collection[SurveyKey],
include_categories: bool = False,
- ):
+ ) -> List[Survey]:
+
assert len(survey_keys) <= 1000
if len(survey_keys) == 0:
return []
@@ -195,15 +197,20 @@ class SurveyManager(PostgresManager):
query,
params=params,
)
+
return [Survey.model_validate(x) for x in res]
- def filter_by_natural_key(self, source: Source, survey_ids: Collection[str]):
+ def filter_by_natural_key(
+ self, source: Source, survey_ids: Collection[str]
+ ) -> List[Survey]:
res = []
for chunk in chunked(survey_ids, 1000):
res.extend(self.filter_by_natural_key_chunk(source, chunk))
return res
- def filter_by_natural_key_chunk(self, source: Source, survey_ids: Collection[str]):
+ def filter_by_natural_key_chunk(
+ self, source: Source, survey_ids: Collection[str]
+ ) -> List[Survey]:
query = """
SELECT id, source, survey_id, created_at, updated_at,
is_live, is_recontact, buyer_id, eligibility_criteria
@@ -217,7 +224,7 @@ class SurveyManager(PostgresManager):
)
return [Survey.model_validate(x) for x in res]
- def filter_by_source_live(self, source: Source):
+ def filter_by_source_live(self, source: Source) -> List[Survey]:
"""
Return all live surveys for this source
"""
@@ -230,7 +237,7 @@ class SurveyManager(PostgresManager):
res = self.pg_config.execute_sql_query(query, params={"source": source.value})
return [Survey.model_validate(x) for x in res]
- def filter_by_live(self, fields: Optional[List[str]] = None):
+ def filter_by_live(self, fields: Optional[List[str]] = None) -> List[Survey]:
"""
Return all live surveys
"""
@@ -245,7 +252,9 @@ class SurveyManager(PostgresManager):
res = self.pg_config.execute_sql_query(query)
return [Survey.model_validate(x) for x in res]
- def turn_off_by_natural_key(self, source: Source, survey_ids: Collection[str]):
+ def turn_off_by_natural_key(
+ self, source: Source, survey_ids: Collection[str]
+ ) -> None:
params = {"survey_ids": list(survey_ids), "source": source.value}
query = """
UPDATE marketplace_survey
@@ -267,7 +276,7 @@ class SurveyManager(PostgresManager):
survey_id = ANY(%(survey_pks)s);
"""
self.pg_config.execute_write(
- query,
+ query=query,
params={"survey_pks": survey_pks},
)
return None
@@ -447,13 +456,16 @@ class SurveyStatManager(PostgresManager):
# info = CompositeInfo.fetch(conn, "surveystat_key")
# info.register(conn)
- def update_or_create(self, survey_stats: List[SurveyStat]):
+ def update_or_create(
+ self, survey_stats: List[SurveyStat]
+ ) -> Optional[List[SurveyStat]]:
"""
This manager is NOT responsible for creating surveys or buyers.
It will check to make sure they exist
"""
if len(survey_stats) == 0:
return []
+
assert all(s.survey_survey_id is not None for s in survey_stats)
assert all(s.survey_source is not None for s in survey_stats)
assert (
@@ -478,6 +490,7 @@ class SurveyStatManager(PostgresManager):
)
# print(f"----aa-----: {datetime.now().isoformat()}")
self.upsert_sql(survey_stats=survey_stats)
+
# print(f"----ab-----: {datetime.now().isoformat()}")
return None
# keys = [s.unique_key for s in survey_stats]
@@ -489,7 +502,7 @@ class SurveyStatManager(PostgresManager):
# survey_stats = sorted(survey_stats, key=lambda s: s.natural_key)
# return survey_stats
- def upsert_sql(self, survey_stats: List[SurveyStat]):
+ def upsert_sql(self, survey_stats: List[SurveyStat]) -> None:
for chunk in chunked(survey_stats, 1000):
self.upsert_sql_chunk(survey_stats=chunk)
return None
@@ -519,7 +532,7 @@ class SurveyStatManager(PostgresManager):
# conn.commit()
# return None
- def upsert_sql_chunk(self, survey_stats: List[SurveyStat]):
+ def upsert_sql_chunk(self, survey_stats: List[SurveyStat]) -> None:
assert len(survey_stats) <= 1000, "chunk me"
keys = self.KEYS
keys_str = ", ".join(keys)
@@ -538,16 +551,19 @@ class SurveyStatManager(PostgresManager):
DO UPDATE SET {update_str};"""
now = datetime.now(tz=timezone.utc)
params = [ss.model_dump_sql() | {"updated_at": now} for ss in survey_stats]
+
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.executemany(query=query, params_seq=params)
conn.commit()
+
return None
- def filter_by_unique_keys(self, keys: Collection[Tuple]):
+ def filter_by_unique_keys(self, keys: Collection[Tuple]) -> List[SurveyStat]:
res = []
for chunk in chunked(keys, 5000):
res.extend(self.filter_by_unique_keys_chunk(chunk))
+
return res
def filter_by_unique_keys_chunk(self, keys: Collection[Tuple]):
@@ -610,7 +626,7 @@ class SurveyStatManager(PostgresManager):
return res
- def filter_by_updated_since(self, since):
+ def filter_by_updated_since(self, since: datetime):
return self.filter(updated_after=since, is_live=None)
def filter_by_live(self):
@@ -624,7 +640,7 @@ class SurveyStatManager(PostgresManager):
survey_keys: Optional[Collection[SurveyKey]] = None,
sources: Optional[Collection[Source]] = None,
country_iso: Optional[str] = None,
- ):
+ ) -> Tuple[str, Dict[str, Any]]:
filters = []
params = dict()
if updated_after is not None:
@@ -665,6 +681,7 @@ class SurveyStatManager(PostgresManager):
filters.append(f"({' OR '.join(sk_filters)})")
filter_str = "WHERE " + " AND ".join(filters) if filters else ""
+
return filter_str, params
def filter_count(
@@ -675,7 +692,7 @@ class SurveyStatManager(PostgresManager):
survey_keys: Optional[Collection[SurveyKey]] = None,
sources: Optional[Collection[Source]] = None,
country_iso: Optional[str] = None,
- ) -> int:
+ ) -> NonNegativeInt:
filter_str, params = self.make_filter_str(
is_live=is_live,
updated_after=updated_after,
@@ -703,7 +720,7 @@ class SurveyStatManager(PostgresManager):
size: Optional[int] = None,
order_by: Optional[str] = None,
debug: Optional[bool] = False,
- ):
+ ) -> List[SurveyStat]:
filter_str, params = self.make_filter_str(
is_live=is_live,
updated_after=updated_after,
@@ -745,15 +762,18 @@ class SurveyStatManager(PostgresManager):
{order_by_str}
{paginated_filter_str} ;
"""
+
if debug:
print(query)
print(params)
+
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute("SET work_mem = '256MB';")
c.execute("SET statement_timeout = '10s';")
c.execute(query, params=params)
res = c.fetchall()
+
return [SurveyStat.model_validate(x) for x in res]
def filter_to_merge_table(
@@ -761,12 +781,14 @@ class SurveyStatManager(PostgresManager):
is_live: Optional[bool] = True,
updated_after: Optional[datetime] = None,
min_score: Optional[float] = 0.0001,
- ):
+ ) -> Optional[pd.DataFrame]:
+
survey_stats = self.filter(
is_live=is_live, updated_after=updated_after, min_score=min_score
)
if not survey_stats:
return None
+
extra_cols = {
"survey_id",
"quota_id",
diff --git a/generalresearch/managers/thl/survey_penalty.py b/generalresearch/managers/thl/survey_penalty.py
index 4e2c104..3c402ee 100644
--- a/generalresearch/managers/thl/survey_penalty.py
+++ b/generalresearch/managers/thl/survey_penalty.py
@@ -2,7 +2,9 @@ import json
import threading
from collections import defaultdict
from datetime import timedelta
-from typing import Optional, List, Tuple, Dict
+from typing import Dict, List, Optional, Tuple
+
+from cachetools import TTLCache, cachedmethod
from generalresearch.decorators import LOG
from generalresearch.managers.base import RedisManager
@@ -11,12 +13,11 @@ from generalresearch.models.custom_types import (
)
from generalresearch.models.thl.survey.penalty import (
BPSurveyPenalty,
- TeamSurveyPenalty,
- PenaltyListAdapter,
Penalty,
+ PenaltyListAdapter,
+ TeamSurveyPenalty,
)
from generalresearch.redis_helper import RedisConfig
-from cachetools import cachedmethod, TTLCache
class SurveyPenaltyManager(RedisManager):
diff --git a/generalresearch/managers/thl/task_adjustment.py b/generalresearch/managers/thl/task_adjustment.py
index 15802e3..b1b509b 100644
--- a/generalresearch/managers/thl/task_adjustment.py
+++ b/generalresearch/managers/thl/task_adjustment.py
@@ -2,7 +2,7 @@ import logging
from datetime import datetime, timezone
from decimal import Decimal
from functools import cached_property
-from typing import Optional
+from typing import List, Optional
from generalresearch.managers import parse_order_by
from generalresearch.managers.base import (
@@ -13,9 +13,10 @@ from generalresearch.managers.thl.ledger_manager.thl_ledger import (
)
from generalresearch.managers.thl.session import SessionManager
from generalresearch.managers.thl.wall import WallManager
+from generalresearch.models.custom_types import UUIDStr
from generalresearch.models.thl.definitions import (
- WallAdjustedStatus,
Status,
+ WallAdjustedStatus,
)
from generalresearch.models.thl.session import (
_check_adjusted_status_wall_consistent,
@@ -35,11 +36,11 @@ class TaskAdjustmentManager(PostgresManager):
def filter_by_wall_uuid(
self,
- wall_uuid,
+ wall_uuid: UUIDStr,
page: int = 1,
size: int = 100,
order_by: Optional[str] = "-created",
- ):
+ ) -> List[TaskAdjustmentEvent]:
params = {"wall_uuid": wall_uuid}
order_by_str = parse_order_by(order_by)
paginated_filter_str = "LIMIT %(limit)s OFFSET %(offset)s"
@@ -67,26 +68,34 @@ class TaskAdjustmentManager(PostgresManager):
)
return [TaskAdjustmentEvent.model_validate(x) for x in res]
- def create_task_adjustment_event(self, event: TaskAdjustmentEvent):
- # Only insert a new record into thl_taskadjustment if the status for this wall_uuid
- # is different from the last one. Don't need the same thing twice
+ def create_task_adjustment_event(
+ self, event: TaskAdjustmentEvent
+ ) -> TaskAdjustmentEvent:
+
+ # Only insert a new record into thl_taskadjustment if the status
+ # for this wall_uuid differs from the last one; don't record
+ # the same status change twice.
res = self.filter_by_wall_uuid(
wall_uuid=event.wall_uuid, page=1, size=1, order_by="-created"
)
if res and event.adjusted_status == res[0].adjusted_status:
- # We already have this and it's the same change. Still call the wall_manager.adjust_status
- # and ledger code b/c 1) it also won't do the same thing twice, and 2) we could be out of sync
- # so check anyway.
+ # We already have this and it's the same change. Still call
+ # the wall_manager.adjust_status and ledger code b/c 1) it
+ # also won't do the same thing twice, and 2) we could be out
+ # of sync so check anyway.
return res[0]
self.pg_config.execute_write(
- """
+ query="""
INSERT INTO thl_taskadjustment
- (uuid, adjusted_status, ext_status_code, amount, alerted,
- created, user_id, wall_uuid, started, source, survey_id)
- VALUES (%(uuid)s, %(adjusted_status)s, %(ext_status_code)s, %(amount)s, %(alerted)s,
- %(created)s, %(user_id)s, %(wall_uuid)s, %(started)s, %(source)s, %(survey_id)s)
+ (uuid, adjusted_status, ext_status_code,
+ amount, alerted, created, user_id,
+ wall_uuid, started, source, survey_id)
+ VALUES (
+ %(uuid)s, %(adjusted_status)s, %(ext_status_code)s,
+ %(amount)s, %(alerted)s, %(created)s, %(user_id)s,
+ %(wall_uuid)s, %(started)s, %(source)s, %(survey_id)s)
""",
params=event.model_dump(mode="json"),
)
@@ -100,13 +109,15 @@ class TaskAdjustmentManager(PostgresManager):
alert_time: Optional[datetime] = None,
ext_status_code: Optional[str] = None,
adjusted_cpi: Optional[Decimal] = None,
- ):
+ ) -> None:
"""
We just got an adjustment notification from a marketplace.
See note on TaskAdjustmentEvent.adjusted_status.
- These fields (specifically adjusted_status and adjusted_cpi) are CHANGES/DELTAS
- as just communicated by the marketplace, not what the Wall's final adjusted_* will be.
+
+ These fields (specifically adjusted_status and adjusted_cpi) are
+ CHANGES/DELTAS as just communicated by the marketplace, not
+ what the Wall's final adjusted_* will be.
"""
alert_time = alert_time or datetime.now(tz=timezone.utc)
assert alert_time.tzinfo == timezone.utc
@@ -129,9 +140,10 @@ class TaskAdjustmentManager(PostgresManager):
else:
raise ValueError
- # If the wall event is a complete -> fail -> complete, we are going to
- # receive an adjusted_status.adjust_to_complete, but internally,
- # this is going to set the adjusted_status to None (b/c it was already a complete)
+ # If the wall event is a complete -> fail -> complete, we are going
+ # to receive an adjusted_status.adjust_to_complete, but internally,
+ # this is going to set the adjusted_status to None (b/c it was
+ # already a complete)
if (
wall.status == Status.COMPLETE
and adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE
@@ -185,3 +197,5 @@ class TaskAdjustmentManager(PostgresManager):
session.wall_events = self.wall_manager.get_wall_events(session.id)
self.session_manager.adjust_status(session)
ledger_manager.create_tx_bp_adjustment(session, created=alert_time)
+
+ return None
diff --git a/generalresearch/managers/thl/user_compensate.py b/generalresearch/managers/thl/user_compensate.py
index c5018d9..7cd16fe 100644
--- a/generalresearch/managers/thl/user_compensate.py
+++ b/generalresearch/managers/thl/user_compensate.py
@@ -3,6 +3,8 @@ from decimal import Decimal
from typing import Optional
from uuid import uuid4
+from pydantic import NonNegativeInt
+
from generalresearch.managers.thl.ledger_manager.thl_ledger import (
ThlLedgerManager,
)
@@ -13,14 +15,14 @@ from generalresearch.models.thl.user import User
def user_compensate(
ledger_manager: ThlLedgerManager,
user: User,
- amount_int: int,
- ext_ref=None,
- description=None,
+ amount_int: NonNegativeInt,
+ ext_ref: Optional[str] = None,
+ description: Optional[str] = None,
skip_flag_check: Optional[bool] = False,
) -> UUIDStr:
"""
- Compensate a user. aka "bribe". The money is paid out of the BP's wallet balance.
- Amount is in USD cents.
+ Compensate a user, a.k.a. "bribe". The money is paid out of the BP's
+ wallet balance. Amount is in USD cents.
"""
pg_config = ledger_manager.pg_config
redis_client = ledger_manager.redis_client
@@ -44,7 +46,7 @@ def user_compensate(
# If there is an external reference ID, don't allow it to be used twice
if ext_ref:
res = pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT 1
FROM event_bribe
WHERE ext_ref_id = %s
@@ -59,9 +61,10 @@ def user_compensate(
# Create a new bribe instance
account = ledger_manager.get_account_or_create_user_wallet(user)
pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO event_bribe
- (uuid, credit_account_uuid, created, amount, ext_ref_id, description, data)
+ (uuid, credit_account_uuid, created,
+ amount, ext_ref_id, description, data)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
params=[
diff --git a/generalresearch/managers/thl/user_manager/__init__.py b/generalresearch/managers/thl/user_manager/__init__.py
index b8b1c35..8875de8 100644
--- a/generalresearch/managers/thl/user_manager/__init__.py
+++ b/generalresearch/managers/thl/user_manager/__init__.py
@@ -5,11 +5,10 @@ import threading
import time
from pathlib import Path
from threading import RLock
-from typing import Dict
+from typing import Any, Dict, Union
-from cachetools import cached, TTLCache
+from cachetools import TTLCache, cached
-from generalresearch.managers.thl.user_manager import mysql_user_manager
from generalresearch.models.thl.product import Product
logger = logging.getLogger()
@@ -54,7 +53,7 @@ def get_bp_trust_df():
convert_int = lambda x: int(float(x))
-def parse_bp_trust_df(fp) -> Dict:
+def parse_bp_trust_df(fp: Union[str, Path]) -> Dict[str, Any]:
dtype = {
"bp_trust": float,
"team_trust": float,
@@ -67,6 +66,7 @@ def parse_bp_trust_df(fp) -> Dict:
with open(fp, newline="") as csvfile:
reader = csv.reader(csvfile)
header = next(reader)
+
for row in reader:
d = dict(zip(header, row))
for k, v in dtype.items():
diff --git a/generalresearch/managers/thl/user_manager/mysql_user_manager.py b/generalresearch/managers/thl/user_manager/mysql_user_manager.py
index ab2c6c3..299b167 100644
--- a/generalresearch/managers/thl/user_manager/mysql_user_manager.py
+++ b/generalresearch/managers/thl/user_manager/mysql_user_manager.py
@@ -1,7 +1,7 @@
import logging
from datetime import datetime, timezone
from functools import lru_cache
-from typing import Optional, Collection, List
+from typing import Collection, List, Optional
from uuid import uuid4
import psycopg
@@ -41,7 +41,7 @@ class MysqlUserManager:
product_user_id: Optional[str] = None,
user_id: Optional[int] = None,
user_uuid: Optional[UUIDStr] = None,
- can_use_read_replica=True,
+ can_use_read_replica: bool = True,
) -> Optional[User]:
logger.info(
@@ -64,7 +64,7 @@ class MysqlUserManager:
if product_id:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT id AS user_id, product_id, product_user_id,
uuid, blocked, created, last_seen
FROM thl_user
@@ -77,7 +77,7 @@ class MysqlUserManager:
elif user_id:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT id AS user_id, product_id, product_user_id,
uuid, blocked, created, last_seen
FROM thl_user
@@ -89,7 +89,7 @@ class MysqlUserManager:
else:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT id AS user_id, product_id, product_user_id,
uuid, blocked, created, last_seen
FROM thl_user
@@ -205,7 +205,7 @@ class MysqlUserManager:
def is_whitelisted(self, user: User):
res = self.pg_config.execute_sql_query(
- f"""
+ """
SELECT value
FROM userprofile_userstat
WHERE user_id = %s
@@ -260,7 +260,7 @@ class MysqlUserManager:
), "must pass a collection of user_ids"
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT id AS user_id, product_id, product_user_id,
uuid, blocked, created, last_seen
FROM thl_user
@@ -275,7 +275,7 @@ class MysqlUserManager:
user_uuids, (list, set)
), "must pass a collection of user_uuids"
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT id AS user_id, product_id, product_user_id,
uuid, blocked, created, last_seen
FROM thl_user
diff --git a/generalresearch/managers/thl/user_manager/rate_limit.py b/generalresearch/managers/thl/user_manager/rate_limit.py
index bd0f9ac..f938664 100644
--- a/generalresearch/managers/thl/user_manager/rate_limit.py
+++ b/generalresearch/managers/thl/user_manager/rate_limit.py
@@ -1,6 +1,6 @@
import logging
-from limits import storage, strategies, RateLimitItemPerHour, RateLimitItem
+from limits import RateLimitItem, RateLimitItemPerHour, storage, strategies
from limits.limits import TIME_TYPES, safe_string
from pydantic import RedisDsn
diff --git a/generalresearch/managers/thl/user_manager/redis_user_manager.py b/generalresearch/managers/thl/user_manager/redis_user_manager.py
index 8e69740..6c869bc 100644
--- a/generalresearch/managers/thl/user_manager/redis_user_manager.py
+++ b/generalresearch/managers/thl/user_manager/redis_user_manager.py
@@ -36,7 +36,7 @@ class RedisUserManager:
product_user_id: Optional[str] = None,
user_id: Optional[int] = None,
user_uuid: Optional[str] = None,
- ) -> User:
+ ) -> Optional[User]:
# assume we did input validation in user_manager.get_user() function
if user_uuid:
d = self.client.get(f"{self.cache_prefix}:uuid:{user_uuid}")
@@ -52,6 +52,8 @@ class RedisUserManager:
if d:
return User.model_validate_json(d)
+ return None
+
def set_user(self, user: User) -> None:
d = user.to_json()
with self.client.pipeline(transaction=False) as p:
diff --git a/generalresearch/managers/thl/user_manager/user_manager.py b/generalresearch/managers/thl/user_manager/user_manager.py
index dd32f8b..ffd8191 100644
--- a/generalresearch/managers/thl/user_manager/user_manager.py
+++ b/generalresearch/managers/thl/user_manager/user_manager.py
@@ -1,7 +1,7 @@
import logging
from datetime import datetime
from functools import lru_cache
-from typing import Collection, Optional, List
+from typing import Collection, List, Optional
from uuid import uuid4
from pydantic import RedisDsn
@@ -35,7 +35,7 @@ class UserManager:
redis: Optional[RedisDsn] = None,
pg_config: Optional[PostgresConfig] = None,
pg_config_rr: Optional[PostgresConfig] = None,
- sql_permissions: Collection[Permission] = None,
+ sql_permissions: Optional[Collection[Permission]] = None,
cache_prefix: Optional[str] = None,
redis_timeout: Optional[float] = None,
):
diff --git a/generalresearch/managers/thl/user_manager/user_metadata_manager.py b/generalresearch/managers/thl/user_manager/user_metadata_manager.py
index 23c9d3c..a97b214 100644
--- a/generalresearch/managers/thl/user_manager/user_metadata_manager.py
+++ b/generalresearch/managers/thl/user_manager/user_metadata_manager.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Collection
+from typing import Collection, List, Optional
from generalresearch.managers.base import PostgresManager
from generalresearch.models.thl.user_profile import UserMetadata
@@ -23,8 +23,10 @@ class UserMetadataManager(PostgresManager):
assert arg is None or isinstance(
arg, (set, list)
), "must pass a collection of objects"
+
filters = []
params = {}
+
if user_ids:
params["user_id"] = list(set(user_ids))
filters.append("user_id = ANY(%(user_id)s)")
@@ -50,6 +52,7 @@ class UserMetadataManager(PostgresManager):
""",
params,
)
+
return [UserMetadata.from_db(**x) for x in res]
def get_if_exists(
@@ -104,8 +107,9 @@ class UserMetadataManager(PostgresManager):
def update(self, user_metadata: UserMetadata) -> int:
"""
- The row in the thl_usermetadata might not exist. We'll implicitly create it
- if it doesn't yet exist. The caller does not need to know this detail.
+ The row in the thl_usermetadata might not exist. We'll
+ implicitly create it if it doesn't yet exist. The caller
+ does not need to know this detail.
"""
res = self.get_if_exists(user_id=user_metadata.user_id)
@@ -132,10 +136,11 @@ class UserMetadataManager(PostgresManager):
def _create(self, user_metadata: UserMetadata) -> int:
return self.pg_config.execute_write(
- """
+ query="""
INSERT INTO thl_usermetadata
(user_id, email_address, email_sha256, email_sha1, email_md5)
- VALUES (%(user_id)s, %(email_address)s, %(email_sha256)s, %(email_sha1)s, %(email_md5)s);
+ VALUES (%(user_id)s, %(email_address)s, %(email_sha256)s,
+ %(email_sha1)s, %(email_md5)s);
""",
params=user_metadata.to_db(),
)
diff --git a/generalresearch/managers/thl/user_streak.py b/generalresearch/managers/thl/user_streak.py
index fd3c1ba..4bc7e70 100644
--- a/generalresearch/managers/thl/user_streak.py
+++ b/generalresearch/managers/thl/user_streak.py
@@ -1,16 +1,16 @@
from datetime import date, datetime
-from typing import Optional, List, Tuple
+from typing import List, Optional, Tuple
import pandas as pd
from generalresearch.managers.base import PostgresManager
from generalresearch.managers.leaderboard import country_timezone
from generalresearch.models.thl.user_streak import (
- UserStreak,
- StreakPeriod,
+ PERIOD_TO_PD_FREQ,
StreakFulfillment,
+ StreakPeriod,
StreakState,
- PERIOD_TO_PD_FREQ,
+ UserStreak,
)
@@ -30,7 +30,7 @@ class UserStreakManager(PostgresManager):
{"user_id": user_id},
)
if res:
- return res[0]["country_iso"]
+ return res[0]["country_iso"] # type: ignore
return None
diff --git a/generalresearch/managers/thl/userhealth.py b/generalresearch/managers/thl/userhealth.py
index 6f66ce3..dab35d1 100644
--- a/generalresearch/managers/thl/userhealth.py
+++ b/generalresearch/managers/thl/userhealth.py
@@ -1,11 +1,12 @@
import ipaddress
-from datetime import timezone, datetime, timedelta
+from datetime import datetime, timedelta, timezone
from itertools import zip_longest
-from random import choice as rchoice, random
-from typing import List, Collection, Optional, Dict, Tuple
+from random import choice as rchoice
+from random import random
+from typing import Any, Collection, Dict, List, Optional, Tuple
import faker
-from pydantic import PositiveInt, NonNegativeInt
+from pydantic import NonNegativeInt, PositiveInt
from generalresearch.decorators import LOG
from generalresearch.managers.base import (
@@ -19,8 +20,8 @@ from generalresearch.models.thl.product import Product
from generalresearch.models.thl.user import User
from generalresearch.models.thl.user_iphistory import (
IPRecord,
- UserIPRecord,
UserIPHistory,
+ UserIPRecord,
)
from generalresearch.models.thl.userhealth import AuditLog, AuditLogLevel
from generalresearch.pg_helper import PostgresConfig
@@ -56,7 +57,7 @@ class UserIpHistoryManager(PostgresManagerWithRedis):
# The IP metadata is ONLY for the 'ip', NOT for any forwarded ips.
# This might get called immediately after a write, so use the non-rr
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT iph.ip, iph.created, iph.user_id,
geo.subdivision_1_iso,
ipinfo.country_iso,
@@ -263,7 +264,7 @@ class IPRecordManager(PostgresManagerWithRedis):
data[col] = ipaddress.ip_address(ip).exploded if ip else ip
self.pg_config.execute_write(
- query=f"""
+ query="""
INSERT INTO userhealth_iphistory (
user_id, ip, created,
forwarded_ip1, forwarded_ip2, forwarded_ip3,
@@ -396,16 +397,17 @@ class AuditLogManager(PostgresManager):
with self.pg_config.make_connection() as conn:
with conn.cursor() as c:
c.execute(
- query=f"""
+ query="""
INSERT INTO userhealth_auditlog
- (user_id, created, level, event_type, event_msg, event_value)
- VALUES ( %(user_id)s , %(created)s, %(level)s, %(event_type)s,
- %(event_msg)s, %(event_value)s)
+ (user_id, created, level,
+ event_type, event_msg, event_value)
+ VALUES ( %(user_id)s , %(created)s, %(level)s,
+ %(event_type)s, %(event_msg)s, %(event_value)s)
RETURNING id;
""",
params=al.model_dump_mysql(),
)
- pk = c.fetchone()["id"]
+ pk = c.fetchone()["id"] # type: ignore
conn.commit()
al.id = pk
@@ -414,7 +416,7 @@ class AuditLogManager(PostgresManager):
def get_by_id(self, auditlog_id: PositiveInt) -> AuditLog:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT al.*
FROM userhealth_auditlog AS al
WHERE al.id = %s
@@ -434,7 +436,7 @@ class AuditLogManager(PostgresManager):
def filter_by_product(self, product: Product) -> List[AuditLog]:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT al.*
FROM userhealth_auditlog AS al
INNER JOIN thl_user AS u
@@ -450,7 +452,7 @@ class AuditLogManager(PostgresManager):
def filter_by_user_id(self, user_id: PositiveInt) -> List[AuditLog]:
res = self.pg_config.execute_sql_query(
- query=f"""
+ query="""
SELECT *
FROM userhealth_auditlog AS al
WHERE al.user_id = %s
@@ -527,7 +529,7 @@ class AuditLogManager(PostgresManager):
assert len(res) == 1
- return int(res[0]["c"])
+ return int(res[0]["c"]) # type: ignore
@staticmethod
def make_filter_str(
@@ -538,7 +540,7 @@ class AuditLogManager(PostgresManager):
event_type_like: Optional[str] = None,
event_msg: Optional[str] = None,
created_after: Optional[datetime] = None,
- ) -> Tuple[str, Dict]:
+ ) -> Tuple[str, Dict[str, Any]]:
assert user_ids, "must pass at least 1 user_id"
assert all(
[isinstance(uid, int) for uid in user_ids]
diff --git a/generalresearch/managers/thl/wall.py b/generalresearch/managers/thl/wall.py
index cd1bdbf..3ddbf51 100644
--- a/generalresearch/managers/thl/wall.py
+++ b/generalresearch/managers/thl/wall.py
@@ -1,16 +1,16 @@
import logging
from collections import defaultdict
-from datetime import datetime, timezone, timedelta
-from decimal import Decimal, ROUND_DOWN
+from datetime import datetime, timedelta, timezone
+from decimal import ROUND_DOWN, Decimal
from functools import cached_property
from random import choice as rchoice
-from typing import Optional, Collection, List
+from typing import Collection, List, Optional
from uuid import uuid4
from faker import Faker
from psycopg import sql
from psycopg.rows import dict_row
-from pydantic import AwareDatetime, PositiveInt, PostgresDsn, RedisDsn
+from pydantic import AwareDatetime, PositiveInt
from generalresearch.managers import parse_order_by
from generalresearch.managers.base import (
@@ -19,23 +19,22 @@ from generalresearch.managers.base import (
PostgresManagerWithRedis,
)
from generalresearch.models import Source
-from generalresearch.models.custom_types import UUIDStr, SurveyKey
+from generalresearch.models.custom_types import SurveyKey, UUIDStr
from generalresearch.models.thl.definitions import (
+ ReportValue,
Status,
StatusCode1,
- WallStatusCode2,
- ReportValue,
WallAdjustedStatus,
+ WallStatusCode2,
)
from generalresearch.models.thl.ledger import OrderBy
from generalresearch.models.thl.session import (
- check_adjusted_status_wall_consistent,
Wall,
WallAttempt,
+ check_adjusted_status_wall_consistent,
)
from generalresearch.models.thl.survey.model import TaskActivity
from generalresearch.pg_helper import PostgresConfig
-from generalresearch.redis_helper import RedisConfig
logger = logging.getLogger("WallManager")
fake = Faker()
@@ -379,7 +378,7 @@ class WallManager(PostgresManager):
a "progress bar" for eligible, live, surveys they've
already attempted.
"""
- query = f"""
+ query = """
SELECT
COUNT(1) as cnt
FROM thl_wall w
@@ -396,7 +395,7 @@ class WallManager(PostgresManager):
query=query,
params=params,
)
- return res[0]["cnt"]
+ return res[0]["cnt"] # type: ignore[union-attr]
def filter_wall_attempts_paginated(
self,
@@ -628,6 +627,7 @@ class WallCacheManager(PostgresManagerWithRedis):
def update_attempts_redis_(self, attempts: List[WallAttempt], user_id: int) -> None:
if not attempts:
return None
+
redis_key = self.get_cache_key_(user_id=user_id)
# Make sure attempts is ordered, so the most recent is last
# "LPUSH mylist a b c will result into a list containing c as first element,
@@ -636,11 +636,12 @@ class WallCacheManager(PostgresManagerWithRedis):
json_res = [attempt.model_dump_json() for attempt in attempts]
res = self.redis_client.lpush(redis_key, *json_res)
self.redis_client.expire(redis_key, time=60 * 60 * 24)
+
# So this doesn't grow forever, keep only the most recent 5k
self.redis_client.ltrim(redis_key, 0, 4999)
return None
- def get_attempts(self, user_id: int) -> List[WallAttempt]:
+ def get_attempts(self, user_id: PositiveInt) -> List[WallAttempt]:
"""
This is used in the GetOpportunityIDs call to get a list of surveys
(& surveygroups) which should be excluded for this user. We don't
diff --git a/generalresearch/managers/thl/wallet/__init__.py b/generalresearch/managers/thl/wallet/__init__.py
index 3b34756..0826263 100644
--- a/generalresearch/managers/thl/wallet/__init__.py
+++ b/generalresearch/managers/thl/wallet/__init__.py
@@ -1,20 +1,20 @@
from decimal import Decimal
-from typing import Optional, Dict
+from typing import Any, Dict, Optional, Union
from generalresearch.managers.thl.ledger_manager.thl_ledger import (
ThlLedgerManager,
)
from generalresearch.managers.thl.payout import (
- UserPayoutEventManager,
PayoutEventManager,
+ UserPayoutEventManager,
)
from generalresearch.managers.thl.user_manager.user_manager import (
UserManager,
)
from generalresearch.managers.thl.userhealth import UserIpHistoryManager
from generalresearch.managers.thl.wallet.approve import (
- approve_paypal_order,
approve_amt_cashout,
+ approve_paypal_order,
)
from generalresearch.models.thl.definitions import PayoutStatus
from generalresearch.models.thl.payout import UserPayoutEvent
@@ -31,7 +31,7 @@ def manage_pending_cashout(
user_ip_history_manager: UserIpHistoryManager,
user_manager: UserManager,
ledger_manager: ThlLedgerManager,
- order_data: Optional[Dict | CashMailOrderData] = None,
+ order_data: Optional[Union[Dict[str, Any], CashMailOrderData]] = None,
) -> UserPayoutEvent:
"""
Called by a UI actions performed by Todd. This rejects/approves/cancels
diff --git a/generalresearch/managers/thl/wallet/tango.py b/generalresearch/managers/thl/wallet/tango.py
index 5587435..7665ef0 100644
--- a/generalresearch/managers/thl/wallet/tango.py
+++ b/generalresearch/managers/thl/wallet/tango.py
@@ -1,7 +1,5 @@
from typing import Any, Dict
-import slack
-
from generalresearch.config import (
is_debug,
)
@@ -92,7 +90,7 @@ def get_tango_order(ref_id: str):
# return json.loads(APIHelper.json_serialize(orders[0]))
-def create_tango_order(request_data: Dict, ref_id: str) -> Dict[str, Any]:
+def create_tango_order(request_data: Dict[str, Any], ref_id: str) -> Dict[str, Any]:
"""
Create a tango gift card order.
Throws exception if anything is not right.
diff --git a/tests/managers/thl/test_ledger/test_lm_accounts.py b/tests/managers/thl/test_ledger/test_lm_accounts.py
index be9cf5b..5cfaac1 100644
--- a/tests/managers/thl/test_ledger/test_lm_accounts.py
+++ b/tests/managers/thl/test_ledger/test_lm_accounts.py
@@ -19,8 +19,18 @@ from generalresearch.models.thl.ledger import (
)
if TYPE_CHECKING:
+ from pydantic import PositiveInt
+
from generalresearch.config import GRLSettings
from generalresearch.currency import LedgerCurrency
+ from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager
+ from generalresearch.models.custom_types import AccountType, Direction, UUIDStr
+ from generalresearch.models.thl import Direction
+ from generalresearch.models.thl.ledger import (
+ AccountType,
+ LedgerAccount,
+ LedgerTransaction,
+ )
from generalresearch.models.thl.product import Product
from generalresearch.models.thl.session import Session
from generalresearch.models.thl.user import User
@@ -40,7 +50,11 @@ if TYPE_CHECKING:
class TestLedgerAccountManagerNoResults:
def test_get_account_no_results(
- self, currency: "LedgerCurrency", kind, acct_id, lm
+ self,
+ currency: "LedgerCurrency",
+ kind: str,
+ acct_id: "UUIDStr",
+ lm: "LedgerManager",
):
"""Try to query for accounts that we know don't exist and confirm that
we either get the expected None result or it raises the correct
@@ -59,7 +73,11 @@ class TestLedgerAccountManagerNoResults:
assert lm.get_account(qualified_name=qn, raise_on_error=False) is None
def test_get_account_no_results_many(
- self, currency: "LedgerCurrency", kind, acct_id, lm
+ self,
+ currency: "LedgerCurrency",
+ kind: str,
+ acct_id: "UUIDStr",
+ lm: "LedgerManager",
):
qn = ":".join([currency, kind, acct_id])
@@ -95,7 +113,11 @@ class TestLedgerAccountManagerNoResults:
class TestLedgerAccountManagerCreate:
def test_create_account_error_permission(
- self, currency: "LedgerCurrency", account_type, direction, lm
+ self,
+ currency: "LedgerCurrency",
+ account_type: "AccountType",
+ direction: "Direction",
+ lm: "LedgerManager",
):
"""Confirm that the Permission values that are set on the Ledger Manger
allow the Creation action to occur.
@@ -140,7 +162,13 @@ class TestLedgerAccountManagerCreate:
str(excinfo.value) == "LedgerManager does not have sufficient permissions"
)
- def test_create(self, currency: "LedgerCurrency", account_type, direction, lm):
+ def test_create(
+ self,
+ currency: "LedgerCurrency",
+ account_type: "AccountType",
+ direction: "Direction",
+ lm: "LedgerManager",
+ ):
"""Confirm that the Permission values that are set on the Ledger Manger
allow the Creation action to occur.
"""
@@ -161,10 +189,15 @@ class TestLedgerAccountManagerCreate:
# Query for, and make sure the Account was saved in the DB
res = lm.get_account(qualified_name=qn, raise_on_error=True)
+ assert res is not None
assert account.uuid == res.uuid
def test_get_or_create(
- self, currency: "LedgerCurrency", account_type, direction, lm
+ self,
+ currency: "LedgerCurrency",
+ account_type: "AccountType",
+ direction: "Direction",
+ lm: "LedgerManager",
):
"""Confirm that the Permission values that are set on the Ledger Manger
allow the Creation action to occur.
@@ -186,13 +219,15 @@ class TestLedgerAccountManagerCreate:
# Query for, and make sure the Account was saved in the DB
res = lm.get_account(qualified_name=qn, raise_on_error=True)
+ assert res is not None
assert account.uuid == res.uuid
class TestLedgerAccountManagerGet:
- def test_get(self, ledger_account, lm):
+ def test_get(self, ledger_account: "LedgerAccount", lm: "LedgerManager"):
res = lm.get_account(qualified_name=ledger_account.qualified_name)
+ assert res is not None
assert res.uuid == ledger_account.uuid
res = lm.get_account_many(qualified_names=[ledger_account.qualified_name])
@@ -207,7 +242,12 @@ class TestLedgerAccountManagerGet:
# creation working
def test_get_balance_empty(
- self, ledger_account, ledger_account_credit, ledger_account_debit, ledger_tx, lm
+ self,
+ ledger_account: "LedgerAccount",
+ ledger_account_credit: "LedgerAccount",
+ ledger_account_debit: "LedgerAccount",
+ ledger_tx: "LedgerTransaction",
+ lm: "LedgerManager",
):
res = lm.get_account_balance(account=ledger_account)
assert res == 0
@@ -221,12 +261,12 @@ class TestLedgerAccountManagerGet:
@pytest.mark.parametrize("n_times", range(5))
def test_get_account_filtered_balance(
self,
- ledger_account,
- ledger_account_credit,
- ledger_account_debit,
- ledger_tx,
- n_times,
- lm,
+ ledger_account: "LedgerAccount",
+ ledger_account_credit: "LedgerAccount",
+ ledger_account_debit: "LedgerAccount",
+ ledger_tx: "LedgerTransaction",
+ n_times: "PositiveInt",
+ lm: "LedgerManager",
):
"""Try searching for random metadata and confirm it's always 0 because
Tx can be found.
@@ -279,6 +319,8 @@ class TestLedgerAccountManagerGet:
== rand_amount
)
- def test_get_balance_timerange_empty(self, ledger_account, lm):
+ def test_get_balance_timerange_empty(
+ self, ledger_account: "LedgerAccount", lm: "LedgerManager"
+ ):
res = lm.get_account_balance_timerange(account=ledger_account)
assert res == 0