aboutsummaryrefslogtreecommitdiff
path: root/tests/managers
diff options
context:
space:
mode:
authorMax Nanis2026-03-06 16:49:46 -0500
committerMax Nanis2026-03-06 16:49:46 -0500
commit91d040211a4ed6e4157896256a762d3854777b5e (patch)
treecd95922ea4257dc8d3f4e4cbe8534474709a20dc /tests/managers
downloadgeneralresearch-91d040211a4ed6e4157896256a762d3854777b5e.tar.gz
generalresearch-91d040211a4ed6e4157896256a762d3854777b5e.zip
Initial commitv3.3.4
Diffstat (limited to 'tests/managers')
-rw-r--r--tests/managers/__init__.py0
-rw-r--r--tests/managers/gr/__init__.py0
-rw-r--r--tests/managers/gr/test_authentication.py125
-rw-r--r--tests/managers/gr/test_business.py150
-rw-r--r--tests/managers/gr/test_team.py125
-rw-r--r--tests/managers/leaderboard.py274
-rw-r--r--tests/managers/test_events.py530
-rw-r--r--tests/managers/test_lucid.py23
-rw-r--r--tests/managers/test_userpid.py68
-rw-r--r--tests/managers/thl/__init__.py0
-rw-r--r--tests/managers/thl/test_buyer.py25
-rw-r--r--tests/managers/thl/test_cashout_method.py139
-rw-r--r--tests/managers/thl/test_category.py100
-rw-r--r--tests/managers/thl/test_contest/__init__.py0
-rw-r--r--tests/managers/thl/test_contest/test_leaderboard.py138
-rw-r--r--tests/managers/thl/test_contest/test_milestone.py296
-rw-r--r--tests/managers/thl/test_contest/test_raffle.py474
-rw-r--r--tests/managers/thl/test_harmonized_uqa.py116
-rw-r--r--tests/managers/thl/test_ipinfo.py117
-rw-r--r--tests/managers/thl/test_ledger/__init__.py0
-rw-r--r--tests/managers/thl/test_ledger/test_lm_accounts.py268
-rw-r--r--tests/managers/thl/test_ledger/test_lm_tx.py235
-rw-r--r--tests/managers/thl/test_ledger/test_lm_tx_entries.py26
-rw-r--r--tests/managers/thl/test_ledger/test_lm_tx_locks.py371
-rw-r--r--tests/managers/thl/test_ledger/test_lm_tx_metadata.py34
-rw-r--r--tests/managers/thl/test_ledger/test_thl_lm_accounts.py411
-rw-r--r--tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py516
-rw-r--r--tests/managers/thl/test_ledger/test_thl_lm_tx.py1762
-rw-r--r--tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py505
-rw-r--r--tests/managers/thl/test_ledger/test_thl_pem.py251
-rw-r--r--tests/managers/thl/test_ledger/test_user_txs.py288
-rw-r--r--tests/managers/thl/test_ledger/test_wallet.py78
-rw-r--r--tests/managers/thl/test_maxmind.py273
-rw-r--r--tests/managers/thl/test_payout.py1269
-rw-r--r--tests/managers/thl/test_product.py362
-rw-r--r--tests/managers/thl/test_product_prod.py82
-rw-r--r--tests/managers/thl/test_profiling/__init__.py0
-rw-r--r--tests/managers/thl/test_profiling/test_question.py49
-rw-r--r--tests/managers/thl/test_profiling/test_schema.py44
-rw-r--r--tests/managers/thl/test_profiling/test_uqa.py1
-rw-r--r--tests/managers/thl/test_profiling/test_user_upk.py59
-rw-r--r--tests/managers/thl/test_session_manager.py137
-rw-r--r--tests/managers/thl/test_survey.py376
-rw-r--r--tests/managers/thl/test_survey_penalty.py101
-rw-r--r--tests/managers/thl/test_task_adjustment.py346
-rw-r--r--tests/managers/thl/test_task_status.py696
-rw-r--r--tests/managers/thl/test_user_manager/__init__.py0
-rw-r--r--tests/managers/thl/test_user_manager/test_base.py274
-rw-r--r--tests/managers/thl/test_user_manager/test_mysql.py25
-rw-r--r--tests/managers/thl/test_user_manager/test_redis.py80
-rw-r--r--tests/managers/thl/test_user_manager/test_user_fetch.py48
-rw-r--r--tests/managers/thl/test_user_manager/test_user_metadata.py88
-rw-r--r--tests/managers/thl/test_user_streak.py225
-rw-r--r--tests/managers/thl/test_userhealth.py367
-rw-r--r--tests/managers/thl/test_wall_manager.py283
55 files changed, 12630 insertions, 0 deletions
diff --git a/tests/managers/__init__.py b/tests/managers/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/__init__.py
diff --git a/tests/managers/gr/__init__.py b/tests/managers/gr/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/gr/__init__.py
diff --git a/tests/managers/gr/test_authentication.py b/tests/managers/gr/test_authentication.py
new file mode 100644
index 0000000..53b6931
--- /dev/null
+++ b/tests/managers/gr/test_authentication.py
@@ -0,0 +1,125 @@
+import logging
+from random import randint
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models.gr.authentication import GRUser
+from test_utils.models.conftest import gr_user
+
+SSO_ISSUER = ""
+
+
+class TestGRUserManager:
+
+ def test_create(self, gr_um):
+ from generalresearch.models.gr.authentication import GRUser
+
+ user: GRUser = gr_um.create_dummy()
+ instance = gr_um.get_by_id(user.id)
+ assert user.id == instance.id
+
+ instance2 = gr_um.get_by_id(user.id)
+ assert user.model_dump_json() == instance2.model_dump_json()
+
+ def test_get_by_id(self, gr_user, gr_um):
+ with pytest.raises(expected_exception=ValueError) as cm:
+ gr_um.get_by_id(gr_user_id=999_999_999)
+ assert "GRUser not found" in str(cm.value)
+
+ instance = gr_um.get_by_id(gr_user_id=gr_user.id)
+ assert instance.sub == gr_user.sub
+
+ def test_get_by_sub(self, gr_user, gr_um):
+ with pytest.raises(expected_exception=ValueError) as cm:
+ gr_um.get_by_sub(sub=uuid4().hex)
+ assert "GRUser not found" in str(cm.value)
+
+ instance = gr_um.get_by_sub(sub=gr_user.sub)
+ assert instance.id == gr_user.id
+
+ def test_get_by_sub_or_create(self, gr_user, gr_um):
+ sub = f"{uuid4().hex}-{uuid4().hex}"
+
+ with pytest.raises(expected_exception=ValueError) as cm:
+ gr_um.get_by_sub(sub=sub)
+ assert "GRUser not found" in str(cm.value)
+
+ instance = gr_um.get_by_sub_or_create(sub=sub)
+ assert isinstance(instance, GRUser)
+ assert instance.sub == sub
+
+ def test_get_all(self, gr_um):
+ res1 = gr_um.get_all()
+ assert isinstance(res1, list)
+
+ gr_um.create_dummy()
+ res2 = gr_um.get_all()
+ assert len(res1) == len(res2) - 1
+
+ def test_get_by_team(self, gr_um):
+ res = gr_um.get_by_team(team_id=999_999_999)
+ assert isinstance(res, list)
+ assert res == []
+
+ def test_list_product_uuids(self, caplog, gr_user, gr_um, thl_web_rr):
+ with caplog.at_level(logging.WARNING):
+ gr_um.list_product_uuids(user=gr_user, thl_pg_config=thl_web_rr)
+ assert "prefetch not run" in caplog.text
+
+
+class TestGRTokenManager:
+
+ def test_create(self, gr_user, gr_tm):
+ assert gr_tm.create(user_id=gr_user.id) is None
+
+ token = gr_tm.get_by_user_id(user_id=gr_user.id)
+ assert gr_user.id == token.user_id
+
+ def test_get_by_user_id(self, gr_user, gr_tm):
+ assert gr_tm.create(user_id=gr_user.id) is None
+
+ token = gr_tm.get_by_user_id(user_id=gr_user.id)
+ assert gr_user.id == token.user_id
+
+ def test_prefetch_user(self, gr_user, gr_tm, gr_db, gr_redis_config):
+ from generalresearch.models.gr.authentication import GRToken
+
+ gr_tm.create(user_id=gr_user.id)
+
+ token: GRToken = gr_tm.get_by_user_id(user_id=gr_user.id)
+ assert token.user is None
+
+ token.prefetch_user(pg_config=gr_db, redis_config=gr_redis_config)
+ assert token.user.id == gr_user.id
+
+ def test_get_by_key(self, gr_user, gr_um, gr_tm):
+ gr_tm.create(user_id=gr_user.id)
+ token = gr_tm.get_by_user_id(user_id=gr_user.id)
+
+ instance = gr_tm.get_by_key(api_key=token.key)
+ assert token.created == instance.created
+
+ # Search for non-existent key
+ with pytest.raises(expected_exception=Exception) as cm:
+ gr_tm.get_by_key(api_key=uuid4().hex)
+ assert "No GRUser with token of " in str(cm.value)
+
+ @pytest.mark.skip(reason="no idea how to actually test this...")
+ def test_get_by_sso_key(self, gr_user, gr_um, gr_tm, gr_redis_config):
+ from generalresearch.models.gr.authentication import GRToken
+
+ api_key = "..."
+ jwks = {
+ # ...
+ }
+
+ instance = gr_tm.get_by_key(
+ api_key=api_key,
+ jwks=jwks,
+ audience="...",
+ issuer=SSO_ISSUER,
+ gr_redis_config=gr_redis_config,
+ )
+
+ assert isinstance(instance, GRToken)
diff --git a/tests/managers/gr/test_business.py b/tests/managers/gr/test_business.py
new file mode 100644
index 0000000..7eb77f8
--- /dev/null
+++ b/tests/managers/gr/test_business.py
@@ -0,0 +1,150 @@
+from uuid import uuid4
+
+import pytest
+
+from test_utils.models.conftest import business
+
+
+class TestBusinessBankAccountManager:
+
+ def test_init(self, business_bank_account_manager, gr_db):
+ assert business_bank_account_manager.pg_config == gr_db
+
+ def test_create(self, business, business_bank_account_manager):
+ from generalresearch.models.gr.business import (
+ TransferMethod,
+ BusinessBankAccount,
+ )
+
+ instance = business_bank_account_manager.create(
+ business_id=business.id,
+ uuid=uuid4().hex,
+ transfer_method=TransferMethod.ACH,
+ )
+ assert isinstance(instance, BusinessBankAccount)
+ assert isinstance(instance.id, int)
+
+ res = business_bank_account_manager.get_by_business_id(
+ business_id=instance.business_id
+ )
+ assert isinstance(res, list)
+ assert len(res) == 1
+ assert isinstance(res[0], BusinessBankAccount)
+ assert res[0].business_id == instance.business_id
+
+
+class TestBusinessAddressManager:
+
+ def test_create(self, business, business_address_manager):
+ from generalresearch.models.gr.business import BusinessAddress
+
+ res = business_address_manager.create(uuid=uuid4().hex, business_id=business.id)
+ assert isinstance(res, BusinessAddress)
+ assert isinstance(res.id, int)
+
+
+class TestBusinessManager:
+
+ def test_create(self, business_manager):
+ from generalresearch.models.gr.business import Business
+
+ instance = business_manager.create_dummy()
+ assert isinstance(instance, Business)
+ assert isinstance(instance.id, int)
+
+ def test_get_or_create(self, business_manager):
+ uuid_key = uuid4().hex
+
+ assert business_manager.get_by_uuid(business_uuid=uuid_key) is None
+
+ instance = business_manager.get_or_create(
+ uuid=uuid_key,
+ name=f"name-{uuid4().hex[:6]}",
+ )
+
+ res = business_manager.get_by_uuid(business_uuid=uuid_key)
+ assert res.id == instance.id
+
+ def test_get_all(self, business_manager):
+ res1 = business_manager.get_all()
+ assert isinstance(res1, list)
+
+ business_manager.create_dummy()
+ res2 = business_manager.get_all()
+ assert len(res1) == len(res2) - 1
+
+ @pytest.mark.skip(reason="TODO")
+ def test_get_by_team(self):
+ pass
+
+ def test_get_by_user_id(
+ self, business_manager, gr_user, team_manager, membership_manager
+ ):
+ res = business_manager.get_by_user_id(user_id=gr_user.id)
+ assert len(res) == 0
+
+ # Create a Business, but don't add it to anything
+ b1 = business_manager.create_dummy()
+ res = business_manager.get_by_user_id(user_id=gr_user.id)
+ assert len(res) == 0
+
+ # Create a Team, but don't create any Memberships
+ t1 = team_manager.create_dummy()
+ res = business_manager.get_by_user_id(user_id=gr_user.id)
+ assert len(res) == 0
+
+ # Create a Membership for the gr_user to the Team... but it doesn't
+ # matter because the Team doesn't have any Business yet
+ m1 = membership_manager.create(team=t1, gr_user=gr_user)
+ res = business_manager.get_by_user_id(user_id=gr_user.id)
+ assert len(res) == 0
+
+ # Add the Business to the Team... now the Business should be available
+ # to the gr_user
+ team_manager.add_business(team=t1, business=b1)
+ res = business_manager.get_by_user_id(user_id=gr_user.id)
+ assert len(res) == 1
+
+ # Add another Business to the Team!
+ b2 = business_manager.create_dummy()
+ team_manager.add_business(team=t1, business=b2)
+ res = business_manager.get_by_user_id(user_id=gr_user.id)
+ assert len(res) == 2
+
+ @pytest.mark.skip(reason="TODO")
+ def test_get_uuids_by_user_id(self):
+ pass
+
+ def test_get_by_uuid(self, business, business_manager):
+ instance = business_manager.get_by_uuid(business_uuid=business.uuid)
+ assert business.id == instance.id
+
+ def test_get_by_id(self, business, business_manager):
+ instance = business_manager.get_by_id(business_id=business.id)
+ assert business.uuid == instance.uuid
+
+ def test_cache_key(self, business):
+ assert "business:" in business.cache_key
+
+ # def test_create_raise_on_duplicate(self):
+ # b_uuid = uuid4().hex
+ #
+ # # Make the first one
+ # business = BusinessManager.create(
+ # uuid=b_uuid,
+ # name=f"test-{b_uuid[:6]}")
+ # assert isinstance(business, Business)
+ #
+ # # Try to make it again
+ # with pytest.raises(expected_exception=psycopg.errors.UniqueViolation):
+ # business = BusinessManager.create(
+ # uuid=b_uuid,
+ # name=f"test-{b_uuid[:6]}")
+ #
+ # def test_get_by_team(self, team):
+ # for idx in range(5):
+ # BusinessManager.create(name=f"Business Name #{uuid4().hex[:6]}", team=team)
+ #
+ # res = BusinessManager.get_by_team(team_id=team.id)
+ # assert isinstance(res, list)
+ # assert 5 == len(res)
diff --git a/tests/managers/gr/test_team.py b/tests/managers/gr/test_team.py
new file mode 100644
index 0000000..9215da4
--- /dev/null
+++ b/tests/managers/gr/test_team.py
@@ -0,0 +1,125 @@
+from uuid import uuid4
+
+from test_utils.models.conftest import team
+
+
+class TestMembershipManager:
+
+ def test_init(self, membership_manager, gr_db):
+ assert membership_manager.pg_config == gr_db
+
+
+class TestTeamManager:
+
+ def test_init(self, team_manager, gr_db):
+ assert team_manager.pg_config == gr_db
+
+ def test_get_or_create(self, team_manager):
+ from generalresearch.models.gr.team import Team
+
+ new_uuid = uuid4().hex
+
+ team: Team = team_manager.get_or_create(uuid=new_uuid)
+
+ assert isinstance(team, Team)
+ assert isinstance(team.id, int)
+ assert team.uuid == new_uuid
+ assert team.name == "< Unknown >"
+
+ def test_get_all(self, team_manager):
+ res1 = team_manager.get_all()
+ assert isinstance(res1, list)
+
+ team_manager.create_dummy()
+ res2 = team_manager.get_all()
+ assert len(res1) == len(res2) - 1
+
+ def test_create(self, team_manager):
+ from generalresearch.models.gr.team import Team
+
+ team: Team = team_manager.create_dummy()
+ assert isinstance(team, Team)
+ assert isinstance(team.id, int)
+
+ def test_add_user(self, team, team_manager, gr_um, gr_db, gr_redis_config):
+ from generalresearch.models.gr.authentication import GRUser
+ from generalresearch.models.gr.team import Membership
+
+ user: GRUser = gr_um.create_dummy()
+
+ instance = team_manager.add_user(team=team, gr_user=user)
+ assert isinstance(instance, Membership)
+
+ # assert team.gr_users is None
+ team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config)
+ assert isinstance(team.gr_users, list)
+ assert len(team.gr_users)
+ assert team.gr_users == [user]
+
+ def test_get_by_uuid(self, team_manager):
+ from generalresearch.models.gr.team import Team
+
+ team: Team = team_manager.create_dummy()
+
+ instance = team_manager.get_by_uuid(team_uuid=team.uuid)
+ assert team.id == instance.id
+
+ def test_get_by_id(self, team_manager):
+ from generalresearch.models.gr.team import Team
+
+ team: Team = team_manager.create_dummy()
+
+ instance = team_manager.get_by_id(team_id=team.id)
+ assert team.uuid == instance.uuid
+
+ def test_get_by_user(self, team, team_manager, gr_um):
+ from generalresearch.models.gr.authentication import GRUser
+ from generalresearch.models.gr.team import Team
+
+ user: GRUser = gr_um.create_dummy()
+ team_manager.add_user(team=team, gr_user=user)
+
+ res = team_manager.get_by_user(gr_user=user)
+ assert isinstance(res, list)
+ assert len(res) == 1
+ instance = res[0]
+ assert isinstance(instance, Team)
+ assert instance.uuid == team.uuid
+
+ def test_get_by_user_duplicates(
+ self,
+ gr_user_token,
+ gr_user,
+ membership,
+ product_factory,
+ membership_factory,
+ team,
+ thl_web_rr,
+ gr_redis_config,
+ gr_db,
+ ):
+ product_factory(team=team)
+ membership_factory(team=team, gr_user=gr_user)
+
+ gr_user.prefetch_teams(
+ pg_config=gr_db,
+ redis_config=gr_redis_config,
+ )
+
+ assert len(gr_user.teams) == 1
+
+ # def test_create_raise_on_duplicate(self):
+ # t_uuid = uuid4().hex
+ #
+ # # Make the first one
+ # team = TeamManager.create(
+ # uuid=t_uuid,
+ # name=f"test-{t_uuid[:6]}")
+ # assert isinstance(team, Team)
+ #
+ # # Try to make it again
+ # with pytest.raises(expected_exception=psycopg.errors.UniqueViolation):
+ # TeamManager.create(
+ # uuid=t_uuid,
+ # name=f"test-{t_uuid[:6]}")
+ #
diff --git a/tests/managers/leaderboard.py b/tests/managers/leaderboard.py
new file mode 100644
index 0000000..4d32dd0
--- /dev/null
+++ b/tests/managers/leaderboard.py
@@ -0,0 +1,274 @@
+import os
+import time
+import zoneinfo
+from datetime import datetime, timezone
+from decimal import Decimal
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.managers.leaderboard.manager import LeaderboardManager
+from generalresearch.managers.leaderboard.tasks import hit_leaderboards
+from generalresearch.models.thl.definitions import Status
+from generalresearch.models.thl.user import User
+from generalresearch.models.thl.product import Product
+from generalresearch.models.thl.session import Session
+from generalresearch.models.thl.leaderboard import (
+ LeaderboardCode,
+ LeaderboardFrequency,
+ LeaderboardRow,
+)
+from generalresearch.models.thl.product import (
+ PayoutConfig,
+ PayoutTransformation,
+ PayoutTransformationPercentArgs,
+)
+
+# random uuid for leaderboard tests
+product_id = uuid4().hex
+
+
+@pytest.fixture(autouse=True)
+def set_timezone():
+ os.environ["TZ"] = "UTC"
+ time.tzset()
+ yield
+ # Optionally reset to default
+ os.environ.pop("TZ", None)
+ time.tzset()
+
+
+@pytest.fixture
+def session_factory():
+ return _create_session
+
+
+def _create_session(
+ product_user_id="aaa", country_iso="us", user_payout=Decimal("1.00")
+):
+ user = User(
+ product_id=product_id,
+ product_user_id=product_user_id,
+ )
+ user.product = Product(
+ id=product_id,
+ name="test",
+ redirect_url="https://www.example.com",
+ payout_config=PayoutConfig(
+ payout_transformation=PayoutTransformation(
+ f="payout_transformation_percent",
+ kwargs=PayoutTransformationPercentArgs(pct=0.5),
+ )
+ ),
+ )
+ session = Session(
+ user=user,
+ started=datetime(2025, 2, 5, 6, tzinfo=timezone.utc),
+ id=1,
+ country_iso=country_iso,
+ status=Status.COMPLETE,
+ payout=Decimal("2.00"),
+ user_payout=user_payout,
+ )
+ return session
+
+
+@pytest.fixture(scope="function")
+def setup_leaderboards(thl_redis):
+ complete_count = {
+ "aaa": 10,
+ "bbb": 6,
+ "ccc": 6,
+ "ddd": 6,
+ "eee": 2,
+ "fff": 1,
+ "ggg": 1,
+ }
+ sum_payout = {"aaa": 345, "bbb": 100, "ccc": 100}
+ max_payout = sum_payout
+ country_iso = "us"
+ for freq in [
+ LeaderboardFrequency.DAILY,
+ LeaderboardFrequency.WEEKLY,
+ LeaderboardFrequency.MONTHLY,
+ ]:
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ thl_redis.delete(m.key)
+ thl_redis.zadd(m.key, complete_count)
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.SUM_PAYOUTS,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ thl_redis.delete(m.key)
+ thl_redis.zadd(m.key, sum_payout)
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.LARGEST_PAYOUT,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ thl_redis.delete(m.key)
+ thl_redis.zadd(m.key, max_payout)
+
+
+class TestLeaderboards:
+
+ def test_leaderboard_manager(self, setup_leaderboards, thl_redis):
+ country_iso = "us"
+ board_code = LeaderboardCode.COMPLETE_COUNT
+ freq = LeaderboardFrequency.DAILY
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=board_code,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 0, 0, 0),
+ )
+ lb = m.get_leaderboard()
+ assert lb.period_start_local == datetime(
+ 2025, 2, 5, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="America/New_York")
+ )
+ assert lb.period_end_local == datetime(
+ 2025,
+ 2,
+ 5,
+ 23,
+ 59,
+ 59,
+ 999999,
+ tzinfo=zoneinfo.ZoneInfo(key="America/New_York"),
+ )
+ assert lb.period_start_utc == datetime(2025, 2, 5, 5, tzinfo=timezone.utc)
+ assert lb.row_count == 7
+ assert lb.rows == [
+ LeaderboardRow(bpuid="aaa", rank=1, value=10),
+ LeaderboardRow(bpuid="bbb", rank=2, value=6),
+ LeaderboardRow(bpuid="ccc", rank=2, value=6),
+ LeaderboardRow(bpuid="ddd", rank=2, value=6),
+ LeaderboardRow(bpuid="eee", rank=5, value=2),
+ LeaderboardRow(bpuid="fff", rank=6, value=1),
+ LeaderboardRow(bpuid="ggg", rank=6, value=1),
+ ]
+
+ def test_leaderboard_manager_bpuid(self, setup_leaderboards, thl_redis):
+ country_iso = "us"
+ board_code = LeaderboardCode.COMPLETE_COUNT
+ freq = LeaderboardFrequency.DAILY
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=board_code,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(bp_user_id="fff", limit=1)
+
+ # TODO: this won't work correctly if I request bpuid 'ggg', because it
+ # is ordered at the end even though it is tied, so it won't get a
+ # row after ('fff')
+
+ assert lb.rows == [
+ LeaderboardRow(bpuid="eee", rank=5, value=2),
+ LeaderboardRow(bpuid="fff", rank=6, value=1),
+ LeaderboardRow(bpuid="ggg", rank=6, value=1),
+ ]
+
+ lb.censor()
+ assert lb.rows[0].bpuid == "ee*"
+
+ def test_leaderboard_hit(self, setup_leaderboards, session_factory, thl_redis):
+ hit_leaderboards(redis_client=thl_redis, session=session_factory())
+
+ for freq in [
+ LeaderboardFrequency.DAILY,
+ LeaderboardFrequency.WEEKLY,
+ LeaderboardFrequency.MONTHLY,
+ ]:
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=freq,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(limit=1)
+ assert lb.row_count == 7
+ assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=11)]
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.LARGEST_PAYOUT,
+ freq=freq,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(limit=1)
+ assert lb.row_count == 3
+ assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345)]
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.SUM_PAYOUTS,
+ freq=freq,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(limit=1)
+ assert lb.row_count == 3
+ assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345 + 100)]
+
+ def test_leaderboard_hit_new_row(
+ self, setup_leaderboards, session_factory, thl_redis
+ ):
+ session = session_factory(product_user_id="zzz")
+ hit_leaderboards(redis_client=thl_redis, session=session)
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=LeaderboardFrequency.DAILY,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard()
+ assert lb.row_count == 8
+ assert LeaderboardRow(bpuid="zzz", value=1, rank=6) in lb.rows
+
+ def test_leaderboard_country(self, thl_redis):
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=LeaderboardFrequency.DAILY,
+ product_id=product_id,
+ country_iso="jp",
+ within_time=datetime(
+ 2025,
+ 2,
+ 1,
+ ),
+ )
+ lb = m.get_leaderboard()
+ assert lb.row_count == 0
+ assert lb.period_start_local == datetime(
+ 2025, 2, 1, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="Asia/Tokyo")
+ )
+ assert lb.local_start_time == "2025-02-01T00:00:00+09:00"
+ assert lb.local_end_time == "2025-02-01T23:59:59.999999+09:00"
+ assert lb.period_start_utc == datetime(2025, 1, 31, 15, tzinfo=timezone.utc)
+ print(lb.model_dump(mode="json"))
diff --git a/tests/managers/test_events.py b/tests/managers/test_events.py
new file mode 100644
index 0000000..a0fab38
--- /dev/null
+++ b/tests/managers/test_events.py
@@ -0,0 +1,530 @@
+import random
+import time
+from datetime import timedelta, datetime, timezone
+from decimal import Decimal
+from functools import partial
+from typing import Optional
+from uuid import uuid4
+
+import math
+import pytest
+from math import floor
+
+from generalresearch.managers.events import EventSubscriber
+from generalresearch.models import Source
+from generalresearch.models.events import (
+ MessageKind,
+ EventType,
+ AggregateBySource,
+ MaxGaugeBySource,
+)
+from generalresearch.models.legacy.bucket import Bucket
+from generalresearch.models.thl.definitions import Status, StatusCode1
+from generalresearch.models.thl.session import Session, Wall
+from generalresearch.models.thl.user import User
+
+
+# We don't need anything in the db, so not using the db fixtures
+@pytest.fixture(scope="function")
+def product_id(product_manager):
+ return uuid4().hex
+
+
+@pytest.fixture(scope="function")
+def user_factory(product_id):
+ return partial(create_dummy, product_id=product_id)
+
+
+@pytest.fixture(scope="function")
+def event_subscriber(thl_redis_config, product_id):
+ return EventSubscriber(redis_config=thl_redis_config, product_id=product_id)
+
+
+def create_dummy(
+ product_id: Optional[str] = None, product_user_id: Optional[str] = None
+) -> User:
+ return User(
+ product_id=product_id,
+ product_user_id=product_user_id or uuid4().hex,
+ uuid=uuid4().hex,
+ created=datetime.now(tz=timezone.utc),
+ user_id=random.randint(0, floor(2**32 / 2)),
+ )
+
+
+class TestActiveUsers:
+
+ def test_run_empty(self, event_manager, product_id):
+ res = event_manager.get_user_stats(product_id)
+ assert res == {
+ "active_users_last_1h": 0,
+ "active_users_last_24h": 0,
+ "signups_last_24h": 0,
+ "in_progress_users": 0,
+ }
+
+ def test_run(self, event_manager, product_id, user_factory):
+ event_manager.clear_global_user_stats()
+ user1: User = user_factory()
+
+ # No matter how many times we do this, they're only active once
+ event_manager.handle_user(user1)
+ event_manager.handle_user(user1)
+ event_manager.handle_user(user1)
+
+ res = event_manager.get_user_stats(product_id)
+ assert res == {
+ "active_users_last_1h": 1,
+ "active_users_last_24h": 1,
+ "signups_last_24h": 1,
+ "in_progress_users": 0,
+ }
+
+ assert event_manager.get_global_user_stats() == {
+ "active_users_last_1h": 1,
+ "active_users_last_24h": 1,
+ "signups_last_24h": 1,
+ "in_progress_users": 0,
+ }
+
+ # Create a 2nd user in another product
+ product_id2 = uuid4().hex
+ user2: User = user_factory(product_id=product_id2)
+ # Change to say user was created >24 hrs ago
+ user2.created = user2.created - timedelta(hours=25)
+ event_manager.handle_user(user2)
+
+ # And now each have 1 active user
+ assert event_manager.get_user_stats(product_id) == {
+ "active_users_last_1h": 1,
+ "active_users_last_24h": 1,
+ "signups_last_24h": 1,
+ "in_progress_users": 0,
+ }
+ # user2 was created older than 24hrs ago
+ assert event_manager.get_user_stats(product_id2) == {
+ "active_users_last_1h": 1,
+ "active_users_last_24h": 1,
+ "signups_last_24h": 0,
+ "in_progress_users": 0,
+ }
+ # 2 globally active
+ assert event_manager.get_global_user_stats() == {
+ "active_users_last_1h": 2,
+ "active_users_last_24h": 2,
+ "signups_last_24h": 1,
+ "in_progress_users": 0,
+ }
+
+ def test_inprogress(self, event_manager, product_id, user_factory):
+ event_manager.clear_global_user_stats()
+ user1: User = user_factory()
+ user2: User = user_factory()
+
+ # No matter how many times we do this, they're only active once
+ event_manager.mark_user_inprogress(user1)
+ event_manager.mark_user_inprogress(user1)
+ event_manager.mark_user_inprogress(user1)
+ event_manager.mark_user_inprogress(user2)
+
+ res = event_manager.get_user_stats(product_id)
+ assert res["in_progress_users"] == 2
+
+ event_manager.unmark_user_inprogress(user1)
+ res = event_manager.get_user_stats(product_id)
+ assert res["in_progress_users"] == 1
+
+ # Shouldn't do anything
+ event_manager.unmark_user_inprogress(user1)
+ res = event_manager.get_user_stats(product_id)
+ assert res["in_progress_users"] == 1
+
+ def test_expiry(self, event_manager, product_id, user_factory):
+ event_manager.clear_global_user_stats()
+ user1: User = user_factory()
+ event_manager.handle_user(user1)
+ event_manager.mark_user_inprogress(user1)
+ sec_24hr = timedelta(hours=24).total_seconds()
+
+ # We don't want to wait an hour to test this, so we're going to
+ # just confirm that the keys will expire
+ time.sleep(1.1)
+ ttl = event_manager.redis_client.httl(
+ f"active_users_last_1h:{product_id}", user1.product_user_id
+ )[0]
+ assert 3600 - 60 <= ttl <= 3600
+ ttl = event_manager.redis_client.httl(
+ f"active_users_last_24h:{product_id}", user1.product_user_id
+ )[0]
+ assert sec_24hr - 60 <= ttl <= sec_24hr
+
+ ttl = event_manager.redis_client.httl("signups_last_24h", user1.uuid)[0]
+ assert sec_24hr - 60 <= ttl <= sec_24hr
+
+ ttl = event_manager.redis_client.httl("in_progress_users", user1.uuid)[0]
+ assert 3600 - 60 <= ttl <= 3600
+
+
+class TestSessionStats:
+
+ def test_run_empty(self, event_manager, product_id):
+ res = event_manager.get_session_stats(product_id)
+ assert res == {
+ "session_enters_last_1h": 0,
+ "session_enters_last_24h": 0,
+ "session_fails_last_1h": 0,
+ "session_fails_last_24h": 0,
+ "session_completes_last_1h": 0,
+ "session_completes_last_24h": 0,
+ "sum_payouts_last_1h": 0,
+ "sum_payouts_last_24h": 0,
+ "sum_user_payouts_last_1h": 0,
+ "sum_user_payouts_last_24h": 0,
+ "session_avg_payout_last_24h": None,
+ "session_avg_user_payout_last_24h": None,
+ "session_complete_avg_loi_last_24h": None,
+ "session_fail_avg_loi_last_24h": None,
+ }
+
+ def test_run(self, event_manager, product_id, user_factory, utc_now, utc_hour_ago):
+ event_manager.clear_global_session_stats()
+
+ user: User = user_factory()
+ session = Session(
+ country_iso="us",
+ started=utc_hour_ago + timedelta(minutes=10),
+ user=user,
+ )
+ event_manager.session_on_enter(session=session, user=user)
+ session.update(
+ finished=utc_now,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ payout=Decimal("1.00"),
+ user_payout=Decimal("0.95"),
+ )
+ event_manager.session_on_finish(session=session, user=user)
+ assert event_manager.get_session_stats(product_id) == {
+ "session_enters_last_1h": 1,
+ "session_enters_last_24h": 1,
+ "session_fails_last_1h": 0,
+ "session_fails_last_24h": 0,
+ "session_completes_last_1h": 1,
+ "session_completes_last_24h": 1,
+ "sum_payouts_last_1h": 100,
+ "sum_payouts_last_24h": 100,
+ "sum_user_payouts_last_1h": 95,
+ "sum_user_payouts_last_24h": 95,
+ "session_avg_payout_last_24h": 100,
+ "session_avg_user_payout_last_24h": 95,
+ "session_complete_avg_loi_last_24h": round(session.elapsed.total_seconds()),
+ "session_fail_avg_loi_last_24h": None,
+ }
+
+ # The session gets inserted into redis using the session.finished
+ # timestamp, no matter what time it is right now. So we can
+ # kind of test the expiry by setting it to 61 min ago
+ session2 = Session(
+ country_iso="us",
+ started=utc_hour_ago - timedelta(minutes=10),
+ user=user,
+ )
+ event_manager.session_on_enter(session=session2, user=user)
+ session2.update(
+ finished=utc_hour_ago - timedelta(minutes=1),
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ payout=Decimal("2.00"),
+ user_payout=Decimal("1.50"),
+ )
+ event_manager.session_on_finish(session=session2, user=user)
+ avg_loi = (
+ round(session.elapsed.total_seconds())
+ + round(session2.elapsed.total_seconds())
+ ) / 2
+ assert event_manager.get_session_stats(product_id) == {
+ "session_enters_last_1h": 1,
+ "session_enters_last_24h": 2,
+ "session_fails_last_1h": 0,
+ "session_fails_last_24h": 0,
+ "session_completes_last_1h": 1,
+ "session_completes_last_24h": 2,
+ "sum_payouts_last_1h": 100,
+ "sum_payouts_last_24h": 300,
+ "sum_user_payouts_last_1h": 95,
+ "sum_user_payouts_last_24h": 95 + 150,
+ "session_avg_payout_last_24h": math.ceil((100 + 200) / 2),
+ "session_avg_user_payout_last_24h": math.ceil((95 + 150) / 2),
+ "session_complete_avg_loi_last_24h": avg_loi,
+ "session_fail_avg_loi_last_24h": None,
+ }
+
+ # Don't want to wait an hour, so confirm the keys will expire
+ name = "session_completes_last_1h:" + product_id
+ res = event_manager.redis_client.hgetall(name)
+ field = (int(utc_now.timestamp()) // 60) * 60
+ field_name = str(field)
+ assert res == {field_name: "1"}
+ assert (
+ 3600 - 60 < event_manager.redis_client.httl(name, field_name)[0] < 3600 + 60
+ )
+
+ # Second BP, fail
+ product_id2 = uuid4().hex
+ user2: User = user_factory(product_id=product_id2)
+ session3 = Session(
+ country_iso="us",
+ started=utc_now - timedelta(minutes=1),
+ user=user2,
+ )
+ event_manager.session_on_enter(session=session3, user=user)
+ session3.update(
+ finished=utc_now,
+ status=Status.FAIL,
+ status_code_1=StatusCode1.BUYER_FAIL,
+ )
+ event_manager.session_on_finish(session=session3, user=user)
+ avg_loi_complete = (
+ round(session.elapsed.total_seconds())
+ + round(session2.elapsed.total_seconds())
+ ) / 2
+ assert event_manager.get_session_stats(product_id) == {
+ "session_enters_last_1h": 2,
+ "session_enters_last_24h": 3,
+ "session_fails_last_1h": 1,
+ "session_fails_last_24h": 1,
+ "session_completes_last_1h": 1,
+ "session_completes_last_24h": 2,
+ "sum_payouts_last_1h": 100,
+ "sum_payouts_last_24h": 300,
+ "sum_user_payouts_last_1h": 95,
+ "sum_user_payouts_last_24h": 95 + 150,
+ "session_avg_payout_last_24h": math.ceil((100 + 200) / 2),
+ "session_avg_user_payout_last_24h": math.ceil((95 + 150) / 2),
+ "session_complete_avg_loi_last_24h": avg_loi_complete,
+ "session_fail_avg_loi_last_24h": round(session3.elapsed.total_seconds()),
+ }
+
+
+class TestTaskStatsManager:
+ def test_empty(self, event_manager):
+ event_manager.clear_task_stats()
+ assert event_manager.get_task_stats_raw() == {
+ "live_task_count": AggregateBySource(total=0),
+ "live_tasks_max_payout": MaxGaugeBySource(value=None),
+ "task_created_count_last_1h": AggregateBySource(total=0),
+ "task_created_count_last_24h": AggregateBySource(total=0),
+ }
+ assert event_manager.get_latest_task_stats() is None
+
+ sm = event_manager.get_stats_message(product_id=uuid4().hex)
+ assert sm.data.task_created_count_last_24h.total == 0
+ assert sm.data.live_tasks_max_payout.value is None
+
+ def test(self, event_manager):
+ event_manager.clear_task_stats()
+ event_manager.set_source_task_stats(
+ source=Source.TESTING,
+ live_task_count=100,
+ live_tasks_max_payout=Decimal("1.00"),
+ )
+ assert event_manager.get_task_stats_raw() == {
+ "live_task_count": AggregateBySource(
+ total=100, by_source={Source.TESTING: 100}
+ ),
+ "live_tasks_max_payout": MaxGaugeBySource(
+ value=100, by_source={Source.TESTING: 100}
+ ),
+ "task_created_count_last_1h": AggregateBySource(total=0),
+ "task_created_count_last_24h": AggregateBySource(total=0),
+ }
+ event_manager.set_source_task_stats(
+ source=Source.TESTING2,
+ live_task_count=50,
+ live_tasks_max_payout=Decimal("2.00"),
+ )
+ assert event_manager.get_task_stats_raw() == {
+ "live_task_count": AggregateBySource(
+ total=150, by_source={Source.TESTING: 100, Source.TESTING2: 50}
+ ),
+ "live_tasks_max_payout": MaxGaugeBySource(
+ value=200, by_source={Source.TESTING: 100, Source.TESTING2: 200}
+ ),
+ "task_created_count_last_1h": AggregateBySource(total=0),
+ "task_created_count_last_24h": AggregateBySource(total=0),
+ }
+ event_manager.set_source_task_stats(
+ source=Source.TESTING,
+ live_task_count=101,
+ live_tasks_max_payout=Decimal("1.50"),
+ )
+ assert event_manager.get_task_stats_raw() == {
+ "live_task_count": AggregateBySource(
+ total=151, by_source={Source.TESTING: 101, Source.TESTING2: 50}
+ ),
+ "live_tasks_max_payout": MaxGaugeBySource(
+ value=200, by_source={Source.TESTING: 150, Source.TESTING2: 200}
+ ),
+ "task_created_count_last_1h": AggregateBySource(total=0),
+ "task_created_count_last_24h": AggregateBySource(total=0),
+ }
+ event_manager.set_source_task_stats(
+ source=Source.TESTING,
+ live_task_count=99,
+ live_tasks_max_payout=Decimal("2.50"),
+ )
+ assert event_manager.get_task_stats_raw() == {
+ "live_task_count": AggregateBySource(
+ total=149, by_source={Source.TESTING: 99, Source.TESTING2: 50}
+ ),
+ "live_tasks_max_payout": MaxGaugeBySource(
+ value=250, by_source={Source.TESTING: 250, Source.TESTING2: 200}
+ ),
+ "task_created_count_last_1h": AggregateBySource(total=0),
+ "task_created_count_last_24h": AggregateBySource(total=0),
+ }
+ event_manager.set_source_task_stats(
+ source=Source.TESTING, live_task_count=0, live_tasks_max_payout=Decimal("0")
+ )
+ assert event_manager.get_task_stats_raw() == {
+ "live_task_count": AggregateBySource(
+ total=50, by_source={Source.TESTING2: 50}
+ ),
+ "live_tasks_max_payout": MaxGaugeBySource(
+ value=200, by_source={Source.TESTING2: 200}
+ ),
+ "task_created_count_last_1h": AggregateBySource(total=0),
+ "task_created_count_last_24h": AggregateBySource(total=0),
+ }
+
+ event_manager.set_source_task_stats(
+ source=Source.TESTING,
+ live_task_count=0,
+ live_tasks_max_payout=Decimal("0"),
+ created_count=10,
+ )
+ res = event_manager.get_task_stats_raw()
+ assert res["task_created_count_last_1h"] == AggregateBySource(
+ total=10, by_source={Source.TESTING: 10}
+ )
+ assert res["task_created_count_last_24h"] == AggregateBySource(
+ total=10, by_source={Source.TESTING: 10}
+ )
+
+ event_manager.set_source_task_stats(
+ source=Source.TESTING,
+ live_task_count=0,
+ live_tasks_max_payout=Decimal("0"),
+ created_count=10,
+ )
+ res = event_manager.get_task_stats_raw()
+ assert res["task_created_count_last_1h"] == AggregateBySource(
+ total=20, by_source={Source.TESTING: 20}
+ )
+ assert res["task_created_count_last_24h"] == AggregateBySource(
+ total=20, by_source={Source.TESTING: 20}
+ )
+
+ event_manager.set_source_task_stats(
+ source=Source.TESTING2,
+ live_task_count=0,
+ live_tasks_max_payout=Decimal("0"),
+ created_count=1,
+ )
+ res = event_manager.get_task_stats_raw()
+ assert res["task_created_count_last_1h"] == AggregateBySource(
+ total=21, by_source={Source.TESTING: 20, Source.TESTING2: 1}
+ )
+ assert res["task_created_count_last_24h"] == AggregateBySource(
+ total=21, by_source={Source.TESTING: 20, Source.TESTING2: 1}
+ )
+
+ sm = event_manager.get_stats_message(product_id=uuid4().hex)
+ assert sm.data.task_created_count_last_24h.total == 21
+
+
+class TestChannelsSubscriptions:
+ def test_stats_worker(
+ self,
+ event_manager,
+ event_subscriber,
+ product_id,
+ user_factory,
+ utc_hour_ago,
+ utc_now,
+ ):
+ event_manager.clear_stats()
+ assert event_subscriber.pubsub
+ # We subscribed, so manually trigger the stats worker
+ # to get all product_ids subscribed and publish a
+ # stats message into that channel
+ event_manager.stats_worker_task()
+ assert product_id in event_manager.get_active_subscribers()
+ msg = event_subscriber.get_next_message()
+ print(msg)
+ assert msg.kind == MessageKind.STATS
+
+ user = user_factory()
+ session = Session(
+ country_iso="us",
+ started=utc_hour_ago,
+ user=user,
+ clicked_bucket=Bucket(user_payout_min=Decimal("0.50")),
+ id=1,
+ )
+
+ event_manager.handle_session_enter(session, user)
+ msg = event_subscriber.get_next_message()
+ print(msg)
+ assert msg.kind == MessageKind.EVENT
+ assert msg.data.event_type == EventType.SESSION_ENTER
+
+ wall = Wall(
+ req_survey_id="a",
+ req_cpi=Decimal("1"),
+ source=Source.TESTING,
+ session_id=session.id,
+ user_id=user.user_id,
+ )
+
+ event_manager.handle_task_enter(wall, session, user)
+ msg = event_subscriber.get_next_message()
+ print(msg)
+ assert msg.kind == MessageKind.EVENT
+ assert msg.data.event_type == EventType.TASK_ENTER
+
+ wall.update(
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ finished=datetime.now(tz=timezone.utc),
+ cpi=Decimal("1"),
+ )
+ event_manager.handle_task_finish(wall, session, user)
+ msg = event_subscriber.get_next_message()
+ print(msg)
+ assert msg.kind == MessageKind.EVENT
+ assert msg.data.event_type == EventType.TASK_FINISH
+
+ session.update(
+ finished=utc_now,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ payout=Decimal("1.00"),
+ user_payout=Decimal("1.00"),
+ )
+
+ event_manager.handle_session_finish(session, user)
+ msg = event_subscriber.get_next_message()
+ print(msg)
+ assert msg.kind == MessageKind.EVENT
+ assert msg.data.event_type == EventType.SESSION_FINISH
+
+ event_manager.stats_worker_task()
+ assert product_id in event_manager.get_active_subscribers()
+ msg = event_subscriber.get_next_message()
+ print(msg)
+ assert msg.kind == MessageKind.STATS
+ assert msg.data.active_users_last_1h == 1
+ assert msg.data.session_enters_last_24h == 1
+ assert msg.data.session_completes_last_1h == 1
+ assert msg.data.signups_last_24h == 1
diff --git a/tests/managers/test_lucid.py b/tests/managers/test_lucid.py
new file mode 100644
index 0000000..1a1bae7
--- /dev/null
+++ b/tests/managers/test_lucid.py
@@ -0,0 +1,23 @@
+import pytest
+
+from generalresearch.managers.lucid.profiling import get_profiling_library
+
+qids = ["42", "43", "45", "97", "120", "639", "15297"]
+
+
+class TestLucidProfiling:
+
+ @pytest.mark.skip
+ def test_get_library(self, thl_web_rr):
+ pks = [(qid, "us", "eng") for qid in qids]
+ qs = get_profiling_library(thl_web_rr, pks=pks)
+ assert len(qids) == len(qs)
+
+ # just making sure this doesn't raise errors
+ for q in qs:
+ q.to_upk_question()
+
+ # a lot will fail parsing because they have no options or the options are blank
+ # just asserting that we get some back
+ qs = get_profiling_library(thl_web_rr, country_iso="mx", language_iso="spa")
+ assert len(qs) > 100
diff --git a/tests/managers/test_userpid.py b/tests/managers/test_userpid.py
new file mode 100644
index 0000000..4a3f699
--- /dev/null
+++ b/tests/managers/test_userpid.py
@@ -0,0 +1,68 @@
+import pytest
+from pydantic import MySQLDsn
+
+from generalresearch.managers.marketplace.user_pid import UserPidMultiManager
+from generalresearch.sql_helper import SqlHelper
+from generalresearch.managers.cint.user_pid import CintUserPidManager
+from generalresearch.managers.dynata.user_pid import DynataUserPidManager
+from generalresearch.managers.innovate.user_pid import InnovateUserPidManager
+from generalresearch.managers.morning.user_pid import MorningUserPidManager
+
+# from generalresearch.managers.precision import PrecisionUserPidManager
+from generalresearch.managers.prodege.user_pid import ProdegeUserPidManager
+from generalresearch.managers.repdata.user_pid import RepdataUserPidManager
+from generalresearch.managers.sago.user_pid import SagoUserPidManager
+from generalresearch.managers.spectrum.user_pid import SpectrumUserPidManager
+
+dsn = ""
+
+
+class TestCintUserPidManager:
+
+ def test_filter(self):
+ m = CintUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-cint")))
+
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ m.filter()
+ assert str(excinfo.value) == "Must pass ONE of user_ids, pids"
+
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ m.filter(user_ids=[1, 2, 3], pids=["ed5b47c8551d453d985501391f190d3f"])
+ assert str(excinfo.value) == "Must pass ONE of user_ids, pids"
+
+ # pids get .hex before and after
+ assert m.filter(pids=["ed5b47c8551d453d985501391f190d3f"]) == m.filter(
+ pids=["ed5b47c8-551d-453d-9855-01391f190d3f"]
+ )
+
+ # user_ids and pids are for the same 3 users
+ res1 = m.filter(user_ids=[61586871, 61458915, 61390116])
+ res2 = m.filter(
+ pids=[
+ "ed5b47c8551d453d985501391f190d3f",
+ "7e640732c59f43d1b7c00137ab66600c",
+ "5160aeec9c3b4dbb85420128e6da6b5a",
+ ]
+ )
+ assert res1 == res2
+
+
+class TestUserPidMultiManager:
+
+ def test_filter(self):
+ managers = [
+ CintUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-cint"))),
+ DynataUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-dynata"))),
+ InnovateUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-innovate"))),
+ MorningUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-morning"))),
+ ProdegeUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-prodege"))),
+ RepdataUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-repdata"))),
+ SagoUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-sago"))),
+ SpectrumUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-spectrum"))),
+ ]
+ m = UserPidMultiManager(sql_helper=SqlHelper(MySQLDsn(dsn)), managers=managers)
+ res = m.filter(user_ids=[1])
+ assert len(res) == len(managers)
+
+ res = m.filter(user_ids=[1, 2, 3])
+ assert len(res) > len(managers)
diff --git a/tests/managers/thl/__init__.py b/tests/managers/thl/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/thl/__init__.py
diff --git a/tests/managers/thl/test_buyer.py b/tests/managers/thl/test_buyer.py
new file mode 100644
index 0000000..69ea105
--- /dev/null
+++ b/tests/managers/thl/test_buyer.py
@@ -0,0 +1,25 @@
+from generalresearch.models import Source
+
+
+class TestBuyer:
+
+ def test(
+ self,
+ delete_buyers_surveys,
+ buyer_manager,
+ ):
+
+ bs = buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=["a", "b"])
+ assert len(bs) == 2
+ buyer_a = bs[0]
+ assert buyer_a.id is not None
+ bs2 = buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=["a", "c"])
+ assert len(bs2) == 2
+ buyer_a2 = bs2[0]
+ buyer_c = bs2[1]
+ # a isn't created again
+ assert buyer_a == buyer_a2
+ assert bs2[0].id is not None
+
+ # and its cached
+ assert buyer_c.id == buyer_manager.source_code_pk[f"{Source.TESTING.value}:c"]
diff --git a/tests/managers/thl/test_cashout_method.py b/tests/managers/thl/test_cashout_method.py
new file mode 100644
index 0000000..fe561d3
--- /dev/null
+++ b/tests/managers/thl/test_cashout_method.py
@@ -0,0 +1,139 @@
+import pytest
+
+from generalresearch.models.thl.wallet import PayoutType
+from generalresearch.models.thl.wallet.cashout_method import (
+ CashMailCashoutMethodData,
+ USDeliveryAddress,
+ PaypalCashoutMethodData,
+)
+from test_utils.managers.cashout_methods import (
+ EXAMPLE_TANGO_CASHOUT_METHODS,
+ AMT_ASSIGNMENT_CASHOUT_METHOD,
+ AMT_BONUS_CASHOUT_METHOD,
+)
+
+
+class TestTangoCashoutMethods:
+
+ def test_create_and_get(self, cashout_method_manager, setup_cashoutmethod_db):
+ res = cashout_method_manager.filter(payout_types=[PayoutType.TANGO])
+ assert len(res) == 2
+ cm = [x for x in res if x.ext_id == "U025035"][0]
+ assert EXAMPLE_TANGO_CASHOUT_METHODS[0] == cm
+
+ def test_user(
+ self, cashout_method_manager, user_with_wallet, setup_cashoutmethod_db
+ ):
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ # This user ONLY has the two tango cashout methods, no AMT
+ assert len(res) == 2
+
+
+class TestAMTCashoutMethods:
+
+ def test_create_and_get(self, cashout_method_manager, setup_cashoutmethod_db):
+ res = cashout_method_manager.filter(payout_types=[PayoutType.AMT])
+ assert len(res) == 2
+ cm = [x for x in res if x.name == "AMT Assignment"][0]
+ assert AMT_ASSIGNMENT_CASHOUT_METHOD == cm
+ cm = [x for x in res if x.name == "AMT Bonus"][0]
+ assert AMT_BONUS_CASHOUT_METHOD == cm
+
+ def test_user(
+ self, cashout_method_manager, user_with_wallet_amt, setup_cashoutmethod_db
+ ):
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet_amt)
+ # This user has the 2 tango, plus amt bonus & assignment
+ assert len(res) == 4
+
+
+class TestUserCashoutMethods:
+
+ def test(self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db):
+ delete_cashoutmethod_db()
+
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 0
+
+ def test_cash_in_mail(
+ self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db
+ ):
+ delete_cashoutmethod_db()
+
+ data = CashMailCashoutMethodData(
+ delivery_address=USDeliveryAddress.model_validate(
+ {
+ "name_or_attn": "Josh Ackerman",
+ "address": "123 Fake St",
+ "city": "San Francisco",
+ "state": "CA",
+ "postal_code": "12345",
+ }
+ )
+ )
+ cashout_method_manager.create_cash_in_mail_cashout_method(
+ data=data, user=user_with_wallet
+ )
+
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 1
+ assert res[0].data.delivery_address.postal_code == "12345"
+
+ # try to create the same one again. should just do nothing
+ cashout_method_manager.create_cash_in_mail_cashout_method(
+ data=data, user=user_with_wallet
+ )
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 1
+
+ # Create with a new address, will create a new one
+ data.delivery_address.postal_code = "99999"
+ cashout_method_manager.create_cash_in_mail_cashout_method(
+ data=data, user=user_with_wallet
+ )
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 2
+
+ def test_paypal(
+ self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db
+ ):
+ delete_cashoutmethod_db()
+
+ data = PaypalCashoutMethodData(email="test@example.com")
+ cashout_method_manager.create_paypal_cashout_method(
+ data=data, user=user_with_wallet
+ )
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 1
+ assert res[0].data.email == "test@example.com"
+
+ # try to create the same one again. should just do nothing
+ cashout_method_manager.create_paypal_cashout_method(
+ data=data, user=user_with_wallet
+ )
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 1
+
+ # Create with a new email, will error! must delete the old one first.
+ # We can only have one paypal active
+ data.email = "test2@example.com"
+ with pytest.raises(
+ ValueError,
+ match="User already has a cashout method of this type. Delete the existing one and try again.",
+ ):
+ cashout_method_manager.create_paypal_cashout_method(
+ data=data, user=user_with_wallet
+ )
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 1
+
+ cashout_method_manager.delete_cashout_method(res[0].id)
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 0
+
+ cashout_method_manager.create_paypal_cashout_method(
+ data=data, user=user_with_wallet
+ )
+ res = cashout_method_manager.get_cashout_methods(user_with_wallet)
+ assert len(res) == 1
+ assert res[0].data.email == "test2@example.com"
diff --git a/tests/managers/thl/test_category.py b/tests/managers/thl/test_category.py
new file mode 100644
index 0000000..ad0f07b
--- /dev/null
+++ b/tests/managers/thl/test_category.py
@@ -0,0 +1,100 @@
+import pytest
+
+from generalresearch.models.thl.category import Category
+
+
+class TestCategory:
+
+ @pytest.fixture
+ def beauty_fitness(self, thl_web_rw):
+
+ return Category(
+ uuid="12c1e96be82c4642a07a12a90ce6f59e",
+ adwords_vertical_id="44",
+ label="Beauty & Fitness",
+ path="/Beauty & Fitness",
+ )
+
+ @pytest.fixture
+ def hair_care(self, beauty_fitness):
+
+ return Category(
+ uuid="dd76c4b565d34f198dad3687326503d6",
+ adwords_vertical_id="146",
+ label="Hair Care",
+ path="/Beauty & Fitness/Hair Care",
+ )
+
+ @pytest.fixture
+ def hair_loss(self, hair_care):
+
+ return Category(
+ uuid="aacff523c8e246888215611ec3b823c0",
+ adwords_vertical_id="235",
+ label="Hair Loss",
+ path="/Beauty & Fitness/Hair Care/Hair Loss",
+ )
+
+ @pytest.fixture
+ def category_data(
+ self, category_manager, thl_web_rw, beauty_fitness, hair_care, hair_loss
+ ):
+ cats = [beauty_fitness, hair_care, hair_loss]
+ data = [x.model_dump(mode="json") for x in cats]
+ # We need the parent pk's to set the parent_id. So insert all without a parent,
+ # then pull back all pks and map to the parents as parsed by the parent_path
+ query = """
+ INSERT INTO marketplace_category
+ (uuid, adwords_vertical_id, label, path)
+ VALUES
+ (%(uuid)s, %(adwords_vertical_id)s, %(label)s, %(path)s)
+ ON CONFLICT (uuid) DO NOTHING;
+ """
+ with thl_web_rw.make_connection() as conn:
+ with conn.cursor() as c:
+ c.executemany(query=query, params_seq=data)
+ conn.commit()
+
+ res = thl_web_rw.execute_sql_query("SELECT id, path FROM marketplace_category")
+ path_id = {x["path"]: x["id"] for x in res}
+ data = [
+ {"id": path_id[c.path], "parent_id": path_id[c.parent_path]}
+ for c in cats
+ if c.parent_path
+ ]
+ query = """
+ UPDATE marketplace_category
+ SET parent_id = %(parent_id)s
+ WHERE id = %(id)s;
+ """
+ with thl_web_rw.make_connection() as conn:
+ with conn.cursor() as c:
+ c.executemany(query=query, params_seq=data)
+ conn.commit()
+
+ category_manager.populate_caches()
+
+ def test(
+ self,
+ category_data,
+ category_manager,
+ beauty_fitness,
+ hair_care,
+ hair_loss,
+ ):
+ # category_manager on init caches the category info. This rarely/never changes so this is fine,
+ # but now that tests get run on a new db each time, the category_manager is inited before
+ # the fixtures run. so category_manager's cache needs to be rerun
+
+ # path='/Beauty & Fitness/Hair Care/Hair Loss'
+ c: Category = category_manager.get_by_label("Hair Loss")
+ # Beauty & Fitness
+ assert beauty_fitness.uuid == category_manager.get_top_level(c).uuid
+
+ c: Category = category_manager.categories[beauty_fitness.uuid]
+ # The root is itself
+ assert c == category_manager.get_category_root(c)
+
+ # The root is Beauty & Fitness
+ c: Category = category_manager.get_by_label("Hair Loss")
+ assert beauty_fitness.uuid == category_manager.get_category_root(c).uuid
diff --git a/tests/managers/thl/test_contest/__init__.py b/tests/managers/thl/test_contest/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/thl/test_contest/__init__.py
diff --git a/tests/managers/thl/test_contest/test_leaderboard.py b/tests/managers/thl/test_contest/test_leaderboard.py
new file mode 100644
index 0000000..80a88a5
--- /dev/null
+++ b/tests/managers/thl/test_contest/test_leaderboard.py
@@ -0,0 +1,138 @@
+from datetime import datetime, timezone, timedelta
+from zoneinfo import ZoneInfo
+
+from generalresearch.currency import USDCent
+from generalresearch.models.thl.contest.definitions import (
+ ContestStatus,
+ ContestEndReason,
+)
+from generalresearch.models.thl.contest.leaderboard import (
+ LeaderboardContest,
+ LeaderboardContestCreate,
+)
+from generalresearch.models.thl.product import Product
+from generalresearch.models.thl.user import User
+from test_utils.managers.contest.conftest import (
+ leaderboard_contest_in_db as contest_in_db,
+ leaderboard_contest_create as contest_create,
+)
+
+
+class TestLeaderboardContestCRUD:
+
+ def test_create(
+ self,
+ contest_create: LeaderboardContestCreate,
+ product_user_wallet_yes: Product,
+ thl_lm,
+ contest_manager,
+ ):
+ c = contest_manager.create(
+ product_id=product_user_wallet_yes.uuid, contest_create=contest_create
+ )
+ c_out = contest_manager.get(c.uuid)
+ assert c == c_out
+
+ assert isinstance(c, LeaderboardContest)
+ assert c.prize_count == 2
+ assert c.status == ContestStatus.ACTIVE
+ # We have it set in the fixture as the daily contest for 2025-01-01
+ assert c.end_condition.ends_at == datetime(
+ 2025, 1, 1, 23, 59, 59, 999999, tzinfo=ZoneInfo("America/New_York")
+ ).astimezone(tz=timezone.utc) + timedelta(minutes=90)
+
+ def test_enter(
+ self,
+ user_with_wallet: User,
+ contest_in_db: LeaderboardContest,
+ thl_lm,
+ contest_manager,
+ user_manager,
+ thl_redis,
+ ):
+ contest = contest_in_db
+ user = user_with_wallet
+
+ c: LeaderboardContest = contest_manager.get(contest_uuid=contest.uuid)
+
+ c = contest_manager.get_leaderboard_user_view(
+ contest_uuid=contest.uuid,
+ user=user,
+ redis_client=thl_redis,
+ user_manager=user_manager,
+ )
+ assert c.user_rank is None
+
+ lbm = c.get_leaderboard_manager()
+ lbm.hit_complete_count(user.product_user_id)
+
+ c = contest_manager.get_leaderboard_user_view(
+ contest_uuid=contest.uuid,
+ user=user,
+ redis_client=thl_redis,
+ user_manager=user_manager,
+ )
+ assert c.user_rank == 1
+
+ def test_contest_ends(
+ self,
+ user_with_wallet: User,
+ contest_in_db: LeaderboardContest,
+ thl_lm,
+ contest_manager,
+ user_manager,
+ thl_redis,
+ ):
+ # The contest should be over. We need to trigger it.
+ contest = contest_in_db
+ contest._redis_client = thl_redis
+ contest._user_manager = user_manager
+ user = user_with_wallet
+
+ lbm = contest.get_leaderboard_manager()
+ lbm.hit_complete_count(user.product_user_id)
+
+ c = contest_manager.get_leaderboard_user_view(
+ contest_uuid=contest.uuid,
+ user=user,
+ redis_client=thl_redis,
+ user_manager=user_manager,
+ )
+ assert c.user_rank == 1
+
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid(user.product_id)
+ bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet)
+ assert bp_wallet_balance == 0
+ user_wallet = thl_lm.get_account_or_create_user_wallet(user=user)
+ user_balance = thl_lm.get_account_balance(user_wallet)
+ assert user_balance == 0
+
+ decision, reason = contest.should_end()
+ assert decision
+ assert reason == ContestEndReason.ENDS_AT
+
+ contest_manager.end_contest_if_over(contest=contest, ledger_manager=thl_lm)
+
+ c: LeaderboardContest = contest_manager.get(contest_uuid=contest.uuid)
+ assert c.status == ContestStatus.COMPLETED
+ print(c)
+
+ user_contest = contest_manager.get_leaderboard_user_view(
+ contest_uuid=contest.uuid,
+ user=user,
+ redis_client=thl_redis,
+ user_manager=user_manager,
+ )
+ assert len(user_contest.user_winnings) == 1
+ w = user_contest.user_winnings[0]
+ assert w.product_user_id == user.product_user_id
+ assert w.prize.cash_amount == USDCent(15_00)
+
+ # The prize is $15.00, so the user should get $15, paid by the bp
+ assert thl_lm.get_account_balance(account=user_wallet) == 15_00
+ # contest wallet is 0, and the BP gets 20c
+ contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid(
+ contest_uuid=c.uuid
+ )
+ assert thl_lm.get_account_balance(account=contest_wallet) == 0
+ assert thl_lm.get_account_balance(account=bp_wallet) == -15_00
diff --git a/tests/managers/thl/test_contest/test_milestone.py b/tests/managers/thl/test_contest/test_milestone.py
new file mode 100644
index 0000000..7312a64
--- /dev/null
+++ b/tests/managers/thl/test_contest/test_milestone.py
@@ -0,0 +1,296 @@
+from datetime import datetime, timezone
+
+from generalresearch.models.thl.contest.definitions import (
+ ContestStatus,
+ ContestEndReason,
+)
+from generalresearch.models.thl.contest.milestone import (
+ MilestoneContest,
+ MilestoneContestCreate,
+ MilestoneUserView,
+ ContestEntryTrigger,
+)
+from generalresearch.models.thl.product import Product
+from generalresearch.models.thl.user import User
+from test_utils.managers.contest.conftest import (
+ milestone_contest as contest,
+ milestone_contest_in_db as contest_in_db,
+ milestone_contest_create as contest_create,
+ milestone_contest_factory as contest_factory,
+)
+
+
+class TestMilestoneContest:
+
+ def test_should_end(self, contest: MilestoneContest, thl_lm, contest_manager):
+ # contest is active and has no entries
+ should, msg = contest.should_end()
+ assert not should, msg
+
+ # Change so that the contest ends now
+ contest.end_condition.ends_at = datetime.now(tz=timezone.utc)
+ should, msg = contest.should_end()
+ assert should
+ assert msg == ContestEndReason.ENDS_AT
+
+ # Change the win amount it thinks it past over the target
+ contest.end_condition.ends_at = None
+ contest.end_condition.max_winners = 10
+ contest.win_count = 10
+ should, msg = contest.should_end()
+ assert should
+ assert msg == ContestEndReason.MAX_WINNERS
+
+
+class TestMilestoneContestCRUD:
+
+ def test_create(
+ self,
+ contest_create: MilestoneContestCreate,
+ product_user_wallet_yes: Product,
+ thl_lm,
+ contest_manager,
+ ):
+ c = contest_manager.create(
+ product_id=product_user_wallet_yes.uuid, contest_create=contest_create
+ )
+ c_out = contest_manager.get(c.uuid)
+ assert c == c_out
+
+ assert isinstance(c, MilestoneContest)
+ assert c.prize_count == 2
+ assert c.status == ContestStatus.ACTIVE
+ assert c.end_condition.max_winners == 5
+ assert c.entry_trigger == ContestEntryTrigger.TASK_COMPLETE
+ assert c.target_amount == 3
+ assert c.win_count == 0
+
+ def test_enter(
+ self,
+ user_with_wallet: User,
+ contest_in_db: MilestoneContest,
+ thl_lm,
+ contest_manager,
+ ):
+ # Users CANNOT directly enter a milestone contest through the api,
+ # but we'll call this manager method when a trigger is hit.
+ contest = contest_in_db
+ user = user_with_wallet
+
+ contest_manager.enter_milestone_contest(
+ contest_uuid=contest.uuid,
+ user=user,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ incr=1,
+ )
+
+ c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid)
+ assert c.status == ContestStatus.ACTIVE
+ assert not hasattr(c, "current_amount")
+ assert not hasattr(c, "current_participants")
+
+ c: MilestoneUserView = contest_manager.get_milestone_user_view(
+ contest_uuid=contest.uuid, user=user_with_wallet
+ )
+ assert c.user_amount == 1
+
+ # Contest wallet should have 0 bc there is no ledger
+ contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid(
+ contest_uuid=contest.uuid
+ )
+ assert thl_lm.get_account_balance(contest_wallet) == 0
+
+ # Enter again!
+ contest_manager.enter_milestone_contest(
+ contest_uuid=contest.uuid,
+ user=user,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ incr=1,
+ )
+ c: MilestoneUserView = contest_manager.get_milestone_user_view(
+ contest_uuid=contest.uuid, user=user_with_wallet
+ )
+ assert c.user_amount == 2
+
+ # We should have ONE entry with a value of 2
+ e = contest_manager.get_entries_by_contest_id(c.id)
+ assert len(e) == 1
+ assert e[0].amount == 2
+
+ def test_enter_win(
+ self,
+ user_with_wallet: User,
+ contest_in_db: MilestoneContest,
+ thl_lm,
+ contest_manager,
+ ):
+ # User enters contest, which brings the USER'S total amount above the limit,
+ # and the user reaches the milestone
+ contest = contest_in_db
+ user = user_with_wallet
+
+ user_wallet = thl_lm.get_account_or_create_user_wallet(user=user)
+ user_balance = thl_lm.get_account_balance(account=user_wallet)
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid(
+ product_uuid=user.product_id
+ )
+ bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet)
+
+ c: MilestoneUserView = contest_manager.get_milestone_user_view(
+ contest_uuid=contest.uuid, user=user_with_wallet
+ )
+ assert c.user_amount == 0
+ res, msg = c.is_user_eligible(country_iso="us")
+ assert res, msg
+
+ # User reaches the milestone after 3 completes/whatevers.
+ for _ in range(3):
+ contest_manager.enter_milestone_contest(
+ contest_uuid=contest.uuid,
+ user=user,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ incr=1,
+ )
+
+ # to be clear, the contest itself doesn't end!
+ c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid)
+ assert c.status == ContestStatus.ACTIVE
+
+ c: MilestoneUserView = contest_manager.get_milestone_user_view(
+ contest_uuid=contest.uuid, user=user_with_wallet
+ )
+ assert c.user_amount == 3
+ res, msg = c.is_user_eligible(country_iso="us")
+ assert not res
+ assert msg == "User should have won already"
+
+ assert len(c.user_winnings) == 2
+ assert c.win_count == 1
+
+ # The prize was awarded! User should have won $1.00
+ assert thl_lm.get_account_balance(user_wallet) - user_balance == 100
+ # Which was paid from the BP's balance
+ assert thl_lm.get_account_balance(bp_wallet) - bp_wallet_balance == -100
+
+ # winnings = cm.get_winnings_by_user(user=user)
+ # assert len(winnings) == 1
+ # win = winnings[0]
+ # assert win.product_user_id == user.product_user_id
+
+ def test_enter_ends(
+ self,
+ user_factory,
+ product_user_wallet_yes: Product,
+ contest_in_db: MilestoneContest,
+ thl_lm,
+ contest_manager,
+ ):
+ # Multiple users reach the milestone. Contest ends after 5 wins.
+ users = [user_factory(product=product_user_wallet_yes) for _ in range(5)]
+ contest = contest_in_db
+
+ for u in users:
+ contest_manager.enter_milestone_contest(
+ contest_uuid=contest.uuid,
+ user=u,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ incr=3,
+ )
+
+ c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid)
+ assert c.status == ContestStatus.COMPLETED
+ assert c.end_reason == ContestEndReason.MAX_WINNERS
+
+ def test_trigger(
+ self,
+ user_with_wallet: User,
+ contest_in_db: MilestoneContest,
+ thl_lm,
+ contest_manager,
+ ):
+ # Pretend user just got a complete
+ cnt = contest_manager.hit_milestone_triggers(
+ country_iso="us",
+ user=user_with_wallet,
+ event=ContestEntryTrigger.TASK_COMPLETE,
+ ledger_manager=thl_lm,
+ )
+ assert cnt == 1
+
+ # Assert this contest got entered
+ c: MilestoneUserView = contest_manager.get_milestone_user_view(
+ contest_uuid=contest_in_db.uuid, user=user_with_wallet
+ )
+ assert c.user_amount == 1
+
+
+class TestMilestoneContestUserViews:
+ def test_list_user_eligible_country(
+ self, user_with_wallet: User, contest_factory, thl_lm, contest_manager
+ ):
+ # No contests exists
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="us"
+ )
+ assert len(cs) == 0
+
+ # Create a contest. It'll be in the US/CA
+ contest_factory(country_isos={"us", "ca"})
+
+ # Not eligible in mexico
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="mx"
+ )
+ assert len(cs) == 0
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="us"
+ )
+ assert len(cs) == 1
+
+ # Create another, any country
+ contest_factory(country_isos=None)
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="mx"
+ )
+ assert len(cs) == 1
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="us"
+ )
+ assert len(cs) == 2
+
+ def test_list_user_eligible(
+ self, user_with_money: User, contest_factory, thl_lm, contest_manager
+ ):
+ # User reaches milestone after 1 complete
+ c = contest_factory(target_amount=1)
+ user = user_with_money
+
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_money, country_iso="us"
+ )
+ assert len(cs) == 1
+
+ contest_manager.enter_milestone_contest(
+ contest_uuid=c.uuid, user=user, country_iso="us", ledger_manager=thl_lm
+ )
+
+ # User isn't eligible anymore
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_money, country_iso="us"
+ )
+ assert len(cs) == 0
+
+ # But it comes back in the list entered
+ cs = contest_manager.get_many_by_user_entered(user=user_with_money)
+ assert len(cs) == 1
+ c = cs[0]
+ assert c.user_amount == 1
+ assert isinstance(c, MilestoneUserView)
+ assert not hasattr(c, "current_win_probability")
+
+ # They won one contest with 2 prizes
+ assert len(contest_manager.get_winnings_by_user(user_with_money)) == 2
diff --git a/tests/managers/thl/test_contest/test_raffle.py b/tests/managers/thl/test_contest/test_raffle.py
new file mode 100644
index 0000000..060055a
--- /dev/null
+++ b/tests/managers/thl/test_contest/test_raffle.py
@@ -0,0 +1,474 @@
+from datetime import datetime, timezone
+
+import pytest
+from pydantic import ValidationError
+from pytest import approx
+
+from generalresearch.currency import USDCent
+from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerTransactionConditionFailedError,
+)
+from generalresearch.models.thl.contest import (
+ ContestPrize,
+ ContestEntryRule,
+ ContestEndCondition,
+)
+from generalresearch.models.thl.contest.definitions import (
+ ContestStatus,
+ ContestPrizeKind,
+ ContestEndReason,
+)
+from generalresearch.models.thl.contest.exceptions import ContestError
+from generalresearch.models.thl.contest.raffle import (
+ ContestEntry,
+ ContestEntryType,
+)
+from generalresearch.models.thl.contest.raffle import (
+ RaffleContest,
+ RaffleContestCreate,
+ RaffleUserView,
+)
+from generalresearch.models.thl.product import Product
+from generalresearch.models.thl.user import User
+from test_utils.managers.contest.conftest import (
+ raffle_contest as contest,
+ raffle_contest_in_db as contest_in_db,
+ raffle_contest_create as contest_create,
+ raffle_contest_factory as contest_factory,
+)
+
+
+class TestRaffleContest:
+
+ def test_should_end(self, contest: RaffleContest, thl_lm, contest_manager):
+ # contest is active and has no entries
+ should, msg = contest.should_end()
+ assert not should, msg
+
+ # Change so that the contest ends now
+ contest.end_condition.ends_at = datetime.now(tz=timezone.utc)
+ should, msg = contest.should_end()
+ assert should
+ assert msg == ContestEndReason.ENDS_AT
+
+ # Change the entry amount it thinks it has to over the target
+ contest.end_condition.ends_at = None
+ contest.current_amount = USDCent(100)
+ should, msg = contest.should_end()
+ assert should
+ assert msg == ContestEndReason.TARGET_ENTRY_AMOUNT
+
+
+class TestRaffleContestCRUD:
+
+ def test_create(
+ self,
+ contest_create: RaffleContestCreate,
+ product_user_wallet_yes: Product,
+ thl_lm,
+ contest_manager,
+ ):
+ c = contest_manager.create(
+ product_id=product_user_wallet_yes.uuid, contest_create=contest_create
+ )
+ c_out = contest_manager.get(c.uuid)
+ assert c == c_out
+
+ assert isinstance(c, RaffleContest)
+ assert c.prize_count == 1
+ assert c.status == ContestStatus.ACTIVE
+ assert c.end_condition.target_entry_amount == USDCent(100)
+ assert c.current_amount == 0
+ assert c.current_participants == 0
+
+ @pytest.mark.parametrize("user_with_money", [{"min_balance": 60}], indirect=True)
+ def test_enter(
+ self,
+ user_with_money: User,
+ contest_in_db: RaffleContest,
+ thl_lm,
+ contest_manager,
+ ):
+ # Raffle ends at $1.00. User enters for $0.60
+ print(user_with_money.product_id)
+ print(contest_in_db.product_id)
+ print(contest_in_db.uuid)
+ contest = contest_in_db
+
+ user_wallet = thl_lm.get_account_or_create_user_wallet(user=user_with_money)
+ user_balance = thl_lm.get_account_balance(account=user_wallet)
+
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(60)
+ )
+ entry = contest_manager.enter_contest(
+ contest_uuid=contest.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ c: RaffleContest = contest_manager.get(contest_uuid=contest.uuid)
+ assert c.current_amount == USDCent(60)
+ assert c.current_participants == 1
+ assert c.status == ContestStatus.ACTIVE
+
+ c: RaffleUserView = contest_manager.get_raffle_user_view(
+ contest_uuid=contest.uuid, user=user_with_money
+ )
+ assert c.user_amount == USDCent(60)
+ assert c.user_amount_today == USDCent(60)
+ assert c.projected_win_probability == approx(60 / 100, rel=0.01)
+
+ # Contest wallet should have $0.60
+ contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid(
+ contest_uuid=contest.uuid
+ )
+ assert thl_lm.get_account_balance(account=contest_wallet) == 60
+ # User spent 60c
+ assert user_balance - thl_lm.get_account_balance(account=user_wallet) == 60
+
+ @pytest.mark.parametrize("user_with_money", [{"min_balance": 120}], indirect=True)
+ def test_enter_ends(
+ self,
+ user_with_money: User,
+ contest_in_db: RaffleContest,
+ thl_lm,
+ contest_manager,
+ ):
+ # User enters contest, which brings the total amount above the limit,
+ # and the contest should end, with a winner selected
+ contest = contest_in_db
+
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid(
+ user_with_money.product_id
+ )
+ # I bribed the user, so the balance is not 0
+ bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet)
+
+ for _ in range(2):
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH,
+ user=user_with_money,
+ amount=USDCent(60),
+ )
+ contest_manager.enter_contest(
+ contest_uuid=contest.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ c: RaffleContest = contest_manager.get(contest_uuid=contest.uuid)
+ assert c.status == ContestStatus.COMPLETED
+ print(c)
+
+ user_contest = contest_manager.get_raffle_user_view(
+ contest_uuid=contest.uuid, user=user_with_money
+ )
+ assert user_contest.current_win_probability == 1
+ assert user_contest.projected_win_probability == 1
+ assert len(user_contest.user_winnings) == 1
+
+ # todo: make a all winning method
+ winnings = contest_manager.get_winnings_by_user(user=user_with_money)
+ assert len(winnings) == 1
+ win = winnings[0]
+ assert win.product_user_id == user_with_money.product_user_id
+
+ # Contest wallet should have gotten zeroed out
+ contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid(
+ contest_uuid=contest.uuid
+ )
+ assert thl_lm.get_account_balance(contest_wallet) == 0
+ # Expense wallet gets the $1.00 expense
+ expense_wallet = thl_lm.get_account_or_create_bp_expense_by_uuid(
+ product_uuid=user_with_money.product_id, expense_name="Prize"
+ )
+ assert thl_lm.get_account_balance(expense_wallet) == -100
+ # And the BP gets 20c
+ assert thl_lm.get_account_balance(bp_wallet) - bp_wallet_balance == 20
+
+ @pytest.mark.parametrize("user_with_money", [{"min_balance": 120}], indirect=True)
+ def test_enter_ends_cash_prize(
+ self, user_with_money: User, contest_factory, thl_lm, contest_manager
+ ):
+ # Same as test_enter_ends, but the prize is cash. Just
+ # testing the ledger methods
+ c = contest_factory(
+ prizes=[
+ ContestPrize(
+ name="$1.00 bonus",
+ kind=ContestPrizeKind.CASH,
+ estimated_cash_value=USDCent(100),
+ cash_amount=USDCent(100),
+ )
+ ]
+ )
+ assert c.prizes[0].kind == ContestPrizeKind.CASH
+
+ user_wallet = thl_lm.get_account_or_create_user_wallet(user=user_with_money)
+ user_balance = thl_lm.get_account_balance(user_wallet)
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid(
+ user_with_money.product_id
+ )
+ bp_wallet_balance = thl_lm.get_account_balance(bp_wallet)
+
+ ## Enter Contest
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(120)
+ )
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+
+ # The prize is $1.00, so the user spent $1.20 entering, won, then got $1.00 back
+ assert (
+ thl_lm.get_account_balance(account=user_wallet) == user_balance + 100 - 120
+ )
+ # contest wallet is 0, and the BP gets 20c
+ contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid(
+ contest_uuid=c.uuid
+ )
+ assert thl_lm.get_account_balance(account=contest_wallet) == 0
+ assert thl_lm.get_account_balance(account=bp_wallet) - bp_wallet_balance == 20
+
+ def test_enter_failure(
+ self,
+ user_with_wallet: User,
+ contest_in_db: RaffleContest,
+ thl_lm,
+ contest_manager,
+ ):
+ c = contest_in_db
+ user = user_with_wallet
+
+ # Tries to enter $0
+ with pytest.raises(ValidationError) as e:
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user, amount=USDCent(0)
+ )
+ assert "Input should be greater than 0" in str(e.value)
+
+ # User has no money
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user, amount=USDCent(20)
+ )
+ with pytest.raises(LedgerTransactionConditionFailedError) as e:
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ assert e.value.args[0] == "insufficient balance"
+
+ # Tries to enter with the wrong entry type (count, on a cash contest)
+ entry = ContestEntry(entry_type=ContestEntryType.COUNT, user=user, amount=1)
+ with pytest.raises(AssertionError) as e:
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ assert "incompatible entry type" in str(e.value)
+
+ @pytest.mark.parametrize("user_with_money", [{"min_balance": 100}], indirect=True)
+ def test_enter_not_eligible(
+ self, user_with_money: User, contest_factory, thl_lm, contest_manager
+ ):
+ # Max entry amount per user $0.10. Contest still ends at $1.00
+ c = contest_factory(
+ entry_rule=ContestEntryRule(
+ max_entry_amount_per_user=USDCent(10),
+ max_daily_entries_per_user=USDCent(8),
+ )
+ )
+ c: RaffleContest = contest_manager.get(c.uuid)
+ assert c.entry_rule.max_entry_amount_per_user == USDCent(10)
+ assert c.entry_rule.max_daily_entries_per_user == USDCent(8)
+
+ # User tries to enter $0.20
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(20)
+ )
+ with pytest.raises(ContestError) as e:
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ assert "Entry would exceed max amount per user." in str(e.value)
+
+ # User tries to enter $0.10
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(10)
+ )
+ with pytest.raises(ContestError) as e:
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ assert "Entry would exceed max amount per user per day." in str(e.value)
+
+ # User enters $0.08 successfully
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(8)
+ )
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+
+ # Then can't anymore
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(1)
+ )
+ with pytest.raises(ContestError) as e:
+ entry = contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ assert "Entry would exceed max amount per user per day." in str(e.value)
+
+
+class TestRaffleContestUserViews:
+ def test_list_user_eligible_country(
+ self, user_with_wallet: User, contest_factory, thl_lm, contest_manager
+ ):
+ # No contests exists
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="us"
+ )
+ assert len(cs) == 0
+
+ # Create a contest. It'll be in the US/CA
+ contest_factory(country_isos={"us", "ca"})
+
+ # Not eligible in mexico
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="mx"
+ )
+ assert len(cs) == 0
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="us"
+ )
+ assert len(cs) == 1
+
+ # Create another, any country
+ contest_factory(country_isos=None)
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="mx"
+ )
+ assert len(cs) == 1
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_wallet, country_iso="us"
+ )
+ assert len(cs) == 2
+
+ def test_list_user_eligible(
+ self, user_with_money: User, contest_factory, thl_lm, contest_manager
+ ):
+ c = contest_factory(
+ end_condition=ContestEndCondition(target_entry_amount=USDCent(10)),
+ entry_rule=ContestEntryRule(
+ max_entry_amount_per_user=USDCent(1),
+ ),
+ )
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_money, country_iso="us"
+ )
+ assert len(cs) == 1
+
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH,
+ user=user_with_money,
+ amount=USDCent(1),
+ )
+ contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+
+ # User isn't eligible anymore
+ cs = contest_manager.get_many_by_user_eligible(
+ user=user_with_money, country_iso="us"
+ )
+ assert len(cs) == 0
+
+ # But it comes back in the list entered
+ cs = contest_manager.get_many_by_user_entered(user=user_with_money)
+ assert len(cs) == 1
+ c = cs[0]
+ assert c.user_amount == USDCent(1)
+ assert c.user_amount_today == USDCent(1)
+ assert c.current_win_probability == 1
+ assert c.projected_win_probability == approx(1 / 10, rel=0.01)
+
+ # And nothing won yet #todo
+ # cs = cm.get_many_by_user_won(user=user_with_money)
+
+ assert len(contest_manager.get_winnings_by_user(user_with_money)) == 0
+
+ def test_list_user_winnings(
+ self, user_with_money: User, contest_factory, thl_lm, contest_manager
+ ):
+ c = contest_factory(
+ end_condition=ContestEndCondition(target_entry_amount=USDCent(100)),
+ )
+ entry = ContestEntry(
+ entry_type=ContestEntryType.CASH,
+ user=user_with_money,
+ amount=USDCent(100),
+ )
+ contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
+ # Contest ends after 100 entry, user enters 100 entry, user wins!
+ ws = contest_manager.get_winnings_by_user(user_with_money)
+ assert len(ws) == 1
+ w = ws[0]
+ assert w.user.user_id == user_with_money.user_id
+ assert w.prize == c.prizes[0]
+ assert w.awarded_cash_amount is None
+
+ cs = contest_manager.get_many_by_user_won(user_with_money)
+ assert len(cs) == 1
+ c = cs[0]
+ w = c.user_winnings[0]
+ assert w.prize == c.prizes[0]
+ assert w.user.user_id == user_with_money.user_id
+
+
+class TestRaffleContestCRUDCount:
+ # This is a COUNT contest. No cash moves. Not really fleshed out what we'd do with this.
+ @pytest.mark.skip
+ def test_enter(
+ self, user_with_wallet: User, contest_factory, thl_lm, contest_manager
+ ):
+ c = contest_factory(entry_type=ContestEntryType.COUNT)
+ entry = ContestEntry(
+ entry_type=ContestEntryType.COUNT,
+ user=user_with_wallet,
+ amount=1,
+ )
+ contest_manager.enter_contest(
+ contest_uuid=c.uuid,
+ entry=entry,
+ country_iso="us",
+ ledger_manager=thl_lm,
+ )
diff --git a/tests/managers/thl/test_harmonized_uqa.py b/tests/managers/thl/test_harmonized_uqa.py
new file mode 100644
index 0000000..6bbbbe1
--- /dev/null
+++ b/tests/managers/thl/test_harmonized_uqa.py
@@ -0,0 +1,116 @@
+from datetime import datetime, timezone
+
+import pytest
+
+from generalresearch.managers.thl.profiling.uqa import UQAManager
+from generalresearch.models.thl.profiling.user_question_answer import (
+ UserQuestionAnswer,
+ DUMMY_UQA,
+)
+from generalresearch.models.thl.user import User
+
+
+@pytest.mark.usefixtures("uqa_db_index", "upk_data", "uqa_manager_clear_cache")
+class TestUQAManager:
+
+ def test_init(self, uqa_manager: UQAManager, user: User):
+ uqas = uqa_manager.get(user)
+ assert len(uqas) == 0
+
+ def test_create(self, uqa_manager: UQAManager, user: User):
+ now = datetime.now(tz=timezone.utc)
+ uqas = [
+ UserQuestionAnswer(
+ user_id=user.user_id,
+ question_id="fd5bd491b75a491aa7251159680bf1f1",
+ country_iso="us",
+ language_iso="eng",
+ answer=("2",),
+ timestamp=now,
+ property_code="m:job_role",
+ calc_answers={"m:job_role": ("2",)},
+ )
+ ]
+ uqa_manager.create(user, uqas)
+
+ res = uqa_manager.get(user=user)
+ assert len(res) == 1
+ assert res[0] == uqas[0]
+
+ # Same question, so this gets updated
+ now = datetime.now(tz=timezone.utc)
+ uqas_update = [
+ UserQuestionAnswer(
+ user_id=user.user_id,
+ question_id="fd5bd491b75a491aa7251159680bf1f1",
+ country_iso="us",
+ language_iso="eng",
+ answer=("3",),
+ timestamp=now,
+ property_code="m:job_role",
+ calc_answers={"m:job_role": ("3",)},
+ )
+ ]
+ uqa_manager.create(user, uqas_update)
+ res = uqa_manager.get(user=user)
+ assert len(res) == 1
+ assert res[0] == uqas_update[0]
+
+ # Add a new answer
+ now = datetime.now(tz=timezone.utc)
+ uqas_new = [
+ UserQuestionAnswer(
+ user_id=user.user_id,
+ question_id="3b65220db85f442ca16bb0f1c0e3a456",
+ country_iso="us",
+ language_iso="eng",
+ answer=("3",),
+ timestamp=now,
+ property_code="gr:children_age_gender",
+ calc_answers={"gr:children_age_gender": ("3",)},
+ )
+ ]
+ uqa_manager.create(user, uqas_new)
+ res = uqa_manager.get(user=user)
+ assert len(res) == 2
+ assert res[1] == uqas_update[0]
+ assert res[0] == uqas_new[0]
+
+
+@pytest.mark.usefixtures("uqa_db_index", "upk_data", "uqa_manager_clear_cache")
+class TestUQAManagerCache:
+
+ def test_get_uqa_empty(self, uqa_manager: UQAManager, user: User, caplog):
+ res = uqa_manager.get(user=user)
+ assert len(res) == 0
+
+ res = uqa_manager.get_from_db(user=user)
+ assert len(res) == 0
+
+ # Checking that the cache has only the dummy_uqa in it
+ res = uqa_manager.get_from_cache(user=user)
+ assert res == [DUMMY_UQA]
+
+ with caplog.at_level("INFO"):
+ res = uqa_manager.get(user=user)
+ assert f"thl-grpc:uqa-cache-v2:{user.user_id} exists" in caplog.text
+ assert len(res) == 0
+
+ def test_get_uqa(self, uqa_manager: UQAManager, user: User, caplog):
+
+ # Now the user sends an answer
+ uqas = [
+ UserQuestionAnswer(
+ question_id="5d6d9f3c03bb40bf9d0a24f306387d7c",
+ answer=("1",),
+ timestamp=datetime.now(tz=timezone.utc),
+ country_iso="us",
+ language_iso="eng",
+ property_code="gr:gender",
+ user_id=user.user_id,
+ calc_answers={},
+ )
+ ]
+ uqa_manager.update_cache(user=user, uqas=uqas)
+ res = uqa_manager.get_from_cache(user=user)
+ assert res == uqas
diff --git a/tests/managers/thl/test_ipinfo.py b/tests/managers/thl/test_ipinfo.py
new file mode 100644
index 0000000..847b00c
--- /dev/null
+++ b/tests/managers/thl/test_ipinfo.py
@@ -0,0 +1,117 @@
+import faker
+
+from generalresearch.managers.thl.ipinfo import (
+ IPGeonameManager,
+ IPInformationManager,
+ GeoIpInfoManager,
+)
+from generalresearch.models.thl.ipinfo import IPGeoname, IPInformation
+
+fake = faker.Faker()
+
+
+class TestIPGeonameManager:
+
+ def test_init(self, thl_web_rr, ip_geoname_manager: IPGeonameManager):
+
+ instance = IPGeonameManager(pg_config=thl_web_rr)
+ assert isinstance(instance, IPGeonameManager)
+ assert isinstance(ip_geoname_manager, IPGeonameManager)
+
+ def test_create(self, ip_geoname_manager: IPGeonameManager):
+
+ instance = ip_geoname_manager.create_dummy()
+
+ assert isinstance(instance, IPGeoname)
+
+ res = ip_geoname_manager.fetch_geoname_ids(filter_ids=[instance.geoname_id])
+
+ assert res[0].model_dump_json() == instance.model_dump_json()
+
+
+class TestIPInformationManager:
+
+ def test_init(self, thl_web_rr, ip_information_manager: IPInformationManager):
+ instance = IPInformationManager(pg_config=thl_web_rr)
+ assert isinstance(instance, IPInformationManager)
+ assert isinstance(ip_information_manager, IPInformationManager)
+
+ def test_create(self, ip_information_manager: IPInformationManager):
+ instance = ip_information_manager.create_dummy()
+
+ assert isinstance(instance, IPInformation)
+
+ res = ip_information_manager.fetch_ip_information(filter_ips=[instance.ip])
+
+ assert res[0].model_dump_json() == instance.model_dump_json()
+
+ def test_prefetch_geoname(self, ip_information, ip_geoname, thl_web_rr):
+ assert isinstance(ip_information, IPInformation)
+
+ assert ip_information.geoname_id == ip_geoname.geoname_id
+ assert ip_information.geoname is None
+
+ ip_information.prefetch_geoname(pg_config=thl_web_rr)
+ assert isinstance(ip_information.geoname, IPGeoname)
+
+
+class TestGeoIpInfoManager:
+ def test_init(
+ self, thl_web_rr, thl_redis_config, geoipinfo_manager: GeoIpInfoManager
+ ):
+ instance = GeoIpInfoManager(pg_config=thl_web_rr, redis_config=thl_redis_config)
+ assert isinstance(instance, GeoIpInfoManager)
+ assert isinstance(geoipinfo_manager, GeoIpInfoManager)
+
+ def test_multi(self, ip_information_factory, ip_geoname, geoipinfo_manager):
+ ip = fake.ipv4_public()
+ ip_information_factory(ip=ip, geoname=ip_geoname)
+ ips = [ip]
+
+ # This only looks up in redis. They don't exist yet
+ res = geoipinfo_manager.get_cache_multi(ip_addresses=ips)
+ assert res == {ip: None}
+
+ # Looks up in redis, if not exists, looks in mysql, then sets
+ # the caches that didn't exist.
+ res = geoipinfo_manager.get_multi(ip_addresses=ips)
+ assert res[ip] is not None
+
+ ip2 = fake.ipv4_public()
+ ip_information_factory(ip=ip2, geoname=ip_geoname)
+ ips = [ip, ip2]
+ res = geoipinfo_manager.get_cache_multi(ip_addresses=ips)
+ assert res[ip] is not None
+ assert res[ip2] is None
+ res = geoipinfo_manager.get_multi(ip_addresses=ips)
+ assert res[ip] is not None
+ assert res[ip2] is not None
+ res = geoipinfo_manager.get_cache_multi(ip_addresses=ips)
+ assert res[ip] is not None
+ assert res[ip2] is not None
+
+ def test_multi_ipv6(self, ip_information_factory, ip_geoname, geoipinfo_manager):
+ ip = fake.ipv6()
+ # Make another IP that will be in the same /64 block.
+ ip2 = ip[:-1] + "a" if ip[-1] != "a" else ip[:-1] + "b"
+ ip_information_factory(ip=ip, geoname=ip_geoname)
+ ips = [ip, ip2]
+ print(f"{ips=}")
+
+ # This only looks up in redis. They don't exist yet
+ res = geoipinfo_manager.get_cache_multi(ip_addresses=ips)
+ assert res == {ip: None, ip2: None}
+
+ # Looks up in redis, if not exists, looks in mysql, then sets
+ # the caches that didn't exist.
+ res = geoipinfo_manager.get_multi(ip_addresses=ips)
+ assert res[ip].ip == ip
+ assert res[ip].lookup_prefix == "/64"
+ assert res[ip2].ip == ip2
+ assert res[ip2].lookup_prefix == "/64"
+ # they should be the same basically, except for the ip
+
+ def test_doesnt_exist(self, geoipinfo_manager):
+ ip = fake.ipv4_public()
+ res = geoipinfo_manager.get_multi(ip_addresses=[ip])
+ assert res == {ip: None}
diff --git a/tests/managers/thl/test_ledger/__init__.py b/tests/managers/thl/test_ledger/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/thl/test_ledger/__init__.py
diff --git a/tests/managers/thl/test_ledger/test_lm_accounts.py b/tests/managers/thl/test_ledger/test_lm_accounts.py
new file mode 100644
index 0000000..e0d9b0b
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_lm_accounts.py
@@ -0,0 +1,268 @@
+from itertools import product as iproduct
+from random import randint
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.currency import LedgerCurrency
+from generalresearch.managers.base import Permission
+from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+)
+from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager
+from generalresearch.models.thl.ledger import LedgerAccount, AccountType, Direction
+from generalresearch.models.thl.ledger import (
+ LedgerEntry,
+)
+from test_utils.managers.ledger.conftest import ledger_account
+
+
+@pytest.mark.parametrize(
+ argnames="currency, kind, acct_id",
+ argvalues=list(
+ iproduct(
+ ["USD", "test", "EUR"],
+ ["expense", "wallet", "revenue", "cash"],
+ [uuid4().hex for i in range(3)],
+ )
+ ),
+)
+class TestLedgerAccountManagerNoResults:
+
+ def test_get_account_no_results(self, currency, kind, acct_id, lm):
+ """Try to query for accounts that we know don't exist and confirm that
+ we either get the expected None result or it raises the correct
+ exception
+ """
+ qn = ":".join([currency, kind, acct_id])
+
+ # (1) .get_account is just a wrapper for .get_account_many_ but
+ # call it either way
+ assert lm.get_account(qualified_name=qn, raise_on_error=False) is None
+
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ lm.get_account(qualified_name=qn, raise_on_error=True)
+
+ # (2) .get_account_if_exists is another wrapper
+ assert lm.get_account(qualified_name=qn, raise_on_error=False) is None
+
+ def test_get_account_no_results_many(self, currency, kind, acct_id, lm):
+ qn = ":".join([currency, kind, acct_id])
+
+ # (1) .get_many_
+ assert lm.get_account_many_(qualified_names=[qn], raise_on_error=False) == []
+
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ lm.get_account_many_(qualified_names=[qn], raise_on_error=True)
+
+ # (2) .get_many
+ assert lm.get_account_many(qualified_names=[qn], raise_on_error=False) == []
+
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ lm.get_account_many(qualified_names=[qn], raise_on_error=True)
+
+ # (3) .get_accounts(..)
+ assert lm.get_accounts_if_exists(qualified_names=[qn]) == []
+
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ lm.get_accounts(qualified_names=[qn])
+
+
+@pytest.mark.parametrize(
+ argnames="currency, account_type, direction",
+ argvalues=list(
+ iproduct(
+ list(LedgerCurrency),
+ list(AccountType),
+ list(Direction),
+ )
+ ),
+)
+class TestLedgerAccountManagerCreate:
+
+ def test_create_account_error_permission(
+ self, currency, account_type, direction, lm
+ ):
+ """Confirm that the Permission values that are set on the Ledger Manger
+ allow the Creation action to occur.
+ """
+ acct_uuid = uuid4().hex
+
+ account = LedgerAccount(
+ display_name=f"test-{uuid4().hex}",
+ currency=currency,
+ qualified_name=f"{currency.value}:{account_type.value}:{acct_uuid}",
+ account_type=account_type,
+ normal_balance=direction,
+ )
+
+ # (1) With no Permissions defined
+ test_lm = LedgerManager(
+ pg_config=lm.pg_config,
+ permissions=[],
+ redis_config=lm.redis_config,
+ cache_prefix=lm.cache_prefix,
+ testing=lm.testing,
+ )
+
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ test_lm.create_account(account=account)
+ assert (
+ str(excinfo.value) == "LedgerManager does not have sufficient permissions"
+ )
+
+ # (2) With Permissions defined, but not CREATE
+ test_lm = LedgerManager(
+ pg_config=lm.pg_config,
+ permissions=[Permission.READ, Permission.UPDATE, Permission.DELETE],
+ redis_config=lm.redis_config,
+ cache_prefix=lm.cache_prefix,
+ testing=lm.testing,
+ )
+
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ test_lm.create_account(account=account)
+ assert (
+ str(excinfo.value) == "LedgerManager does not have sufficient permissions"
+ )
+
+ def test_create(self, currency, account_type, direction, lm):
+ """Confirm that the Permission values that are set on the Ledger Manger
+ allow the Creation action to occur.
+ """
+
+ acct_uuid = uuid4().hex
+ qn = f"{currency.value}:{account_type.value}:{acct_uuid}"
+
+ acct_model = LedgerAccount(
+ uuid=acct_uuid,
+ display_name=f"test-{uuid4().hex}",
+ currency=currency,
+ qualified_name=qn,
+ account_type=account_type,
+ normal_balance=direction,
+ )
+ account = lm.create_account(account=acct_model)
+ assert isinstance(account, LedgerAccount)
+
+ # Query for, and make sure the Account was saved in the DB
+ res = lm.get_account(qualified_name=qn, raise_on_error=True)
+ assert account.uuid == res.uuid
+
+ def test_get_or_create(self, currency, account_type, direction, lm):
+ """Confirm that the Permission values that are set on the Ledger Manger
+ allow the Creation action to occur.
+ """
+
+ acct_uuid = uuid4().hex
+ qn = f"{currency.value}:{account_type.value}:{acct_uuid}"
+
+ acct_model = LedgerAccount(
+ uuid=acct_uuid,
+ display_name=f"test-{uuid4().hex}",
+ currency=currency,
+ qualified_name=qn,
+ account_type=account_type,
+ normal_balance=direction,
+ )
+ account = lm.get_account_or_create(account=acct_model)
+ assert isinstance(account, LedgerAccount)
+
+ # Query for, and make sure the Account was saved in the DB
+ res = lm.get_account(qualified_name=qn, raise_on_error=True)
+ assert account.uuid == res.uuid
+
+
+class TestLedgerAccountManagerGet:
+
+ def test_get(self, ledger_account, lm):
+ res = lm.get_account(qualified_name=ledger_account.qualified_name)
+ assert res.uuid == ledger_account.uuid
+
+ res = lm.get_account_many(qualified_names=[ledger_account.qualified_name])
+ assert len(res) == 1
+ assert res[0].uuid == ledger_account.uuid
+
+ res = lm.get_accounts(qualified_names=[ledger_account.qualified_name])
+ assert len(res) == 1
+ assert res[0].uuid == ledger_account.uuid
+
+ # TODO: I can't test the get_balance without first having Transaction
+ # creation working
+
+ def test_get_balance_empty(
+ self, ledger_account, ledger_account_credit, ledger_account_debit, ledger_tx, lm
+ ):
+ res = lm.get_account_balance(account=ledger_account)
+ assert res == 0
+
+ res = lm.get_account_balance(account=ledger_account_credit)
+ assert res == 100
+
+ res = lm.get_account_balance(account=ledger_account_debit)
+ assert res == 100
+
+ @pytest.mark.parametrize("n_times", range(5))
+ def test_get_account_filtered_balance(
+ self,
+ ledger_account,
+ ledger_account_credit,
+ ledger_account_debit,
+ ledger_tx,
+ n_times,
+ lm,
+ ):
+ """Try searching for random metadata and confirm it's always 0 because
+ Tx can be found.
+ """
+ rand_key = f"key-{uuid4().hex[:10]}"
+ rand_value = uuid4().hex
+
+ assert (
+ lm.get_account_filtered_balance(
+ account=ledger_account, metadata_key=rand_key, metadata_value=rand_value
+ )
+ == 0
+ )
+
+ # Let's create a transaction with this metadata to confirm it saves
+ # and that we can filter it back
+ rand_amount = randint(10, 1_000)
+
+ lm.create_tx(
+ entries=[
+ LedgerEntry(
+ direction=Direction.CREDIT,
+ account_uuid=ledger_account_credit.uuid,
+ amount=rand_amount,
+ ),
+ LedgerEntry(
+ direction=Direction.DEBIT,
+ account_uuid=ledger_account_debit.uuid,
+ amount=rand_amount,
+ ),
+ ],
+ metadata={rand_key: rand_value},
+ )
+
+ assert (
+ lm.get_account_filtered_balance(
+ account=ledger_account_credit,
+ metadata_key=rand_key,
+ metadata_value=rand_value,
+ )
+ == rand_amount
+ )
+
+ assert (
+ lm.get_account_filtered_balance(
+ account=ledger_account_debit,
+ metadata_key=rand_key,
+ metadata_value=rand_value,
+ )
+ == rand_amount
+ )
+
+ def test_get_balance_timerange_empty(self, ledger_account, lm):
+ res = lm.get_account_balance_timerange(account=ledger_account)
+ assert res == 0
diff --git a/tests/managers/thl/test_ledger/test_lm_tx.py b/tests/managers/thl/test_ledger/test_lm_tx.py
new file mode 100644
index 0000000..37b7ba3
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_lm_tx.py
@@ -0,0 +1,235 @@
+from decimal import Decimal
+from random import randint
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.currency import LedgerCurrency
+from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager
+from generalresearch.models.thl.ledger import (
+ Direction,
+ LedgerEntry,
+ LedgerTransaction,
+)
+
+
+class TestLedgerManagerCreateTx:
+
+ def test_create_account_error_permission(self, lm):
+ """Confirm that the Permission values that are set on the Ledger Manger
+ allow the Creation action to occur.
+ """
+ acct_uuid = uuid4().hex
+
+ # (1) With no Permissions defined
+ test_lm = LedgerManager(
+ pg_config=lm.pg_config,
+ permissions=[],
+ redis_config=lm.redis_config,
+ cache_prefix=lm.cache_prefix,
+ testing=lm.testing,
+ )
+
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ test_lm.create_tx(entries=[])
+ assert (
+ str(excinfo.value)
+ == "LedgerTransactionManager has insufficient Permissions"
+ )
+
+ def test_create_assertions(self, ledger_account_debit, ledger_account_credit, lm):
+ with pytest.raises(expected_exception=ValueError) as excinfo:
+ lm.create_tx(
+ entries=[
+ {
+ "direction": Direction.CREDIT,
+ "account_uuid": uuid4().hex,
+ "amount": randint(a=1, b=100),
+ }
+ ]
+ )
+ assert (
+ "Assertion failed, ledger transaction must have 2 or more entries"
+ in str(excinfo.value)
+ )
+
+ def test_create(self, ledger_account_credit, ledger_account_debit, lm):
+ amount = int(Decimal("1.00") * 100)
+
+ entries = [
+ LedgerEntry(
+ direction=Direction.CREDIT,
+ account_uuid=ledger_account_credit.uuid,
+ amount=amount,
+ ),
+ LedgerEntry(
+ direction=Direction.DEBIT,
+ account_uuid=ledger_account_debit.uuid,
+ amount=amount,
+ ),
+ ]
+
+ # Create a Transaction and validate the operation was successful
+ tx = lm.create_tx(entries=entries)
+ assert isinstance(tx, LedgerTransaction)
+
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert isinstance(res, LedgerTransaction)
+ assert len(res.entries) == 2
+ assert tx.id == res.id
+
+ def test_create_and_reverse(self, ledger_account_credit, ledger_account_debit, lm):
+ amount = int(Decimal("1.00") * 100)
+
+ entries = [
+ LedgerEntry(
+ direction=Direction.CREDIT,
+ account_uuid=ledger_account_credit.uuid,
+ amount=amount,
+ ),
+ LedgerEntry(
+ direction=Direction.DEBIT,
+ account_uuid=ledger_account_debit.uuid,
+ amount=amount,
+ ),
+ ]
+
+ tx = lm.create_tx(entries=entries)
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.id == tx.id
+
+ assert lm.get_account_balance(account=ledger_account_credit) == 100
+ assert lm.get_account_balance(account=ledger_account_debit) == 100
+ assert lm.check_ledger_balanced() is True
+
+ # Reverse it
+ entries = [
+ LedgerEntry(
+ direction=Direction.DEBIT,
+ account_uuid=ledger_account_credit.uuid,
+ amount=amount,
+ ),
+ LedgerEntry(
+ direction=Direction.CREDIT,
+ account_uuid=ledger_account_debit.uuid,
+ amount=amount,
+ ),
+ ]
+
+ tx = lm.create_tx(entries=entries)
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.id == tx.id
+
+ assert lm.get_account_balance(ledger_account_credit) == 0
+ assert lm.get_account_balance(ledger_account_debit) == 0
+ assert lm.check_ledger_balanced()
+
+ # subtract again
+ entries = [
+ LedgerEntry(
+ direction=Direction.DEBIT,
+ account_uuid=ledger_account_credit.uuid,
+ amount=amount,
+ ),
+ LedgerEntry(
+ direction=Direction.CREDIT,
+ account_uuid=ledger_account_debit.uuid,
+ amount=amount,
+ ),
+ ]
+ tx = lm.create_tx(entries=entries)
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.id == tx.id
+
+ assert lm.get_account_balance(ledger_account_credit) == -100
+ assert lm.get_account_balance(ledger_account_debit) == -100
+ assert lm.check_ledger_balanced()
+
+
+class TestLedgerManagerGetTx:
+
+ # @pytest.mark.parametrize("currency", [LedgerCurrency.TEST], indirect=True)
+ def test_get_tx_by_id(self, ledger_tx, lm):
+ with pytest.raises(expected_exception=AssertionError):
+ lm.get_tx_by_id(transaction_id=ledger_tx)
+
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ assert res.id == ledger_tx.id
+
+ # @pytest.mark.parametrize("currency", [LedgerCurrency.TEST], indirect=True)
+ def test_get_tx_by_ids(self, ledger_tx, lm):
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ assert res.id == ledger_tx.id
+
+ @pytest.mark.parametrize(
+ "tag", [f"{LedgerCurrency.TEST}:{uuid4().hex}"], indirect=True
+ )
+ def test_get_tx_ids_by_tag(self, ledger_tx, tag, lm):
+ # (1) search for a random tag
+ res = lm.get_tx_ids_by_tag(tag="aaa:bbb")
+ assert isinstance(res, set)
+ assert len(res) == 0
+
+ # (2) search for the tag that was used during ledger_transaction creation
+ res = lm.get_tx_ids_by_tag(tag=tag)
+ assert isinstance(res, set)
+ assert len(res) == 1
+
+ def test_get_tx_by_tag(self, ledger_tx, tag, lm):
+ # (1) search for a random tag
+ res = lm.get_tx_by_tag(tag="aaa:bbb")
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ # (2) search for the tag that was used during ledger_transaction creation
+ res = lm.get_tx_by_tag(tag=tag)
+ assert isinstance(res, list)
+ assert len(res) == 1
+
+ assert isinstance(res[0], LedgerTransaction)
+ assert ledger_tx.id == res[0].id
+
+ def test_get_tx_filtered_by_account(
+ self, ledger_tx, ledger_account, ledger_account_debit, ledger_account_credit, lm
+ ):
+ # (1) Do basic assertion checks first
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ lm.get_tx_filtered_by_account(account_uuid=ledger_account)
+ assert str(excinfo.value) == "account_uuid must be a str"
+
+ # (2) This search doesn't return anything because this ledger account
+ # wasn't actually used in the entries for the ledger_transaction
+ res = lm.get_tx_filtered_by_account(account_uuid=ledger_account.uuid)
+ assert len(res) == 0
+
+ # (3) Either the credit or the debit example ledger_accounts wll work
+ # to find this transaction because they're both used in the entries
+ res = lm.get_tx_filtered_by_account(account_uuid=ledger_account_debit.uuid)
+ assert len(res) == 1
+ assert res[0].id == ledger_tx.id
+
+ res = lm.get_tx_filtered_by_account(account_uuid=ledger_account_credit.uuid)
+ assert len(res) == 1
+ assert ledger_tx.id == res[0].id
+
+ res2 = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ assert res2.model_dump_json() == res[0].model_dump_json()
+
+ def test_filter_metadata(self, ledger_tx, tx_metadata, lm):
+ key, value = next(iter(tx_metadata.items()))
+
+ # (1) Confirm a random key,value pair returns nothing
+ res = lm.get_tx_filtered_by_metadata(
+ metadata_key=f"key-{uuid4().hex[:10]}", metadata_value=uuid4().hex[:12]
+ )
+ assert len(res) == 0
+
+ # (2) confirm a key,value pair return the correct results
+ res = lm.get_tx_filtered_by_metadata(metadata_key=key, metadata_value=value)
+ assert len(res) == 1
+
+ # assert 0 == THL_lm.get_filtered_account_balance(account2, "thl_wall", "ccc")
+ # assert 300 == THL_lm.get_filtered_account_balance(account1, "thl_wall", "aaa")
+ # assert 300 == THL_lm.get_filtered_account_balance(account2, "thl_wall", "aaa")
+ # assert 0 == THL_lm.get_filtered_account_balance(account3, "thl_wall", "ccc")
+ # assert THL_lm.check_ledger_balanced()
diff --git a/tests/managers/thl/test_ledger/test_lm_tx_entries.py b/tests/managers/thl/test_ledger/test_lm_tx_entries.py
new file mode 100644
index 0000000..5bf1c48
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_lm_tx_entries.py
@@ -0,0 +1,26 @@
+from generalresearch.models.thl.ledger import LedgerEntry
+
+
+class TestLedgerEntryManager:
+
+ def test_get_tx_entries_by_tx(self, ledger_tx, lm):
+ # First confirm the Ledger TX exists with 2 Entries
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ assert len(res.entries) == 2
+
+ tx_entries = lm.get_tx_entries_by_tx(transaction=ledger_tx)
+ assert len(tx_entries) == 2
+
+ assert res.entries == tx_entries
+ assert isinstance(tx_entries[0], LedgerEntry)
+
+ def test_get_tx_entries_by_txs(self, ledger_tx, lm):
+ # First confirm the Ledger TX exists with 2 Entries
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ assert len(res.entries) == 2
+
+ tx_entries = lm.get_tx_entries_by_txs(transactions=[ledger_tx])
+ assert len(tx_entries) == 2
+
+ assert res.entries == tx_entries
+ assert isinstance(tx_entries[0], LedgerEntry)
diff --git a/tests/managers/thl/test_ledger/test_lm_tx_locks.py b/tests/managers/thl/test_ledger/test_lm_tx_locks.py
new file mode 100644
index 0000000..df2611b
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_lm_tx_locks.py
@@ -0,0 +1,371 @@
+import logging
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from typing import Callable
+
+import pytest
+
+from generalresearch.managers.thl.ledger_manager.conditions import (
+ generate_condition_mp_payment,
+)
+from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerTransactionCreateLockError,
+ LedgerTransactionFlagAlreadyExistsError,
+ LedgerTransactionCreateError,
+)
+from generalresearch.models import Source
+from generalresearch.models.thl.ledger import LedgerTransaction
+from generalresearch.models.thl.session import (
+ Wall,
+ Status,
+ StatusCode1,
+ Session,
+ WallAdjustedStatus,
+)
+from generalresearch.models.thl.user import User
+from test_utils.models.conftest import user_factory, session, product_user_wallet_no
+
+logger = logging.getLogger("LedgerManager")
+
+
+class TestLedgerLocks:
+
+ def test_a(
+ self,
+ user_factory,
+ session_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ caplog,
+ thl_lm,
+ lm,
+ utc_hour_ago,
+ currency,
+ wall_factory,
+ delete_ledger_db,
+ ):
+ """
+ TODO: This whole test is confusing a I don't really understand.
+ It needs to be better documented and explained what we want
+ it to do and evaluate...
+ """
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_user_wallet_no)
+ s1 = session_factory(
+ user=user,
+ wall_count=3,
+ wall_req_cpis=[Decimal("1.23"), Decimal("3.21"), Decimal("4")],
+ wall_statuses=[Status.COMPLETE, Status.COMPLETE, Status.COMPLETE],
+ )
+
+ # A User does a Wall Completion in Session=1
+ w1 = s1.wall_events[0]
+ tx = thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started)
+ assert isinstance(tx, LedgerTransaction)
+
+ # A User does another Wall Completion in Session=1
+ w2 = s1.wall_events[1]
+ tx = thl_lm.create_tx_task_complete(wall=w2, user=user, created=w2.started)
+ assert isinstance(tx, LedgerTransaction)
+
+ # That first Wall Complete was "adjusted" to instead be marked
+ # as a Failure
+ w1.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=0,
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ tx = thl_lm.create_tx_task_adjustment(wall=w1, user=user)
+ assert isinstance(tx, LedgerTransaction)
+
+ # A User does another! Wall Completion in Session=1; however, we
+ # don't create a transaction for it
+ w3 = s1.wall_events[2]
+
+ # Make sure we clear any flags/locks first
+ lock_key = f"{currency.value}:thl_wall:{w3.uuid}"
+ lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}"
+ flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}"
+ lm.redis_client.delete(lock_name)
+ lm.redis_client.delete(flag_name)
+
+ # Despite the
+ f1 = generate_condition_mp_payment(wall=w1)
+ f2 = generate_condition_mp_payment(wall=w2)
+ f3 = generate_condition_mp_payment(wall=w3)
+ assert f1(lm=lm) is False
+ assert f2(lm=lm) is False
+ assert f3(lm=lm) is True
+
+ condition = f3
+ create_tx_func = lambda: thl_lm.create_tx_task_complete_(wall=w3, user=user)
+ assert isinstance(create_tx_func, Callable)
+ assert f3(lm) is True
+
+ lm.redis_client.delete(flag_name)
+ lm.redis_client.delete(lock_name)
+
+ tx = thl_lm.create_tx_protected(
+ lock_key=lock_key, condition=condition, create_tx_func=create_tx_func
+ )
+ assert f3(lm) is False
+
+ # purposely hold the lock open
+ tx = None
+ lm.redis_client.set(lock_name, "1")
+ with caplog.at_level(logging.ERROR):
+ with pytest.raises(expected_exception=LedgerTransactionCreateLockError):
+ tx = thl_lm.create_tx_protected(
+ lock_key=lock_key,
+ condition=condition,
+ create_tx_func=create_tx_func,
+ )
+ assert tx is None
+ assert "Unable to acquire lock within the time specified" in caplog.text
+ lm.redis_client.delete(lock_name)
+
+ def test_locking(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ delete_ledger_db,
+ caplog,
+ thl_lm,
+ lm,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ now = datetime.now(timezone.utc) - timedelta(hours=1)
+ user: User = user_factory(product=product_user_wallet_no)
+
+ # A User does a Wall complete on Session.id=1 and the transaction is
+ # logged to the ledger
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("1.23"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=now,
+ finished=now + timedelta(seconds=1),
+ )
+ thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started)
+
+ # A User does a Wall complete on Session.id=1 and the transaction is
+ # logged to the ledger
+ wall2 = Wall(
+ user_id=user.user_id,
+ source=Source.FULL_CIRCLE,
+ req_survey_id="yyy",
+ req_cpi=Decimal("3.21"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=now,
+ finished=now + timedelta(seconds=1),
+ )
+ thl_lm.create_tx_task_complete(wall=wall2, user=user, created=wall2.started)
+
+ # An hour later, the first wall complete is adjusted to a Failure and
+ # it's tracked in the ledger
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=0,
+ adjusted_timestamp=now + timedelta(hours=1),
+ )
+ thl_lm.create_tx_task_adjustment(wall=wall1, user=user)
+
+ # A User does a Wall complete on Session.id=1 and the transaction
+ # IS NOT logged to the ledger
+ wall3 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("4"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=now,
+ finished=now + timedelta(seconds=1),
+ uuid="867a282d8b4d40d2a2093d75b802b629",
+ )
+
+ revenue_account = thl_lm.get_account_task_complete_revenue()
+ assert 0 == thl_lm.get_account_filtered_balance(
+ account=revenue_account,
+ metadata_key="thl_wall",
+ metadata_value=wall3.uuid,
+ )
+ # Make sure we clear any flags/locks first
+ lock_key = f"test:thl_wall:{wall3.uuid}"
+ lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}"
+ flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}"
+ lm.redis_client.delete(lock_name)
+ lm.redis_client.delete(flag_name)
+
+ # Purposely hold the lock open
+ lm.redis_client.set(name=lock_name, value="1")
+ with caplog.at_level(logging.DEBUG):
+ with pytest.raises(expected_exception=LedgerTransactionCreateLockError):
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall3, user=user, created=wall3.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+ assert "Unable to acquire lock within the time specified" in caplog.text
+
+ # Release the lock
+ lm.redis_client.delete(lock_name)
+
+ # Set the redis flag to indicate it has been run
+ lm.redis_client.set(flag_name, "1")
+ # with self.assertLogs(logger=logger, level=logging.DEBUG) as cm2:
+ with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError):
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall3, user=user, created=wall3.started
+ )
+ # self.assertIn("entered_lock: True, flag_set: True", cm2.output[0])
+
+ # Unset the flag
+ lm.redis_client.delete(flag_name)
+
+ assert 0 == lm.get_account_filtered_balance(
+ account=revenue_account,
+ metadata_key="thl_wall",
+ metadata_value=wall3.uuid,
+ )
+
+ # Now actually run it
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall3, user=user, created=wall3.started
+ )
+ assert tx is not None
+
+ # Run it again, should return None
+ # Confirm the Exception inheritance works
+ tx = None
+ with pytest.raises(expected_exception=LedgerTransactionCreateError):
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall3, user=user, created=wall3.started
+ )
+ assert tx is None
+
+ # clear the redis flag, it should query the db
+ assert lm.redis_client.get(flag_name) is not None
+ lm.redis_client.delete(flag_name)
+ assert lm.redis_client.get(flag_name) is None
+
+ with pytest.raises(expected_exception=LedgerTransactionCreateError):
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall3, user=user, created=wall3.started
+ )
+
+ assert 400 == thl_lm.get_account_filtered_balance(
+ account=revenue_account,
+ metadata_key="thl_wall",
+ metadata_value=wall3.uuid,
+ )
+
+ def test_bp_payment_without_locks(
+ self, user_factory, product_user_wallet_no, create_main_accounts, thl_lm, lm
+ ):
+ user: User = user_factory(product=product_user_wallet_no)
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.SAGO,
+ req_survey_id="xxx",
+ req_cpi=Decimal("0.50"),
+ session_id=3,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+
+ thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started)
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": session.started + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+ print(thl_net, commission_amount, bp_pay, user_pay)
+
+ # Run it 3 times without any checks, and it gets made three times!
+ thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+ thl_lm.create_tx_bp_payment_(session=session, created=wall1.started)
+ thl_lm.create_tx_bp_payment_(session=session, created=wall1.started)
+
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ assert 48 * 3 == lm.get_account_balance(account=bp_wallet)
+ assert 48 * 3 == thl_lm.get_account_filtered_balance(
+ account=bp_wallet, metadata_key="thl_session", metadata_value=session.uuid
+ )
+ assert lm.check_ledger_balanced()
+
+ def test_bp_payment_with_locks(
+ self, user_factory, product_user_wallet_no, create_main_accounts, thl_lm, lm
+ ):
+ user: User = user_factory(product=product_user_wallet_no)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.SAGO,
+ req_survey_id="xxx",
+ req_cpi=Decimal("0.50"),
+ session_id=3,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+
+ thl_lm.create_tx_task_complete(wall1, user, created=wall1.started)
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": session.started + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+ print(thl_net, commission_amount, bp_pay, user_pay)
+
+ # Make sure we clear any flags/locks first
+ lock_key = f"test:thl_wall:{wall1.uuid}"
+ lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}"
+ flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}"
+ lm.redis_client.delete(lock_name)
+ lm.redis_client.delete(flag_name)
+
+ # Run it 3 times with check, and it gets made once!
+ thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+ with pytest.raises(expected_exception=LedgerTransactionCreateError):
+ thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+
+ with pytest.raises(expected_exception=LedgerTransactionCreateError):
+ thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ assert 48 == thl_lm.get_account_balance(bp_wallet)
+ assert 48 == thl_lm.get_account_filtered_balance(
+ account=bp_wallet,
+ metadata_key="thl_session",
+ metadata_value=session.uuid,
+ )
+ assert lm.check_ledger_balanced()
diff --git a/tests/managers/thl/test_ledger/test_lm_tx_metadata.py b/tests/managers/thl/test_ledger/test_lm_tx_metadata.py
new file mode 100644
index 0000000..5d12633
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_lm_tx_metadata.py
@@ -0,0 +1,34 @@
+class TestLedgerMetadataManager:
+
+ def test_get_tx_metadata_by_txs(self, ledger_tx, lm):
+ # First confirm the Ledger TX exists with 2 Entries
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ assert isinstance(res.metadata, dict)
+
+ tx_metadatas = lm.get_tx_metadata_by_txs(transactions=[ledger_tx])
+ assert isinstance(tx_metadatas, dict)
+ assert isinstance(tx_metadatas[ledger_tx.id], dict)
+
+ assert res.metadata == tx_metadatas[ledger_tx.id]
+
+ def test_get_tx_metadata_ids_by_tx(self, ledger_tx, lm):
+ # First confirm the Ledger TX exists with 2 Entries
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ tx_metadata_cnt = len(res.metadata.keys())
+
+ tx_metadata_ids = lm.get_tx_metadata_ids_by_tx(transaction=ledger_tx)
+ assert isinstance(tx_metadata_ids, set)
+ assert isinstance(list(tx_metadata_ids)[0], int)
+
+ assert tx_metadata_cnt == len(tx_metadata_ids)
+
+ def test_get_tx_metadata_ids_by_txs(self, ledger_tx, lm):
+ # First confirm the Ledger TX exists with 2 Entries
+ res = lm.get_tx_by_id(transaction_id=ledger_tx.id)
+ tx_metadata_cnt = len(res.metadata.keys())
+
+ tx_metadata_ids = lm.get_tx_metadata_ids_by_txs(transactions=[ledger_tx])
+ assert isinstance(tx_metadata_ids, set)
+ assert isinstance(list(tx_metadata_ids)[0], int)
+
+ assert tx_metadata_cnt == len(tx_metadata_ids)
diff --git a/tests/managers/thl/test_ledger/test_thl_lm_accounts.py b/tests/managers/thl/test_ledger/test_thl_lm_accounts.py
new file mode 100644
index 0000000..01d5fe1
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_thl_lm_accounts.py
@@ -0,0 +1,411 @@
+from uuid import uuid4
+
+import pytest
+
+
+class TestThlLedgerManagerAccounts:
+
+ def test_get_account_or_create_user_wallet(self, user, thl_lm, lm):
+ from generalresearch.currency import LedgerCurrency
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ Direction,
+ AccountType,
+ )
+
+ account = thl_lm.get_account_or_create_user_wallet(user=user)
+ assert isinstance(account, LedgerAccount)
+
+ assert user.uuid in account.qualified_name
+ assert account.display_name == f"User Wallet {user.uuid}"
+ assert account.account_type == AccountType.USER_WALLET
+ assert account.normal_balance == Direction.CREDIT
+ assert account.reference_type == "user"
+ assert account.reference_uuid == user.uuid
+ assert account.currency == LedgerCurrency.TEST
+
+ # Actually query for it to confirm
+ res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True)
+ assert res.model_dump_json() == account.model_dump_json()
+
+ def test_get_account_or_create_bp_wallet(self, product, thl_lm, lm):
+ from generalresearch.currency import LedgerCurrency
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ Direction,
+ AccountType,
+ )
+
+ account = thl_lm.get_account_or_create_bp_wallet(product=product)
+ assert isinstance(account, LedgerAccount)
+
+ assert product.uuid in account.qualified_name
+ assert account.display_name == f"BP Wallet {product.uuid}"
+ assert account.account_type == AccountType.BP_WALLET
+ assert account.normal_balance == Direction.CREDIT
+ assert account.reference_type == "bp"
+ assert account.reference_uuid == product.uuid
+ assert account.currency == LedgerCurrency.TEST
+
+ # Actually query for it to confirm
+ res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True)
+ assert res.model_dump_json() == account.model_dump_json()
+
+ def test_get_account_or_create_bp_commission(self, product, thl_lm, lm):
+ from generalresearch.currency import LedgerCurrency
+ from generalresearch.models.thl.ledger import (
+ Direction,
+ AccountType,
+ )
+
+ account = thl_lm.get_account_or_create_bp_commission(product=product)
+
+ assert product.uuid in account.qualified_name
+ assert account.display_name == f"Revenue from commission {product.uuid}"
+ assert account.account_type == AccountType.REVENUE
+ assert account.normal_balance == Direction.CREDIT
+ assert account.reference_type == "bp"
+ assert account.reference_uuid == product.uuid
+ assert account.currency == LedgerCurrency.TEST
+
+ # Actually query for it to confirm
+ res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True)
+ assert res.model_dump_json() == account.model_dump_json()
+
+ @pytest.mark.parametrize("expense", ["tango", "paypal", "gift", "tremendous"])
+ def test_get_account_or_create_bp_expense(self, product, expense, thl_lm, lm):
+ from generalresearch.currency import LedgerCurrency
+ from generalresearch.models.thl.ledger import (
+ Direction,
+ AccountType,
+ )
+
+ account = thl_lm.get_account_or_create_bp_expense(
+ product=product, expense_name=expense
+ )
+ assert product.uuid in account.qualified_name
+ assert account.display_name == f"Expense {expense} {product.uuid}"
+ assert account.account_type == AccountType.EXPENSE
+ assert account.normal_balance == Direction.DEBIT
+ assert account.reference_type == "bp"
+ assert account.reference_uuid == product.uuid
+ assert account.currency == LedgerCurrency.TEST
+
+ # Actually query for it to confirm
+ res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True)
+ assert res.model_dump_json() == account.model_dump_json()
+
+ def test_get_or_create_bp_pending_payout_account(self, product, thl_lm, lm):
+ from generalresearch.currency import LedgerCurrency
+ from generalresearch.models.thl.ledger import (
+ Direction,
+ AccountType,
+ )
+
+ account = thl_lm.get_or_create_bp_pending_payout_account(product=product)
+
+ assert product.uuid in account.qualified_name
+ assert account.display_name == f"BP Wallet Pending {product.uuid}"
+ assert account.account_type == AccountType.BP_WALLET
+ assert account.normal_balance == Direction.CREDIT
+ assert account.reference_type == "bp"
+ assert account.reference_uuid == product.uuid
+ assert account.currency == LedgerCurrency.TEST
+
+ # Actually query for it to confirm
+ res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True)
+ assert res.model_dump_json() == account.model_dump_json()
+
+ def test_get_account_task_complete_revenue_raises(
+ self, delete_ledger_db, thl_lm, lm
+ ):
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+
+ delete_ledger_db()
+
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ thl_lm.get_account_task_complete_revenue()
+
+ def test_get_account_task_complete_revenue(
+ self, account_cash, account_revenue_task_complete, thl_lm, lm
+ ):
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ AccountType,
+ )
+
+ res = thl_lm.get_account_task_complete_revenue()
+ assert isinstance(res, LedgerAccount)
+ assert res.reference_type is None
+ assert res.reference_uuid is None
+ assert res.account_type == AccountType.REVENUE
+ assert res.display_name == "Cash flow task complete"
+
+ def test_get_account_cash_raises(self, delete_ledger_db, thl_lm, lm):
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+
+ delete_ledger_db()
+
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ thl_lm.get_account_cash()
+
+ def test_get_account_cash(self, account_cash, thl_lm, lm):
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ AccountType,
+ )
+
+ res = thl_lm.get_account_cash()
+ assert isinstance(res, LedgerAccount)
+ assert res.reference_type is None
+ assert res.reference_uuid is None
+ assert res.account_type == AccountType.CASH
+ assert res.display_name == "Operating Cash Account"
+
+ def test_get_accounts(self, setup_accounts, product, user_factory, thl_lm, lm, lam):
+ from generalresearch.models.thl.user import User
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+
+ user1: User = user_factory(product=product)
+ user2: User = user_factory(product=product)
+
+ account1 = thl_lm.get_account_or_create_bp_wallet(product=product)
+
+ # (1) known account and confirm it comes back
+ res = lm.get_account(qualified_name=account1.qualified_name)
+ assert account1.model_dump_json() == res.model_dump_json()
+
+ # (2) known accounts and confirm they both come back
+ res = lam.get_accounts(qualified_names=[account1.qualified_name])
+ assert isinstance(res, list)
+ assert len(res) == 1
+ assert account1 in res
+
+ # Get 2 known and 1 made up qualified names, and confirm it raises
+ # an error
+ with pytest.raises(LedgerAccountDoesntExistError):
+ lam.get_accounts(
+ qualified_names=[
+ account1.qualified_name,
+ f"test:bp_wall:{uuid4().hex}",
+ ]
+ )
+
+ def test_get_accounts_if_exists(self, product_factory, currency, thl_lm, lm):
+ from generalresearch.models.thl.product import Product
+
+ p1: Product = product_factory()
+ p2: Product = product_factory()
+
+ account1 = thl_lm.get_account_or_create_bp_wallet(product=p1)
+ account2 = thl_lm.get_account_or_create_bp_wallet(product=p2)
+
+ # (1) known account and confirm it comes back
+ res = lm.get_account(qualified_name=account1.qualified_name)
+ assert account1.model_dump_json() == res.model_dump_json()
+
+ # (2) known accounts and confirm they both come back
+ res = lm.get_accounts(
+ qualified_names=[account1.qualified_name, account2.qualified_name]
+ )
+ assert isinstance(res, list)
+ assert len(res) == 2
+ assert account1 in res
+ assert account2 in res
+
+ # Get 2 known and 1 made up qualified names, and confirm only 2
+ # come back
+ lm.get_accounts_if_exists(
+ qualified_names=[
+ account1.qualified_name,
+ account2.qualified_name,
+ f"{currency.value}:bp_wall:{uuid4().hex}",
+ ]
+ )
+
+ assert isinstance(res, list)
+ assert len(res) == 2
+
+ # Confirm an empty array comes back for all unknown qualified names
+ res = lm.get_accounts_if_exists(
+ qualified_names=[
+ f"{lm.currency.value}:bp_wall:{uuid4().hex}" for i in range(5)
+ ]
+ )
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ def test_get_accounts_for_products(self, product_factory, thl_lm, lm):
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ )
+
+ # Create 5 Products
+ product_uuids = []
+ for i in range(5):
+ _p = product_factory()
+ product_uuids.append(_p.uuid)
+
+ # Confirm that this fails.. because none of those accounts have been
+ # created yet
+ with pytest.raises(expected_exception=LedgerAccountDoesntExistError):
+ thl_lm.get_accounts_bp_wallet_for_products(product_uuids=product_uuids)
+
+ # Create the bp_wallet accounts and then try again
+ for p_uuid in product_uuids:
+ thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=p_uuid)
+
+ res = thl_lm.get_accounts_bp_wallet_for_products(product_uuids=product_uuids)
+ assert len(res) == len(product_uuids)
+ assert all([isinstance(i, LedgerAccount) for i in res])
+
+
+class TestLedgerAccountManager:
+
+ def test_get_or_create(self, thl_lm, lm, lam):
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ Direction,
+ AccountType,
+ )
+
+ u = uuid4().hex
+ name = f"test-{u[:8]}"
+
+ account = LedgerAccount(
+ display_name=name,
+ qualified_name=f"test:bp_wallet:{u}",
+ normal_balance=Direction.DEBIT,
+ account_type=AccountType.BP_WALLET,
+ currency="test",
+ reference_type="bp",
+ reference_uuid=u,
+ )
+
+ # First we want to validate that using the get_account method raises
+ # an error for a random LedgerAccount which we know does not exist.
+ with pytest.raises(LedgerAccountDoesntExistError):
+ lam.get_account(qualified_name=account.qualified_name)
+
+ # Now that we know it doesn't exist, get_or_create for it
+ instance = lam.get_account_or_create(account=account)
+
+ # It should always return
+ assert isinstance(instance, LedgerAccount)
+ assert instance.reference_uuid == u
+
+ def test_get(self, user, thl_lm, lm, lam):
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ AccountType,
+ )
+
+ with pytest.raises(LedgerAccountDoesntExistError):
+ lam.get_account(qualified_name=f"test:bp_wallet:{user.product.id}")
+
+ thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ account = lam.get_account(qualified_name=f"test:bp_wallet:{user.product.id}")
+
+ assert isinstance(account, LedgerAccount)
+ assert AccountType.BP_WALLET == account.account_type
+ assert user.product.uuid == account.reference_uuid
+
+ def test_get_many(self, product_factory, thl_lm, lm, lam, currency):
+ from generalresearch.models.thl.product import Product
+ from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerAccountDoesntExistError,
+ )
+
+ p1: Product = product_factory()
+ p2: Product = product_factory()
+
+ account1 = thl_lm.get_account_or_create_bp_wallet(product=p1)
+ account2 = thl_lm.get_account_or_create_bp_wallet(product=p2)
+
+ # Get 1 known account and confirm it comes back
+ res = lam.get_account_many(
+ qualified_names=[account1.qualified_name, account2.qualified_name]
+ )
+ assert isinstance(res, list)
+ assert len(res) == 2
+ assert account1 in res
+
+ # Get 2 known accounts and confirm they both come back
+ res = lam.get_account_many(
+ qualified_names=[account1.qualified_name, account2.qualified_name]
+ )
+ assert isinstance(res, list)
+ assert len(res) == 2
+ assert account1 in res
+ assert account2 in res
+
+ # Get 2 known and 1 made up qualified names, and confirm only 2 come
+ # back. Don't raise on error, so we can confirm the array is "short"
+ res = lam.get_account_many(
+ qualified_names=[
+ account1.qualified_name,
+ account2.qualified_name,
+ f"test:bp_wall:{uuid4().hex}",
+ ],
+ raise_on_error=False,
+ )
+ assert isinstance(res, list)
+ assert len(res) == 2
+
+ # Same as above, but confirm the raise works on checking res length
+ with pytest.raises(LedgerAccountDoesntExistError):
+ lam.get_account_many(
+ qualified_names=[
+ account1.qualified_name,
+ account2.qualified_name,
+ f"test:bp_wall:{uuid4().hex}",
+ ],
+ raise_on_error=True,
+ )
+
+ # Confirm an empty array comes back for all unknown qualified names
+ res = lam.get_account_many(
+ qualified_names=[f"test:bp_wall:{uuid4().hex}" for i in range(5)],
+ raise_on_error=False,
+ )
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ def test_create_account(self, thl_lm, lm, lam):
+ from generalresearch.models.thl.ledger import (
+ LedgerAccount,
+ Direction,
+ AccountType,
+ )
+
+ u = uuid4().hex
+ name = f"test-{u[:8]}"
+
+ account = LedgerAccount(
+ display_name=name,
+ qualified_name=f"test:bp_wallet:{u}",
+ normal_balance=Direction.DEBIT,
+ account_type=AccountType.BP_WALLET,
+ currency="test",
+ reference_type="bp",
+ reference_uuid=u,
+ )
+
+ lam.create_account(account=account)
+ assert lam.get_account(f"test:bp_wallet:{u}") == account
+ assert lam.get_account_or_create(account) == account
diff --git a/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py b/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py
new file mode 100644
index 0000000..294d092
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py
@@ -0,0 +1,516 @@
+import logging
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from random import randint
+from uuid import uuid4
+
+import pytest
+import redis
+from pydantic import RedisDsn
+from redis.lock import Lock
+
+from generalresearch.currency import USDCent
+from generalresearch.managers.base import Permission
+from generalresearch.managers.thl.ledger_manager.thl_ledger import ThlLedgerManager
+from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerTransactionFlagAlreadyExistsError,
+ LedgerTransactionConditionFailedError,
+ LedgerTransactionReleaseLockError,
+ LedgerTransactionCreateError,
+)
+from generalresearch.managers.thl.ledger_manager.ledger import LedgerTransaction
+from generalresearch.models import Source
+from generalresearch.models.thl.definitions import PayoutStatus
+from generalresearch.models.thl.ledger import Direction, TransactionType
+from generalresearch.models.thl.session import (
+ Wall,
+ Status,
+ StatusCode1,
+ Session,
+)
+from generalresearch.models.thl.user import User
+from generalresearch.models.thl.wallet import PayoutType
+from generalresearch.redis_helper import RedisConfig
+
+
+def broken_acquire(self, *args, **kwargs):
+ raise redis.exceptions.TimeoutError("Simulated timeout during acquire")
+
+
+def broken_release(self, *args, **kwargs):
+ raise redis.exceptions.TimeoutError("Simulated timeout during release")
+
+
+class TestThlLedgerManagerBPPayout:
+
+ def test_create_tx_with_bp_payment(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ caplog,
+ thl_lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ now = datetime.now(timezone.utc) - timedelta(hours=1)
+ user: User = user_factory(product=product_user_wallet_no)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("6.00"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=now,
+ finished=now + timedelta(seconds=1),
+ )
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall1, user=user, created=wall1.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": now + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+ thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+
+ lock_key = f"test:bp_payout:{user.product.id}"
+ flag_name = f"{thl_lm.cache_prefix}:transaction_flag:{lock_key}"
+ thl_lm.redis_client.delete(flag_name)
+
+ payoutevent_uuid = uuid4().hex
+ thl_lm.create_tx_bp_payout(
+ product=user.product,
+ amount=USDCent(200),
+ created=now,
+ payoutevent_uuid=payoutevent_uuid,
+ )
+
+ payoutevent_uuid = uuid4().hex
+ thl_lm.create_tx_bp_payout(
+ product=user.product,
+ amount=USDCent(200),
+ created=now + timedelta(minutes=2),
+ skip_one_per_day_check=True,
+ payoutevent_uuid=payoutevent_uuid,
+ )
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ assert 170 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 200 == thl_lm.get_account_balance(cash)
+
+ with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError):
+ thl_lm.create_tx_bp_payout(
+ user.product,
+ amount=USDCent(200),
+ created=now + timedelta(minutes=2),
+ skip_one_per_day_check=False,
+ skip_wallet_balance_check=False,
+ payoutevent_uuid=payoutevent_uuid,
+ )
+
+ payoutevent_uuid = uuid4().hex
+ with caplog.at_level(logging.INFO):
+ with pytest.raises(LedgerTransactionConditionFailedError):
+ thl_lm.create_tx_bp_payout(
+ user.product,
+ amount=USDCent(10_000),
+ created=now + timedelta(minutes=2),
+ skip_one_per_day_check=True,
+ skip_wallet_balance_check=False,
+ payoutevent_uuid=payoutevent_uuid,
+ )
+ assert "failed condition check balance:" in caplog.text
+
+ thl_lm.create_tx_bp_payout(
+ product=user.product,
+ amount=USDCent(10_00),
+ created=now + timedelta(minutes=2),
+ skip_one_per_day_check=True,
+ skip_wallet_balance_check=True,
+ payoutevent_uuid=payoutevent_uuid,
+ )
+ assert 170 - 1000 == thl_lm.get_account_balance(bp_wallet_account)
+
+ def test_create_tx(self, product, caplog, thl_lm, currency):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ payoutevent_uuid = uuid4().hex
+
+ # Create a BP Payout for a Product without any activity. By issuing,
+ # the skip_* checks, we should be able to force it to work, and will
+ # then ultimately result in a negative balance
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ skip_flag_check=True,
+ )
+
+ # Check the basic attributes
+ assert isinstance(tx, LedgerTransaction)
+ assert tx.ext_description == "BP Payout"
+ assert (
+ tx.tag
+ == f"{currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}"
+ )
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+
+ # Check the Product's balance, it should be negative the amount that was
+ # paid out. That's because the Product earned nothing.. and then was
+ # sent something.
+ balance = thl_lm.get_account_balance(
+ account=thl_lm.get_account_or_create_bp_wallet(product=product)
+ )
+ assert balance == int(rand_amount) * -1
+
+ # Test some basic assertions
+ with caplog.at_level(logging.INFO):
+ with pytest.raises(expected_exception=Exception):
+ thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=uuid4().hex,
+ created=datetime.now(tz=timezone.utc),
+ skip_wallet_balance_check=False,
+ skip_one_per_day_check=False,
+ skip_flag_check=False,
+ )
+ assert "failed condition check >1 tx per day" in caplog.text
+
+ def test_create_tx_redis_failure(self, product, thl_web_rw, thl_lm):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ payoutevent_uuid = uuid4().hex
+ now = datetime.now(tz=timezone.utc)
+
+ thl_lm.create_tx_plug_bp_wallet(
+ product, rand_amount, now, direction=Direction.CREDIT
+ )
+
+ # Non routable IP address. Redis will fail
+ thl_lm_redis_0 = ThlLedgerManager(
+ pg_config=thl_web_rw,
+ permissions=[
+ Permission.CREATE,
+ Permission.READ,
+ Permission.UPDATE,
+ Permission.DELETE,
+ ],
+ testing=True,
+ redis_config=RedisConfig(
+ dsn=RedisDsn("redis://10.255.255.1:6379"),
+ socket_connect_timeout=0.1,
+ ),
+ )
+
+ with pytest.raises(expected_exception=Exception) as e:
+ tx = thl_lm_redis_0.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ )
+ assert e.type is redis.exceptions.TimeoutError
+ # No txs were created
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product)
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ assert len(txs) == 0
+
+ def test_create_tx_multiple_per_day(self, product, thl_lm):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ payoutevent_uuid = uuid4().hex
+ now = datetime.now(tz=timezone.utc)
+
+ thl_lm.create_tx_plug_bp_wallet(
+ product, rand_amount * USDCent(2), now, direction=Direction.CREDIT
+ )
+
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ )
+
+ # Try to create another
+ # Will fail b/c it has the same payout event uuid
+ with pytest.raises(expected_exception=Exception) as e:
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ )
+ assert e.type is LedgerTransactionFlagAlreadyExistsError
+
+ # Try to create another
+ # Will fail due to multiple per day
+ payoutevent_uuid2 = uuid4().hex
+ with pytest.raises(expected_exception=Exception) as e:
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid2,
+ created=datetime.now(tz=timezone.utc),
+ )
+ assert e.type is LedgerTransactionConditionFailedError
+ assert str(e.value) == ">1 tx per day"
+
+ # Make it run by skipping one per day check
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid2,
+ created=datetime.now(tz=timezone.utc),
+ skip_one_per_day_check=True,
+ )
+
+ def test_create_tx_redis_lock_release_error(self, product, thl_lm):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ payoutevent_uuid = uuid4().hex
+ now = datetime.now(tz=timezone.utc)
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product)
+
+ thl_lm.create_tx_plug_bp_wallet(
+ product, rand_amount * USDCent(2), now, direction=Direction.CREDIT
+ )
+
+ original_acquire = Lock.acquire
+ original_release = Lock.release
+ Lock.acquire = broken_acquire
+
+ # Create TX will fail on lock enter, no tx will actually get created
+ with pytest.raises(expected_exception=Exception) as e:
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ )
+ assert e.type is LedgerTransactionCreateError
+ assert str(e.value) == "Redis error: Simulated timeout during acquire"
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ assert len(txs) == 0
+
+ Lock.acquire = original_acquire
+ Lock.release = broken_release
+
+ # Create TX will fail on lock exit, after the tx was created!
+ with pytest.raises(expected_exception=Exception) as e:
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ )
+ assert e.type is LedgerTransactionReleaseLockError
+ assert str(e.value) == "Redis error: Simulated timeout during release"
+
+ # Transaction was still created!
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ assert len(txs) == 1
+ Lock.release = original_release
+
+
+class TestPayoutEventManagerBPPayout:
+
+ def test_create(self, product, thl_lm, brokerage_product_payout_event_manager):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ now = datetime.now(tz=timezone.utc)
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product)
+ assert thl_lm.get_account_balance(bp_wallet_account) == 0
+ thl_lm.create_tx_plug_bp_wallet(
+ product, rand_amount, now, direction=Direction.CREDIT
+ )
+ assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ created=now,
+ amount=rand_amount,
+ payout_type=PayoutType.ACH,
+ )
+ assert brokerage_product_payout_event_manager.check_for_ledger_tx(
+ thl_ledger_manager=thl_lm,
+ product_id=product.id,
+ amount=rand_amount,
+ payout_event=pe,
+ )
+ assert thl_lm.get_account_balance(bp_wallet_account) == 0
+
+ def test_create_with_redis_error(
+ self, product, caplog, thl_lm, brokerage_product_payout_event_manager
+ ):
+ caplog.set_level("WARNING")
+ original_acquire = Lock.acquire
+ original_release = Lock.release
+
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ now = datetime.now(tz=timezone.utc)
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product)
+ assert thl_lm.get_account_balance(bp_wallet_account) == 0
+ thl_lm.create_tx_plug_bp_wallet(
+ product, rand_amount, now, direction=Direction.CREDIT
+ )
+ assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ # Will fail on lock enter, no tx will actually get created
+ Lock.acquire = broken_acquire
+ with pytest.raises(expected_exception=Exception) as e:
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ created=now,
+ amount=rand_amount,
+ payout_type=PayoutType.ACH,
+ )
+ assert e.type is LedgerTransactionCreateError
+ assert str(e.value) == "Redis error: Simulated timeout during acquire"
+ assert any(
+ "Simulated timeout during acquire. No ledger tx was created" in m
+ for m in caplog.messages
+ )
+
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ # One payout event is created, status is failed, and no ledger txs exist
+ assert len(txs) == 0
+ pes = (
+ brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products(
+ thl_ledger_manager=thl_lm, product_uuids=[product.id]
+ )
+ )
+ assert len(pes) == 1
+ assert pes[0].status == PayoutStatus.FAILED
+ pe = pes[0]
+
+ # Fix the redis method
+ Lock.acquire = original_acquire
+
+ # Try to fix the failed payout, by trying ledger tx again
+ brokerage_product_payout_event_manager.retry_create_bp_payout_event_tx(
+ product=product, thl_ledger_manager=thl_lm, payout_event_uuid=pe.uuid
+ )
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ assert len(txs) == 1
+ assert thl_lm.get_account_balance(bp_wallet_account) == 0
+
+ # And then try to run it again, it'll fail because a payout event with the same info exists
+ with pytest.raises(expected_exception=Exception) as e:
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ created=now,
+ amount=rand_amount,
+ payout_type=PayoutType.ACH,
+ )
+ assert e.type is ValueError
+ assert "Payout event already exists!" in str(e.value)
+
+ # We wouldn't do this in practice, because this is paying out the BP again, but
+ # we can if want to.
+ # Change the timestamp so it'll create a new payout event
+ now = datetime.now(tz=timezone.utc)
+ with pytest.raises(LedgerTransactionConditionFailedError) as e:
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ created=now,
+ amount=rand_amount,
+ payout_type=PayoutType.ACH,
+ )
+ # But it will fail due to 1 per day check
+ assert str(e.value) == ">1 tx per day"
+ pe = brokerage_product_payout_event_manager.get_by_uuid(e.value.pe_uuid)
+ assert pe.status == PayoutStatus.FAILED
+
+ # And if we really want to, we can make it again
+ now = datetime.now(tz=timezone.utc)
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ created=now,
+ amount=rand_amount,
+ payout_type=PayoutType.ACH,
+ skip_one_per_day_check=True,
+ skip_wallet_balance_check=True,
+ )
+
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ assert len(txs) == 2
+ # since they were paid twice
+ assert thl_lm.get_account_balance(bp_wallet_account) == 0 - rand_amount
+
+ Lock.release = original_release
+ Lock.acquire = original_acquire
+
+ def test_create_with_redis_error_release(
+ self, product, caplog, thl_lm, brokerage_product_payout_event_manager
+ ):
+ caplog.set_level("WARNING")
+
+ original_release = Lock.release
+
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ now = datetime.now(tz=timezone.utc)
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ assert thl_lm.get_account_balance(bp_wallet_account) == 0
+ thl_lm.create_tx_plug_bp_wallet(
+ product, rand_amount, now, direction=Direction.CREDIT
+ )
+ assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount
+
+ # Will fail on lock exit, after the tx was created!
+ # But it'll see that the tx was created and so everything will be fine
+ Lock.release = broken_release
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ created=now,
+ amount=rand_amount,
+ payout_type=PayoutType.ACH,
+ )
+ assert any(
+ "Simulated timeout during release but ledger tx exists" in m
+ for m in caplog.messages
+ )
+
+ txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid)
+ txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"]
+ assert len(txs) == 1
+ pes = (
+ brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products(
+ thl_ledger_manager=thl_lm, product_uuids=[product.uuid]
+ )
+ )
+ assert len(pes) == 1
+ assert pes[0].status == PayoutStatus.COMPLETE
+ Lock.release = original_release
diff --git a/tests/managers/thl/test_ledger/test_thl_lm_tx.py b/tests/managers/thl/test_ledger/test_thl_lm_tx.py
new file mode 100644
index 0000000..31c7107
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_thl_lm_tx.py
@@ -0,0 +1,1762 @@
+import logging
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from random import randint
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.currency import USDCent
+from generalresearch.managers.thl.ledger_manager.ledger import (
+ LedgerTransaction,
+)
+from generalresearch.models import Source
+from generalresearch.models.thl.definitions import (
+ WALL_ALLOWED_STATUS_STATUS_CODE,
+)
+from generalresearch.models.thl.ledger import Direction
+from generalresearch.models.thl.ledger import TransactionType
+from generalresearch.models.thl.product import (
+ PayoutConfig,
+ PayoutTransformation,
+ UserWalletConfig,
+)
+from generalresearch.models.thl.session import (
+ Wall,
+ Status,
+ StatusCode1,
+ Session,
+ WallAdjustedStatus,
+)
+from generalresearch.models.thl.user import User
+from generalresearch.models.thl.wallet import PayoutType
+from generalresearch.models.thl.payout import UserPayoutEvent
+
+logger = logging.getLogger("LedgerManager")
+
+
+class TestThlLedgerTxManager:
+
+ def test_create_tx_task_complete(
+ self,
+ wall,
+ user,
+ account_revenue_task_complete,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ ):
+ create_main_accounts()
+ tx = thl_lm.create_tx_task_complete(wall=wall, user=user)
+ assert isinstance(tx, LedgerTransaction)
+
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.created == tx.created
+
+ def test_create_tx_task_complete_(
+ self, wall, user, account_revenue_task_complete, thl_lm, lm
+ ):
+ tx = thl_lm.create_tx_task_complete_(wall=wall, user=user)
+ assert isinstance(tx, LedgerTransaction)
+
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.created == tx.created
+
+ def test_create_tx_bp_payment(
+ self,
+ session_factory,
+ user,
+ create_main_accounts,
+ delete_ledger_db,
+ thl_lm,
+ lm,
+ session_manager,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ s1 = session_factory(user=user)
+
+ status, status_code_1 = s1.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments()
+ session_manager.finish_with_status(
+ session=s1,
+ status=Status.COMPLETE,
+ status_code_1=status_code_1,
+ finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10),
+ payout=bp_pay,
+ user_payout=user_pay,
+ )
+
+ tx = thl_lm.create_tx_bp_payment(session=s1)
+ assert isinstance(tx, LedgerTransaction)
+
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.created == tx.created
+
+ def test_create_tx_bp_payment_amt(
+ self,
+ session_factory,
+ user_factory,
+ product_manager,
+ create_main_accounts,
+ delete_ledger_db,
+ thl_lm,
+ lm,
+ session_manager,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ product = product_manager.create_dummy(
+ payout_config=PayoutConfig(
+ payout_transformation=PayoutTransformation(
+ f="payout_transformation_amt"
+ )
+ ),
+ user_wallet_config=UserWalletConfig(amt=True, enabled=True),
+ )
+ user = user_factory(product=product)
+ s1 = session_factory(user=user, wall_req_cpi=Decimal("1"))
+
+ status, status_code_1 = s1.determine_session_status()
+ assert status == Status.COMPLETE
+ thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments(
+ thl_ledger_manager=thl_lm
+ )
+ print(thl_net, commission_amount, bp_pay, user_pay)
+ session_manager.finish_with_status(
+ session=s1,
+ status=Status.COMPLETE,
+ status_code_1=status_code_1,
+ finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10),
+ payout=bp_pay,
+ user_payout=user_pay,
+ )
+
+ tx = thl_lm.create_tx_bp_payment(session=s1)
+ assert isinstance(tx, LedgerTransaction)
+
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.created == tx.created
+
+ def test_create_tx_bp_payment_(
+ self,
+ session_factory,
+ user,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ session_manager,
+ utc_hour_ago,
+ ):
+ s1 = session_factory(user=user)
+ status, status_code_1 = s1.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments()
+ session_manager.finish_with_status(
+ session=s1,
+ status=status,
+ status_code_1=status_code_1,
+ finished=utc_hour_ago + timedelta(minutes=10),
+ payout=bp_pay,
+ user_payout=user_pay,
+ )
+
+ s1.determine_payments()
+ tx = thl_lm.create_tx_bp_payment_(session=s1)
+ assert isinstance(tx, LedgerTransaction)
+
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+ assert res.created == tx.created
+
+ def test_create_tx_task_adjustment(
+ self, wall_factory, session, user, create_main_accounts, thl_lm, lm
+ ):
+ """Create Wall event Complete, and Create a Tx Task Adjustment
+
+ - I don't know what this does exactly... but we can confirm
+ the transaction comes back with balanced amounts, and that
+ the name of the Source is in the Tx description
+ """
+
+ wall_status = Status.COMPLETE
+ wall: Wall = wall_factory(session=session, wall_status=wall_status)
+
+ tx = thl_lm.create_tx_task_adjustment(wall=wall, user=user)
+ assert isinstance(tx, LedgerTransaction)
+ res = lm.get_tx_by_id(transaction_id=tx.id)
+
+ assert res.entries[0].amount == int(wall.cpi * 100)
+ assert res.entries[1].amount == int(wall.cpi * 100)
+ assert wall.source.name in res.ext_description
+ assert res.created == tx.created
+
+ def test_create_tx_bp_adjustment(self, session, user, caplog, thl_lm, lm):
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+
+ # The default session fixture is just an unfinished wall event
+ assert len(session.wall_events) == 1
+ assert session.finished is None
+ assert status == Status.TIMEOUT
+ assert status_code_1 in list(
+ WALL_ALLOWED_STATUS_STATUS_CODE.get(Status.TIMEOUT, {})
+ )
+ assert thl_net == Decimal(0)
+ assert commission_amount == Decimal(0)
+ assert bp_pay == Decimal(0)
+ assert user_pay is None
+
+ # Update the finished timestamp, but nothing else. This means that
+ # there is no financial changes needed
+ session.update(
+ **{
+ "finished": datetime.now(tz=timezone.utc) + timedelta(minutes=10),
+ }
+ )
+ assert session.finished
+ with caplog.at_level(logging.INFO):
+ tx = thl_lm.create_tx_bp_adjustment(session=session)
+ assert tx is None
+ assert "No transactions needed." in caplog.text
+
+ def test_create_tx_bp_payout(self, product, caplog, thl_lm, currency):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ payoutevent_uuid = uuid4().hex
+
+ # Create a BP Payout for a Product without any activity. By issuing,
+ # the skip_* checks, we should be able to force it to work, and will
+ # then ultimately result in a negative balance
+ tx = thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ skip_flag_check=True,
+ )
+
+ # Check the basic attributes
+ assert isinstance(tx, LedgerTransaction)
+ assert tx.ext_description == "BP Payout"
+ assert (
+ tx.tag
+ == f"{thl_lm.currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}"
+ )
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+
+ # Check the Product's balance, it should be negative the amount that was
+ # paid out. That's because the Product earned nothing.. and then was
+ # sent something.
+ balance = thl_lm.get_account_balance(
+ account=thl_lm.get_account_or_create_bp_wallet(product=product)
+ )
+ assert balance == int(rand_amount) * -1
+
+ # Test some basic assertions
+ with caplog.at_level(logging.INFO):
+ with pytest.raises(expected_exception=Exception):
+ thl_lm.create_tx_bp_payout(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=uuid4().hex,
+ created=datetime.now(tz=timezone.utc),
+ skip_wallet_balance_check=False,
+ skip_one_per_day_check=False,
+ skip_flag_check=False,
+ )
+ assert "failed condition check >1 tx per day" in caplog.text
+
+ def test_create_tx_bp_payout_(self, product, thl_lm, lm, currency):
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+ payoutevent_uuid = uuid4().hex
+
+ # Create a BP Payout for a Product without any activity.
+ tx = thl_lm.create_tx_bp_payout_(
+ product=product,
+ amount=rand_amount,
+ payoutevent_uuid=payoutevent_uuid,
+ created=datetime.now(tz=timezone.utc),
+ )
+
+ # Check the basic attributes
+ assert isinstance(tx, LedgerTransaction)
+ assert tx.ext_description == "BP Payout"
+ assert (
+ tx.tag
+ == f"{currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}"
+ )
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+
+ def test_create_tx_plug_bp_wallet(
+ self, product, create_main_accounts, thl_lm, lm, currency
+ ):
+ """A BP Wallet "plug" is a way to makeup discrepancies and simply
+ add or remove money
+ """
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+
+ tx = thl_lm.create_tx_plug_bp_wallet(
+ product=product,
+ amount=rand_amount,
+ created=datetime.now(tz=timezone.utc),
+ direction=Direction.DEBIT,
+ skip_flag_check=False,
+ )
+
+ assert isinstance(tx, LedgerTransaction)
+
+ # We issued the BP money they didn't earn, so now they have a
+ # negative balance
+ balance = thl_lm.get_account_balance(
+ account=thl_lm.get_account_or_create_bp_wallet(product=product)
+ )
+ assert balance == int(rand_amount) * -1
+
+ def test_create_tx_plug_bp_wallet_(
+ self, product, create_main_accounts, thl_lm, lm, currency
+ ):
+ """A BP Wallet "plug" is a way to fix discrepancies and simply
+ add or remove money.
+
+ Similar to above, but because it's unprotected, we can immediately
+ issue another to see if the balance changes
+ """
+ rand_amount: USDCent = USDCent(randint(100, 1_000))
+
+ tx = thl_lm.create_tx_plug_bp_wallet_(
+ product=product,
+ amount=rand_amount,
+ created=datetime.now(tz=timezone.utc),
+ direction=Direction.DEBIT,
+ )
+
+ assert isinstance(tx, LedgerTransaction)
+
+ # We issued the BP money they didn't earn, so now they have a
+ # negative balance
+ balance = thl_lm.get_account_balance(
+ account=thl_lm.get_account_or_create_bp_wallet(product=product)
+ )
+ assert balance == int(rand_amount) * -1
+
+ # Issue a positive one now, and confirm the balance goes positive
+ thl_lm.create_tx_plug_bp_wallet_(
+ product=product,
+ amount=rand_amount + rand_amount,
+ created=datetime.now(tz=timezone.utc),
+ direction=Direction.CREDIT,
+ )
+ balance = thl_lm.get_account_balance(
+ account=thl_lm.get_account_or_create_bp_wallet(product=product)
+ )
+ assert balance == int(rand_amount)
+
+ def test_create_tx_user_payout_request(
+ self,
+ user,
+ product_user_wallet_yes,
+ user_factory,
+ delete_df_collection,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=500,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ # The default user fixture uses a product that doesn't have wallet
+ # mode enabled
+ with pytest.raises(expected_exception=AssertionError):
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ skip_flag_check=True,
+ skip_wallet_balance_check=True,
+ )
+
+ # Now try it for a user on a product with wallet mode
+ u2 = user_factory(product=product_user_wallet_yes)
+
+ # User's pre-balance is 0 because no activity has occurred yet
+ pre_balance = lm.get_account_balance(
+ account=thl_lm.get_account_or_create_user_wallet(user=u2)
+ )
+ assert pre_balance == 0
+
+ tx = thl_lm.create_tx_user_payout_request(
+ user=u2,
+ payout_event=pe,
+ skip_flag_check=True,
+ skip_wallet_balance_check=True,
+ )
+ assert isinstance(tx, LedgerTransaction)
+ assert tx.entries[0].amount == pe.amount
+ assert tx.entries[1].amount == pe.amount
+ assert tx.ext_description == "User Payout Paypal Request $5.00"
+
+ #
+ # (TODO): This key ":user_payout:" is NOT part of the TransactionType
+ # enum and was manually set. It should be based off the
+ # TransactionType names.
+ #
+
+ assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:request"
+
+ # Post balance is -$5.00 because it comes out of the wallet before
+ # it's Approved or Completed
+ post_balance = lm.get_account_balance(
+ account=thl_lm.get_account_or_create_user_wallet(user=u2)
+ )
+ assert post_balance == -500
+
+ def test_create_tx_user_payout_request_(
+ self,
+ user,
+ product_user_wallet_yes,
+ user_factory,
+ delete_ledger_db,
+ thl_lm,
+ lm,
+ ):
+ delete_ledger_db()
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=500,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ rand_description = uuid4().hex
+ tx = thl_lm.create_tx_user_payout_request_(
+ user=user, payout_event=pe, description=rand_description
+ )
+
+ assert tx.ext_description == rand_description
+
+ post_balance = lm.get_account_balance(
+ account=thl_lm.get_account_or_create_user_wallet(user=user)
+ )
+ assert post_balance == -500
+
+ def test_create_tx_user_payout_complete(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ delete_ledger_db,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_user_wallet_yes)
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ rand_amount = randint(100, 1_000)
+
+ # Ensure the user starts out with nothing...
+ assert lm.get_account_balance(account=user_account) == 0
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=rand_amount,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ # Confirm it's not possible unless a request occurred happen
+ with pytest.raises(expected_exception=ValueError):
+ thl_lm.create_tx_user_payout_complete(
+ user=user,
+ payout_event=pe,
+ fee_amount=None,
+ skip_flag_check=False,
+ )
+
+ # (1) Make a request first
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ skip_flag_check=True,
+ skip_wallet_balance_check=True,
+ )
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == rand_amount * -1
+
+ # (2) Complete the request
+ tx = thl_lm.create_tx_user_payout_complete(
+ user=user,
+ payout_event=pe,
+ fee_amount=Decimal(0),
+ skip_flag_check=False,
+ )
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+ assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:complete"
+ assert isinstance(tx, LedgerTransaction)
+
+ # The amount that comes out of the user wallet doesn't change after
+ # it's approved becuase it's already been withdrawn
+ assert lm.get_account_balance(account=user_account) == rand_amount * -1
+
+ def test_create_tx_user_payout_complete_(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ ):
+ user: User = user_factory(product=product_user_wallet_yes)
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ rand_amount = randint(100, 1_000)
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=rand_amount,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ # (1) Make a request first
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ skip_flag_check=True,
+ skip_wallet_balance_check=True,
+ )
+
+ # (2) Complete the request
+ rand_desc = uuid4().hex
+
+ bp_expense_account = thl_lm.get_account_or_create_bp_expense(
+ product=user.product, expense_name="paypal"
+ )
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+
+ tx = thl_lm.create_tx_user_payout_complete_(
+ user=user,
+ payout_event=pe,
+ fee_amount=Decimal("0.00"),
+ fee_expense_account=bp_expense_account,
+ fee_payer_account=bp_wallet_account,
+ description=rand_desc,
+ )
+ assert tx.ext_description == rand_desc
+ assert lm.get_account_balance(account=user_account) == rand_amount * -1
+
+ def test_create_tx_user_payout_cancelled(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ user: User = user_factory(product=product_user_wallet_yes)
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ rand_amount = randint(100, 1_000)
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=rand_amount,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ # (1) Make a request first
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ skip_flag_check=True,
+ skip_wallet_balance_check=True,
+ )
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == rand_amount * -1
+
+ # (2) Cancel the request
+ tx = thl_lm.create_tx_user_payout_cancelled(
+ user=user,
+ payout_event=pe,
+ skip_flag_check=False,
+ )
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+ assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:cancel"
+ assert isinstance(tx, LedgerTransaction)
+
+ # Assert the balance comes back to 0 after it was cancelled
+ assert lm.get_account_balance(account=user_account) == 0
+
+ def test_create_tx_user_payout_cancelled_(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ user: User = user_factory(product=product_user_wallet_yes)
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ rand_amount = randint(100, 1_000)
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=rand_amount,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ # (1) Make a request first
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ skip_flag_check=True,
+ skip_wallet_balance_check=True,
+ )
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == rand_amount * -1
+
+ # (2) Cancel the request
+ rand_desc = uuid4().hex
+ tx = thl_lm.create_tx_user_payout_cancelled_(
+ user=user, payout_event=pe, description=rand_desc
+ )
+ assert isinstance(tx, LedgerTransaction)
+ assert tx.ext_description == rand_desc
+ assert lm.get_account_balance(account=user_account) == 0
+
+ def test_create_tx_user_bonus(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ user: User = user_factory(product=product_user_wallet_yes)
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ rand_amount = randint(100, 1_000)
+ rand_ref_uuid = uuid4().hex
+ rand_desc = uuid4().hex
+
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == 0
+
+ tx = thl_lm.create_tx_user_bonus(
+ user=user,
+ amount=Decimal(rand_amount / 100),
+ ref_uuid=rand_ref_uuid,
+ description=rand_desc,
+ skip_flag_check=True,
+ )
+ assert tx.ext_description == rand_desc
+ assert tx.tag == f"{thl_lm.currency.value}:user_bonus:{rand_ref_uuid}"
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == rand_amount
+
+ def test_create_tx_user_bonus_(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ user: User = user_factory(product=product_user_wallet_yes)
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ rand_amount = randint(100, 1_000)
+ rand_ref_uuid = uuid4().hex
+ rand_desc = uuid4().hex
+
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == 0
+
+ tx = thl_lm.create_tx_user_bonus_(
+ user=user,
+ amount=Decimal(rand_amount / 100),
+ ref_uuid=rand_ref_uuid,
+ description=rand_desc,
+ )
+ assert tx.ext_description == rand_desc
+ assert tx.tag == f"{thl_lm.currency.value}:user_bonus:{rand_ref_uuid}"
+ assert tx.entries[0].amount == rand_amount
+ assert tx.entries[1].amount == rand_amount
+
+ # Assert the balance came out of their user wallet
+ assert lm.get_account_balance(account=user_account) == rand_amount
+
+
+class TestThlLedgerTxManagerFlows:
+ """Combine the various THL_LM methods to create actual "real world"
+ examples
+ """
+
+ def test_create_tx_task_complete(
+ self, user, create_main_accounts, thl_lm, lm, currency, delete_ledger_db
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ wall1 = Wall(
+ user_id=1,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("1.23"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+ thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started)
+
+ wall2 = Wall(
+ user_id=1,
+ source=Source.FULL_CIRCLE,
+ req_survey_id="yyy",
+ req_cpi=Decimal("3.21"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+ thl_lm.create_tx_task_complete(wall=wall2, user=user, created=wall2.started)
+
+ cash = thl_lm.get_account_cash()
+ revenue = thl_lm.get_account_task_complete_revenue()
+
+ assert lm.get_account_balance(cash) == 123 + 321
+ assert lm.get_account_balance(revenue) == 123 + 321
+ assert lm.check_ledger_balanced()
+
+ assert (
+ lm.get_account_filtered_balance(
+ account=revenue, metadata_key="source", metadata_value="d"
+ )
+ == 123
+ )
+
+ assert (
+ lm.get_account_filtered_balance(
+ account=revenue, metadata_key="source", metadata_value="f"
+ )
+ == 321
+ )
+
+ assert (
+ lm.get_account_filtered_balance(
+ account=revenue, metadata_key="source", metadata_value="x"
+ )
+ == 0
+ )
+
+ assert (
+ thl_lm.get_account_filtered_balance(
+ account=revenue,
+ metadata_key="thl_wall",
+ metadata_value=wall1.uuid,
+ )
+ == 123
+ )
+
+ def test_create_transaction_task_complete_1_cent(
+ self, user, create_main_accounts, thl_lm, lm, currency
+ ):
+ wall1 = Wall(
+ user_id=1,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("0.007"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall1, user=user, created=wall1.started
+ )
+
+ assert isinstance(tx, LedgerTransaction)
+
+ def test_create_transaction_bp_payment(
+ self,
+ user,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ delete_ledger_db,
+ session_factory,
+ utc_hour_ago,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ s1: Session = session_factory(
+ user=user,
+ wall_count=1,
+ started=utc_hour_ago,
+ wall_source=Source.TESTING,
+ )
+ w1: Wall = s1.wall_events[0]
+
+ tx = thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started)
+ assert isinstance(tx, LedgerTransaction)
+
+ status, status_code_1 = s1.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments()
+ s1.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": s1.started + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+ print(thl_net, commission_amount, bp_pay, user_pay)
+ thl_lm.create_tx_bp_payment(session=s1, created=w1.started)
+
+ revenue = thl_lm.get_account_task_complete_revenue()
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ bp_commission = thl_lm.get_account_or_create_bp_commission(product=user.product)
+
+ assert 0 == lm.get_account_balance(account=revenue)
+ assert 50 == lm.get_account_filtered_balance(
+ account=revenue,
+ metadata_key="source",
+ metadata_value=Source.TESTING,
+ )
+ assert 48 == lm.get_account_balance(account=bp_wallet)
+ assert 48 == lm.get_account_filtered_balance(
+ account=bp_wallet,
+ metadata_key="thl_session",
+ metadata_value=s1.uuid,
+ )
+ assert 2 == thl_lm.get_account_balance(account=bp_commission)
+ assert thl_lm.check_ledger_balanced()
+
+ def test_create_transaction_bp_payment_round(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ product_user_wallet_no.commission_pct = Decimal("0.085")
+ user: User = user_factory(product=product_user_wallet_no)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.SAGO,
+ req_survey_id="xxx",
+ req_cpi=Decimal("0.287"),
+ session_id=3,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall1, user=user, created=wall1.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": session.started + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+
+ print(thl_net, commission_amount, bp_pay, user_pay)
+ tx = thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+ assert isinstance(tx, LedgerTransaction)
+
+ def test_create_transaction_bp_payment_round2(
+ self, delete_ledger_db, user, create_main_accounts, thl_lm, lm, currency
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ # user must be no user wallet
+ # e.g. session 869b5bfa47f44b4f81cd095ed01df2ff this fails if you dont round properly
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.SAGO,
+ req_survey_id="xxx",
+ req_cpi=Decimal("1.64500"),
+ session_id=3,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+
+ thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started)
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ # thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": session.started + timedelta(minutes=10),
+ "payout": Decimal("1.53"),
+ "user_payout": Decimal("1.53"),
+ }
+ )
+
+ thl_lm.create_tx_bp_payment(session=session, created=wall1.started)
+
+ def test_create_transaction_bp_payment_round3(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ # e.g. session ___ fails b/c we rounded incorrectly
+ # before, and now we are off by a penny...
+ user: User = user_factory(product=product_user_wallet_yes)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.SAGO,
+ req_survey_id="xxx",
+ req_cpi=Decimal("0.385"),
+ session_id=3,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=datetime.now(timezone.utc),
+ finished=datetime.now(timezone.utc) + timedelta(seconds=1),
+ )
+ thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started)
+
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ # thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": session.started + timedelta(minutes=10),
+ "payout": Decimal("0.39"),
+ "user_payout": Decimal("0.26"),
+ }
+ )
+ # with pytest.logs(logger, level=logging.WARNING) as cm:
+ # tx = thl_lm.create_transaction_bp_payment(session, created=wall1.started)
+ # assert "Capping bp_pay to thl_net" in cm.output[0]
+
+ def test_create_transaction_bp_payment_user_wallet(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ delete_ledger_db,
+ thl_lm,
+ session_manager,
+ wall_manager,
+ lm,
+ session_factory,
+ currency,
+ utc_hour_ago,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_user_wallet_yes)
+ assert user.product.user_wallet_enabled
+
+ s1: Session = session_factory(
+ user=user,
+ wall_count=1,
+ started=utc_hour_ago,
+ wall_req_cpi=Decimal(".50"),
+ wall_source=Source.TESTING,
+ )
+ w1: Wall = s1.wall_events[0]
+
+ thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started)
+
+ status, status_code_1 = s1.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments()
+ session_manager.finish_with_status(
+ session=s1,
+ status=status,
+ status_code_1=status_code_1,
+ finished=s1.started + timedelta(minutes=10),
+ payout=bp_pay,
+ user_payout=user_pay,
+ )
+ thl_lm.create_tx_bp_payment(session=s1, created=w1.started)
+
+ revenue = thl_lm.get_account_task_complete_revenue()
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ bp_commission = thl_lm.get_account_or_create_bp_commission(product=user.product)
+ user_wallet = thl_lm.get_account_or_create_user_wallet(user=user)
+
+ assert 0 == thl_lm.get_account_balance(account=revenue)
+ assert 50 == thl_lm.get_account_filtered_balance(
+ account=revenue,
+ metadata_key="source",
+ metadata_value=Source.TESTING,
+ )
+
+ assert 48 - 19 == thl_lm.get_account_balance(account=bp_wallet)
+ assert 48 - 19 == thl_lm.get_account_filtered_balance(
+ account=bp_wallet,
+ metadata_key="thl_session",
+ metadata_value=s1.uuid,
+ )
+ assert 2 == thl_lm.get_account_balance(bp_commission)
+ assert 19 == thl_lm.get_account_balance(user_wallet)
+ assert 19 == thl_lm.get_account_filtered_balance(
+ account=user_wallet,
+ metadata_key="thl_session",
+ metadata_value=s1.uuid,
+ )
+
+ assert 0 == thl_lm.get_account_filtered_balance(
+ account=user_wallet, metadata_key="thl_session", metadata_value="x"
+ )
+ assert thl_lm.check_ledger_balanced()
+
+
+class TestThlLedgerManagerAdj:
+
+ def test_create_tx_task_adjustment(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ delete_ledger_db,
+ thl_lm,
+ lm,
+ utc_hour_ago,
+ currency,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_user_wallet_no)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("1.23"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=utc_hour_ago,
+ finished=utc_hour_ago + timedelta(seconds=1),
+ )
+
+ thl_lm.create_tx_task_complete(wall1, user, created=wall1.started)
+
+ wall2 = Wall(
+ user_id=1,
+ source=Source.FULL_CIRCLE,
+ req_survey_id="yyy",
+ req_cpi=Decimal("3.21"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=utc_hour_ago,
+ finished=utc_hour_ago + timedelta(seconds=1),
+ )
+ thl_lm.create_tx_task_complete(wall2, user, created=wall2.started)
+
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=0,
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ print(wall1.get_cpi_after_adjustment())
+ thl_lm.create_tx_task_adjustment(wall1, user)
+
+ cash = thl_lm.get_account_cash()
+ revenue = thl_lm.get_account_task_complete_revenue()
+
+ assert 123 + 321 - 123 == thl_lm.get_account_balance(account=cash)
+ assert 123 + 321 - 123 == thl_lm.get_account_balance(account=revenue)
+ assert thl_lm.check_ledger_balanced()
+ assert 0 == thl_lm.get_account_filtered_balance(
+ revenue, metadata_key="source", metadata_value="d"
+ )
+ assert 321 == thl_lm.get_account_filtered_balance(
+ revenue, metadata_key="source", metadata_value="f"
+ )
+ assert 0 == thl_lm.get_account_filtered_balance(
+ revenue, metadata_key="source", metadata_value="x"
+ )
+ assert 123 - 123 == thl_lm.get_account_filtered_balance(
+ account=revenue, metadata_key="thl_wall", metadata_value=wall1.uuid
+ )
+
+ # un-reconcile it
+ wall1.update(
+ adjusted_status=None,
+ adjusted_cpi=None,
+ adjusted_timestamp=utc_hour_ago + timedelta(minutes=45),
+ )
+ print(wall1.get_cpi_after_adjustment())
+ thl_lm.create_tx_task_adjustment(wall1, user)
+ # and then run it again to make sure it does nothing
+ thl_lm.create_tx_task_adjustment(wall1, user)
+
+ cash = thl_lm.get_account_cash()
+ revenue = thl_lm.get_account_task_complete_revenue()
+
+ assert 123 + 321 - 123 + 123 == thl_lm.get_account_balance(cash)
+ assert 123 + 321 - 123 + 123 == thl_lm.get_account_balance(revenue)
+ assert thl_lm.check_ledger_balanced()
+ assert 123 == thl_lm.get_account_filtered_balance(
+ account=revenue, metadata_key="source", metadata_value="d"
+ )
+ assert 321 == thl_lm.get_account_filtered_balance(
+ account=revenue, metadata_key="source", metadata_value="f"
+ )
+ assert 0 == thl_lm.get_account_filtered_balance(
+ account=revenue, metadata_key="source", metadata_value="x"
+ )
+ assert 123 - 123 + 123 == thl_lm.get_account_filtered_balance(
+ account=revenue, metadata_key="thl_wall", metadata_value=wall1.uuid
+ )
+
+ def test_create_tx_bp_adjustment(
+ self,
+ user,
+ product_user_wallet_no,
+ create_main_accounts,
+ caplog,
+ thl_lm,
+ lm,
+ currency,
+ session_manager,
+ wall_manager,
+ session_factory,
+ utc_hour_ago,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ s1 = session_factory(
+ user=user,
+ wall_count=2,
+ wall_req_cpis=[Decimal(1), Decimal(3)],
+ wall_statuses=[Status.COMPLETE, Status.COMPLETE],
+ started=utc_hour_ago,
+ )
+
+ w1: Wall = s1.wall_events[0]
+ w2: Wall = s1.wall_events[1]
+
+ thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started)
+ thl_lm.create_tx_task_complete(wall=w2, user=user, created=w2.started)
+
+ status, status_code_1 = s1.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments()
+ session_manager.finish_with_status(
+ session=s1,
+ status=status,
+ status_code_1=status_code_1,
+ finished=utc_hour_ago + timedelta(minutes=10),
+ payout=bp_pay,
+ user_payout=user_pay,
+ )
+ thl_lm.create_tx_bp_payment(session=s1, created=w1.started)
+ revenue = thl_lm.get_account_task_complete_revenue()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ bp_commission_account = thl_lm.get_account_or_create_bp_commission(
+ product=user.product
+ )
+ assert 380 == thl_lm.get_account_balance(account=bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(account=revenue)
+ assert 20 == thl_lm.get_account_balance(account=bp_commission_account)
+ thl_lm.check_ledger_balanced()
+
+ # This should do nothing (since we haven't adjusted any wall events)
+ s1.adjust_status()
+ with caplog.at_level(logging.INFO):
+ thl_lm.create_tx_bp_adjustment(session=s1)
+
+ assert (
+ "create_transaction_bp_adjustment. No transactions needed." in caplog.text
+ )
+
+ # self.assertEqual(380, ledger_manager.get_account_balance(bp_wallet_account))
+ # self.assertEqual(0, ledger_manager.get_account_balance(revenue))
+ # self.assertEqual(20, ledger_manager.get_account_balance(bp_commission_account))
+ # self.assertTrue(ledger_manager.check_ledger_balanced())
+
+ # recon $1 survey.
+ wall_manager.adjust_status(
+ wall=w1,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=Decimal(0),
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ thl_lm.create_tx_task_adjustment(wall=w1, user=user)
+ # -$1.00 b/c the MP took the $1 back, but we haven't yet taken the BP payment back
+ assert -100 == thl_lm.get_account_balance(revenue)
+ s1.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session=s1)
+
+ with caplog.at_level(logging.INFO):
+ thl_lm.create_tx_bp_adjustment(session=s1)
+ assert (
+ "create_transaction_bp_adjustment. No transactions needed." in caplog.text
+ )
+
+ assert 380 - 95 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20 - 5 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # unrecon the $1 survey
+ wall_manager.adjust_status(
+ wall=w1,
+ adjusted_status=None,
+ adjusted_cpi=None,
+ adjusted_timestamp=utc_hour_ago + timedelta(minutes=45),
+ )
+ thl_lm.create_tx_task_adjustment(
+ wall=w1,
+ user=user,
+ created=utc_hour_ago + timedelta(minutes=45),
+ )
+ new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts()
+ s1.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session=s1)
+ assert 380 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20, thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ def test_create_tx_bp_adjustment_small(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ delete_ledger_db,
+ thl_lm,
+ lm,
+ utc_hour_ago,
+ currency,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ # This failed when I didn't check that `change_commission` > 0 in
+ # create_transaction_bp_adjustment
+ user: User = user_factory(product=product_user_wallet_no)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("0.10"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=utc_hour_ago,
+ finished=utc_hour_ago + timedelta(seconds=1),
+ )
+
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall1, user=user, created=wall1.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ session = Session(started=wall1.started, user=user, wall_events=[wall1])
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": utc_hour_ago + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+ thl_lm.create_tx_bp_payment(session, created=wall1.started)
+
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=0,
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ thl_lm.create_tx_task_adjustment(wall1, user)
+ session.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session)
+
+ def test_create_tx_bp_adjustment_abandon(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ delete_ledger_db,
+ session_factory,
+ create_main_accounts,
+ caplog,
+ thl_lm,
+ lm,
+ currency,
+ utc_hour_ago,
+ session_manager,
+ wall_manager,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ user: User = user_factory(product=product_user_wallet_no)
+ s1: Session = session_factory(
+ user=user, final_status=Status.ABANDON, wall_req_cpi=Decimal(1)
+ )
+ w1 = s1.wall_events[-1]
+
+ # Adjust to complete.
+ wall_manager.adjust_status(
+ wall=w1,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=w1.cpi,
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ thl_lm.create_tx_task_adjustment(wall=w1, user=user)
+ s1.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session=s1)
+ # And then adjust it back (it was abandon before, but now it should be
+ # fail (?) or back to abandon?)
+ wall_manager.adjust_status(
+ wall=w1,
+ adjusted_status=None,
+ adjusted_cpi=None,
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ thl_lm.create_tx_task_adjustment(wall=w1, user=user)
+ s1.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session=s1)
+
+ revenue = thl_lm.get_account_task_complete_revenue()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ bp_commission_account = thl_lm.get_account_or_create_bp_commission(
+ product=user.product
+ )
+ assert 0 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 0 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # This should do nothing
+ s1.adjust_status()
+ with caplog.at_level(logging.INFO):
+ thl_lm.create_tx_bp_adjustment(session=s1)
+ assert "No transactions needed" in caplog.text
+
+ # Now back to complete again
+ wall_manager.adjust_status(
+ wall=w1,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=w1.cpi,
+ adjusted_timestamp=utc_hour_ago + timedelta(hours=1),
+ )
+ s1.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session=s1)
+ assert 95 == thl_lm.get_account_balance(bp_wallet_account)
+
+ def test_create_tx_bp_adjustment_user_wallet(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ delete_ledger_db,
+ caplog,
+ thl_lm,
+ lm,
+ currency,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ now = datetime.now(timezone.utc) - timedelta(days=1)
+ user: User = user_factory(product=product_user_wallet_yes)
+
+ # Create 2 Wall completes and create the respective transaction for
+ # them. We then create a 3rd wall event which is a failure but we
+ # do NOT create a transaction for it
+
+ wall3 = Wall(
+ user_id=8,
+ source=Source.CINT,
+ req_survey_id="zzz",
+ req_cpi=Decimal("2.00"),
+ session_id=1,
+ status=Status.FAIL,
+ status_code_1=StatusCode1.BUYER_FAIL,
+ started=now,
+ finished=now + timedelta(minutes=1),
+ )
+
+ now_w1 = now + timedelta(minutes=1, milliseconds=1)
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("1.00"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=now_w1,
+ finished=now_w1 + timedelta(minutes=1),
+ )
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall1, user=user, created=wall1.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ now_w2 = now + timedelta(minutes=2, milliseconds=1)
+ wall2 = Wall(
+ user_id=user.user_id,
+ source=Source.FULL_CIRCLE,
+ req_survey_id="yyy",
+ req_cpi=Decimal("3.00"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=now_w2,
+ finished=now_w2 + timedelta(minutes=1),
+ )
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall2, user=user, created=wall2.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ # It doesn't matter what order these wall events go in as because
+ # we have a pydantic valiator that sorts them
+ wall_events = [wall3, wall1, wall2]
+ # shuffle(wall_events)
+ session = Session(started=wall1.started, user=user, wall_events=wall_events)
+ status, status_code_1 = session.determine_session_status()
+ assert status == Status.COMPLETE
+ assert status_code_1 == StatusCode1.COMPLETE
+
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ assert thl_net == Decimal("4.00")
+ assert commission_amount == Decimal("0.20")
+ assert bp_pay == Decimal("3.80")
+ assert user_pay == Decimal("1.52")
+
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": now + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+
+ tx = thl_lm.create_tx_bp_adjustment(session=session, created=wall1.started)
+ assert isinstance(tx, LedgerTransaction)
+
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product)
+ assert 228 == thl_lm.get_account_balance(account=bp_wallet_account)
+
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ assert 152 == thl_lm.get_account_balance(account=user_account)
+
+ revenue = thl_lm.get_account_task_complete_revenue()
+ assert 0 == thl_lm.get_account_balance(account=revenue)
+
+ bp_commission_account = thl_lm.get_account_or_create_bp_commission(
+ product=user.product
+ )
+ assert 20 == thl_lm.get_account_balance(account=bp_commission_account)
+
+ # the total (4.00) = 2.28 + 1.52 + .20
+ assert thl_lm.check_ledger_balanced()
+
+ # This should do nothing (since we haven't adjusted any wall events)
+ session.adjust_status()
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ with caplog.at_level(logging.INFO):
+ thl_lm.create_tx_bp_adjustment(session)
+ assert (
+ "create_transaction_bp_adjustment. No transactions needed." in caplog.text
+ )
+
+ # recon $1 survey.
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=0,
+ adjusted_timestamp=now + timedelta(hours=1),
+ )
+ thl_lm.create_tx_task_adjustment(wall1, user)
+ # -$1.00 b/c the MP took the $1 back, but we haven't yet taken the BP payment back
+ assert -100 == thl_lm.get_account_balance(revenue)
+ session.adjust_status()
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ thl_lm.create_tx_bp_adjustment(session)
+
+ # running this twice b/c it should do nothing the 2nd time
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ with caplog.at_level(logging.INFO):
+ thl_lm.create_tx_bp_adjustment(session)
+ assert (
+ "create_transaction_bp_adjustment. No transactions needed." in caplog.text
+ )
+
+ assert 228 - 57 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 152 - 38 == thl_lm.get_account_balance(user_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20 - 5 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # unrecon the $1 survey
+ wall1.update(
+ adjusted_status=None,
+ adjusted_cpi=None,
+ adjusted_timestamp=now + timedelta(hours=2),
+ )
+ tx = thl_lm.create_tx_task_adjustment(wall=wall1, user=user)
+ assert isinstance(tx, LedgerTransaction)
+
+ new_status, new_payout, new_user_payout = (
+ session.determine_new_status_and_payouts()
+ )
+ print(new_status, new_payout, new_user_payout, session.adjusted_payout)
+ session.adjust_status()
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ thl_lm.create_tx_bp_adjustment(session)
+
+ assert 228 - 57 + 57 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 152 - 38 + 38 == thl_lm.get_account_balance(user_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20 - 5 + 5 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # make the $2 failure into a complete also
+ wall3.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=wall3.cpi,
+ adjusted_timestamp=now + timedelta(hours=2),
+ )
+ thl_lm.create_tx_task_adjustment(wall3, user)
+ new_status, new_payout, new_user_payout = (
+ session.determine_new_status_and_payouts()
+ )
+ print(new_status, new_payout, new_user_payout, session.adjusted_payout)
+ session.adjust_status()
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ thl_lm.create_tx_bp_adjustment(session)
+ assert 228 - 57 + 57 + 114 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 152 - 38 + 38 + 76 == thl_lm.get_account_balance(user_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20 - 5 + 5 + 10 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ def test_create_transaction_bp_adjustment_cpi_adjustment(
+ self,
+ user_factory,
+ product_user_wallet_no,
+ create_main_accounts,
+ delete_ledger_db,
+ caplog,
+ thl_lm,
+ lm,
+ utc_hour_ago,
+ currency,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ user: User = user_factory(product=product_user_wallet_no)
+
+ wall1 = Wall(
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="xxx",
+ req_cpi=Decimal("1.00"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=utc_hour_ago,
+ finished=utc_hour_ago + timedelta(seconds=1),
+ )
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall1, user=user, created=wall1.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ wall2 = Wall(
+ user_id=user.user_id,
+ source=Source.FULL_CIRCLE,
+ req_survey_id="yyy",
+ req_cpi=Decimal("3.00"),
+ session_id=1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ started=utc_hour_ago,
+ finished=utc_hour_ago + timedelta(seconds=1),
+ )
+ tx = thl_lm.create_tx_task_complete(
+ wall=wall2, user=user, created=wall2.started
+ )
+ assert isinstance(tx, LedgerTransaction)
+
+ session = Session(started=wall1.started, user=user, wall_events=[wall1, wall2])
+ status, status_code_1 = session.determine_session_status()
+ thl_net, commission_amount, bp_pay, user_pay = session.determine_payments()
+ session.update(
+ **{
+ "status": status,
+ "status_code_1": status_code_1,
+ "finished": utc_hour_ago + timedelta(minutes=10),
+ "payout": bp_pay,
+ "user_payout": user_pay,
+ }
+ )
+ thl_lm.create_tx_bp_payment(session, created=wall1.started)
+
+ revenue = thl_lm.get_account_task_complete_revenue()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_commission_account = thl_lm.get_account_or_create_bp_commission(user.product)
+ assert 380 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # cpi adjustment $1 -> $.60.
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT,
+ adjusted_cpi=Decimal("0.60"),
+ adjusted_timestamp=utc_hour_ago + timedelta(minutes=30),
+ )
+ thl_lm.create_tx_task_adjustment(wall1, user)
+
+ # -$0.40 b/c the MP took $0.40 back, but we haven't yet taken the BP payment back
+ assert -40 == thl_lm.get_account_balance(revenue)
+ session.adjust_status()
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ thl_lm.create_tx_bp_adjustment(session)
+
+ # running this twice b/c it should do nothing the 2nd time
+ print(
+ session.get_status_after_adjustment(),
+ session.get_payout_after_adjustment(),
+ session.get_user_payout_after_adjustment(),
+ )
+ with caplog.at_level(logging.INFO):
+ thl_lm.create_tx_bp_adjustment(session)
+ assert "create_transaction_bp_adjustment." in caplog.text
+ assert "No transactions needed." in caplog.text
+
+ assert 380 - 38 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 20 - 2 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # adjust it to failure
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_cpi=0,
+ adjusted_timestamp=utc_hour_ago + timedelta(minutes=45),
+ )
+ thl_lm.create_tx_task_adjustment(wall1, user)
+ session.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session)
+ assert 300 - (300 * 0.05) == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 300 * 0.05 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # and then back to cpi adj again, but this time for more than the orig amount
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT,
+ adjusted_cpi=Decimal("2.00"),
+ adjusted_timestamp=utc_hour_ago + timedelta(minutes=45),
+ )
+ thl_lm.create_tx_task_adjustment(wall1, user)
+ session.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session)
+ assert 500 - (500 * 0.05) == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(revenue)
+ assert 500 * 0.05 == thl_lm.get_account_balance(bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
+
+ # And adjust again
+ wall1.update(
+ adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT,
+ adjusted_cpi=Decimal("3.00"),
+ adjusted_timestamp=utc_hour_ago + timedelta(minutes=45),
+ )
+ thl_lm.create_tx_task_adjustment(wall=wall1, user=user)
+ session.adjust_status()
+ thl_lm.create_tx_bp_adjustment(session=session)
+ assert 600 - (600 * 0.05) == thl_lm.get_account_balance(
+ account=bp_wallet_account
+ )
+ assert 0 == thl_lm.get_account_balance(account=revenue)
+ assert 600 * 0.05 == thl_lm.get_account_balance(account=bp_commission_account)
+ assert thl_lm.check_ledger_balanced()
diff --git a/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py b/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py
new file mode 100644
index 0000000..1e7146a
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py
@@ -0,0 +1,505 @@
+import logging
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerTransactionFlagAlreadyExistsError,
+ LedgerTransactionConditionFailedError,
+)
+from generalresearch.models.thl.user import User
+from generalresearch.models.thl.wallet import PayoutType
+from generalresearch.models.thl.payout import UserPayoutEvent
+from test_utils.managers.ledger.conftest import create_main_accounts
+
+
+class TestLedgerManagerAMT:
+
+ def test_create_transaction_amt_ass_request(
+ self,
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ user: User = user_factory(product=product_amt_true)
+
+ # debit_account_uuid nothing checks they match the ledger ... todo?
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.AMT_HIT,
+ amount=5,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+ flag_key = f"test:user_payout:{pe.uuid}:request"
+ flag_name = f"ledger-manager:transaction_flag:{flag_key}"
+ lm.redis_client.delete(flag_name)
+
+ # User has $0 in their wallet. They are allowed amt_assignment payouts until -$1.00
+ thl_lm.create_tx_user_payout_request(user=user, payout_event=pe)
+ with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError):
+ thl_lm.create_tx_user_payout_request(
+ user=user, payout_event=pe, skip_flag_check=False
+ )
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ thl_lm.create_tx_user_payout_request(
+ user=user, payout_event=pe, skip_flag_check=True
+ )
+ pe2 = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.AMT_HIT,
+ amount=96,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ flag_key = f"test:user_payout:{pe2.uuid}:request"
+ flag_name = f"ledger-manager:transaction_flag:{flag_key}"
+ lm.redis_client.delete(flag_name)
+ # 96 cents would put them over the -$1.00 limit
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ thl_lm.create_tx_user_payout_request(user, payout_event=pe2)
+
+ # But they could do 0.95 cents
+ pe2.amount = 95
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe2, skip_flag_check=True
+ )
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ product=user.product
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user)
+
+ assert 0 == lm.get_account_balance(account=bp_wallet_account)
+ assert 0 == lm.get_account_balance(account=cash)
+ assert 100 == lm.get_account_balance(account=bp_pending_account)
+ assert -100 == lm.get_account_balance(account=user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+ assert -5 == thl_lm.get_account_filtered_balance(
+ account=user_wallet_account,
+ metadata_key="payoutevent",
+ metadata_value=pe.uuid,
+ )
+
+ assert -95 == thl_lm.get_account_filtered_balance(
+ account=user_wallet_account,
+ metadata_key="payoutevent",
+ metadata_value=pe2.uuid,
+ )
+
+ def test_create_transaction_amt_ass_complete(
+ self,
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ user: User = user_factory(product=product_amt_true)
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.AMT_HIT,
+ amount=5,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+ flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:request"
+ lm.redis_client.delete(flag)
+ flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:complete"
+ lm.redis_client.delete(flag)
+
+ # User has $0 in their wallet. They are allowed amt_assignment payouts until -$1.00
+ thl_lm.create_tx_user_payout_request(user, payout_event=pe)
+ thl_lm.create_tx_user_payout_complete(user, payout_event=pe)
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ user.product
+ )
+ bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense(
+ user.product, expense_name="amt"
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user)
+
+ # BP wallet pays the 1cent fee
+ assert -1 == thl_lm.get_account_balance(bp_wallet_account)
+ assert -5 == thl_lm.get_account_balance(cash)
+ assert -1 == thl_lm.get_account_balance(bp_amt_expense_account)
+ assert 0 == thl_lm.get_account_balance(bp_pending_account)
+ assert -5 == lm.get_account_balance(user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+
+ def test_create_transaction_amt_bonus(
+ self,
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_amt_true)
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.AMT_BONUS,
+ amount=34,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+ flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:request"
+ lm.redis_client.delete(flag)
+ flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:complete"
+ lm.redis_client.delete(flag)
+
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ # User has $0 in their wallet. No amt bonus allowed
+ thl_lm.create_tx_user_payout_request(user, payout_event=pe)
+
+ thl_lm.create_tx_user_bonus(
+ user,
+ amount=Decimal(5),
+ ref_uuid="e703830dec124f17abed2d697d8d7701",
+ description="Bribe",
+ skip_flag_check=True,
+ )
+ pe.amount = 101
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe, skip_flag_check=False
+ )
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe, skip_flag_check=False
+ )
+ with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError):
+ # duplicate, even if amount changed
+ pe.amount = 200
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe, skip_flag_check=False
+ )
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ # duplicate
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe, skip_flag_check=True
+ )
+ pe.uuid = "533364150de4451198e5774e221a2acb"
+ pe.amount = 9900
+ with pytest.raises(expected_exception=ValueError):
+ # Trying to complete payout with no pending tx
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe, skip_flag_check=True
+ )
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ # trying to payout $99 with only a $5 balance
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe, skip_flag_check=True
+ )
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ user.product
+ )
+ bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense(
+ user.product, expense_name="amt"
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user)
+ assert -500 + round(-101 * 0.20) == thl_lm.get_account_balance(
+ bp_wallet_account
+ )
+ assert -101 == lm.get_account_balance(cash)
+ assert -20 == lm.get_account_balance(bp_amt_expense_account)
+ assert 0 == lm.get_account_balance(bp_pending_account)
+ assert 500 - 101 == lm.get_account_balance(user_wallet_account)
+ assert lm.check_ledger_balanced() is True
+
+ def test_create_transaction_amt_bonus_cancel(
+ self,
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ caplog,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ now = datetime.now(timezone.utc) - timedelta(hours=1)
+ user: User = user_factory(product=product_amt_true)
+
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.AMT_BONUS,
+ amount=101,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+
+ thl_lm.create_tx_user_bonus(
+ user,
+ amount=Decimal(5),
+ ref_uuid="c44f4da2db1d421ebc6a5e5241ca4ce6",
+ description="Bribe",
+ skip_flag_check=True,
+ )
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe, skip_flag_check=True
+ )
+ thl_lm.create_tx_user_payout_cancelled(
+ user, payout_event=pe, skip_flag_check=True
+ )
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ with caplog.at_level(logging.WARNING):
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe, skip_flag_check=True
+ )
+ assert "trying to complete payout that was already cancelled" in caplog.text
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ user.product
+ )
+ bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense(
+ user.product, expense_name="amt"
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user)
+ assert -500 == thl_lm.get_account_balance(account=bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(account=cash)
+ assert 0 == thl_lm.get_account_balance(account=bp_amt_expense_account)
+ assert 0 == thl_lm.get_account_balance(account=bp_pending_account)
+ assert 500 == thl_lm.get_account_balance(account=user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+
+ pe2 = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.AMT_BONUS,
+ amount=200,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe2, skip_flag_check=True
+ )
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe2, skip_flag_check=True
+ )
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ with caplog.at_level(logging.WARNING):
+ thl_lm.create_tx_user_payout_cancelled(
+ user, payout_event=pe2, skip_flag_check=True
+ )
+ assert "trying to cancel payout that was already completed" in caplog.text
+
+
+class TestLedgerManagerTango:
+
+ def test_create_transaction_tango_request(
+ self,
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_amt_true)
+
+ # debit_account_uuid nothing checks they match the ledger ... todo?
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.TANGO,
+ amount=500,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+ flag_key = f"test:user_payout:{pe.uuid}:request"
+ flag_name = f"ledger-manager:transaction_flag:{flag_key}"
+ lm.redis_client.delete(flag_name)
+ thl_lm.create_tx_user_bonus(
+ user,
+ amount=Decimal(6),
+ ref_uuid="e703830dec124f17abed2d697d8d7701",
+ description="Bribe",
+ skip_flag_check=True,
+ )
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe, skip_flag_check=True
+ )
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ user.product
+ )
+ bp_tango_expense_account = thl_lm.get_account_or_create_bp_expense(
+ user.product, expense_name="tango"
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user)
+ assert -600 == thl_lm.get_account_balance(bp_wallet_account)
+ assert 0 == thl_lm.get_account_balance(cash)
+ assert 0 == thl_lm.get_account_balance(bp_tango_expense_account)
+ assert 500 == thl_lm.get_account_balance(bp_pending_account)
+ assert 600 - 500 == thl_lm.get_account_balance(user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+
+ thl_lm.create_tx_user_payout_complete(
+ user, payout_event=pe, skip_flag_check=True
+ )
+ assert -600 - round(500 * 0.035) == thl_lm.get_account_balance(
+ bp_wallet_account
+ )
+ assert -500, thl_lm.get_account_balance(cash)
+ assert round(-500 * 0.035) == thl_lm.get_account_balance(
+ bp_tango_expense_account
+ )
+ assert 0 == lm.get_account_balance(bp_pending_account)
+ assert 100 == lm.get_account_balance(user_wallet_account)
+ assert lm.check_ledger_balanced()
+
+
+class TestLedgerManagerPaypal:
+
+ def test_create_transaction_paypal_request(
+ self,
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ now = datetime.now(tz=timezone.utc) - timedelta(hours=1)
+ user: User = user_factory(product=product_amt_true)
+
+ # debit_account_uuid nothing checks they match the ledger ... todo?
+ pe = UserPayoutEvent(
+ uuid=uuid4().hex,
+ payout_type=PayoutType.PAYPAL,
+ amount=500,
+ cashout_method_uuid=uuid4().hex,
+ debit_account_uuid=uuid4().hex,
+ )
+ flag_key = f"test:user_payout:{pe.uuid}:request"
+ flag_name = f"ledger-manager:transaction_flag:{flag_key}"
+ lm.redis_client.delete(flag_name)
+ thl_lm.create_tx_user_bonus(
+ user=user,
+ amount=Decimal(6),
+ ref_uuid="e703830dec124f17abed2d697d8d7701",
+ description="Bribe",
+ skip_flag_check=True,
+ )
+
+ thl_lm.create_tx_user_payout_request(
+ user, payout_event=pe, skip_flag_check=True
+ )
+
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ product=user.product
+ )
+ bp_paypal_expense_account = thl_lm.get_account_or_create_bp_expense(
+ product=user.product, expense_name="paypal"
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ assert -600 == lm.get_account_balance(account=bp_wallet_account)
+ assert 0 == lm.get_account_balance(account=cash)
+ assert 0 == lm.get_account_balance(account=bp_paypal_expense_account)
+ assert 500 == lm.get_account_balance(account=bp_pending_account)
+ assert 600 - 500 == lm.get_account_balance(account=user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+
+ thl_lm.create_tx_user_payout_complete(
+ user=user, payout_event=pe, skip_flag_check=True, fee_amount=Decimal("0.50")
+ )
+ assert -600 - 50 == thl_lm.get_account_balance(bp_wallet_account)
+ assert -500 == thl_lm.get_account_balance(cash)
+ assert -50 == thl_lm.get_account_balance(bp_paypal_expense_account)
+ assert 0 == thl_lm.get_account_balance(bp_pending_account)
+ assert 100 == thl_lm.get_account_balance(user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+
+
+class TestLedgerManagerBonus:
+
+ def test_create_transaction_bonus(
+ self,
+ user_factory,
+ product_user_wallet_yes,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_user_wallet_yes)
+
+ thl_lm.create_tx_user_bonus(
+ user=user,
+ amount=Decimal(5),
+ ref_uuid="8d0aaf612462448a9ebdd57fab0fc660",
+ description="Bribe",
+ skip_flag_check=True,
+ )
+ cash = thl_lm.get_account_cash()
+ bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product)
+ bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account(
+ product=user.product
+ )
+ bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense(
+ user.product, expense_name="amt"
+ )
+ user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user)
+
+ assert -500 == lm.get_account_balance(account=bp_wallet_account)
+ assert 0 == lm.get_account_balance(account=cash)
+ assert 0 == lm.get_account_balance(account=bp_amt_expense_account)
+ assert 0 == lm.get_account_balance(account=bp_pending_account)
+ assert 500 == lm.get_account_balance(account=user_wallet_account)
+ assert thl_lm.check_ledger_balanced()
+
+ with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError):
+ thl_lm.create_tx_user_bonus(
+ user=user,
+ amount=Decimal(5),
+ ref_uuid="8d0aaf612462448a9ebdd57fab0fc660",
+ description="Bribe",
+ skip_flag_check=False,
+ )
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ thl_lm.create_tx_user_bonus(
+ user=user,
+ amount=Decimal(5),
+ ref_uuid="8d0aaf612462448a9ebdd57fab0fc660",
+ description="Bribe",
+ skip_flag_check=True,
+ )
diff --git a/tests/managers/thl/test_ledger/test_thl_pem.py b/tests/managers/thl/test_ledger/test_thl_pem.py
new file mode 100644
index 0000000..5fb9e7d
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_thl_pem.py
@@ -0,0 +1,251 @@
+import uuid
+from random import randint
+from uuid import uuid4, UUID
+
+import pytest
+
+from generalresearch.currency import USDCent
+from generalresearch.models.thl.definitions import PayoutStatus
+from generalresearch.models.thl.payout import BrokerageProductPayoutEvent
+from generalresearch.models.thl.product import Product
+from generalresearch.models.thl.wallet.cashout_method import (
+ CashoutRequestInfo,
+)
+
+
+class TestThlPayoutEventManager:
+
+ def test_get_by_uuid(self, brokerage_product_payout_event_manager, thl_lm):
+ """This validates that the method raises an exception if it
+ fails. There are plenty of other tests that use this method so
+ it seems silly to duplicate it here again
+ """
+
+ with pytest.raises(expected_exception=AssertionError) as excinfo:
+ brokerage_product_payout_event_manager.get_by_uuid(pe_uuid=uuid4().hex)
+ assert "expected 1 result, got 0" in str(excinfo.value)
+
+ def test_filter_by(
+ self,
+ product_factory,
+ usd_cent,
+ bp_payout_event_factory,
+ thl_lm,
+ brokerage_product_payout_event_manager,
+ ):
+ from generalresearch.models.thl.payout import UserPayoutEvent
+
+ N_PRODUCTS = randint(3, 10)
+ N_PAYOUT_EVENTS = randint(3, 10)
+ amounts = []
+ products = []
+
+ for x_idx in range(N_PRODUCTS):
+ product: Product = product_factory()
+ thl_lm.get_account_or_create_bp_wallet(product=product)
+ products.append(product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(
+ thl_lm=thl_lm
+ )
+
+ for y_idx in range(N_PAYOUT_EVENTS):
+ pe = bp_payout_event_factory(product=product, usd_cent=usd_cent)
+ amounts.append(int(usd_cent))
+ assert isinstance(pe, BrokerageProductPayoutEvent)
+
+ # We just added Payout Events for Products, now go ahead and
+ # query for them
+ accounts = thl_lm.get_accounts_bp_wallet_for_products(
+ product_uuids=[i.uuid for i in products]
+ )
+ res = brokerage_product_payout_event_manager.filter_by(
+ debit_account_uuids=[i.uuid for i in accounts]
+ )
+
+ assert len(res) == (N_PRODUCTS * N_PAYOUT_EVENTS)
+ assert sum([i.amount for i in res]) == sum(amounts)
+
+ def test_get_bp_payout_events_for_product(
+ self,
+ product_factory,
+ usd_cent,
+ bp_payout_event_factory,
+ brokerage_product_payout_event_manager,
+ thl_lm,
+ ):
+ from generalresearch.models.thl.payout import UserPayoutEvent
+
+ N_PRODUCTS = randint(3, 10)
+ N_PAYOUT_EVENTS = randint(3, 10)
+ amounts = []
+ products = []
+
+ for x_idx in range(N_PRODUCTS):
+ product: Product = product_factory()
+ products.append(product)
+ thl_lm.get_account_or_create_bp_wallet(product=product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(
+ thl_lm=thl_lm
+ )
+
+ for y_idx in range(N_PAYOUT_EVENTS):
+ pe = bp_payout_event_factory(product=product, usd_cent=usd_cent)
+ amounts.append(usd_cent)
+ assert isinstance(pe, BrokerageProductPayoutEvent)
+
+ # We just added 5 Payouts for a specific Product, now go
+ # ahead and query for them
+ res = brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products(
+ thl_ledger_manager=thl_lm, product_uuids=[product.id]
+ )
+
+ assert len(res) == N_PAYOUT_EVENTS
+
+ # Now that all the Payouts for all the Products have been added, go
+ # ahead and query for them
+ res = (
+ brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products(
+ thl_ledger_manager=thl_lm, product_uuids=[i.uuid for i in products]
+ )
+ )
+
+ assert len(res) == (N_PRODUCTS * N_PAYOUT_EVENTS)
+ assert sum([i.amount for i in res]) == sum(amounts)
+
+ @pytest.mark.skip
+ def test_get_payout_detail(self, user_payout_event_manager):
+ """This fails because the description coming back is None, but then
+ it tries to return a PayoutEvent which validates that the
+ description can't be None
+ """
+ from generalresearch.models.thl.payout import (
+ UserPayoutEvent,
+ PayoutType,
+ )
+
+ rand_amount = randint(a=99, b=999)
+
+ pe = user_payout_event_manager.create(
+ debit_account_uuid=uuid4().hex,
+ account_reference_type="str-type-random",
+ account_reference_uuid=uuid4().hex,
+ cashout_method_uuid=uuid4().hex,
+ description="Best payout !",
+ amount=rand_amount,
+ status=PayoutStatus.PENDING,
+ ext_ref_id="123",
+ payout_type=PayoutType.CASH_IN_MAIL,
+ request_data={"foo": 123},
+ order_data={},
+ )
+
+ res = user_payout_event_manager.get_payout_detail(pe_uuid=pe.uuid)
+ assert isinstance(res, CashoutRequestInfo)
+
+ # def test_filter_by(self):
+ # raise NotImplementedError
+
+ def test_create(self, user_payout_event_manager):
+ from generalresearch.models.thl.payout import UserPayoutEvent
+
+ # Confirm the creation method returns back an instance.
+ pe = user_payout_event_manager.create_dummy()
+ assert isinstance(pe, UserPayoutEvent)
+
+ # Now query the DB for that PayoutEvent to confirm it was actually
+ # saved.
+ res = user_payout_event_manager.get_by_uuid(pe_uuid=pe.uuid)
+ assert isinstance(res, UserPayoutEvent)
+ assert UUID(res.uuid)
+
+ # Confirm they're the same
+ # assert pe.model_dump_json() == res2.model_dump_json()
+ assert res.description is None
+
+ # def test_update(self):
+ # raise NotImplementedError
+
+ def test_create_bp_payout(
+ self,
+ product,
+ delete_ledger_db,
+ create_main_accounts,
+ thl_lm,
+ brokerage_product_payout_event_manager,
+ lm,
+ ):
+ from generalresearch.models.thl.payout import UserPayoutEvent
+
+ delete_ledger_db()
+ create_main_accounts()
+
+ account_bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ rand_amount = randint(a=99, b=999)
+
+ # Save a Brokerage Product Payout, so we have something in the
+ # Payout Event table and the respective ledger TX and Entry rows for it
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(rand_amount),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ assert isinstance(pe, BrokerageProductPayoutEvent)
+
+ # Now try to query for it!
+ res = thl_lm.get_tx_bp_payouts(account_uuids=[account_bp_wallet.uuid])
+ assert len(res) == 1
+ res = thl_lm.get_tx_bp_payouts(account_uuids=[uuid4().hex])
+ assert len(res) == 0
+
+ # Confirm it added to the users balance. The amount is negative because
+ # money was sent to the Brokerage Product, but they didn't have
+ # any activity that earned them money
+ bal = lm.get_account_balance(account=account_bp_wallet)
+ assert rand_amount == bal * -1
+
+
+class TestBPPayoutEvent:
+
+ def test_get_bp_bp_payout_events_for_products(
+ self,
+ product_factory,
+ bp_payout_event_factory,
+ usd_cent,
+ delete_ledger_db,
+ create_main_accounts,
+ brokerage_product_payout_event_manager,
+ thl_lm,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ N_PAYOUT_EVENTS = randint(3, 10)
+ amounts = []
+
+ product: Product = product_factory()
+ thl_lm.get_account_or_create_bp_wallet(product=product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ for y_idx in range(N_PAYOUT_EVENTS):
+ bp_payout_event_factory(product=product, usd_cent=usd_cent)
+ amounts.append(usd_cent)
+
+ # Fetch using the _bp_bp_ approach, so we have an
+ # array of BPPayoutEvents
+ bp_bp_res = (
+ brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products(
+ thl_ledger_manager=thl_lm, product_uuids=[product.uuid]
+ )
+ )
+ assert isinstance(bp_bp_res, list)
+ assert sum(amounts) == sum([i.amount for i in bp_bp_res])
+ for i in bp_bp_res:
+ assert isinstance(i, BrokerageProductPayoutEvent)
+ assert isinstance(i.amount, int)
+ assert isinstance(i.amount_usd, USDCent)
+ assert isinstance(i.amount_usd_str, str)
+ assert i.amount_usd_str[0] == "$"
diff --git a/tests/managers/thl/test_ledger/test_user_txs.py b/tests/managers/thl/test_ledger/test_user_txs.py
new file mode 100644
index 0000000..d81b244
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_user_txs.py
@@ -0,0 +1,288 @@
+from datetime import timedelta, datetime, timezone
+from decimal import Decimal
+from uuid import uuid4
+
+from generalresearch.managers.thl.payout import (
+ AMT_ASSIGNMENT_CASHOUT_METHOD,
+ AMT_BONUS_CASHOUT_METHOD,
+)
+from generalresearch.managers.thl.user_compensate import user_compensate
+from generalresearch.models.thl.definitions import (
+ Status,
+ WallAdjustedStatus,
+)
+from generalresearch.models.thl.ledger import (
+ UserLedgerTransactionTypesSummary,
+ UserLedgerTransactionTypeSummary,
+ TransactionType,
+)
+from generalresearch.models.thl.session import Session
+from generalresearch.models.thl.user import User
+from generalresearch.models.thl.wallet import PayoutType
+
+
+def test_user_txs(
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ session_with_tx_factory,
+ adj_to_fail_with_tx_factory,
+ adj_to_complete_with_tx_factory,
+ session_factory,
+ user_payout_event_manager,
+ utc_now,
+):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_amt_true)
+ account = thl_lm.get_account_or_create_user_wallet(user)
+ print(f"{account.uuid=}")
+
+ s: Session = session_with_tx_factory(user=user, wall_req_cpi=Decimal("1.00"))
+
+ bribe_uuid = user_compensate(
+ ledger_manager=thl_lm,
+ user=user,
+ amount_int=100,
+ )
+
+ pe = user_payout_event_manager.create(
+ uuid=uuid4().hex,
+ debit_account_uuid=account.uuid,
+ cashout_method_uuid=AMT_ASSIGNMENT_CASHOUT_METHOD,
+ amount=5,
+ created=utc_now,
+ payout_type=PayoutType.AMT_HIT,
+ request_data=dict(),
+ )
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ )
+ pe = user_payout_event_manager.create(
+ uuid=uuid4().hex,
+ debit_account_uuid=account.uuid,
+ cashout_method_uuid=AMT_BONUS_CASHOUT_METHOD,
+ amount=127,
+ created=utc_now,
+ payout_type=PayoutType.AMT_BONUS,
+ request_data=dict(),
+ )
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ )
+
+ wall = s.wall_events[-1]
+ adj_to_fail_with_tx_factory(session=s, created=wall.finished)
+
+ # And a fail -> complete adjustment
+ s_fail: Session = session_factory(
+ user=user,
+ wall_count=1,
+ final_status=Status.FAIL,
+ wall_req_cpi=Decimal("2.00"),
+ )
+ adj_to_complete_with_tx_factory(session=s_fail, created=utc_now)
+
+ # txs = thl_lm.get_tx_filtered_by_account(account.uuid)
+ # print(len(txs), txs)
+ txs = thl_lm.get_user_txs(user)
+ assert len(txs.transactions) == 6
+ assert txs.total == 6
+ assert txs.page == 1
+ assert txs.size == 50
+
+ # print(len(txs.transactions), txs)
+ d = txs.model_dump_json()
+ # print(d)
+
+ descriptions = {x.description for x in txs.transactions}
+ assert descriptions == {
+ "Compensation Bonus",
+ "HIT Bonus",
+ "HIT Reward",
+ "Task Adjustment",
+ "Task Complete",
+ }
+ amounts = {x.amount for x in txs.transactions}
+ assert amounts == {-127, 100, 38, -38, -5, 76}
+
+ assert txs.summary == UserLedgerTransactionTypesSummary(
+ bp_adjustment=UserLedgerTransactionTypeSummary(
+ entry_count=2, min_amount=-38, max_amount=76, total_amount=76 - 38
+ ),
+ bp_payment=UserLedgerTransactionTypeSummary(
+ entry_count=1, min_amount=38, max_amount=38, total_amount=38
+ ),
+ user_bonus=UserLedgerTransactionTypeSummary(
+ entry_count=1, min_amount=100, max_amount=100, total_amount=100
+ ),
+ user_payout_request=UserLedgerTransactionTypeSummary(
+ entry_count=2, min_amount=-127, max_amount=-5, total_amount=-132
+ ),
+ )
+ tx_adj_c = [
+ tx for tx in txs.transactions if tx.tx_type == TransactionType.BP_ADJUSTMENT
+ ]
+ assert sorted([tx.amount for tx in tx_adj_c]) == [-38, 76]
+
+
+def test_user_txs_pagination(
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ session_with_tx_factory,
+ adj_to_fail_with_tx_factory,
+ user_payout_event_manager,
+ utc_now,
+):
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_amt_true)
+ account = thl_lm.get_account_or_create_user_wallet(user)
+ print(f"{account.uuid=}")
+
+ for _ in range(12):
+ user_compensate(
+ ledger_manager=thl_lm,
+ user=user,
+ amount_int=100,
+ skip_flag_check=True,
+ )
+
+ txs = thl_lm.get_user_txs(user, page=1, size=5)
+ assert len(txs.transactions) == 5
+ assert txs.total == 12
+ assert txs.page == 1
+ assert txs.size == 5
+ assert txs.summary.user_bonus.total_amount == 1200
+ assert txs.summary.user_bonus.entry_count == 12
+
+ # Skip to the 3rd page. We made 12, so there are 2 left
+ txs = thl_lm.get_user_txs(user, page=3, size=5)
+ assert len(txs.transactions) == 2
+ assert txs.total == 12
+ assert txs.page == 3
+ assert txs.summary.user_bonus.total_amount == 1200
+ assert txs.summary.user_bonus.entry_count == 12
+
+ # Should be empty, not fail
+ txs = thl_lm.get_user_txs(user, page=4, size=5)
+ assert len(txs.transactions) == 0
+ assert txs.total == 12
+ assert txs.page == 4
+ assert txs.summary.user_bonus.total_amount == 1200
+ assert txs.summary.user_bonus.entry_count == 12
+
+ # Test filtering. We should pull back only this one
+ now = datetime.now(tz=timezone.utc)
+ user_compensate(
+ ledger_manager=thl_lm,
+ user=user,
+ amount_int=100,
+ skip_flag_check=True,
+ )
+ txs = thl_lm.get_user_txs(user, page=1, size=5, time_start=now)
+ assert len(txs.transactions) == 1
+ assert txs.total == 1
+ assert txs.page == 1
+ # And the summary is restricted to this time range also!
+ assert txs.summary.user_bonus.total_amount == 100
+ assert txs.summary.user_bonus.entry_count == 1
+
+ # And filtering with 0 results
+ now = datetime.now(tz=timezone.utc)
+ txs = thl_lm.get_user_txs(user, page=1, size=5, time_start=now)
+ assert len(txs.transactions) == 0
+ assert txs.total == 0
+ assert txs.page == 1
+ assert txs.pages == 0
+ # And the summary is restricted to this time range also!
+ assert txs.summary.user_bonus.total_amount == None
+ assert txs.summary.user_bonus.entry_count == 0
+
+
+def test_user_txs_rolling_balance(
+ user_factory,
+ product_amt_true,
+ create_main_accounts,
+ thl_lm,
+ lm,
+ delete_ledger_db,
+ session_with_tx_factory,
+ adj_to_fail_with_tx_factory,
+ user_payout_event_manager,
+):
+ """
+ Creates 3 $1.00 bonuses (postive),
+ then 1 cashout (negative), $1.50
+ then 3 more $1.00 bonuses.
+ Note: pagination + rolling balance will BREAK if txs have
+ identical timestamps. In practice, they do not.
+ """
+ delete_ledger_db()
+ create_main_accounts()
+
+ user: User = user_factory(product=product_amt_true)
+ account = thl_lm.get_account_or_create_user_wallet(user)
+
+ for _ in range(3):
+ user_compensate(
+ ledger_manager=thl_lm,
+ user=user,
+ amount_int=100,
+ skip_flag_check=True,
+ )
+ pe = user_payout_event_manager.create(
+ uuid=uuid4().hex,
+ debit_account_uuid=account.uuid,
+ cashout_method_uuid=AMT_BONUS_CASHOUT_METHOD,
+ amount=150,
+ payout_type=PayoutType.AMT_BONUS,
+ request_data=dict(),
+ )
+ thl_lm.create_tx_user_payout_request(
+ user=user,
+ payout_event=pe,
+ )
+ for _ in range(3):
+ user_compensate(
+ ledger_manager=thl_lm,
+ user=user,
+ amount_int=100,
+ skip_flag_check=True,
+ )
+
+ txs = thl_lm.get_user_txs(user, page=1, size=10)
+ assert txs.transactions[0].balance_after == 100
+ assert txs.transactions[1].balance_after == 200
+ assert txs.transactions[2].balance_after == 300
+ assert txs.transactions[3].balance_after == 150
+ assert txs.transactions[4].balance_after == 250
+ assert txs.transactions[5].balance_after == 350
+ assert txs.transactions[6].balance_after == 450
+
+ # Ascending order, get 2nd page, make sure the balances include
+ # the previous txs. (will return last 3 txs)
+ txs = thl_lm.get_user_txs(user, page=2, size=4)
+ assert len(txs.transactions) == 3
+ assert txs.transactions[0].balance_after == 250
+ assert txs.transactions[1].balance_after == 350
+ assert txs.transactions[2].balance_after == 450
+
+ # Descending order, get 1st page. Will
+ # return most recent 3 txs in desc order
+ txs = thl_lm.get_user_txs(user, page=1, size=3, order_by="-created")
+ assert len(txs.transactions) == 3
+ assert txs.transactions[0].balance_after == 450
+ assert txs.transactions[1].balance_after == 350
+ assert txs.transactions[2].balance_after == 250
diff --git a/tests/managers/thl/test_ledger/test_wallet.py b/tests/managers/thl/test_ledger/test_wallet.py
new file mode 100644
index 0000000..a0abd7c
--- /dev/null
+++ b/tests/managers/thl/test_ledger/test_wallet.py
@@ -0,0 +1,78 @@
+from decimal import Decimal
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models.thl.product import (
+ UserWalletConfig,
+ PayoutConfig,
+ PayoutTransformation,
+ PayoutTransformationPercentArgs,
+)
+from generalresearch.models.thl.user import User
+
+
+@pytest.fixture()
+def schrute_product(product_manager):
+ return product_manager.create_dummy(
+ user_wallet_config=UserWalletConfig(enabled=True, amt=False),
+ payout_config=PayoutConfig(
+ payout_transformation=PayoutTransformation(
+ f="payout_transformation_percent",
+ kwargs=PayoutTransformationPercentArgs(pct=0.4),
+ ),
+ payout_format="{payout:,.0f} Schrute Bucks",
+ ),
+ )
+
+
+class TestGetUserWalletBalance:
+ def test_get_user_wallet_balance_non_managed(self, user, thl_lm):
+ with pytest.raises(
+ AssertionError,
+ match="Can't get wallet balance on non-managed account.",
+ ):
+ thl_lm.get_user_wallet_balance(user=user)
+
+ def test_get_user_wallet_balance_managed_0(
+ self, schrute_product, user_factory, thl_lm
+ ):
+ assert (
+ schrute_product.payout_config.payout_format == "{payout:,.0f} Schrute Bucks"
+ )
+ user: User = user_factory(schrute_product)
+ balance = thl_lm.get_user_wallet_balance(user=user)
+ assert balance == 0
+ balance_string = user.product.format_payout_format(Decimal(balance) / 100)
+ assert balance_string == "0 Schrute Bucks"
+ redeemable_balance = thl_lm.get_user_redeemable_wallet_balance(
+ user=user, user_wallet_balance=balance
+ )
+ assert redeemable_balance == 0
+ redeemable_balance_string = user.product.format_payout_format(
+ Decimal(redeemable_balance) / 100
+ )
+ assert redeemable_balance_string == "0 Schrute Bucks"
+
+ def test_get_user_wallet_balance_managed(
+ self, schrute_product, user_factory, thl_lm, session_with_tx_factory
+ ):
+ user: User = user_factory(schrute_product)
+ thl_lm.create_tx_user_bonus(
+ user=user,
+ amount=Decimal(1),
+ ref_uuid=uuid4().hex,
+ description="cheese",
+ )
+ session_with_tx_factory(user=user, wall_req_cpi=Decimal("1.23"))
+
+ # This product has a payout xform of 40% and commission of 5%
+ # 1.23 * 0.05 = 0.06 of commission
+ # 1.17 of payout * 0.40 = 0.47 of user pay and (1.17-0.47) 0.70 bp pay
+ balance = thl_lm.get_user_wallet_balance(user=user)
+ assert balance == 47 + 100 # plus the $1 bribe
+
+ redeemable_balance = thl_lm.get_user_redeemable_wallet_balance(
+ user=user, user_wallet_balance=balance
+ )
+ assert redeemable_balance == 20 + 100
diff --git a/tests/managers/thl/test_maxmind.py b/tests/managers/thl/test_maxmind.py
new file mode 100644
index 0000000..c588c58
--- /dev/null
+++ b/tests/managers/thl/test_maxmind.py
@@ -0,0 +1,273 @@
+import json
+import logging
+from typing import Callable
+
+import geoip2.models
+import pytest
+from faker import Faker
+from faker.providers.address.en_US import Provider as USAddressProvider
+
+from generalresearch.managers.thl.ipinfo import GeoIpInfoManager
+from generalresearch.managers.thl.maxmind import MaxmindManager
+from generalresearch.managers.thl.maxmind.basic import (
+ MaxmindBasicManager,
+)
+from generalresearch.models.thl.ipinfo import (
+ GeoIPInformation,
+ normalize_ip,
+)
+from generalresearch.models.thl.maxmind.definitions import UserType
+
+fake = Faker()
+
+US_STATES = {x.lower() for x in USAddressProvider.states}
+
+IP_v4_INDIA = "106.203.146.157"
+IP_v6_INDIA = "2402:3a80:4649:de3f:0:24:74aa:b601"
+IP_v4_US = "174.218.60.101"
+IP_v6_US = "2600:1700:ece0:9410:55d:faf3:c15d:6e4"
+IP_v6_US_SAME_64 = "2600:1700:ece0:9410:55d:faf3:c15d:aaaa"
+
+
+@pytest.fixture(scope="session")
+def delete_ipinfo(thl_web_rw) -> Callable:
+ def _delete_ipinfo(ip):
+ thl_web_rw.execute_write(
+ query="DELETE FROM thl_geoname WHERE geoname_id IN (SELECT geoname_id FROM thl_ipinformation WHERE ip = %s);",
+ params=[ip],
+ )
+ thl_web_rw.execute_write(
+ query="DELETE FROM thl_ipinformation WHERE ip = %s;",
+ params=[ip],
+ )
+
+ return _delete_ipinfo
+
+
+class TestMaxmindBasicManager:
+
+ def test_init(self, maxmind_basic_manager):
+
+ assert isinstance(maxmind_basic_manager, MaxmindBasicManager)
+
+ def test_get_basic_ip_information(self, maxmind_basic_manager):
+ ip = IP_v4_INDIA
+ maxmind_basic_manager.run_update_geoip_db()
+
+ res1 = maxmind_basic_manager.get_basic_ip_information(ip_address=ip)
+ assert isinstance(res1, geoip2.models.Country)
+ assert res1.country.iso_code == "IN"
+ assert res1.country.name == "India"
+
+ res2 = maxmind_basic_manager.get_basic_ip_information(
+ ip_address=fake.ipv4_private()
+ )
+ assert res2 is None
+
+ def test_get_country_iso_from_ip_geoip2db(self, maxmind_basic_manager):
+ ip = IP_v4_INDIA
+ maxmind_basic_manager.run_update_geoip_db()
+
+ res1 = maxmind_basic_manager.get_country_iso_from_ip_geoip2db(ip=ip)
+ assert res1 == "in"
+
+ res2 = maxmind_basic_manager.get_country_iso_from_ip_geoip2db(
+ ip=fake.ipv4_private()
+ )
+ assert res2 is None
+
+ def test_get_basic_ip_information_ipv6(self, maxmind_basic_manager):
+ ip = IP_v6_INDIA
+ maxmind_basic_manager.run_update_geoip_db()
+
+ res1 = maxmind_basic_manager.get_basic_ip_information(ip_address=ip)
+ assert isinstance(res1, geoip2.models.Country)
+ assert res1.country.iso_code == "IN"
+ assert res1.country.name == "India"
+
+
+class TestMaxmindManager:
+
+ def test_init(self, thl_web_rr, thl_redis_config, maxmind_manager: MaxmindManager):
+ instance = MaxmindManager(pg_config=thl_web_rr, redis_config=thl_redis_config)
+ assert isinstance(instance, MaxmindManager)
+ assert isinstance(maxmind_manager, MaxmindManager)
+
+ def test_create_basic(
+ self,
+ maxmind_manager: MaxmindManager,
+ geoipinfo_manager: GeoIpInfoManager,
+ delete_ipinfo,
+ ):
+ # This is (currently) an IP in India, and so it should only do the basic lookup
+ ip = IP_v4_INDIA
+ delete_ipinfo(ip)
+ geoipinfo_manager.clear_cache(ip)
+ assert geoipinfo_manager.get_cache(ip) is None
+ assert geoipinfo_manager.get_mysql_if_exists(ip) is None
+
+ maxmind_manager.run_ip_information(ip, force_insights=False)
+ # Check that it is in the cache and in mysql
+ res = geoipinfo_manager.get_cache(ip)
+ assert res.ip == ip
+ assert res.basic
+ res = geoipinfo_manager.get_mysql(ip)
+ assert res.ip == ip
+ assert res.basic
+
+ def test_create_basic_ipv6(
+ self,
+ maxmind_manager: MaxmindManager,
+ geoipinfo_manager: GeoIpInfoManager,
+ delete_ipinfo,
+ ):
+ # This is (currently) an IP in India, and so it should only do the basic lookup
+ ip = IP_v6_INDIA
+ normalized_ip, lookup_prefix = normalize_ip(ip)
+ delete_ipinfo(ip)
+ geoipinfo_manager.clear_cache(ip)
+ delete_ipinfo(normalized_ip)
+ geoipinfo_manager.clear_cache(normalized_ip)
+ assert geoipinfo_manager.get_cache(ip) is None
+ assert geoipinfo_manager.get_cache(normalized_ip) is None
+ assert geoipinfo_manager.get_mysql_if_exists(ip) is None
+ assert geoipinfo_manager.get_mysql_if_exists(normalized_ip) is None
+
+ maxmind_manager.run_ip_information(ip, force_insights=False)
+
+ # Check that it is in the cache
+ res = geoipinfo_manager.get_cache(ip)
+ # The looked up IP (/128) is returned,
+ assert res.ip == ip
+ assert res.lookup_prefix == "/64"
+ assert res.basic
+
+ # ... but the normalized version was stored (/64)
+ assert geoipinfo_manager.get_cache_raw(ip) is None
+ res = json.loads(geoipinfo_manager.get_cache_raw(normalized_ip))
+ assert res["ip"] == normalized_ip
+
+ # Check mysql
+ res = geoipinfo_manager.get_mysql(ip)
+ assert res.ip == ip
+ assert res.lookup_prefix == "/64"
+ assert res.basic
+ with pytest.raises(AssertionError):
+ geoipinfo_manager.get_mysql_raw(ip)
+ res = geoipinfo_manager.get_mysql_raw(normalized_ip)
+ assert res["ip"] == normalized_ip
+
+ def test_create_insights(
+ self,
+ maxmind_manager: MaxmindManager,
+ geoipinfo_manager: GeoIpInfoManager,
+ delete_ipinfo,
+ ):
+ # This is (currently) an IP in the US, so it should do insights
+ ip = IP_v4_US
+ delete_ipinfo(ip)
+ geoipinfo_manager.clear_cache(ip)
+ assert geoipinfo_manager.get_cache(ip) is None
+ assert geoipinfo_manager.get_mysql_if_exists(ip) is None
+
+ res1 = maxmind_manager.run_ip_information(ip, force_insights=False)
+ assert isinstance(res1, GeoIPInformation)
+
+ # Check that it is in the cache and in mysql
+ res2 = geoipinfo_manager.get_cache(ip)
+ assert isinstance(res2, GeoIPInformation)
+ assert res2.ip == ip
+ assert not res2.basic
+
+ res3 = geoipinfo_manager.get_mysql(ip)
+ assert isinstance(res3, GeoIPInformation)
+ assert res3.ip == ip
+ assert not res3.basic
+ assert res3.is_anonymous is False
+ assert res3.subdivision_1_name.lower() in US_STATES
+ # this might change ...
+ assert res3.user_type == UserType.CELLULAR
+
+ assert res1 == res2 == res3, "runner, cache, mysql all return same instance"
+
+ def test_create_insights_ipv6(
+ self,
+ maxmind_manager: MaxmindManager,
+ geoipinfo_manager: GeoIpInfoManager,
+ delete_ipinfo,
+ ):
+ # This is (currently) an IP in the US, so it should do insights
+ ip = IP_v6_US
+ normalized_ip, lookup_prefix = normalize_ip(ip)
+ delete_ipinfo(ip)
+ geoipinfo_manager.clear_cache(ip)
+ delete_ipinfo(normalized_ip)
+ geoipinfo_manager.clear_cache(normalized_ip)
+ assert geoipinfo_manager.get_cache(ip) is None
+ assert geoipinfo_manager.get_cache(normalized_ip) is None
+ assert geoipinfo_manager.get_mysql_if_exists(ip) is None
+ assert geoipinfo_manager.get_mysql_if_exists(normalized_ip) is None
+
+ res1 = maxmind_manager.run_ip_information(ip, force_insights=False)
+ assert isinstance(res1, GeoIPInformation)
+ assert res1.lookup_prefix == "/64"
+
+ # Check that it is in the cache and in mysql
+ res2 = geoipinfo_manager.get_cache(ip)
+ assert isinstance(res2, GeoIPInformation)
+ assert res2.ip == ip
+ assert not res2.basic
+
+ res3 = geoipinfo_manager.get_mysql(ip)
+ assert isinstance(res3, GeoIPInformation)
+ assert res3.ip == ip
+ assert not res3.basic
+ assert res3.is_anonymous is False
+ assert res3.subdivision_1_name.lower() in US_STATES
+ # this might change ...
+ assert res3.user_type == UserType.RESIDENTIAL
+
+ assert res1 == res2 == res3, "runner, cache, mysql all return same instance"
+
+ def test_get_or_create_ip_information(self, maxmind_manager):
+ ip = IP_v4_US
+
+ res1 = maxmind_manager.get_or_create_ip_information(ip_address=ip)
+ assert isinstance(res1, GeoIPInformation)
+
+ res2 = maxmind_manager.get_or_create_ip_information(
+ ip_address=fake.ipv4_private()
+ )
+ assert res2 is None
+
+ def test_get_or_create_ip_information_ipv6(
+ self, maxmind_manager, delete_ipinfo, geoipinfo_manager, caplog
+ ):
+ ip = IP_v6_US
+ normalized_ip, lookup_prefix = normalize_ip(ip)
+ delete_ipinfo(normalized_ip)
+ geoipinfo_manager.clear_cache(normalized_ip)
+
+ with caplog.at_level(logging.INFO):
+ res1 = maxmind_manager.get_or_create_ip_information(ip_address=ip)
+ assert isinstance(res1, GeoIPInformation)
+ assert res1.ip == ip
+ # It looks up in insight using the normalize IP!
+ assert f"get_insights_ip_information: {normalized_ip}" in caplog.text
+
+ # And it should NOT do the lookup again with an ipv6 in the same /64 block!
+ ip = IP_v6_US_SAME_64
+ caplog.clear()
+ with caplog.at_level(logging.INFO):
+ res2 = maxmind_manager.get_or_create_ip_information(ip_address=ip)
+ assert isinstance(res2, GeoIPInformation)
+ assert res2.ip == ip
+ assert "get_insights_ip_information" not in caplog.text
+
+ def test_run_ip_information(self, maxmind_manager):
+ ip = IP_v4_US
+
+ res = maxmind_manager.run_ip_information(ip_address=ip)
+ assert isinstance(res, GeoIPInformation)
+ assert res.country_name == "United States"
+ assert res.country_iso == "us"
diff --git a/tests/managers/thl/test_payout.py b/tests/managers/thl/test_payout.py
new file mode 100644
index 0000000..31087b8
--- /dev/null
+++ b/tests/managers/thl/test_payout.py
@@ -0,0 +1,1269 @@
+import logging
+import os
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from random import choice as rand_choice, randint
+from typing import Optional
+from uuid import uuid4
+
+import pandas as pd
+import pytest
+
+from generalresearch.currency import USDCent
+from generalresearch.managers.thl.ledger_manager.exceptions import (
+ LedgerTransactionConditionFailedError,
+)
+from generalresearch.managers.thl.payout import UserPayoutEventManager
+from generalresearch.models.thl.definitions import PayoutStatus
+from generalresearch.models.thl.ledger import LedgerEntry, Direction
+from generalresearch.models.thl.payout import BusinessPayoutEvent
+from generalresearch.models.thl.payout import UserPayoutEvent
+from generalresearch.models.thl.wallet import PayoutType
+from generalresearch.models.thl.ledger import LedgerAccount
+
+logger = logging.getLogger()
+
+cashout_method_uuid = uuid4().hex
+
+
+class TestPayout:
+
+ def test_get_by_uuid_and_create(
+ self,
+ user,
+ user_payout_event_manager: UserPayoutEventManager,
+ thl_lm,
+ utc_now,
+ ):
+
+ user_account: LedgerAccount = thl_lm.get_account_or_create_user_wallet(
+ user=user
+ )
+
+ pe1: UserPayoutEvent = user_payout_event_manager.create(
+ debit_account_uuid=user_account.uuid,
+ payout_type=PayoutType.PAYPAL,
+ cashout_method_uuid=cashout_method_uuid,
+ amount=100,
+ created=utc_now,
+ )
+
+ # these get added by the query
+ pe1.account_reference_type = "user"
+ pe1.account_reference_uuid = user.uuid
+ # pe1.description = "PayPal"
+
+ pe2 = user_payout_event_manager.get_by_uuid(pe_uuid=pe1.uuid)
+
+ assert pe1 == pe2
+
+ def test_update(self, user, user_payout_event_manager, lm, thl_lm, utc_now):
+ from generalresearch.models.thl.definitions import PayoutStatus
+ from generalresearch.models.thl.wallet import PayoutType
+
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+
+ pe1 = user_payout_event_manager.create(
+ status=PayoutStatus.PENDING,
+ debit_account_uuid=user_account.uuid,
+ payout_type=PayoutType.PAYPAL,
+ cashout_method_uuid=cashout_method_uuid,
+ amount=100,
+ created=utc_now,
+ )
+ user_payout_event_manager.update(
+ payout_event=pe1,
+ status=PayoutStatus.APPROVED,
+ order_data={"foo": "bar"},
+ ext_ref_id="abc",
+ )
+ pe = user_payout_event_manager.get_by_uuid(pe_uuid=pe1.uuid)
+ assert pe.status == PayoutStatus.APPROVED
+ assert pe.order_data == {"foo": "bar"}
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ user_payout_event_manager.update(
+ payout_event=pe, status=PayoutStatus.PENDING
+ )
+ assert "status APPROVED can only be" in str(cm.value)
+
+ def test_create_bp_payout(
+ self,
+ user,
+ thl_web_rr,
+ user_payout_event_manager,
+ lm,
+ thl_lm,
+ product,
+ brokerage_product_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ from generalresearch.models.thl.ledger import LedgerAccount
+
+ thl_lm.get_account_or_create_bp_wallet(product=product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ with pytest.raises(expected_exception=LedgerTransactionConditionFailedError):
+ # wallet balance failure
+ brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(100),
+ skip_wallet_balance_check=False,
+ skip_one_per_day_check=False,
+ )
+
+ # (we don't have a special method for this) Put money in the BP's account
+ amount_cents = 100
+ cash_account: LedgerAccount = thl_lm.get_account_cash()
+ bp_wallet: LedgerAccount = thl_lm.get_account_or_create_bp_wallet(
+ product=product
+ )
+
+ entries = [
+ LedgerEntry(
+ direction=Direction.DEBIT,
+ account_uuid=cash_account.uuid,
+ amount=amount_cents,
+ ),
+ LedgerEntry(
+ direction=Direction.CREDIT,
+ account_uuid=bp_wallet.uuid,
+ amount=amount_cents,
+ ),
+ ]
+
+ lm.create_tx(entries=entries)
+ assert 100 == lm.get_account_balance(account=bp_wallet)
+
+ # Then run it again for $1.00
+ brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(100),
+ skip_wallet_balance_check=False,
+ skip_one_per_day_check=False,
+ )
+ assert 0 == lm.get_account_balance(account=bp_wallet)
+
+ # Run again should without balance check, should still fail due to day check
+ with pytest.raises(LedgerTransactionConditionFailedError):
+ brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(100),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=False,
+ )
+
+ # And then we can run again skip both checks
+ pe = brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(100),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ assert -100 == lm.get_account_balance(account=bp_wallet)
+
+ pe = brokerage_product_payout_event_manager.get_by_uuid(pe.uuid)
+ txs = lm.get_tx_filtered_by_metadata(
+ metadata_key="event_payout", metadata_value=pe.uuid
+ )
+
+ assert 1 == len(txs)
+
+ def test_create_bp_payout_quick_dupe(
+ self,
+ user,
+ product,
+ thl_web_rw,
+ brokerage_product_payout_event_manager,
+ thl_lm,
+ lm,
+ utc_now,
+ create_main_accounts,
+ ):
+ thl_lm.get_account_or_create_bp_wallet(product=product)
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(100),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ created=utc_now,
+ )
+
+ with pytest.raises(ValueError) as cm:
+ brokerage_product_payout_event_manager.create_bp_payout_event(
+ thl_ledger_manager=thl_lm,
+ product=product,
+ amount=USDCent(100),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ created=utc_now,
+ )
+ assert "Payout event already exists!" in str(cm.value)
+
+ def test_filter(
+ self,
+ thl_web_rw,
+ thl_lm,
+ lm,
+ product,
+ user,
+ user_payout_event_manager,
+ utc_now,
+ ):
+ from generalresearch.models.thl.definitions import PayoutStatus
+ from generalresearch.models.thl.wallet import PayoutType
+
+ user_account = thl_lm.get_account_or_create_user_wallet(user=user)
+ bp_account = thl_lm.get_account_or_create_bp_wallet(product=product)
+
+ user_payout_event_manager.create(
+ status=PayoutStatus.PENDING,
+ debit_account_uuid=user_account.uuid,
+ payout_type=PayoutType.PAYPAL,
+ cashout_method_uuid=cashout_method_uuid,
+ amount=100,
+ created=utc_now,
+ )
+
+ user_payout_event_manager.create(
+ status=PayoutStatus.PENDING,
+ debit_account_uuid=bp_account.uuid,
+ payout_type=PayoutType.PAYPAL,
+ cashout_method_uuid=cashout_method_uuid,
+ amount=200,
+ created=utc_now,
+ )
+
+ pes = user_payout_event_manager.filter_by(
+ reference_uuid=user.uuid, created=utc_now
+ )
+ assert 1 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(
+ debit_account_uuids=[bp_account.uuid], created=utc_now
+ )
+ assert 1 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(
+ debit_account_uuids=[bp_account.uuid], amount=123
+ )
+ assert 0 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(
+ product_ids=[user.product_id], bp_user_ids=["x"]
+ )
+ assert 0 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(product_ids=[user.product_id])
+ assert 1 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(
+ cashout_types=[PayoutType.PAYPAL],
+ bp_user_ids=[user.product_user_id],
+ )
+ assert 1 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(
+ statuses=[PayoutStatus.FAILED],
+ cashout_types=[PayoutType.PAYPAL],
+ bp_user_ids=[user.product_user_id],
+ )
+ assert 0 == len(pes)
+
+ pes = user_payout_event_manager.filter_by(
+ statuses=[PayoutStatus.PENDING],
+ cashout_types=[PayoutType.PAYPAL],
+ bp_user_ids=[user.product_user_id],
+ )
+ assert 1 == len(pes)
+
+
+class TestPayoutEventManager:
+
+ def test_set_account_lookup_table(
+ self, payout_event_manager, thl_redis_config, thl_lm, delete_ledger_db
+ ):
+ delete_ledger_db()
+ rc = thl_redis_config.create_redis_client()
+ rc.delete("pem:account_to_product")
+ rc.delete("pem:product_to_account")
+ N = 5
+
+ for idx in range(N):
+ thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=uuid4().hex)
+
+ res = rc.hgetall(name="pem:account_to_product")
+ assert len(res.items()) == 0
+
+ res = rc.hgetall(name="pem:product_to_account")
+ assert len(res.items()) == 0
+
+ payout_event_manager.set_account_lookup_table(
+ thl_lm=thl_lm,
+ )
+
+ res = rc.hgetall(name="pem:account_to_product")
+ assert len(res.items()) == N
+
+ res = rc.hgetall(name="pem:product_to_account")
+ assert len(res.items()) == N
+
+ thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=uuid4().hex)
+ payout_event_manager.set_account_lookup_table(
+ thl_lm=thl_lm,
+ )
+
+ res = rc.hgetall(name="pem:account_to_product")
+ assert len(res.items()) == N + 1
+
+ res = rc.hgetall(name="pem:product_to_account")
+ assert len(res.items()) == N + 1
+
+
+class TestBusinessPayoutEventManager:
+
+ @pytest.fixture
+ def start(self) -> "datetime":
+ return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc)
+
+ @pytest.fixture
+ def offset(self) -> str:
+ return "5d"
+
+ @pytest.fixture
+ def duration(self) -> Optional["timedelta"]:
+ return timedelta(days=10)
+
+ def test_base(
+ self,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ thl_lm,
+ thl_web_rr,
+ product_factory,
+ bp_payout_factory,
+ business,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+
+ from generalresearch.models.thl.product import Product
+
+ p1: Product = product_factory(business=business)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+ business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ ach_id1 = uuid4().hex
+ ach_id2 = uuid4().hex
+
+ bp_payout_factory(
+ product=p1,
+ amount=USDCent(1),
+ ext_ref_id=None,
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+
+ bp_payout_factory(
+ product=p1,
+ amount=USDCent(1),
+ ext_ref_id=ach_id1,
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+
+ bp_payout_factory(
+ product=p1,
+ amount=USDCent(25),
+ ext_ref_id=ach_id1,
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+
+ bp_payout_factory(
+ product=p1,
+ amount=USDCent(50),
+ ext_ref_id=ach_id2,
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+
+ business.prebuild_payouts(
+ thl_pg_config=thl_web_rr,
+ thl_lm=thl_lm,
+ bpem=business_payout_event_manager,
+ )
+
+ assert len(business.payouts) == 3
+ assert business.payouts_total == sum([pe.amount for pe in business.payouts])
+ assert business.payouts[0].created > business.payouts[1].created
+ assert len(business.payouts[0].bp_payouts) == 1
+ assert len(business.payouts[1].bp_payouts) == 2
+
+ assert business.payouts[0].ext_ref_id == ach_id2
+ assert business.payouts[1].ext_ref_id == ach_id1
+ assert business.payouts[2].ext_ref_id is None
+
+ def test_update_ext_reference_ids(
+ self,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ thl_lm,
+ thl_web_rr,
+ product_factory,
+ bp_payout_factory,
+ delete_df_collection,
+ user_factory,
+ ledger_collection,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ client_no_amm,
+ mnt_filepath,
+ lm,
+ product_manager,
+ start,
+ business,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ p1: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+ business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ # $250.00 to work with
+ for idx in range(1, 10):
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal("25.00"),
+ started=start + timedelta(days=1, minutes=idx),
+ )
+
+ ach_id1 = uuid4().hex
+ ach_id2 = uuid4().hex
+
+ with pytest.raises(expected_exception=Warning) as cm:
+ business_payout_event_manager.update_ext_reference_ids(
+ new_value=ach_id2,
+ current_value=ach_id1,
+ )
+ assert "No event_payouts found to UPDATE" in str(cm)
+
+ # We must build the balance to issue ACH/Wire
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ res = business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(100_01),
+ pm=product_manager,
+ thl_lm=thl_lm,
+ transaction_id=ach_id1,
+ )
+ assert isinstance(res, BusinessPayoutEvent)
+
+ # Okay, now that there is a payout_event, let's try to update the
+ # ext_reference_id
+ business_payout_event_manager.update_ext_reference_ids(
+ new_value=ach_id2,
+ current_value=ach_id1,
+ )
+
+ res = business_payout_event_manager.filter_by(ext_ref_id=ach_id1)
+ assert len(res) == 0
+
+ res = business_payout_event_manager.filter_by(ext_ref_id=ach_id2)
+ assert len(res) == 1
+
+ def test_delete_failed_business_payout(
+ self,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ thl_lm,
+ thl_web_rr,
+ product_factory,
+ bp_payout_factory,
+ currency,
+ delete_df_collection,
+ user_factory,
+ ledger_collection,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ client_no_amm,
+ mnt_filepath,
+ lm,
+ product_manager,
+ start,
+ business,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ p1: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+ business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ # $250.00 to work with
+ for idx in range(1, 10):
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal("25.00"),
+ started=start + timedelta(days=1, minutes=idx),
+ )
+
+ # We must build the balance to issue ACH/Wire
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ ach_id1 = uuid4().hex
+
+ res = business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(100_01),
+ pm=product_manager,
+ thl_lm=thl_lm,
+ transaction_id=ach_id1,
+ )
+ assert isinstance(res, BusinessPayoutEvent)
+
+ # (1) Confirm the initial Event Payout, Tx, TxMeta, TxEntry all exist
+ event_payouts = business_payout_event_manager.filter_by(ext_ref_id=ach_id1)
+ event_payout_uuids = [i.uuid for i in event_payouts]
+ assert len(event_payout_uuids) == 1
+ tags = [f"{currency.value}:bp_payout:{x}" for x in event_payout_uuids]
+ transactions = thl_lm.get_txs_by_tags(tags=tags)
+ assert len(transactions) == 1
+ tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions)
+ assert len(tx_metadata_ids) == 2
+ tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions)
+ assert len(tx_entries) == 2
+
+ # (2) Delete!
+ business_payout_event_manager.delete_failed_business_payout(
+ ext_ref_id=ach_id1, thl_lm=thl_lm
+ )
+
+ # (3) Confirm the initial Event Payout, Tx, TxMeta, TxEntry have
+ # all been deleted
+ res = business_payout_event_manager.filter_by(ext_ref_id=ach_id1)
+ assert len(res) == 0
+
+ # Note: b/c the event_payout shouldn't exist anymore, we are taking
+ # the tag strings and transactions from when they did..
+ res = thl_lm.get_txs_by_tags(tags=tags)
+ assert len(res) == 0
+
+ tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions)
+ assert len(tx_metadata_ids) == 0
+ tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions)
+ assert len(tx_entries) == 0
+
+ def test_recoup_empty(self, business_payout_event_manager):
+ res = {uuid4().hex: USDCent(0) for i in range(100)}
+ df = pd.DataFrame.from_dict(res, orient="index").reset_index()
+ df.columns = ["product_id", "available_balance"]
+
+ with pytest.raises(expected_exception=ValueError) as cm:
+ business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=USDCent(1)
+ )
+ assert "Total available amount is empty, cannot recoup" in str(cm)
+
+ def test_recoup_exceeds(self, business_payout_event_manager):
+ from random import randint
+
+ res = {uuid4().hex: USDCent(randint(a=0, b=1_000_00)) for i in range(100)}
+ df = pd.DataFrame.from_dict(res, orient="index").reset_index()
+ df.columns = ["product_id", "available_balance"]
+
+ avail_balance = USDCent(int(df.available_balance.sum()))
+
+ with pytest.raises(expected_exception=ValueError) as cm:
+ business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=avail_balance + USDCent(1)
+ )
+ assert " exceeds total available " in str(cm)
+
+ def test_recoup(self, business_payout_event_manager):
+ from random import randint, random
+
+ res = {uuid4().hex: USDCent(randint(a=0, b=1_000_00)) for i in range(100)}
+ df = pd.DataFrame.from_dict(res, orient="index").reset_index()
+ df.columns = ["product_id", "available_balance"]
+
+ avail_balance = USDCent(int(df.available_balance.sum()))
+ random_recoup_amount = USDCent(1 + int(int(avail_balance) * random() * 0.5))
+
+ res = business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=random_recoup_amount
+ )
+
+ assert isinstance(res, pd.DataFrame)
+ assert res.weight.sum() == pytest.approx(1)
+ assert res.deduction.sum() == random_recoup_amount
+ assert res.remaining_balance.sum() == avail_balance - random_recoup_amount
+
+ def test_recoup_loop(self, business_payout_event_manager, request):
+ # TODO: Generate this file at random
+ fp = os.path.join(
+ request.config.rootpath, "data/pytest_recoup_proportional.csv"
+ )
+ df = pd.read_csv(fp, index_col=0)
+ res = business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=USDCent(1416089)
+ )
+
+ res = res[res["remaining_balance"] > 0]
+
+ assert int(res.deduction.sum()) == 1416089
+
+ def test_recoup_loop_single_profitable_account(self, business_payout_event_manager):
+ res = [{"product_id": uuid4().hex, "available_balance": 0} for i in range(1000)]
+ for x in range(100):
+ item = rand_choice(res)
+ item["available_balance"] = randint(8, 12)
+
+ df = pd.DataFrame(res)
+ res = business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=USDCent(500)
+ )
+ # res = res[res["remaining_balance"] > 0]
+ assert int(res.deduction.sum()) == 500
+
+ def test_recoup_loop_assertions(self, business_payout_event_manager):
+ df = pd.DataFrame(
+ [
+ {
+ "product_id": uuid4().hex,
+ "available_balance": randint(0, 999_999),
+ }
+ for i in range(10_000)
+ ]
+ )
+ available_balance = int(df.available_balance.sum())
+
+ # Exact amount
+ res = business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=available_balance
+ )
+ assert res.remaining_balance.sum() == 0
+ assert int(res.deduction.sum()) == available_balance
+
+ # Slightly less
+ res = business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=available_balance - 1
+ )
+ assert res.remaining_balance.sum() == 1
+ assert int(res.deduction.sum()) == available_balance - 1
+
+ # Slightly less
+ with pytest.raises(expected_exception=Exception) as cm:
+ res = business_payout_event_manager.recoup_proportional(
+ df=df, target_amount=available_balance + 1
+ )
+
+ # Don't pull anything
+ res = business_payout_event_manager.recoup_proportional(df=df, target_amount=0)
+ assert res.remaining_balance.sum() == available_balance
+ assert int(res.deduction.sum()) == 0
+
+ def test_distribute_amount(self, business_payout_event_manager):
+ import io
+
+ df = pd.read_csv(
+ io.StringIO(
+ "product_id,available_balance,weight,raw_deduction,deduction,remainder,remaining_balance\n0,15faf,11768,0.0019298788807489663,222.3374860933269,223,0.33748609332690194,11545\n5,793e3,202,3.312674489388946e-05,3.8164660257352168,3,0.8164660257352168,199\n6,c703b,22257,0.0036500097084321667,420.5103184890531,421,0.510318489053077,21836\n13,72a70,1424,0.00023352715212326036,26.90419614181658,27,0.9041961418165805,1397\n14,86173,45634,0.007483692457860156,862.1812406851528,863,0.18124068515282943,44771\n17,4f230,143676,0.02356197128403199,2714.5275876907576,2715,0.5275876907576276,140961\n18,e1ee6,129,2.1155198471840298e-05,2.437248105543777,2,0.43724810554377713,127\n22,4524a,85613,0.014040000052478012,1617.5203260458868,1618,0.5203260458868044,83995\n25,4f5e2,30282,0.004966059845924558,572.1298227292765,573,0.12982272927649774,29709\n28,0f3b3,135,2.213916119146078e-05,2.5506084825458135,2,0.5506084825458135,133\n29,15c6f,1226,0.00020105638237578454,23.163303700749385,23,0.16330370074938472,1203\n31,c0c04,376,6.166166376288335e-05,7.103916958794265,7,0.10391695879426521,369\n33,2934c,37649,0.006174202071831903,711.3174722916099,712,0.31747229160987445,36937\n38,5585d,16471,0.0027011416591448184,311.1931282667562,312,0.19312826675621864,16159\n42,0a749,663,0.00010872788051806293,12.526321658724994,12,0.5263216587249939,651\n43,9e336,322,5.280599928629904e-05,6.08367356577594,6,0.08367356577593998,316\n46,043a5,11656,0.001911511576649384,220.22142572262223,221,0.2214257226222287,11435\n47,d6f4e,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n48,3e123,3012,0.0004939492852494804,56.90690925502214,57,0.9069092550221427,2955\n49,76ccc,76,1.24635277818594e-05,1.4358981086924578,1,0.4358981086924578,75\n51,8acb9,7710,0.0012643920947123155,145.66808444761645,146,0.6680844476164509,7564\n56,fef3e,212,3.476668275992359e-05,4.005399987405277,4,0.005399987405277251,208\n57,6d7c2,455709,0.07473344449925481,8609.890673870148,8610,0.8906738701480208,447099\n58,f51f6,257,4.2146403157077186e-05,4.855602814920548,4,0.8556028149205481,253\n61,06acf,84310,0.013826316148533765,1592.902230840278,1593,0.90223084027798,82717\n62,6eca7,40,6.559751464136527e-06,0.755735846680241,0,0.755735846680241,40\n68,4a415,1955,0.00032060785280967275,36.93658950649678,37,0.9365895064967802,1918\n69,57a16,409,6.707345872079598e-05,7.727399032305463,7,0.7273990323054633,402\n70,c3ef6,593,9.724831545582401e-05,11.203783927034571,11,0.20378392703457138,582\n71,385da,825,0.00013529487394781586,15.587051837779969,15,0.5870518377799687,810\n72,a8435,748,0.00012266735237935304,14.132260332920506,14,0.13226033292050587,734\n75,c9374,263383,0.043193175496966774,4976.199362654548,4977,0.19936265454816748,258406\n76,7fcc7,136,2.230315497806419e-05,2.569501878712819,2,0.5695018787128192,134\n77,26aec,356,5.838178803081509e-05,6.726049035454145,6,0.7260490354541451,350\n78,76cc5,413,6.772943386720964e-05,7.802972616973489,7,0.8029726169734888,406\n82,9476a,13973,0.002291485180209492,263.99742464157515,264,0.9974246415751509,13709\n85,ee099,2397,0.0003930931064883814,45.28747061231344,46,0.2874706123134416,2351\n87,24bd5,122295,0.020055620132664414,2310.567884244002,2311,0.5678842440020162,119984\n91,1fa8f,4,6.559751464136527e-07,0.0755735846680241,0,0.0755735846680241,4\n92,53b5a,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n93,f4f9e,201,3.296275110728605e-05,3.797572629568211,3,0.7975726295682111,198\n95,ff5d7,21317,0.0034958555490249587,402.75052609206745,403,0.7505260920674459,20914\n96,6290d,80,1.3119502928273054e-05,1.511471693360482,1,0.5114716933604819,79\n100,9c34b,1870,0.0003066683809483826,35.33065083230127,36,0.33065083230126646,1834\n101,d32a9,11577,0.0018985560675077143,218.72884742542874,219,0.7288474254287394,11358\n102,a6001,8981,0.0014728281974852537,169.6815909758811,170,0.6815909758811074,8811\n106,0a1ee,9,1.4759440794307186e-06,0.17004056550305424,0,0.17004056550305424,9\n108,8ad36,51,8.363683116774072e-06,0.9635632045173074,0,0.9635632045173074,51\n111,389ab,75,1.2299533995255988e-05,1.417004712525452,1,0.4170047125254519,74\n114,be86b,4831,0.000792253983081089,91.2739968828061,92,0.27399688280610235,4739\n118,99e96,271,4.444231616952497e-05,5.120110361258632,5,0.12011036125863228,266\n121,3b729,12417,0.0020363108482545815,234.59930020571383,235,0.5993002057138312,12182\n122,f16e2,2697,0.0004422912424694053,50.95548946241525,51,0.955489462415251,2646\n125,241c5,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n127,523af,569,9.33124645773421e-05,10.750342419026428,10,0.7503424190264276,559\n130,dce2e,31217,0.005119394036398749,589.7951481454271,590,0.7951481454271061,30627\n133,214b8,37360,0.006126807867503516,705.8572807993451,706,0.8572807993450624,36654\n136,03c88,35,5.739782531119461e-06,0.6612688658452109,0,0.6612688658452109,35\n137,08021,44828,0.007351513465857806,846.9531633745461,847,0.9531633745460795,43981\n144,bc3d9,1174,0.00019252870547240705,22.180847100065073,22,0.18084710006507265,1152\n148,bac2f,3745,0.0006141567308297824,70.75576864543757,71,0.7557686454375698,3674\n150,d9e69,8755,0.0014357656017128823,165.41168344213776,166,0.41168344213775754,8589\n151,36b18,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n152,e15d7,5259,0.0008624433237473498,99.36037044228468,100,0.3603704422846761,5159\n155,bad23,16504,0.002706553454102731,311.81661034026746,312,0.8166103402674594,16192\n159,9f77f,2220,0.00036406620625957726,41.943339490753374,42,0.9433394907533739,2178\n160,945ab,131,2.1483186045047126e-05,2.4750348978777894,2,0.4750348978777894,129\n161,3e2dc,354,5.8053800457608265e-05,6.688262243120133,6,0.6882622431201328,348\n162,62811,1608,0.0002637020088582884,30.38058103654569,31,0.38058103654568853,1577\n164,59a3d,2541,0.0004167082117592729,48.00811966036231,49,0.008119660362311265,2492\n165,871ca,8,1.3119502928273053e-06,0.1511471693360482,0,0.1511471693360482,8\n167,55d38,683,0.00011200775625013119,12.904189582065115,12,0.9041895820651149,671\n169,1e5af,6,9.83962719620479e-07,0.11336037700203613,0,0.11336037700203613,6\n170,b2901,15994,0.0026229166229349904,302.18097829509435,303,0.1809782950943486,15691\n174,dc880,2435,0.00039932487037931105,46.00541966665967,47,0.005419666659669531,2388\n176,d189e,34579,0.005670741146959424,653.3147460589013,654,0.31474605890127805,33925\n177,d35f5,41070,0.006735224815802179,775.9517805789375,776,0.9517805789374734,40294\n178,bd7a6,12563,0.0020602539410986796,237.35773604609668,238,0.35773604609667586,12325\n180,662cd,9675,0.0015866398853880224,182.79360791578327,183,0.7936079157832694,9492\n181,e995f,1011,0.00016579771825605072,19.101223524843093,19,0.1012235248430926,992\n189,0faf8,626,0.00010266011041373665,11.827266000545771,11,0.8272660005457713,615\n190,1fd84,213,3.4930676546527005e-05,4.024293383572283,4,0.02429338357228339,209\n193,c9b44,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n194,bbd8d,16686,0.0027364003232645522,315.25520844266254,316,0.2552084426625356,16370\n196,d4945,2654,0.00043523950964545856,50.14307342723399,51,0.1430734272339933,2603\n197,bbde8,3043,0.0004990330926341863,57.49260453619934,58,0.4926045361993374,2985\n200,579ad,20833,0.0034164825563089067,393.6061223472365,394,0.6061223472365214,20439\n202,2f932,15237,0.0024987733264762065,287.8786773966708,288,0.8786773966708097,14949\n208,b649d,103551,0.01698172059657004,1956.430066489641,1957,0.43006648964092165,101594\n209,0f939,26025,0.0042679382963538275,491.7006352463318,492,0.7006352463317853,25533\n211,638d8,6218,0.0010197133651000231,117.47913736644347,118,0.47913736644346727,6100\n215,d2a81,8301,0.0013613124225949327,156.834081582317,157,0.8340815823169976,8144\n216,62293,4,6.559751464136527e-07,0.0755735846680241,0,0.0755735846680241,4\n218,c8ae9,2829,0.00046393842230105583,53.44941775646004,54,0.44941775646004345,2775\n219,83a9f,6556,0.0010751432649719768,123.8651052708915,124,0.8651052708915046,6432\n221,d256a,72,1.1807552635445749e-05,1.360324524024434,1,0.36032452402443393,71\n222,6fdc2,7,1.1479565062238922e-06,0.1322537731690422,0,0.1322537731690422,7\n224,56b88,146928,0.02409527907806629,2775.9689120258613,2776,0.9689120258613002,144152\n230,f50f8,5798,0.0009508359747265895,109.54391097630092,110,0.5439109763009213,5688\n231,fa3be,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n232,94934,537381,0.08812714503872877,10152.952125621865,10153,0.9521256218649796,527228\n234,4e20c,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n235,a5d68,31101,0.005100370757152753,587.6035141900543,588,0.603514190054284,30513\n236,e5a29,3208,0.0005260920674237494,60.61001490375533,61,0.610014903755328,3147\n237,0ce0f,294,4.821417326140347e-05,5.554658473099771,5,0.5546584730997708,289\n240,66d2b,18633,0.0030556962257813976,352.04065077982324,353,0.04065077982323828,18280\n244,1bd17,1815,0.0002976487226851949,34.29151404311594,35,0.29151404311593865,1780\n245,32aca,224,3.673460819916455e-05,4.23212074140935,4,0.23212074140935002,220\n247,cbf8e,4747,0.0007784785050064023,89.6869516047776,90,0.6869516047776045,4657\n249,8f24b,807633,0.13244679385587438,15258.930226547576,15259,0.9302265475762397,792374\n251,c97c3,20526,0.0033661364638216586,387.80584972396565,388,0.8058497239656504,20138\n252,88fee,13821,0.0022665581246457734,261.12562842419027,262,0.12562842419026765,13559\n253,a9ad3,178,2.9190894015407545e-05,3.3630245177270726,3,0.36302451772707256,175\n254,83738,104,1.705535380675497e-05,1.9649132013686266,1,0.9649132013686266,103\n255,21f6c,6288,0.001031192930162262,118.8016750981339,119,0.8016750981338987,6169\n256,97e28,6,9.83962719620479e-07,0.11336037700203613,0,0.11336037700203613,6\n257,7f689,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n258,e7a50,28031,0.004596909832280275,529.6007879573459,530,0.600787957345915,27501\n259,2eb98,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n260,349ba,19518,0.0032008307269254183,368.76130638762356,369,0.7613063876235628,19149\n261,a3d04,235,3.853853985180209e-05,4.439948099246416,4,0.43994809924641576,231\n262,40971,2249,0.0003688220260710762,42.49124797959655,43,0.491247979596551,2206\n264,4f588,6105,0.0010011820672138373,115.34418359957178,116,0.34418359957177813,5989\n269,f182e,1020,0.00016727366233548145,19.271264090346147,19,0.27126409034614696,1001\n270,3798e,5168,0.0008475198891664393,97.64107139108714,98,0.641071391087138,5070\n274,81dc3,274,4.493429752933521e-05,5.176790549759651,5,0.1767905497596507,269\n285,8520b,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n287,37b89,3742,0.0006136647494699721,70.69908845693654,71,0.6990884569365363,3671\n291,63706,4740,0.0007773305485001784,89.55469783160856,90,0.5546978316085642,4650\n293,21c9d,241,3.9522502571422574e-05,4.553308476248452,4,0.5533084762484517,237\n296,3a42a,610,0.00010003620982808204,11.524971661873675,11,0.5249716618736748,599\n302,31a8f,148533,0.02435848910556477,2806.292812873906,2807,0.2928128739058593,145726\n303,38467,2399,0.0003934210940615882,45.32525740464745,46,0.3252574046474521,2353\n304,a49a6,26573,0.004357806891412498,502.0542163458511,503,0.05421634585110269,26070\n305,3c4e7,2286,0.0003748897961754025,43.19030363777577,44,0.19030363777577008,2242\n306,63986,34132,0.005597435924347699,644.8693979722497,645,0.8693979722496579,33487\n307,640f4,50,8.199689330170658e-06,0.9446698083503012,0,0.9446698083503012,50\n309,ba81a,3015,0.0004944412666092907,56.96358944352316,57,0.963589443523162,2958\n310,b7e8e,11409,0.0018710051113583408,215.55475686937172,216,0.5547568693717153,11193\n311,98301,5694,0.0009337806209198346,107.5789977749323,108,0.5789977749323043,5586\n312,e70bd,19,3.11588194546485e-06,0.35897452717311445,0,0.35897452717311445,19\n314,adf9c,5023,0.0008237407901089443,94.90152894687125,95,0.9015289468712524,4928\n316,ab5f4,55,9.019658263187724e-06,1.0391367891853314,1,0.039136789185331367,54\n318,d4a4c,2242,0.0003676740695648523,42.358994206427504,43,0.35899420642750357,2199\n319,b4062,222,3.640662062595773e-05,4.194333949075338,4,0.19433394907533774,218\n321,cb691,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n322,b83ed,3716,0.0006094009110182833,70.20786015659439,71,0.20786015659439272,3645\n323,8759e,5964,0.0009780589433027562,112.68021474002393,113,0.6802147400239278,5851\n325,2f217,3625,0.0005944774764373728,68.48856110539684,69,0.4885611053968404,3556\n326,d683f,28156,0.004617409055605701,531.9624624782216,532,0.962462478221596,27624\n327,97cf7,928,0.00015218623396796743,17.533071642981593,17,0.5330716429815929,911\n328,75135,2841,0.0004659063477402968,53.67613851046411,54,0.6761385104641136,2787\n329,d7c10,29913,0.0049055461386678986,565.1581595436512,566,0.15815954365120888,29347\n330,8598e,66,1.082358991582527e-05,1.2469641470223976,1,0.24696414702239755,65\n331,5e72b,10825,0.0017752327399819475,204.52101350784022,205,0.5210135078402232,10620\n332,30bff,975,0.00015989394193832783,18.42106126283087,18,0.42106126283087164,957\n333,f79d7,696,0.00011413967547597557,13.149803732236194,13,0.1498037322361938,683\n334,d0a0d,2121,0.0003478308213858393,40.07289327021978,41,0.07289327021977954,2080\n335,22ec5,3942,0.0006464635067906547,74.47776769033774,75,0.4777676903377426,3867\n336,efeac,515,8.445680010075779e-05,9.730099026008103,9,0.7300990260081033,506\n337,c3854,105882,0.017363990113142592,2000.4705729549316,2001,0.4705729549316402,103881\n338,cd2a3,42924,0.007039269296164907,810.9801370725665,811,0.9801370725665493,42113\n339,d7333,5089,0.0008345643800247697,96.14849309389366,97,0.14849309389366283,4992\n340,7d48c,33,5.411794957912635e-06,0.6234820735111988,0,0.6234820735111988,33\n341,6a148,100,1.6399378660341316e-05,1.8893396167006025,1,0.8893396167006025,99\n342,ba8ff,400,6.559751464136527e-05,7.55735846680241,7,0.5573584668024099,393\n343,38810,3194,0.0005237961544113017,60.34550735741724,61,0.34550735741724026,3133\n344,04d6e,43,7.051732823946766e-06,0.812416035181259,0,0.812416035181259,43\n345,d4f05,1394,0.00022860733852515797,26.337394256806398,27,0.3373942568063981,1367\n346,53f71,1325,0.00021729176724952245,25.033749921282983,26,0.0337499212829826,1299\n347,e4d06,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n348,8ab28,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n349,f21e0,65,1.0659596129221857e-05,1.2280707508553916,1,0.22807075085539164,64\n350,02fc2,2202,0.0003611143181007158,41.603258359747265,42,0.6032583597472652,2160\n351,310c0,18370,0.0030125658599047,347.07168758790067,348,0.07168758790066931,18022\n352,fceae,169,2.7714949935976826e-05,3.192983952224018,3,0.1929839522240182,166\n353,25c52,256,4.198240937047377e-05,4.836709418753542,4,0.836709418753542,252\n354,2e0cc,7029,0.0011527123260353911,132.80168165788535,133,0.8016816578853536,6896\n355,4baa9,405,6.641748357438233e-05,7.65182544763744,7,0.6518254476374397,398\n356,f2cb1,2236,0.00036669010684523185,42.24563382942547,43,0.24563382942547207,2193\n357,c4f9f,69,1.1315571275635508e-05,1.3036443355234157,1,0.30364433552341574,68\n358,0fe2b,33,5.411794957912635e-06,0.6234820735111988,0,0.6234820735111988,33\n359,ba2cc,911,0.0001493983395957094,17.21188390814249,17,0.21188390814248947,894\n360,22c5c,84,1.3775478074686707e-05,1.587045278028506,1,0.587045278028506,83\n361,b0369,42,6.8877390373433535e-06,0.793522639014253,0,0.793522639014253,42\n362,a7d03,12344,0.002024339301832532,233.22008228552235,234,0.2200822855223521,12110\n363,f6d7e,11703,0.0019192192846197442,221.1094153424715,222,0.1094153424714932,11481\n364,c281d,50376,0.008261350993933542,951.7737253090955,952,0.7737253090955392,49424\n365,dfbd6,3932,0.0006448235689246206,74.2888337286677,75,0.2888337286676972,3857\n366,a34ba,3909,0.0006410517118327421,73.85428561682654,74,0.8542856168265445,3835\n367,157ee,253,4.149042801066353e-05,4.780029230252524,4,0.7800292302525236,249\n369,33a29,6954,0.0011404127920401352,131.3846769453599,132,0.3846769453598995,6822\n370,2bd95,11279,0.001849685919099897,213.09861536766095,214,0.09861536766095469,11065\n371,138b7,2094,0.00034340298914754716,39.56277157371061,40,0.5627715737106129,2054\n372,eea6f,3884,0.0006369518671676568,73.3819507126514,74,0.38195071265140257,3810\n373,8f39b,1046,0.00017153750078717018,19.7624923906883,19,0.7624923906883012,1027\n374,bacfb,773,0.00012676719704443837,14.604595237095657,14,0.6045952370956567,759\n375,23403,855,0.00014021468754591825,16.15385372279015,16,0.1538537227901493,839\n376,0ad32,1630,0.00026730987216356345,30.79623575221982,31,0.7962357522198182,1599\n377,6be35,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n378,12402,2994,0.000490997397090619,56.566828124016034,57,0.566828124016034,2937\n379,bae61,40980,0.006720465375007872,774.251374923907,775,0.25137492390695115,40205\n380,384f7,23545,0.003861233705577363,444.84501275215683,445,0.8450127521568334,23100\n381,03ed6,115,1.8859285459392513e-05,2.1727405592056925,2,0.17274055920569253,113\n382,cc4c8,3529,0.0005787340729234451,66.67479507336427,67,0.6747950733642654,3462\n383,263fa,4019,0.0006590910283591176,75.93255919519721,76,0.9325591951972143,3943\n384,704a1,11504,0.001886584521085665,217.3496295052373,218,0.34962950523728864,11286\n385,a0efb,7960,0.0013053905413631687,150.39143348936796,151,0.3914334893679552,7809\n386,b3197,981,0.00016087790465794832,18.53442163983291,18,0.5344216398329102,963\n387,1c91d,174,2.8534918868993892e-05,3.2874509330590485,3,0.28745093305904845,171\n388,1afff,797,0.0001307030479229203,15.058036745103802,15,0.05803674510380219,782\n389,20304,10605,0.0017391541069291968,200.3644663510989,201,0.3644663510989119,10404\n390,638fd,79,1.295550914166964e-05,1.4925782971934758,1,0.4925782971934758,78\n391,8c258,11,1.8039316526375448e-06,0.20782735783706627,0,0.20782735783706627,11\n392,c1847,20,3.2798757320682634e-06,0.3778679233401205,0,0.3778679233401205,20\n393,d72ec,2445,0.0004009648082453452,46.19435362832973,47,0.19435362832972913,2398\n394,d83e3,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n395,04a4f,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n396,94c2e,52,8.527676903377485e-06,0.9824566006843133,0,0.9824566006843133,52\n397,2fc45,8,1.3119502928273053e-06,0.1511471693360482,0,0.1511471693360482,8\n398,e7986,30,4.919813598102395e-06,0.5668018850101807,0,0.5668018850101807,30\n399,612a5,20340,0.003335633619513424,384.2916780369025,385,0.2916780369025105,19955\n400,f9e3a,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n401,75d14,667,0.00010938385566447659,12.601895243393018,12,0.6018952433930185,655\n402,e02b6,86,1.4103465647893532e-05,1.624832070362518,1,0.6248320703625181,85\n403,b4c46,479,7.855302378303491e-05,9.049936763995886,9,0.04993676399588587,470\n404,f9fb4,299,4.9034142194420536e-05,5.6491254539348015,5,0.6491254539348015,294\n405,08461,34,5.575788744516048e-06,0.6423754696782048,0,0.6423754696782048,34\n406,42032,124,2.0335229538823232e-05,2.342781124708747,2,0.3427811247087469,122\n407,b7ea6,2861,0.0004691862234723651,54.05400643380424,55,0.05400643380423986,2806\n408,cf91f,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n409,64f3d,11,1.8039316526375448e-06,0.20782735783706627,0,0.20782735783706627,11\n410,49d78,408,6.690946493419258e-05,7.708505636138459,7,0.708505636138459,401\n411,aa802,3907,0.0006407237242595352,73.81649882449253,74,0.8164988244925269,3833\n412,7feff,450,7.379720397153592e-05,8.50202827515271,8,0.5020282751527105,442\n413,08bc7,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n414,c7cad,16,2.6239005856546107e-06,0.3022943386720964,0,0.3022943386720964,16\n415,ec21a,84,1.3775478074686707e-05,1.587045278028506,1,0.587045278028506,83\n416,a1ba3,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n417,7f276,240,3.935850878481916e-05,4.534415080081446,4,0.5344150800814456,236\n419,e86cf,188,3.0830831881441676e-05,3.5519584793971326,3,0.5519584793971326,185\n420,4f7b0,14,2.2959130124477845e-06,0.2645075463380844,0,0.2645075463380844,14\n421,c61f3,363,5.952974453703898e-05,6.858302808623187,6,0.8583028086231872,357\n422,0c672,32,5.247801171309221e-06,0.6045886773441927,0,0.6045886773441927,32\n423,9f766,14028,0.00230050483847268,265.0365614307605,266,0.03656143076051421,13762\n424,2f990,4092,0.0006710625747811667,77.31177711538865,78,0.3117771153886508,4014\n425,660cb,1043,0.00017104551942735994,19.705812202187285,19,0.7058122021872855,1024\n426,0045c,81,1.3283496714876466e-05,1.5303650895274878,1,0.5303650895274878,80\n427,934e9,818,0.00013414691744159196,15.454798064610927,15,0.45479806461092664,803\n428,820eb,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n429,8736b,14,2.2959130124477845e-06,0.2645075463380844,0,0.2645075463380844,14\n430,3de8b,18,2.9518881588614373e-06,0.3400811310061085,0,0.3400811310061085,18\n431,16140,21927,0.0035958917588530407,414.2754977539411,415,0.2754977539411243,21512\n432,44e80,17490,0.0028682513276936964,330.44549896093537,331,0.44549896093536745,17159\n434,c10c0,58575,0.009605936050294927,1106.680680482378,1107,0.6806804823779657,57468\n435,05214,129705,0.021270814091395706,2450.5679498415166,2451,0.5679498415165654,127254\n436,d70bb,130766,0.021444811498981926,2470.61384317471,2471,0.6138431747099276,128295\n437,a7eba,155,2.5419036923529042e-05,2.928476405885934,2,0.928476405885934,153\n438,2afa7,603,9.888825332185814e-05,11.392717888704633,11,0.39271788870463276,592\n439,ff2e5,1855,0.00030420847414933144,35.04724988979618,36,0.04724988979617706,1819\n440,b784f,148,2.4271080417305148e-05,2.7962226327168915,2,0.7962226327168915,146\n441,c2ce4,16469,0.0027008136715716115,311.15534147442224,312,0.15534147442224366,16157\n442,02ab5,491,8.052094922227586e-05,9.276657517999958,9,0.27665751799995775,482\n443,56013,65781,0.010787675276559121,1242.8264932618233,1243,0.8264932618233161,64538\n444,b4b2c,2429,0.0003983409076596906,45.89205928965763,46,0.8920592896576309,2383\n445,298a3,231,3.788256470538844e-05,4.364374514578392,4,0.3643745145783921,227\n446,3df28,3918,0.0006425276559121728,74.0243261823296,75,0.024326182329602375,3843\n447,99b56,594,9.741230924242742e-05,11.222677323201578,11,0.2226773232015784,583\n448,7c7e3,285,4.6738229181972755e-05,5.384617907596717,5,0.3846179075967173,280\n449,230a9,29334,0.004810593736224522,554.2188831629547,555,0.21888316295473942,28779\n450,11677,332788,0.05457516425617666,6287.4955236256,6288,0.4955236256000717,326500\n451,2e832,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n452,b961f,105,1.7219347593358382e-05,1.9838065975356325,1,0.9838065975356325,104\n453,caff0,2624,0.00043031969604735617,49.57627154222381,50,0.5762715422238074,2574\n454,c0a9c,76675,0.012574223587816704,1448.6511511051867,1449,0.6511511051867274,75226\n455,91d53,2145,0.00035176667226432124,40.52633477822792,41,0.5263347782279197,2104\n456,77494,1404,0.0002302472763911921,26.526328218476458,27,0.5263282184764577,1377\n457,9346a,160758,0.026363313146791495,3037.2645810155545,3038,0.2645810155545405,157720\n458,48d36,690,0.00011315571275635509,13.036443355234157,13,0.03644335523415698,677\n459,de96c,20110,0.003297915048594639,379.9461969184912,380,0.9461969184911823,19730\n460,1eca8,1577,0.00025861820147358256,29.7948857553685,30,0.7948857553685009,1547\n461,29179,1443,0.0002366430340687252,27.26317066898969,28,0.2631706689896909,1415\n462,f6416,12243,0.0020077759293855874,231.31184927265477,232,0.3118492726547686,12011\n463,eb5ab,24654,0.004043102814920548,465.79778910136656,466,0.7977891013665612,24188\n464,bb24e,275,4.5098291315938624e-05,5.195683945926657,5,0.19568394592665683,270\n465,089a1,1595,0.000261570089632444,30.13496688637461,31,0.13496688637460963,1564\n466,8f9ad,8142,0.00133523741052499,153.83003159176306,154,0.8300315917630599,7988\n467,01aa2,67287,0.011034649919183862,1271.2799478893344,1272,0.2799478893343803,66015\n468,b8427,76545,0.012552904395558262,1446.1950096034761,1447,0.19500960347613727,75098\n469,4ca48,20684,0.003392047482104998,390.7910063183526,391,0.7910063183525722,20293\n470,a6fe0,673,0.00011036781838409706,12.715255620395054,12,0.7152556203950535,661\n471,8131b,524,8.59327441801885e-05,9.900139591511158,9,0.9001395915111576,515\n472,9ed2f,2084,0.000341763051281513,39.37383761204055,40,0.3738376120405533,2044\n473,62851,47540,0.007796264615126262,898.1920537794664,899,0.19205377946639146,46641\n474,c8364,85303,0.013989161978630954,1611.663373234115,1612,0.6633732341149425,83691\n475,63dcc,5712,0.000936732509078696,107.9190789059384,108,0.9190789059384059,5604\n476,53490,53415,0.008759728111421314,1009.1907562606268,1010,0.19075626062681295,52405\n477,ea1ca,9513,0.0015600728919582694,179.7328777367283,180,0.7328777367283124,9333\n478,b98cd,86047,0.014111173355863893,1625.7200599823675,1626,0.720059982367502,84421\n479,5cc45,458,7.510915426436323e-05,8.65317544448876,8,0.6531754444887596,450\n480,935e4,1124,0.0001843290161422364,21.23617729171477,21,0.2361772917147711,1103\n481,25830,51,8.363683116774072e-06,0.9635632045173074,0,0.9635632045173074,51\n482,e1bce,401,6.576150842796868e-05,7.576251862969416,7,0.576251862969416,394\n483,522f3,25806,0.00423202365708768,487.5629814857574,488,0.5629814857574047,25318\n484,a091d,297,4.870615462121371e-05,5.611338661600789,5,0.6113386616007892,292\n485,5d15a,3101,0.0005085447322571842,58.588421513885685,59,0.5884215138856845,3042\n486,3807a,606,9.938023468166839e-05,11.449398077205652,11,0.44939807720565206,595\n487,6fc74,14406,0.00236249448980877,272.1782651818888,273,0.17826518188877571,14133\n488,c4b61,186,3.0502844308234848e-05,3.5141716870631203,3,0.5141716870631203,183\n489,1d557,935,0.0001533341904741913,17.665325416150633,17,0.6653254161506332,918\n490,4810f,267,4.378634102311132e-05,5.044536776590609,5,0.04453677659060862,262\n492,20171,1303,0.00021368390394424736,24.61809520560885,24,0.6180952056088493,1279\n493,645a8,507,8.314484980793048e-05,9.578951856672054,9,0.5789518566720542,498\n494,32855,5841,0.0009578877075505363,110.3563270114822,111,0.35632701148219326,5730\n495,9f0b8,1978,0.00032437970990155125,37.37113761833792,38,0.3711376183379187,1940\n496,bfa97,171,2.804293750918365e-05,3.23077074455803,3,0.23077074455803004,168\n497,c7ae8,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n498,0443b,107,1.754733516656521e-05,2.0215933898696448,2,0.021593389869644763,105\n499,8b6d1,36,5.9037763177228745e-06,0.680162262012217,0,0.680162262012217,36\n500,0df28,803,0.00013168701064254076,15.171397122105837,15,0.17139712210583724,788\n501,b8da7,1772,0.00029059698986124815,33.479098007934674,34,0.47909800793467383,1738\n502,be111,30,4.919813598102395e-06,0.5668018850101807,0,0.5668018850101807,30\n503,fff89,160,2.6239005856546107e-05,3.022943386720964,3,0.022943386720963854,157\n504,69f64,6756,0.0011079420222926595,127.64378450429271,128,0.6437845042927108,6628\n505,cd02f,1582,0.00025943817040659966,29.889352736203534,30,0.8893527362035343,1552\n506,49604,1205,0.00019761251285711287,22.76654238124226,22,0.7665423812422603,1183\n507,f5741,2674,0.0004385193853775268,50.52094135057411,51,0.5209413505741125,2623\n508,35c93,984,0.00016136988601775856,18.59110182833393,18,0.5911018283339295,966\n509,7af3c,15579,0.0025548592014945737,294.34021888578684,295,0.34021888578683956,15284\n510,b061f,733,0.00012020744558030186,13.848859390415416,13,0.8488593904154165,720\n511,10963,255,4.181841558387036e-05,4.817816022586537,4,0.8178160225865367,251\n512,7429f,596,9.774029681563425e-05,11.26046411553559,11,0.2604641155355907,585\n513,8e06f,114,1.86952916727891e-05,2.153847163038687,2,0.15384716303868684,112\n514,adfd4,10604,0.0017389901131425933,200.3455729549319,201,0.345572954931896,10403\n515,f5fd5,4215,0.0006912338105333865,79.63566484393039,80,0.6356648439303854,4135\n516,565d2,285,4.6738229181972755e-05,5.384617907596717,5,0.3846179075967173,280\n517,8f6cf,119,1.9515260605806166e-05,2.2483141438737166,2,0.24831414387371664,117\n518,4f72e,3024,0.0004959172106887215,57.13363000902623,58,0.13363000902622701,2966\n519,d9867,1147,0.00018810087323411492,21.670725403555913,21,0.6707254035559131,1126\n520,2c9f1,82,1.344749050147988e-05,1.549258485694494,1,0.549258485694494,81\n521,dccb0,2004,0.00032864354835324,37.86236591868007,38,0.8623659186800694,1966\n522,6e283,2322,0.00038079357249312537,43.87046589978799,44,0.8704658997879875,2278\n523,43593,923,0.00015136626503495037,17.438604662146563,17,0.43860466214656313,906\n524,17b6e,903,0.0001480863893028821,17.06073673880644,17,0.06073673880644037,886\n525,72ae9,357,5.85457818174185e-05,6.74494243162115,6,0.7449424316211504,351\n526,d902d,873,0.0001431665757047797,16.49393485379626,16,0.49393485379626156,857\n527,64690,318,5.215002413988539e-05,6.008099981107916,6,0.00809998110791632,312\n528,f476f,25,4.099844665085329e-06,0.4723349041751506,0,0.4723349041751506,25\n529,82153,127,2.0827210898633473e-05,2.3994613132097653,2,0.3994613132097653,125\n530,de876,128,2.0991204685236885e-05,2.418354709376771,2,0.418354709376771,126\n531,0f108,641,0.00010512001721278785,12.110666943050862,12,0.11066694305086244,629\n532,c4fd2,1020,0.00016727366233548145,19.271264090346147,19,0.27126409034614696,1001\n533,5a3a4,147,2.4107086630701735e-05,2.7773292365498854,2,0.7773292365498854,145\n534,59b0c,157,2.5747024496735867e-05,2.966263198219946,2,0.9662631982199459,155\n535,1b5c6,24,3.935850878481916e-06,0.45344150800814453,0,0.45344150800814453,24\n536,433d0,129,2.1155198471840298e-05,2.437248105543777,2,0.43724810554377713,127\n537,69d01,1266,0.00020761613383992107,23.919039547429627,23,0.9190395474296267,1243\n538,d05da,277,4.542627888914545e-05,5.233470738260669,5,0.2334707382606691,272\n539,c552a,924,0.00015153025882155377,17.45749805831357,17,0.4574980583135684,907\n540,ede77,196,3.214278217426898e-05,3.703105648733181,3,0.7031056487331808,193\n541,5fb34,336,5.510191229874683e-05,6.348181112114024,6,0.34818111211402414,330\n542,3c8b0,303,4.969011734083419e-05,5.724699038602826,5,0.724699038602826,298\n543,69faf,48,7.871701756963832e-06,0.9068830160162891,0,0.9068830160162891,48\n544,4282e,40,6.559751464136527e-06,0.755735846680241,0,0.755735846680241,40\n545,c8c9d,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n"
+ )
+ )
+
+ for x in range(100_01, df.remaining_balance.sum(), 12345):
+ amount = USDCent(x)
+ res = business_payout_event_manager.distribute_amount(df=df, amount=amount)
+ assert isinstance(res, pd.Series)
+ assert res.sum() == amount
+
+ def test_ach_payment_min_amount(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ payout_event_manager,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ adj_to_fail_with_tx_factory,
+ thl_web_rr,
+ lm,
+ product_manager,
+ ):
+ """Test having a Business with three products.. one that lost money
+ and two that gained money. Ensure that the Business balance
+ reflects that to compensate for the Product in the negative and only
+ assigns Brokerage Product payments from the 2 accounts that have
+ a positive balance.
+ """
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ p1: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal("5.00"),
+ started=start + timedelta(days=1),
+ )
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal("5.00"),
+ started=start + timedelta(days=6),
+ )
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+ bp_payout_factory(
+ product=u1.product,
+ amount=USDCent(475), # 95% of $5.00
+ created=start + timedelta(days=1, minutes=1),
+ )
+
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(500),
+ pm=product_manager,
+ thl_lm=thl_lm,
+ )
+ assert "Must issue Supplier Payouts at least $100 minimum." in str(cm)
+
+ def test_ach_payment(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ payout_event_manager,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ adj_to_fail_with_tx_factory,
+ thl_web_rr,
+ lm,
+ product_manager,
+ rm_ledger_collection,
+ rm_pop_ledger_merge,
+ ):
+ """Test having a Business with three products.. one that lost money
+ and two that gained money. Ensure that the Business balance
+ reflects that to compensate for the Product in the negative and only
+ assigns Brokerage Product payments from the 2 accounts that have
+ a positive balance.
+ """
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ p1: Product = product_factory(business=business)
+ p2: Product = product_factory(business=business)
+ p3: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ u2: User = user_factory(product=p2)
+ u3: User = user_factory(product=p3)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+ thl_lm.get_account_or_create_bp_wallet(product=p2)
+ thl_lm.get_account_or_create_bp_wallet(product=p3)
+
+ ach_id1 = uuid4().hex
+ ach_id2 = uuid4().hex
+
+ # Product 1: Complete, Payout, Recon..
+ s1 = session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal("5.00"),
+ started=start + timedelta(days=1),
+ )
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+ bp_payout_factory(
+ product=u1.product,
+ amount=USDCent(475), # 95% of $5.00
+ ext_ref_id=ach_id1,
+ created=start + timedelta(days=1, minutes=1),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ adj_to_fail_with_tx_factory(
+ session=s1,
+ created=start + timedelta(days=1, minutes=2),
+ )
+
+ # Product 2: Complete x10
+ for idx in range(15):
+ session_with_tx_factory(
+ user=u2,
+ wall_req_cpi=Decimal("7.50"),
+ started=start + timedelta(days=1, hours=2, minutes=1 + idx),
+ )
+
+ # Product 3: Complete x5
+ for idx in range(10):
+ session_with_tx_factory(
+ user=u3,
+ wall_req_cpi=Decimal("7.50"),
+ started=start + timedelta(days=1, hours=3, minutes=1 + idx),
+ )
+
+ # Now that we paid out the business, let's confirm the updated balances
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ bb1 = business.balance
+ pb1 = bb1.product_balances[0]
+ pb2 = bb1.product_balances[1]
+ pb3 = bb1.product_balances[2]
+ assert bb1.payout == 25 * 712 + 475 # $7.50 * .95% = $7.125 = $7.12
+ assert bb1.adjustment == -475
+ assert bb1.net == ((25 * 7.12 + 4.75) - 4.75) * 100
+ # The balance is lower because the single $4.75 payout
+ assert bb1.balance == bb1.net - 475
+ assert (
+ bb1.available_balance
+ == (pb2.available_balance + pb3.available_balance) - 475
+ )
+
+ assert pb1.available_balance == 0
+ assert pb2.available_balance == 8010
+ assert pb3.available_balance == 5340
+
+ assert bb1.recoup_usd_str == "$4.75"
+ assert pb1.recoup_usd_str == "$4.75"
+ assert pb2.recoup_usd_str == "$0.00"
+ assert pb3.recoup_usd_str == "$0.00"
+
+ assert business.payouts is None
+ business.prebuild_payouts(
+ thl_pg_config=thl_web_rr,
+ thl_lm=thl_lm,
+ bpem=business_payout_event_manager,
+ )
+ assert len(business.payouts) == 1
+ assert business.payouts[0].ext_ref_id == ach_id1
+
+ bp1 = business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(bb1.available_balance),
+ pm=product_manager,
+ thl_lm=thl_lm,
+ created=start + timedelta(days=1, hours=5),
+ )
+ assert isinstance(bp1, BusinessPayoutEvent)
+ assert len(bp1.bp_payouts) == 2
+ assert bp1.bp_payouts[0].status == PayoutStatus.COMPLETE
+ assert bp1.bp_payouts[1].status == PayoutStatus.COMPLETE
+ bp1_tx = brokerage_product_payout_event_manager.check_for_ledger_tx(
+ thl_ledger_manager=thl_lm,
+ payout_event=bp1.bp_payouts[0],
+ product_id=bp1.bp_payouts[0].product_id,
+ amount=bp1.bp_payouts[0].amount,
+ )
+ assert bp1_tx
+
+ bp2_tx = brokerage_product_payout_event_manager.check_for_ledger_tx(
+ thl_ledger_manager=thl_lm,
+ payout_event=bp1.bp_payouts[1],
+ product_id=bp1.bp_payouts[1].product_id,
+ amount=bp1.bp_payouts[1].amount,
+ )
+ assert bp2_tx
+
+ # Now that we paid out the business, let's confirm the updated balances
+ rm_ledger_collection()
+ rm_pop_ledger_merge()
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+ business.prebuild_payouts(
+ thl_pg_config=thl_web_rr,
+ thl_lm=thl_lm,
+ bpem=business_payout_event_manager,
+ )
+ assert len(business.payouts) == 2
+ assert len(business.payouts[0].bp_payouts) == 2
+ assert len(business.payouts[1].bp_payouts) == 1
+
+ bb2 = business.balance
+
+ # Okay os we have the balance before, and after the Business Payout
+ # of bb1.available_balance worth..
+ assert bb1.payout == bb2.payout
+ assert bb1.adjustment == bb2.adjustment
+ assert bb1.net == bb2.net
+ assert bb1.available_balance > bb2.available_balance
+
+ # This is the ultimate test. Confirm that the second time we get the
+ # Business balance, it is equal to the first time we Business balance
+ # minus the amount that was just paid out across any children
+ # Brokerage Products.
+ #
+ # This accounts for all the net positive and net negative children
+ # Brokerage Products under the Business in thi
+ assert bb2.balance == bb1.balance - bb1.available_balance
+
+ def test_ach_payment_partial_amount(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ payout_event_manager,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ adj_to_fail_with_tx_factory,
+ thl_web_rr,
+ lm,
+ product_manager,
+ rm_ledger_collection,
+ rm_pop_ledger_merge,
+ ):
+ """There are valid instances when we want issue a ACH or Wire to a
+ Business, but not for the full Available Balance amount in their
+ account.
+
+ To test this, we'll create a Business with multiple Products, and
+ a cumulative Available Balance of $100 (for example), but then only
+ issue a payout of $75 (for example). We want to confirm the sum
+ of the Product payouts equals the $75 number and isn't greedy and
+ takes the full $100 amount that is available to the Business.
+
+ """
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ p1: Product = product_factory(business=business)
+ p2: Product = product_factory(business=business)
+ p3: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ u2: User = user_factory(product=p2)
+ u3: User = user_factory(product=p3)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+ thl_lm.get_account_or_create_bp_wallet(product=p2)
+ thl_lm.get_account_or_create_bp_wallet(product=p3)
+
+ # Product 1, 2, 3: Complete, and Payout multiple times.
+ for idx in range(5):
+ for u in [u1, u2, u3]:
+ session_with_tx_factory(
+ user=u,
+ wall_req_cpi=Decimal("50.00"),
+ started=start + timedelta(days=1, hours=2, minutes=1 + idx),
+ )
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ # Now that we paid out the business, let's confirm the updated balances
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+ business.prebuild_payouts(
+ thl_pg_config=thl_web_rr,
+ thl_lm=thl_lm,
+ bpem=business_payout_event_manager,
+ )
+
+ # Confirm the initial amounts.
+ assert len(business.payouts) == 0
+ bb1 = business.balance
+ assert bb1.payout == 3 * 5 * 4750
+ assert bb1.adjustment == 0
+ assert bb1.payout == bb1.net
+
+ assert bb1.balance_usd_str == "$712.50"
+ assert bb1.available_balance_usd_str == "$534.39"
+
+ for x in range(2):
+ assert bb1.product_balances[x].balance == 5 * 4750
+ assert bb1.product_balances[x].available_balance_usd_str == "$178.13"
+
+ assert business.payouts_total_str == "$0.00"
+ assert business.balance.payment_usd_str == "$0.00"
+ assert business.balance.available_balance_usd_str == "$534.39"
+
+ # This is the important part, even those the Business has $534.39
+ # available to it, we are only trying to issue out a $250.00 ACH or
+ # Wire to the Business
+ bp1 = business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(250_00),
+ pm=product_manager,
+ thl_lm=thl_lm,
+ created=start + timedelta(days=1, hours=3),
+ )
+ assert isinstance(bp1, BusinessPayoutEvent)
+ assert len(bp1.bp_payouts) == 3
+
+ # Now that we paid out the business, let's confirm the updated
+ # balances. Clear and rebuild the parquet files.
+ rm_ledger_collection()
+ rm_pop_ledger_merge()
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ # Now rebuild and confirm the payouts, balance.payment, and the
+ # balance.available_balance are reflective of having a $250 ACH/Wire
+ # sent to the Business
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+ business.prebuild_payouts(
+ thl_pg_config=thl_web_rr,
+ thl_lm=thl_lm,
+ bpem=business_payout_event_manager,
+ )
+ assert len(business.payouts) == 1
+ assert len(business.payouts[0].bp_payouts) == 3
+ assert business.payouts_total_str == "$250.00"
+ assert business.balance.payment_usd_str == "$250.00"
+ assert business.balance.available_balance_usd_str == "$346.88"
+
+ def test_ach_tx_id_reference(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ payout_event_manager,
+ brokerage_product_payout_event_manager,
+ business_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ adj_to_fail_with_tx_factory,
+ thl_web_rr,
+ lm,
+ product_manager,
+ rm_ledger_collection,
+ rm_pop_ledger_merge,
+ ):
+
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ p1: Product = product_factory(business=business)
+ p2: Product = product_factory(business=business)
+ p3: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ u2: User = user_factory(product=p2)
+ u3: User = user_factory(product=p3)
+ thl_lm.get_account_or_create_bp_wallet(product=p1)
+ thl_lm.get_account_or_create_bp_wallet(product=p2)
+ thl_lm.get_account_or_create_bp_wallet(product=p3)
+
+ ach_id1 = uuid4().hex
+ ach_id2 = uuid4().hex
+
+ for idx in range(15):
+ for iidx, u in enumerate([u1, u2, u3]):
+ session_with_tx_factory(
+ user=u,
+ wall_req_cpi=Decimal("7.50"),
+ started=start + timedelta(days=1, hours=1 + iidx, minutes=1 + idx),
+ )
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ rm_ledger_collection()
+ rm_pop_ledger_merge()
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ bp1 = business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(100_01),
+ transaction_id=ach_id1,
+ pm=product_manager,
+ thl_lm=thl_lm,
+ created=start + timedelta(days=2, hours=1),
+ )
+
+ rm_ledger_collection()
+ rm_pop_ledger_merge()
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ bp2 = business_payout_event_manager.create_from_ach_or_wire(
+ business=business,
+ amount=USDCent(100_02),
+ transaction_id=ach_id2,
+ pm=product_manager,
+ thl_lm=thl_lm,
+ created=start + timedelta(days=4, hours=1),
+ )
+
+ assert isinstance(bp1, BusinessPayoutEvent)
+ assert isinstance(bp2, BusinessPayoutEvent)
+
+ rm_ledger_collection()
+ rm_pop_ledger_merge()
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+ business.prebuild_payouts(
+ thl_pg_config=thl_web_rr,
+ thl_lm=thl_lm,
+ bpem=business_payout_event_manager,
+ )
+ business.prebuild_balance(
+ thl_pg_config=thl_web_rr,
+ lm=lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+ assert business.payouts[0].ext_ref_id == ach_id2
+ assert business.payouts[1].ext_ref_id == ach_id1
diff --git a/tests/managers/thl/test_product.py b/tests/managers/thl/test_product.py
new file mode 100644
index 0000000..78d5dde
--- /dev/null
+++ b/tests/managers/thl/test_product.py
@@ -0,0 +1,362 @@
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models import Source
+from generalresearch.models.thl.product import (
+ Product,
+ SourceConfig,
+ UserCreateConfig,
+ SourcesConfig,
+ UserHealthConfig,
+ ProfilingConfig,
+ SupplyPolicy,
+ SupplyConfig,
+)
+from test_utils.models.conftest import product_factory
+
+
+class TestProductManagerGetMethods:
+ def test_get_by_uuid(self, product_manager):
+ product: Product = product_manager.create_dummy(
+ product_id=uuid4().hex,
+ team_id=uuid4().hex,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ )
+
+ instance = product_manager.get_by_uuid(product_uuid=product.id)
+ assert isinstance(instance, Product)
+ # self.assertEqual(instance.model_dump(mode="json"), product.model_dump(mode="json"))
+ assert instance.id == product.id
+
+ with pytest.raises(AssertionError) as cm:
+ product_manager.get_by_uuid(product_uuid="abc123")
+ assert "invalid uuid" in str(cm.value)
+
+ with pytest.raises(AssertionError) as cm:
+ product_manager.get_by_uuid(product_uuid=uuid4().hex)
+ assert "product not found" in str(cm.value)
+
+ def test_get_by_uuids(self, product_manager):
+ cnt = 5
+
+ product_uuids = [uuid4().hex for idx in range(cnt)]
+ for product_id in product_uuids:
+ product_manager.create_dummy(
+ product_id=product_id,
+ team_id=uuid4().hex,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ )
+
+ res = product_manager.get_by_uuids(product_uuids=product_uuids)
+ assert isinstance(res, list)
+ assert cnt == len(res)
+ for instance in res:
+ assert isinstance(instance, Product)
+
+ with pytest.raises(AssertionError) as cm:
+ product_manager.get_by_uuids(product_uuids=product_uuids + [uuid4().hex])
+ assert "incomplete product response" in str(cm.value)
+
+ with pytest.raises(AssertionError) as cm:
+ product_manager.get_by_uuids(product_uuids=product_uuids + ["abc123"])
+ assert "invalid uuid" in str(cm.value)
+
+ def test_get_by_uuid_if_exists(self, product_manager):
+ product: Product = product_manager.create_dummy(
+ product_id=uuid4().hex,
+ team_id=uuid4().hex,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ )
+ instance = product_manager.get_by_uuid_if_exists(product_uuid=product.id)
+ assert isinstance(instance, Product)
+
+ instance = product_manager.get_by_uuid_if_exists(product_uuid="abc123")
+ assert instance == None
+
+ def test_get_by_uuids_if_exists(self, product_manager):
+ product_uuids = [uuid4().hex for _ in range(2)]
+ for product_id in product_uuids:
+ product_manager.create_dummy(
+ product_id=product_id,
+ team_id=uuid4().hex,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ )
+
+ res = product_manager.get_by_uuids_if_exists(product_uuids=product_uuids)
+ assert isinstance(res, list)
+ assert 2 == len(res)
+ for instance in res:
+ assert isinstance(instance, Product)
+
+ res = product_manager.get_by_uuids_if_exists(
+ product_uuids=product_uuids + [uuid4().hex]
+ )
+ assert isinstance(res, list)
+ assert 2 == len(res)
+ for instance in res:
+ assert isinstance(instance, Product)
+
+ # # This will raise an error b/c abc123 isn't a uuid.
+ # res = product_manager.get_by_uuids_if_exists(
+ # product_uuids=product_uuids + ["abc123"]
+ # )
+ # assert isinstance(res, list)
+ # assert 2 == len(res)
+ # for instance in res:
+ # assert isinstance(instance, Product)
+
+ def test_get_by_business_ids(self, product_manager):
+ business_ids = [uuid4().hex for i in range(5)]
+
+ product_manager.fetch_uuids(business_uuids=business_ids)
+
+ for business_id in business_ids:
+ product_manager.create(
+ product_id=uuid4().hex,
+ team_id=None,
+ business_id=business_id,
+ redirect_url="https://www.example.com",
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ user_create_config=None,
+ )
+
+
+class TestProductManagerCreation:
+
+ def test_base(self, product_manager):
+ instance = product_manager.create_dummy(
+ product_id=uuid4().hex,
+ team_id=uuid4().hex,
+ name=f"New Test Product {uuid4().hex[:6]}",
+ )
+
+ assert isinstance(instance, Product)
+
+
+class TestProductManagerCreate:
+
+ def test_create_simple(self, product_manager):
+ # Always required: product_id, team_id, name, redirect_url
+ # Required internally - if not passed use default: harmonizer_domain,
+ # commission_pct, sources
+ product_id = uuid4().hex
+ team_id = uuid4().hex
+ business_id = uuid4().hex
+
+ product_manager.create(
+ product_id=product_id,
+ team_id=team_id,
+ business_id=business_id,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ redirect_url="https://www.example.com",
+ )
+
+ instance = product_manager.get_by_uuid(product_uuid=product_id)
+
+ assert team_id == instance.team_id
+ assert business_id == instance.business_id
+
+
+class TestProductManager:
+ sources = [
+ SourceConfig.model_validate(x)
+ for x in [
+ {"name": "d", "active": True},
+ {
+ "name": "f",
+ "active": False,
+ "banned_countries": ["ps", "in", "in"],
+ },
+ {
+ "name": "s",
+ "active": True,
+ "supplier_id": "3488",
+ "allow_pii_only_buyers": True,
+ "allow_unhashed_buyers": True,
+ },
+ {"name": "e", "active": True, "withhold_profiling": True},
+ ]
+ ]
+
+ def test_get_by_uuid1(self, product_manager, team, product, product_factory):
+ p1 = product_factory(team=team)
+ instance = product_manager.get_by_uuid(product_uuid=p1.uuid)
+ assert instance.id == p1.id
+
+ # No Team and no user_create_config
+ assert instance.team_id == team.uuid
+
+ # user_create_config can't be None, so ensure the default was set.
+ assert isinstance(instance.user_create_config, UserCreateConfig)
+ assert 0 == instance.user_create_config.min_hourly_create_limit
+ assert instance.user_create_config.max_hourly_create_limit is None
+
+ def test_get_by_uuid2(self, product_manager, product_factory):
+ p2 = product_factory()
+ instance = product_manager.get_by_uuid(p2.id)
+ assert instance.id, p2.id
+
+ # Team and default user_create_config
+ assert instance.team_id is not None
+ assert instance.team_id == p2.team_id
+
+ assert 0 == instance.user_create_config.min_hourly_create_limit
+ assert instance.user_create_config.max_hourly_create_limit is None
+
+ def test_get_by_uuid3(self, product_manager, product_factory):
+ p3 = product_factory()
+ instance = product_manager.get_by_uuid(p3.id)
+ assert instance.id == p3.id
+
+ # Team and default user_create_config
+ assert instance.team_id is not None
+ assert instance.team_id == p3.team_id
+
+ assert (
+ p3.user_create_config.min_hourly_create_limit
+ == instance.user_create_config.min_hourly_create_limit
+ )
+ assert instance.user_create_config.max_hourly_create_limit is None
+ assert not instance.user_wallet_config.enabled
+
+ def test_sources(self, product_manager):
+ user_defined = [SourceConfig(name=Source.DYNATA, active=False)]
+ sources_config = SourcesConfig(user_defined=user_defined)
+ p = product_manager.create_dummy(sources_config=sources_config)
+
+ p2 = product_manager.get_by_uuid(p.id)
+
+ assert p == p2
+ assert p2.sources_config.user_defined == user_defined
+
+ # Assert d is off and everything else is on
+ dynata = p2.sources_dict[Source.DYNATA]
+ assert not dynata.active
+ assert all(x.active is True for x in p2.sources if x.name != Source.DYNATA)
+
+ def test_global_sources(self, product_manager):
+ sources_config = SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ name=Source.DYNATA,
+ active=True,
+ address=["https://www.example.com"],
+ distribute_harmonizer_active=True,
+ )
+ ]
+ )
+ p1 = product_manager.create_dummy(sources_config=sources_config)
+ p2 = product_manager.get_by_uuid(p1.id)
+ assert p1 == p2
+
+ p1.sources_config.policies.append(
+ SupplyPolicy(
+ name=Source.CINT,
+ active=True,
+ address=["https://www.example.com"],
+ distribute_harmonizer_active=True,
+ )
+ )
+ product_manager.update(p1)
+ p2 = product_manager.get_by_uuid(p1.id)
+ assert p1 == p2
+
+ def test_user_health_config(self, product_manager):
+ p = product_manager.create_dummy(
+ user_health_config=UserHealthConfig(banned_countries=["ng", "in"])
+ )
+
+ p2 = product_manager.get_by_uuid(p.id)
+
+ assert p == p2
+ assert p2.user_health_config.banned_countries == ["in", "ng"]
+ assert p2.user_health_config.allow_ban_iphist
+
+ def test_profiling_config(self, product_manager):
+ p = product_manager.create_dummy(
+ profiling_config=ProfilingConfig(max_questions=1)
+ )
+ p2 = product_manager.get_by_uuid(p.id)
+
+ assert p == p2
+ assert p2.profiling_config.max_questions == 1
+
+ # def test_user_create_config(self):
+ # # -- Product 1 ---
+ # instance1 = PM.get_by_uuid(self.product_id1)
+ # self.assertEqual(instance1.user_create_config.min_hourly_create_limit, 0)
+ #
+ # self.assertEqual(instance1.user_create_config.max_hourly_create_limit, None)
+ # self.assertEqual(60, instance1.user_create_config.clip_hourly_create_limit(60))
+ #
+ # # -- Product 2 ---
+ # instance2 = PM.get_by_uuid(self.product2.id)
+ # self.assertEqual(instance2.user_create_config.min_hourly_create_limit, 200)
+ # self.assertEqual(instance2.user_create_config.max_hourly_create_limit, None)
+ #
+ # self.assertEqual(200, instance2.user_create_config.clip_hourly_create_limit(60))
+ # self.assertEqual(300, instance2.user_create_config.clip_hourly_create_limit(300))
+ #
+ # def test_create_and_cache(self):
+ # PM.uuid_cache.clear()
+ #
+ # # Hit it once, should hit mysql
+ # with self.assertLogs(level='INFO') as cm:
+ # p = PM.get_by_uuid(product_id3)
+ # self.assertEqual(cm.output, [f"INFO:root:Product.get_by_uuid:{product_id3}"])
+ # self.assertIsNotNone(p)
+ # self.assertEqual(p.id, product_id3)
+ #
+ # # Hit it again, should be pulled from cachetools cache
+ # with self.assertLogs(level='INFO') as cm:
+ # logger.info("nothing")
+ # p = PM.get_by_uuid(product_id3)
+ # self.assertEqual(cm.output, [f"INFO:root:nothing"])
+ # self.assertEqual(len(cm.output), 1)
+ # self.assertIsNotNone(p)
+ # self.assertEqual(p.id, product_id3)
+
+
+class TestProductManagerUpdate:
+
+ def test_update(self, product_manager):
+ p = product_manager.create_dummy()
+ p.name = "new name"
+ p.enabled = False
+ p.user_create_config = UserCreateConfig(min_hourly_create_limit=200)
+ p.sources_config = SourcesConfig(
+ user_defined=[SourceConfig(name=Source.DYNATA, active=False)]
+ )
+ product_manager.update(new_product=p)
+ # We cleared the cache in the update
+ # PM.id_cache.clear()
+ p2 = product_manager.get_by_uuid(p.id)
+
+ assert p2.name == "new name"
+ assert not p2.enabled
+ assert p2.user_create_config.min_hourly_create_limit == 200
+ assert not p2.sources_dict[Source.DYNATA].active
+
+
+class TestProductManagerCacheClear:
+
+ def test_cache_clear(self, product_manager):
+ p = product_manager.create_dummy()
+ product_manager.get_by_uuid(product_uuid=p.id)
+ product_manager.get_by_uuid(product_uuid=p.id)
+ product_manager.pg_config.execute_write(
+ query="""
+ UPDATE userprofile_brokerageproduct
+ SET name = 'test-6d9a5ddfd'
+ WHERE id = %s""",
+ params=[p.id],
+ )
+
+ product_manager.cache_clear(p.id)
+
+ # Calling this with or without kwargs hits different internal keys in the cache!
+ p2 = product_manager.get_by_uuid(product_uuid=p.id)
+ assert p2.name == "test-6d9a5ddfd"
+ p2 = product_manager.get_by_uuid(product_uuid=p.id)
+ assert p2.name == "test-6d9a5ddfd"
diff --git a/tests/managers/thl/test_product_prod.py b/tests/managers/thl/test_product_prod.py
new file mode 100644
index 0000000..7b4f677
--- /dev/null
+++ b/tests/managers/thl/test_product_prod.py
@@ -0,0 +1,82 @@
+import logging
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models.thl.product import Product
+from test_utils.models.conftest import product_factory
+
+logger = logging.getLogger()
+
+
+class TestProductManagerGetMethods:
+
+ def test_get_by_uuid(self, product_manager, product_factory):
+ # Just test that we load properly
+ for p in [product_factory(), product_factory(), product_factory()]:
+ instance = product_manager.get_by_uuid(product_uuid=p.id)
+ assert isinstance(instance, Product)
+ assert instance.id == p.id
+ assert instance.uuid == p.uuid
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ product_manager.get_by_uuid(product_uuid=uuid4().hex)
+ assert "product not found" in str(cm.value)
+
+ def test_get_by_uuids(self, product_manager, product_factory):
+ products = [product_factory(), product_factory(), product_factory()]
+ cnt = len(products)
+ res = product_manager.get_by_uuids(product_uuids=[p.id for p in products])
+ assert isinstance(res, list)
+ assert cnt == len(res)
+ for instance in res:
+ assert isinstance(instance, Product)
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ product_manager.get_by_uuids(
+ product_uuids=[p.id for p in products] + [uuid4().hex]
+ )
+ assert "incomplete product response" in str(cm.value)
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ product_manager.get_by_uuids(
+ product_uuids=[p.id for p in products] + ["abc123"]
+ )
+ assert "invalid uuid passed" in str(cm.value)
+
+ def test_get_by_uuid_if_exists(self, product_factory, product_manager):
+ products = [product_factory(), product_factory(), product_factory()]
+
+ instance = product_manager.get_by_uuid_if_exists(product_uuid=products[0].id)
+ assert isinstance(instance, Product)
+
+ instance = product_manager.get_by_uuid_if_exists(product_uuid="abc123")
+ assert instance is None
+
+ def test_get_by_uuids_if_exists(self, product_manager, product_factory):
+ products = [product_factory(), product_factory(), product_factory()]
+
+ res = product_manager.get_by_uuids_if_exists(
+ product_uuids=[p.id for p in products[:2]]
+ )
+ assert isinstance(res, list)
+ assert 2 == len(res)
+ for instance in res:
+ assert isinstance(instance, Product)
+
+ res = product_manager.get_by_uuids_if_exists(
+ product_uuids=[p.id for p in products[:2]] + [uuid4().hex]
+ )
+ assert isinstance(res, list)
+ assert 2 == len(res)
+ for instance in res:
+ assert isinstance(instance, Product)
+
+
+class TestProductManagerGetAll:
+
+ @pytest.mark.skip(reason="TODO")
+ def test_get_ALL_by_ids(self, product_manager):
+ products = product_manager.get_all(rand_limit=50)
+ logger.info(f"Fetching {len(products)} product uuids")
+ # todo: once timebucks stops spamming broken accounts, fetch more
+ pass
diff --git a/tests/managers/thl/test_profiling/__init__.py b/tests/managers/thl/test_profiling/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/thl/test_profiling/__init__.py
diff --git a/tests/managers/thl/test_profiling/test_question.py b/tests/managers/thl/test_profiling/test_question.py
new file mode 100644
index 0000000..998466e
--- /dev/null
+++ b/tests/managers/thl/test_profiling/test_question.py
@@ -0,0 +1,49 @@
+from uuid import uuid4
+
+from generalresearch.managers.thl.profiling.question import QuestionManager
+from generalresearch.models import Source
+
+
+class TestQuestionManager:
+
+ def test_get_multi_upk(self, question_manager: QuestionManager, upk_data):
+ qs = question_manager.get_multi_upk(
+ question_ids=[
+ "8a22de34f985476aac85e15547100db8",
+ "0565f87d4bf044298ba169de1339ff7e",
+ "b2b32d68403647e3a87e778a6348d34c",
+ uuid4().hex,
+ ]
+ )
+ assert len(qs) == 3
+
+ def test_get_questions_ranked(self, question_manager: QuestionManager, upk_data):
+ qs = question_manager.get_questions_ranked(country_iso="mx", language_iso="spa")
+ assert len(qs) >= 40
+ assert qs[0].importance.task_score > qs[40].importance.task_score
+ assert all(q.country_iso == "mx" and q.language_iso == "spa" for q in qs)
+
+ def test_lookup_by_property(self, question_manager: QuestionManager, upk_data):
+ q = question_manager.lookup_by_property(
+ property_code="i:industry", country_iso="us", language_iso="eng"
+ )
+ assert q.source == Source.INNOVATE
+
+ q.explanation_template = "You work in the {answer} industry."
+ q.explanation_fragment_template = "you work in the {answer} industry"
+ question_manager.update_question_explanation(q)
+
+ q = question_manager.lookup_by_property(
+ property_code="i:industry", country_iso="us", language_iso="eng"
+ )
+ assert q.explanation_template
+
+ def test_filter_by_property(self, question_manager: QuestionManager, upk_data):
+ lookup = [
+ ("i:industry", "us", "eng"),
+ ("i:industry", "mx", "eng"),
+ ("m:age", "us", "eng"),
+ (f"m:{uuid4().hex}", "us", "eng"),
+ ]
+ qs = question_manager.filter_by_property(lookup)
+ assert len(qs) == 3
diff --git a/tests/managers/thl/test_profiling/test_schema.py b/tests/managers/thl/test_profiling/test_schema.py
new file mode 100644
index 0000000..ae61527
--- /dev/null
+++ b/tests/managers/thl/test_profiling/test_schema.py
@@ -0,0 +1,44 @@
+from generalresearch.models.thl.profiling.upk_property import PropertyType
+
+
+class TestUpkSchemaManager:
+
+ def test_get_props_info(self, upk_schema_manager, upk_data):
+ props = upk_schema_manager.get_props_info()
+ assert (
+ len(props) == 16955
+ ) # ~ 70 properties x each country they are available in
+
+ gender = [
+ x
+ for x in props
+ if x.country_iso == "us"
+ and x.property_id == "73175402104741549f21de2071556cd7"
+ ]
+ assert len(gender) == 1
+ gender = gender[0]
+ assert len(gender.allowed_items) == 3
+ assert gender.allowed_items[0].label == "female"
+ assert gender.allowed_items[1].label == "male"
+ assert gender.prop_type == PropertyType.UPK_ITEM
+ assert gender.categories[0].label == "Demographic"
+
+ age = [
+ x
+ for x in props
+ if x.country_iso == "us"
+ and x.property_id == "94f7379437874076b345d76642d4ce6d"
+ ]
+ assert len(age) == 1
+ age = age[0]
+ assert age.allowed_items is None
+ assert age.prop_type == PropertyType.UPK_NUMERICAL
+ assert age.gold_standard
+
+ cars = [
+ x
+ for x in props
+ if x.country_iso == "us" and x.property_label == "household_auto_type"
+ ][0]
+ assert not cars.gold_standard
+ assert cars.categories[0].label == "Autos & Vehicles"
diff --git a/tests/managers/thl/test_profiling/test_uqa.py b/tests/managers/thl/test_profiling/test_uqa.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/tests/managers/thl/test_profiling/test_uqa.py
@@ -0,0 +1 @@
+
diff --git a/tests/managers/thl/test_profiling/test_user_upk.py b/tests/managers/thl/test_profiling/test_user_upk.py
new file mode 100644
index 0000000..53bb8fe
--- /dev/null
+++ b/tests/managers/thl/test_profiling/test_user_upk.py
@@ -0,0 +1,59 @@
+from datetime import datetime, timezone
+
+from generalresearch.managers.thl.profiling.user_upk import UserUpkManager
+
+now = datetime.now(tz=timezone.utc)
+base = {
+ "country_iso": "us",
+ "language_iso": "eng",
+ "timestamp": now,
+}
+upk_ans_dict = [
+ {"pred": "gr:gender", "obj": "gr:male"},
+ {"pred": "gr:age_in_years", "obj": "43"},
+ {"pred": "gr:home_postal_code", "obj": "33143"},
+ {"pred": "gr:ethnic_group", "obj": "gr:caucasians"},
+ {"pred": "gr:ethnic_group", "obj": "gr:asian"},
+]
+for a in upk_ans_dict:
+ a.update(base)
+
+
+class TestUserUpkManager:
+
+ def test_user_upk_empty(self, user_upk_manager: UserUpkManager, upk_data, user):
+ res = user_upk_manager.get_user_upk_mysql(user_id=user.user_id)
+ assert len(res) == 0
+
+ def test_user_upk(self, user_upk_manager: UserUpkManager, upk_data, user):
+ for x in upk_ans_dict:
+ x["user_id"] = user.user_id
+ user_upk = user_upk_manager.populate_user_upk_from_dict(upk_ans_dict)
+ user_upk_manager.set_user_upk(upk_ans=user_upk)
+
+ d = user_upk_manager.get_user_upk_simple(user_id=user.user_id)
+ assert d["gender"] == "male"
+ assert d["age_in_years"] == 43
+ assert d["home_postal_code"] == "33143"
+ assert d["ethnic_group"] == {"caucasians", "asian"}
+
+ # Change my answers. age 43->44, gender male->female,
+ # ethnic->remove asian, add black_or_african_american
+ for x in upk_ans_dict:
+ if x["pred"] == "age_in_years":
+ x["obj"] = "44"
+ if x["pred"] == "gender":
+ x["obj"] = "female"
+ upk_ans_dict[-1]["obj"] = "black_or_african_american"
+ user_upk = user_upk_manager.populate_user_upk_from_dict(upk_ans_dict)
+ user_upk_manager.set_user_upk(upk_ans=user_upk)
+
+ d = user_upk_manager.get_user_upk_simple(user_id=user.user_id)
+ assert d["gender"] == "female"
+ assert d["age_in_years"] == 44
+ assert d["home_postal_code"] == "33143"
+ assert d["ethnic_group"] == {"caucasians", "black_or_african_american"}
+
+ age, gender = user_upk_manager.get_age_gender(user_id=user.user_id)
+ assert age == 44
+ assert gender == "female"
diff --git a/tests/managers/thl/test_session_manager.py b/tests/managers/thl/test_session_manager.py
new file mode 100644
index 0000000..6bedc2b
--- /dev/null
+++ b/tests/managers/thl/test_session_manager.py
@@ -0,0 +1,137 @@
+from datetime import timedelta
+from decimal import Decimal
+from uuid import uuid4
+
+from faker import Faker
+
+from generalresearch.models import DeviceType
+from generalresearch.models.legacy.bucket import Bucket
+from generalresearch.models.thl.definitions import (
+ Status,
+ StatusCode1,
+ SessionStatusCode2,
+)
+from test_utils.models.conftest import user
+
+fake = Faker()
+
+
+class TestSessionManager:
+ def test_create_session(self, session_manager, user, utc_hour_ago):
+ bucket = Bucket(
+ loi_min=timedelta(seconds=60),
+ loi_max=timedelta(seconds=120),
+ user_payout_min=Decimal("1"),
+ user_payout_max=Decimal("2"),
+ )
+
+ s1 = session_manager.create(
+ started=utc_hour_ago,
+ user=user,
+ country_iso="us",
+ device_type=DeviceType.MOBILE,
+ ip=fake.ipv4_public(),
+ bucket=bucket,
+ url_metadata={"foo": "bar"},
+ uuid_id=uuid4().hex,
+ )
+
+ assert s1.id is not None
+ s2 = session_manager.get_from_uuid(session_uuid=s1.uuid)
+ assert s1 == s2
+
+ def test_finish_with_status(self, session_manager, user, utc_hour_ago):
+ uuid_1 = uuid4().hex
+ session = session_manager.create(
+ started=utc_hour_ago, user=user, uuid_id=uuid_1
+ )
+ session_manager.finish_with_status(
+ session=session,
+ status=Status.FAIL,
+ status_code_1=StatusCode1.SESSION_CONTINUE_FAIL,
+ status_code_2=SessionStatusCode2.USER_IS_BLOCKED,
+ )
+
+ s2 = session_manager.get_from_uuid(session_uuid=uuid_1)
+ assert s2.status == Status.FAIL
+ assert s2.status_code_1 == StatusCode1.SESSION_CONTINUE_FAIL
+ assert s2.status_code_2 == SessionStatusCode2.USER_IS_BLOCKED
+
+
+class TestSessionManagerFilter:
+
+ def test_base(self, session_manager, user, utc_now):
+ uuid_id = uuid4().hex
+ session_manager.create(started=utc_now, user=user, uuid_id=uuid_id)
+ res = session_manager.filter(limit=1)
+ assert len(res) != 0
+ assert isinstance(res, list)
+ assert res[0].uuid == uuid_id
+
+ def test_user(self, session_manager, user, utc_hour_ago):
+ session_manager.create(started=utc_hour_ago, user=user, uuid_id=uuid4().hex)
+ session_manager.create(started=utc_hour_ago, user=user, uuid_id=uuid4().hex)
+
+ res = session_manager.filter(user=user)
+ assert len(res) == 2
+
+ def test_product(
+ self, product_factory, user_factory, session_manager, user, utc_hour_ago
+ ):
+ from generalresearch.models.thl.session import Session
+ from generalresearch.models.thl.user import User
+
+ p1 = product_factory()
+
+ for n in range(5):
+ u = user_factory(product=p1)
+ session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex)
+
+ res = session_manager.filter(
+ product_uuids=[p1.uuid], started_since=utc_hour_ago
+ )
+ assert isinstance(res[0], Session)
+ assert isinstance(res[0].user, User)
+ assert len(res) == 5
+
+ def test_team(
+ self,
+ product_factory,
+ user_factory,
+ team,
+ session_manager,
+ user,
+ utc_hour_ago,
+ thl_web_rr,
+ ):
+ p1 = product_factory(team=team)
+
+ for n in range(5):
+ u = user_factory(product=p1)
+ session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex)
+
+ team.prefetch_products(thl_pg_config=thl_web_rr)
+ assert len(team.product_uuids) == 1
+ res = session_manager.filter(product_uuids=team.product_uuids)
+ assert len(res) == 5
+
+ def test_business(
+ self,
+ product_factory,
+ business,
+ user_factory,
+ session_manager,
+ user,
+ utc_hour_ago,
+ thl_web_rr,
+ ):
+ p1 = product_factory(business=business)
+
+ for n in range(5):
+ u = user_factory(product=p1)
+ session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex)
+
+ business.prefetch_products(thl_pg_config=thl_web_rr)
+ assert len(business.product_uuids) == 1
+ res = session_manager.filter(product_uuids=business.product_uuids)
+ assert len(res) == 5
diff --git a/tests/managers/thl/test_survey.py b/tests/managers/thl/test_survey.py
new file mode 100644
index 0000000..58c4577
--- /dev/null
+++ b/tests/managers/thl/test_survey.py
@@ -0,0 +1,376 @@
+import uuid
+from datetime import datetime, timezone
+from decimal import Decimal
+
+import pytest
+
+from generalresearch.models import Source
+from generalresearch.models.legacy.bucket import (
+ SurveyEligibilityCriterion,
+ TopNPlusBucket,
+ DurationSummary,
+ PayoutSummary,
+)
+from generalresearch.models.thl.profiling.user_question_answer import (
+ UserQuestionAnswer,
+)
+from generalresearch.models.thl.survey.model import (
+ Survey,
+ SurveyStat,
+ SurveyCategoryModel,
+ SurveyEligibilityDefinition,
+)
+
+
+@pytest.fixture(scope="session")
+def surveys_fixture():
+ return [
+ Survey(source=Source.TESTING, survey_id="a", buyer_code="buyer1"),
+ Survey(source=Source.TESTING, survey_id="b", buyer_code="buyer2"),
+ Survey(source=Source.TESTING, survey_id="c", buyer_code="buyer2"),
+ ]
+
+
+ssa = SurveyStat(
+ survey_source=Source.TESTING,
+ survey_survey_id="a",
+ survey_is_live=True,
+ quota_id="__all__",
+ cpi=Decimal(1),
+ country_iso="us",
+ version=1,
+ complete_too_fast_cutoff=300,
+ prescreen_conv_alpha=10,
+ prescreen_conv_beta=10,
+ conv_alpha=10,
+ conv_beta=10,
+ dropoff_alpha=10,
+ dropoff_beta=10,
+ completion_time_mu=1,
+ completion_time_sigma=0.4,
+ mobile_eligible_alpha=10,
+ mobile_eligible_beta=10,
+ desktop_eligible_alpha=10,
+ desktop_eligible_beta=10,
+ tablet_eligible_alpha=10,
+ tablet_eligible_beta=10,
+ long_fail_rate=1,
+ user_report_coeff=1,
+ recon_likelihood=0,
+ score_x0=0,
+ score_x1=1,
+ score=100,
+)
+ssb = ssa.model_dump()
+ssb["survey_source"] = Source.TESTING
+ssb["survey_survey_id"] = "b"
+ssb["completion_time_mu"] = 2
+ssb["score"] = 90
+ssb = SurveyStat.model_validate(ssb)
+
+
+class TestSurvey:
+
+ def test(
+ self,
+ delete_buyers_surveys,
+ buyer_manager,
+ survey_manager,
+ surveys_fixture,
+ ):
+ survey_manager.create_or_update(surveys_fixture)
+ survey_ids = {s.survey_id for s in surveys_fixture}
+ res = survey_manager.filter_by_natural_key(
+ source=Source.TESTING, survey_ids=survey_ids
+ )
+ assert len(res) == len(surveys_fixture)
+ assert res[0].id is not None
+ surveys2 = surveys_fixture.copy()
+ surveys2.append(
+ Survey(source=Source.TESTING, survey_id="d", buyer_code="buyer2")
+ )
+ survey_manager.create_or_update(surveys2)
+ survey_ids = {s.survey_id for s in surveys2}
+ res2 = survey_manager.filter_by_natural_key(
+ source=Source.TESTING, survey_ids=survey_ids
+ )
+ assert res2[0].id == res[0].id
+ assert res2[0] == res[0]
+ assert len(res2) == len(surveys2)
+
+ def test_category(self, survey_manager):
+ survey1 = Survey(id=562289, survey_id="a", source=Source.TESTING)
+ survey2 = Survey(id=562290, survey_id="a", source=Source.TESTING)
+ categories = list(survey_manager.category_manager.categories.values())
+ sc = [SurveyCategoryModel(category=c, strength=1 / 2) for c in categories[:2]]
+ sc2 = [SurveyCategoryModel(category=c, strength=1 / 2) for c in categories[2:4]]
+ survey1.categories = sc
+ survey2.categories = sc2
+ surveys = [survey1, survey2]
+ survey_manager.update_surveys_categories(surveys)
+
+ def test_survey_eligibility(
+ self, survey_manager, upk_data, question_manager, uqa_manager
+ ):
+ bucket = TopNPlusBucket(
+ id="c82cf98c578a43218334544ab376b00e",
+ contents=[],
+ duration=DurationSummary(max=1, min=1, q1=1, q2=1, q3=1),
+ quality_score=1,
+ payout=PayoutSummary(max=1, min=1, q1=1, q2=1, q3=1),
+ uri="https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142",
+ )
+
+ survey1 = Survey(
+ survey_id="a",
+ source=Source.TESTING,
+ eligibility_criteria=SurveyEligibilityDefinition(
+ # "5d6d9f3c03bb40bf9d0a24f306387d7c", # gr:gender
+ # "c1309f099ab84a39b01200a56dac65cf", # d:600
+ # "90e86550ddf84b08a9f7f5372dd9651b", # i:gender
+ # "1a8216ddb09440a8bc748cf8ca89ecec", # i:adhoc_13126
+ property_codes=("i:adhoc_13126", "i:gender", "i:something"),
+ ),
+ )
+ # might have a couple surveys in the bucket ... merge them all together
+ qualifying_questions = set(survey1.eligibility_criteria.property_codes)
+
+ uqas = [
+ UserQuestionAnswer(
+ question_id="5d6d9f3c03bb40bf9d0a24f306387d7c",
+ answer=("1",),
+ calc_answers={"i:gender": ("1",), "d:1": ("2",)},
+ country_iso="us",
+ language_iso="eng",
+ property_code="gr:gender",
+ ),
+ UserQuestionAnswer(
+ question_id="c1309f099ab84a39b01200a56dac65cf",
+ answer=("50796", "50784"),
+ country_iso="us",
+ language_iso="eng",
+ property_code="d:600",
+ calc_answers={"d:600": ("50796", "50784")},
+ ),
+ UserQuestionAnswer(
+ question_id="1a8216ddb09440a8bc748cf8ca89ecec",
+ answer=("3", "4"),
+ country_iso="us",
+ language_iso="eng",
+ property_code="i:adhoc_13126",
+ calc_answers={"i:adhoc_13126": ("3", "4")},
+ ),
+ ]
+ uqad = dict()
+ for uqa in uqas:
+ for k, v in uqa.calc_answers.items():
+ if k in qualifying_questions:
+ uqad[k] = uqa
+ uqad[uqa.property_code] = uqa
+
+ question_ids = {uqa.question_id for uqa in uqad.values()}
+ qs = question_manager.get_multi_upk(question_ids)
+ # Sort question by LOWEST task_count (rarest)
+ qs = sorted(qs, key=lambda x: x.importance.task_count if x.importance else 0)
+ # qd = {q.id: q for q in qs}
+
+ q = [x for x in qs if x.ext_question_id == "i:adhoc_13126"][0]
+ q.explanation_template = "You have been diagnosed with: {answer}."
+ q = [x for x in qs if x.ext_question_id == "gr:gender"][0]
+ q.explanation_template = "Your gender is {answer}."
+
+ ecs = []
+ for q in qs:
+ answer_code = uqad[q.ext_question_id].answer
+ answer_label = tuple([q.choices_text_lookup[ans] for ans in answer_code])
+ explanation = None
+ if q.explanation_template:
+ explanation = q.explanation_template.format(
+ answer=", ".join(answer_label)
+ )
+ sec = SurveyEligibilityCriterion(
+ question_id=q.id,
+ question_text=q.text,
+ qualifying_answer=answer_code,
+ qualifying_answer_label=answer_label,
+ explanation=explanation,
+ property_code=q.ext_question_id,
+ )
+ print(sec)
+ ecs.append(sec)
+ bucket.eligibility_criteria = tuple(ecs)
+ print(bucket)
+
+
+class TestSurveyStat:
+ def test(
+ self,
+ delete_buyers_surveys,
+ surveystat_manager,
+ survey_manager,
+ surveys_fixture,
+ ):
+ survey_manager.create_or_update(surveys_fixture)
+ ss = [ssa, ssb]
+ surveystat_manager.update_or_create(ss)
+ keys = [s.unique_key for s in ss]
+ res = surveystat_manager.filter_by_unique_keys(keys)
+ assert len(res) == 2
+ assert res[0].conv_alpha == 10
+
+ ssa.conv_alpha = 11
+ surveystat_manager.update_or_create([ssa])
+ res = surveystat_manager.filter_by_unique_keys([ssa.unique_key])
+ assert len(res) == 1
+ assert res[0].conv_alpha == 11
+
+ @pytest.mark.skip()
+ def test_big(
+ self,
+ # delete_buyers_surveys,
+ surveystat_manager,
+ survey_manager,
+ surveys_fixture,
+ ):
+ survey = surveys_fixture[0].model_copy()
+ surveys = []
+ for idx in range(20_000):
+ s = survey.model_copy()
+ s.survey_id = uuid.uuid4().hex
+ surveys.append(s)
+
+ survey_manager.create_or_update(surveys)
+ # Realistically, we're going to have like, say 20k surveys
+ # and 99% of them will be updated each time
+ survey_stats = []
+ for survey in surveys:
+ ss = ssa.model_copy()
+ ss.survey__survey_id = survey.survey_id
+ ss.quota_id = "__all__"
+ ss.score = 10
+ survey_stats.append(ss)
+ print(len(survey_stats))
+ print(survey_stats[12].natural_key, survey_stats[2000].natural_key)
+ print(f"----a-----: {datetime.now().isoformat()}")
+ res = surveystat_manager.update_or_create(survey_stats)
+ print(f"----b-----: {datetime.now().isoformat()}")
+ assert len(res) == 20_000
+ return
+
+ # 1,000 of the 20,000 are "new"
+ now = datetime.now(tz=timezone.utc)
+ for s in ss[:1000]:
+ s.survey__survey_id = "b"
+ s.updated_at = now
+ # 18,000 need to be updated (scores update or whatever)
+ for s in ss[1000:-1000]:
+ s.score = 20
+ s.conv_alpha = 20
+ s.conv_beta = 20
+ s.updated_at = now
+ # and 1,000 don't change
+ print(f"----c-----: {datetime.now().isoformat()}")
+ res2 = surveystat_manager.update_or_create(ss)
+ print(f"----d-----: {datetime.now().isoformat()}")
+ assert len(res2) == 20_000
+
+ def test_ymsp(
+ self,
+ delete_buyers_surveys,
+ surveys_fixture,
+ survey_manager,
+ surveystat_manager,
+ ):
+ source = Source.TESTING
+ survey = surveys_fixture[0].model_copy()
+ surveys = []
+ for idx in range(100):
+ s = survey.model_copy()
+ s.survey_id = uuid.uuid4().hex
+ surveys.append(s)
+ survey_stats = []
+ for survey in surveys:
+ ss = ssa.model_copy()
+ ss.survey_survey_id = survey.survey_id
+ survey_stats.append(ss)
+
+ surveystat_manager.update_surveystats_for_source(
+ source=source, surveys=surveys, survey_stats=survey_stats
+ )
+ # UPDATE -------
+ since = datetime.now(tz=timezone.utc)
+ print(f"{since=}")
+
+ # 10 survey disappear
+ surveys = surveys[10:]
+
+ # and 2 new ones are created
+ for idx in range(2):
+ s = survey.model_copy()
+ s.survey_id = uuid.uuid4().hex
+ surveys.append(s)
+ survey_stats = []
+ for survey in surveys:
+ ss = ssa.model_copy()
+ ss.survey_survey_id = survey.survey_id
+ survey_stats.append(ss)
+ surveystat_manager.update_surveystats_for_source(
+ source=source, surveys=surveys, survey_stats=survey_stats
+ )
+
+ live_surveys = survey_manager.filter_by_source_live(source=source)
+ assert len(live_surveys) == 92 # 100 - 10 + 2
+
+ ss = surveystat_manager.filter_by_updated_since(since=since)
+ assert len(ss) == 102 # 92 existing + 10 not live
+
+ ss = surveystat_manager.filter_by_live()
+ assert len(ss) == 92
+
+ def test_filter(
+ self,
+ delete_buyers_surveys,
+ surveys_fixture,
+ survey_manager,
+ surveystat_manager,
+ ):
+ surveys = []
+ survey = surveys_fixture[0].model_copy()
+ survey.source = Source.TESTING
+ survey.survey_id = "a"
+ surveys.append(survey.model_copy())
+ survey.source = Source.TESTING
+ survey.survey_id = "b"
+ surveys.append(survey.model_copy())
+ survey.source = Source.TESTING2
+ survey.survey_id = "b"
+ surveys.append(survey.model_copy())
+ survey.source = Source.TESTING2
+ survey.survey_id = "c"
+ surveys.append(survey.model_copy())
+ # 4 surveys t:a, t:b, u:b, u:c
+
+ survey_stats = []
+ for survey in surveys:
+ ss = ssa.model_copy()
+ ss.survey_survey_id = survey.survey_id
+ ss.survey_source = survey.source
+ survey_stats.append(ss)
+
+ surveystat_manager.update_surveystats_for_source(
+ source=Source.TESTING, surveys=surveys[:2], survey_stats=survey_stats[:2]
+ )
+ surveystat_manager.update_surveystats_for_source(
+ source=Source.TESTING2, surveys=surveys[2:], survey_stats=survey_stats[2:]
+ )
+
+ survey_keys = [f"{s.source.value}:{s.survey_id}" for s in surveys]
+ res = surveystat_manager.filter(survey_keys=survey_keys, min_score=0.01)
+ assert len(res) == 4
+ res = survey_manager.filter_by_keys(survey_keys)
+ assert len(res) == 4
+
+ res = surveystat_manager.filter(survey_keys=survey_keys[:2])
+ assert len(res) == 2
+ res = survey_manager.filter_by_keys(survey_keys[:2])
+ assert len(res) == 2
diff --git a/tests/managers/thl/test_survey_penalty.py b/tests/managers/thl/test_survey_penalty.py
new file mode 100644
index 0000000..4c7dc08
--- /dev/null
+++ b/tests/managers/thl/test_survey_penalty.py
@@ -0,0 +1,101 @@
+import uuid
+
+import pytest
+from cachetools.keys import _HashedTuple
+
+from generalresearch.models import Source
+from generalresearch.models.thl.survey.penalty import (
+ BPSurveyPenalty,
+ TeamSurveyPenalty,
+)
+
+
+@pytest.fixture
+def product_uuid() -> str:
+ # Nothing touches the db here, we don't need actual products or teams
+ return uuid.uuid4().hex
+
+
+@pytest.fixture
+def team_uuid() -> str:
+ # Nothing touches the db here, we don't need actual products or teams
+ return uuid.uuid4().hex
+
+
+@pytest.fixture
+def penalties(product_uuid, team_uuid):
+ return [
+ BPSurveyPenalty(
+ source=Source.TESTING, survey_id="a", penalty=0.1, product_id=product_uuid
+ ),
+ BPSurveyPenalty(
+ source=Source.TESTING, survey_id="b", penalty=0.2, product_id=product_uuid
+ ),
+ TeamSurveyPenalty(
+ source=Source.TESTING, survey_id="a", penalty=1, team_id=team_uuid
+ ),
+ TeamSurveyPenalty(
+ source=Source.TESTING, survey_id="c", penalty=1, team_id=team_uuid
+ ),
+ # Source.TESTING:a is different from Source.TESTING2:a !
+ BPSurveyPenalty(
+ source=Source.TESTING2, survey_id="a", penalty=0.5, product_id=product_uuid
+ ),
+ # For a random BP, should not do anything
+ BPSurveyPenalty(
+ source=Source.TESTING, survey_id="b", penalty=1, product_id=uuid.uuid4().hex
+ ),
+ ]
+
+
+class TestSurveyPenalty:
+ def test(self, surveypenalty_manager, penalties, product_uuid, team_uuid):
+ surveypenalty_manager.set_penalties(penalties)
+
+ res = surveypenalty_manager.get_penalties_for(
+ product_id=product_uuid, team_id=team_uuid
+ )
+ assert res == {"t:a": 1.0, "t:b": 0.2, "t:c": 1, "u:a": 0.5}
+
+ # We can update penalties for a marketplace and not erase them for another.
+ # But remember, marketplace is batched, so it'll overwrite the previous
+ # values for that marketplace
+ penalties = [
+ BPSurveyPenalty(
+ source=Source.TESTING2,
+ survey_id="b",
+ penalty=0.1,
+ product_id=product_uuid,
+ )
+ ]
+ surveypenalty_manager.set_penalties(penalties)
+ res = surveypenalty_manager.get_penalties_for(
+ product_id=product_uuid, team_id=team_uuid
+ )
+ assert res == {"t:a": 1.0, "t:b": 0.2, "t:c": 1, "u:b": 0.1}
+
+ # Team id doesn't exist, so it should return the product's penalties
+ team_id_random = uuid.uuid4().hex
+ surveypenalty_manager.cache.clear()
+ res = surveypenalty_manager.get_penalties_for(
+ product_id=product_uuid, team_id=team_id_random
+ )
+ assert res == {"t:a": 0.1, "t:b": 0.2, "u:b": 0.1}
+
+ # Assert it is cached (no redis lookup needed)
+ assert surveypenalty_manager.cache.currsize == 1
+ res = surveypenalty_manager.get_penalties_for(
+ product_id=product_uuid, team_id=team_id_random
+ )
+ assert res == {"t:a": 0.1, "t:b": 0.2, "u:b": 0.1}
+ assert surveypenalty_manager.cache.currsize == 1
+ cached_key = tuple(list(list(surveypenalty_manager.cache.keys())[0])[1:])
+ assert cached_key == tuple(
+ ["product_id", product_uuid, "team_id", team_id_random]
+ )
+
+ # Both don't exist, return nothing
+ res = surveypenalty_manager.get_penalties_for(
+ product_id=uuid.uuid4().hex, team_id=uuid.uuid4().hex
+ )
+ assert res == {}
diff --git a/tests/managers/thl/test_task_adjustment.py b/tests/managers/thl/test_task_adjustment.py
new file mode 100644
index 0000000..839bbe1
--- /dev/null
+++ b/tests/managers/thl/test_task_adjustment.py
@@ -0,0 +1,346 @@
+import logging
+from random import randint
+
+import pytest
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+
+from generalresearch.models import Source
+from generalresearch.models.thl.definitions import (
+ Status,
+ StatusCode1,
+ WallAdjustedStatus,
+)
+
+
+@pytest.fixture()
+def session_complete(session_with_tx_factory, user):
+ return session_with_tx_factory(
+ user=user, final_status=Status.COMPLETE, wall_req_cpi=Decimal("1.23")
+ )
+
+
+@pytest.fixture()
+def session_complete_with_wallet(session_with_tx_factory, user_with_wallet):
+ return session_with_tx_factory(
+ user=user_with_wallet,
+ final_status=Status.COMPLETE,
+ wall_req_cpi=Decimal("1.23"),
+ )
+
+
+@pytest.fixture()
+def session_fail(user, session_manager, wall_manager):
+ session = session_manager.create_dummy(
+ started=datetime.now(timezone.utc), user=user
+ )
+ wall1 = wall_manager.create_dummy(
+ session_id=session.id,
+ user_id=user.user_id,
+ source=Source.DYNATA,
+ req_survey_id="72723",
+ req_cpi=Decimal("3.22"),
+ started=datetime.now(timezone.utc),
+ )
+ wall_manager.finish(
+ wall=wall1,
+ status=Status.FAIL,
+ status_code_1=StatusCode1.PS_FAIL,
+ finished=wall1.started + timedelta(seconds=randint(a=60 * 2, b=60 * 10)),
+ )
+ session.wall_events.append(wall1)
+ return session
+
+
+class TestHandleRecons:
+
+ def test_complete_to_recon(
+ self,
+ session_complete,
+ thl_lm,
+ task_adjustment_manager,
+ wall_manager,
+ session_manager,
+ caplog,
+ ):
+ print(wall_manager.pg_config.dsn)
+ mid = session_complete.uuid
+ wall_uuid = session_complete.wall_events[-1].uuid
+ s = session_complete
+ ledger_manager = thl_lm
+
+ revenue_account = ledger_manager.get_account_task_complete_revenue()
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert (
+ current_amount == 123
+ ), "this is the amount of revenue from this task complete"
+
+ bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet(
+ s.user.product
+ )
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 117, "this is the amount paid to the BP"
+
+ # Do the work here !! ----v
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ )
+ assert (
+ len(task_adjustment_manager.filter_by_wall_uuid(wall_uuid=wall_uuid)) == 1
+ )
+
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert current_amount == 0, "after recon, it should be zeroed"
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 0, "this is the amount paid to the BP"
+ commission_account = ledger_manager.get_account_or_create_bp_commission(
+ s.user.product
+ )
+ assert ledger_manager.get_account_balance(commission_account) == 0
+
+ # Now, say we get the exact same *adjust to incomplete* msg again. It should do nothing!
+ adjusted_timestamp = datetime.now(tz=timezone.utc)
+ wall = wall_manager.get_from_uuid(wall_uuid=wall_uuid)
+ with pytest.raises(match=" is already "):
+ wall_manager.adjust_status(
+ wall,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_timestamp=adjusted_timestamp,
+ adjusted_cpi=Decimal(0),
+ )
+
+ session = session_manager.get_from_id(wall.session_id)
+ user = session.user
+ with caplog.at_level(logging.INFO):
+ ledger_manager.create_tx_task_adjustment(
+ wall, user=user, created=adjusted_timestamp
+ )
+ assert "No transactions needed" in caplog.text
+
+ session.wall_events = wall_manager.get_wall_events(session.id)
+ session.user.prefetch_product(wall_manager.pg_config)
+
+ with caplog.at_level(logging.INFO, logger="Wall"):
+ session_manager.adjust_status(session)
+ assert "is already f" in caplog.text or "is already Status.FAIL" in caplog.text
+
+ with caplog.at_level(logging.INFO, logger="LedgerManager"):
+ ledger_manager.create_tx_bp_adjustment(session, created=adjusted_timestamp)
+ assert "No transactions needed" in caplog.text
+
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert current_amount == 0, "after recon, it should be zeroed"
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 0, "this is the amount paid to the BP"
+
+ # And if we get an adj to fail, and handle it, it should do nothing at all
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ )
+ assert (
+ len(task_adjustment_manager.filter_by_wall_uuid(wall_uuid=wall_uuid)) == 1
+ )
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert current_amount == 0, "after recon, it should be zeroed"
+
+ def test_fail_to_complete(self, session_fail, thl_lm, task_adjustment_manager):
+ s = session_fail
+ mid = session_fail.uuid
+ wall_uuid = session_fail.wall_events[-1].uuid
+ ledger_manager = thl_lm
+
+ revenue_account = ledger_manager.get_account_task_complete_revenue()
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", mid
+ )
+ assert (
+ current_amount == 0
+ ), "this is the amount of revenue from this task complete"
+
+ bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet(
+ s.user.product
+ )
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 0, "this is the amount paid to the BP"
+
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ )
+
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert current_amount == 322, "after recon, we should be paid"
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 306, "this is the amount paid to the BP"
+
+ # Now reverse it back to fail
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ )
+
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert current_amount == 0
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 0
+
+ commission_account = ledger_manager.get_account_or_create_bp_commission(
+ s.user.product
+ )
+ assert ledger_manager.get_account_balance(commission_account) == 0
+
+ def test_complete_already_complete(
+ self, session_complete, thl_lm, task_adjustment_manager
+ ):
+ s = session_complete
+ mid = session_complete.uuid
+ wall_uuid = session_complete.wall_events[-1].uuid
+ ledger_manager = thl_lm
+
+ for _ in range(4):
+ # just run it 4 times to make sure nothing happens 4 times
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ )
+
+ revenue_account = ledger_manager.get_account_task_complete_revenue()
+ bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet(
+ s.user.product
+ )
+ commission_account = ledger_manager.get_account_or_create_bp_commission(
+ s.user.product
+ )
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert current_amount == 123
+ assert ledger_manager.get_account_balance(commission_account) == 6
+
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 117
+
+ def test_incomplete_already_incomplete(
+ self, session_fail, thl_lm, task_adjustment_manager
+ ):
+ s = session_fail
+ mid = session_fail.uuid
+ wall_uuid = session_fail.wall_events[-1].uuid
+ ledger_manager = thl_lm
+
+ for _ in range(4):
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ )
+
+ revenue_account = ledger_manager.get_account_task_complete_revenue()
+ bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet(
+ s.user.product
+ )
+ commission_account = ledger_manager.get_account_or_create_bp_commission(
+ s.user.product
+ )
+ current_amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", mid
+ )
+ assert current_amount == 0
+ assert ledger_manager.get_account_balance(commission_account) == 0
+
+ current_bp_payout = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert current_bp_payout == 0
+
+ def test_complete_to_recon_user_wallet(
+ self,
+ session_complete_with_wallet,
+ user_with_wallet,
+ thl_lm,
+ task_adjustment_manager,
+ ):
+ s = session_complete_with_wallet
+ mid = s.uuid
+ wall_uuid = s.wall_events[-1].uuid
+ ledger_manager = thl_lm
+
+ revenue_account = ledger_manager.get_account_task_complete_revenue()
+ amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert amount == 123, "this is the amount of revenue from this task complete"
+
+ bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet(
+ s.user.product
+ )
+ user_wallet_account = ledger_manager.get_account_or_create_user_wallet(s.user)
+ commission_account = ledger_manager.get_account_or_create_bp_commission(
+ s.user.product
+ )
+ amount = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert amount == 70, "this is the amount paid to the BP"
+ amount = ledger_manager.get_account_filtered_balance(
+ user_wallet_account, "thl_session", mid
+ )
+ assert amount == 47, "this is the amount paid to the user"
+ assert (
+ ledger_manager.get_account_balance(commission_account) == 6
+ ), "earned commission"
+
+ task_adjustment_manager.handle_single_recon(
+ ledger_manager=thl_lm,
+ wall_uuid=wall_uuid,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ )
+
+ amount = ledger_manager.get_account_filtered_balance(
+ revenue_account, "thl_wall", wall_uuid
+ )
+ assert amount == 0
+ amount = ledger_manager.get_account_filtered_balance(
+ bp_wallet_account, "thl_session", mid
+ )
+ assert amount == 0
+ amount = ledger_manager.get_account_filtered_balance(
+ user_wallet_account, "thl_session", mid
+ )
+ assert amount == 0
+ assert (
+ ledger_manager.get_account_balance(commission_account) == 0
+ ), "earned commission"
diff --git a/tests/managers/thl/test_task_status.py b/tests/managers/thl/test_task_status.py
new file mode 100644
index 0000000..55c89c0
--- /dev/null
+++ b/tests/managers/thl/test_task_status.py
@@ -0,0 +1,696 @@
+import pytest
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+
+from generalresearch.managers.thl.session import SessionManager
+from generalresearch.models import Source
+from generalresearch.models.thl.definitions import (
+ Status,
+ WallAdjustedStatus,
+ StatusCode1,
+)
+from generalresearch.models.thl.product import (
+ PayoutConfig,
+ UserWalletConfig,
+ PayoutTransformation,
+ PayoutTransformationPercentArgs,
+)
+from generalresearch.models.thl.session import Session, WallOut
+from generalresearch.models.thl.task_status import TaskStatusResponse
+from generalresearch.models.thl.user import User
+
+
+start1 = datetime(2023, 2, 1, tzinfo=timezone.utc)
+finish1 = start1 + timedelta(minutes=5)
+recon1 = start1 + timedelta(days=20)
+start2 = datetime(2023, 2, 2, tzinfo=timezone.utc)
+finish2 = start2 + timedelta(minutes=5)
+start3 = datetime(2023, 2, 3, tzinfo=timezone.utc)
+finish3 = start3 + timedelta(minutes=5)
+
+
+@pytest.fixture(scope="session")
+def bp1(product_manager):
+ # user wallet disabled, payout xform NULL
+ return product_manager.create_dummy(
+ user_wallet_config=UserWalletConfig(enabled=False),
+ payout_config=PayoutConfig(),
+ )
+
+
+@pytest.fixture(scope="session")
+def bp2(product_manager):
+ # user wallet disabled, payout xform 40%
+ return product_manager.create_dummy(
+ user_wallet_config=UserWalletConfig(enabled=False),
+ payout_config=PayoutConfig(
+ payout_transformation=PayoutTransformation(
+ f="payout_transformation_percent",
+ kwargs=PayoutTransformationPercentArgs(pct=0.4),
+ )
+ ),
+ )
+
+
+@pytest.fixture(scope="session")
+def bp3(product_manager):
+ # user wallet enabled, payout xform 50%
+ return product_manager.create_dummy(
+ user_wallet_config=UserWalletConfig(enabled=True),
+ payout_config=PayoutConfig(
+ payout_transformation=PayoutTransformation(
+ f="payout_transformation_percent",
+ kwargs=PayoutTransformationPercentArgs(pct=0.5),
+ )
+ ),
+ )
+
+
+class TestTaskStatus:
+
+ def test_task_status_complete_1(
+ self,
+ bp1,
+ user_factory,
+ finished_session_factory,
+ session_manager: SessionManager,
+ ):
+ # User Payout xform NULL
+ user1: User = user_factory(product=bp1)
+ s1: Session = finished_session_factory(
+ user=user1, started=start1, wall_req_cpi=Decimal(1), wall_count=2
+ )
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s1.uuid,
+ "product_id": user1.product_id,
+ "product_user_id": user1.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s1.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "c",
+ "payout": 95,
+ "user_payout": None,
+ "payout_format": None,
+ "user_payout_string": None,
+ "status_code_1": 14,
+ "payout_transformation": None,
+ }
+ )
+ w1 = s1.wall_events[0]
+ wo1 = WallOut(
+ uuid=w1.uuid,
+ source=Source.TESTING,
+ buyer_id=None,
+ req_survey_id=w1.req_survey_id,
+ req_cpi=w1.req_cpi,
+ started=w1.started,
+ survey_id=w1.survey_id,
+ cpi=w1.cpi,
+ finished=w1.finished,
+ status=w1.status,
+ status_code_1=w1.status_code_1,
+ )
+ w2 = s1.wall_events[1]
+ wo2 = WallOut(
+ uuid=w2.uuid,
+ source=Source.TESTING,
+ buyer_id=None,
+ req_survey_id=w2.req_survey_id,
+ req_cpi=w2.req_cpi,
+ started=w2.started,
+ survey_id=w2.survey_id,
+ cpi=w2.cpi,
+ finished=w2.finished,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ )
+ expected_tsr.wall_events = [wo1, wo2]
+ tsr = session_manager.get_task_status_response(s1.uuid)
+ assert tsr == expected_tsr
+
+ def test_task_status_complete_2(
+ self, bp2, user_factory, finished_session_factory, session_manager
+ ):
+ # User Payout xform 40%
+ user2: User = user_factory(product=bp2)
+ s2: Session = finished_session_factory(
+ user=user2, started=start1, wall_req_cpi=Decimal(1), wall_count=2
+ )
+ payout = 95 # 1.00 - 5% commission
+ user_payout = round(95 * 0.40) # 38
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s2.uuid,
+ "product_id": user2.product_id,
+ "product_user_id": user2.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s2.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "c",
+ "payout": payout,
+ "user_payout": user_payout,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": "$0.38",
+ "status_code_1": 14,
+ "status_code_2": None,
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.4"},
+ },
+ }
+ )
+ w1 = s2.wall_events[0]
+ wo1 = WallOut(
+ uuid=w1.uuid,
+ source=Source.TESTING,
+ buyer_id=None,
+ req_survey_id=w1.req_survey_id,
+ req_cpi=w1.req_cpi,
+ started=w1.started,
+ survey_id=w1.survey_id,
+ cpi=w1.cpi,
+ finished=w1.finished,
+ status=w1.status,
+ status_code_1=w1.status_code_1,
+ user_cpi=Decimal("0.38"),
+ user_cpi_string="$0.38",
+ )
+ w2 = s2.wall_events[1]
+ wo2 = WallOut(
+ uuid=w2.uuid,
+ source=Source.TESTING,
+ buyer_id=None,
+ req_survey_id=w2.req_survey_id,
+ req_cpi=w2.req_cpi,
+ started=w2.started,
+ survey_id=w2.survey_id,
+ cpi=w2.cpi,
+ finished=w2.finished,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ user_cpi=Decimal("0.38"),
+ user_cpi_string="$0.38",
+ )
+ expected_tsr.wall_events = [wo1, wo2]
+
+ tsr = session_manager.get_task_status_response(s2.uuid)
+ assert tsr == expected_tsr
+
+ def test_task_status_complete_3(
+ self, bp3, user_factory, finished_session_factory, session_manager
+ ):
+ # Wallet enabled User Payout xform 50% (the response is identical
+ # to the user wallet disabled w same xform)
+ user3: User = user_factory(product=bp3)
+ s3: Session = finished_session_factory(
+ user=user3, started=start1, wall_req_cpi=Decimal(1)
+ )
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s3.uuid,
+ "product_id": user3.product_id,
+ "product_user_id": user3.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s3.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "c",
+ "payout": 95,
+ "user_payout": 48,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": "$0.48",
+ "kwargs": {},
+ "status_code_1": 14,
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.5"},
+ },
+ }
+ )
+ tsr = session_manager.get_task_status_response(s3.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_fail(
+ self, bp1, user_factory, finished_session_factory, session_manager
+ ):
+ # User Payout xform NULL: user payout is None always
+ user1: User = user_factory(product=bp1)
+ s1: Session = finished_session_factory(
+ user=user1, started=start1, final_status=Status.FAIL
+ )
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s1.uuid,
+ "product_id": user1.product_id,
+ "product_user_id": user1.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s1.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "f",
+ "payout": 0,
+ "user_payout": None,
+ "payout_format": None,
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": s1.status_code_1.value,
+ "status_code_2": None,
+ "adjusted_status": None,
+ "adjusted_timestamp": None,
+ "adjusted_payout": None,
+ "adjusted_user_payout": None,
+ "adjusted_user_payout_string": None,
+ "payout_transformation": None,
+ }
+ )
+ tsr = session_manager.get_task_status_response(s1.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_fail_xform(
+ self, bp2, user_factory, finished_session_factory, session_manager
+ ):
+ # User Payout xform 40%: user_payout is 0 (not None)
+
+ user: User = user_factory(product=bp2)
+ s: Session = finished_session_factory(
+ user=user, started=start1, final_status=Status.FAIL
+ )
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "f",
+ "payout": 0,
+ "user_payout": 0,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": "$0.00",
+ "kwargs": {},
+ "status_code_1": s.status_code_1.value,
+ "status_code_2": None,
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.4"},
+ },
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_abandon(
+ self, bp1, user_factory, session_factory, session_manager
+ ):
+ # User Payout xform NULL: all payout fields are None
+ user: User = user_factory(product=bp1)
+ s = session_factory(user=user, started=start1)
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": None,
+ "status": None,
+ "payout": None,
+ "user_payout": None,
+ "payout_format": None,
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": None,
+ "status_code_2": None,
+ "adjusted_status": None,
+ "adjusted_timestamp": None,
+ "adjusted_payout": None,
+ "adjusted_user_payout": None,
+ "adjusted_user_payout_string": None,
+ "payout_transformation": None,
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_abandon_xform(
+ self, bp2, user_factory, session_factory, session_manager
+ ):
+ # User Payout xform 40%: all payout fields are None (same as when payout xform is null)
+ user: User = user_factory(product=bp2)
+ s = session_factory(user=user, started=start1)
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": None,
+ "status": None,
+ "payout": None,
+ "user_payout": None,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": None,
+ "status_code_2": None,
+ "adjusted_status": None,
+ "adjusted_timestamp": None,
+ "adjusted_payout": None,
+ "adjusted_user_payout": None,
+ "adjusted_user_payout_string": None,
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.4"},
+ },
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_adj_fail(
+ self,
+ bp1,
+ user_factory,
+ finished_session_factory,
+ wall_manager,
+ session_manager,
+ ):
+ # Complete -> Fail
+ # User Payout xform NULL: adjusted_user_* and user_* is still all None
+ user: User = user_factory(product=bp1)
+ s: Session = finished_session_factory(
+ user=user, started=start1, wall_req_cpi=Decimal(1)
+ )
+ wall_manager.adjust_status(
+ wall=s.wall_events[-1],
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_timestamp=recon1,
+ )
+ session_manager.adjust_status(s)
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "c",
+ "payout": 95,
+ "user_payout": None,
+ "payout_format": None,
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": StatusCode1.COMPLETE.value,
+ "status_code_2": None,
+ "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_FAIL.value,
+ "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "adjusted_payout": 0,
+ "adjusted_user_payout": None,
+ "adjusted_user_payout_string": None,
+ "payout_transformation": None,
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_adj_fail_xform(
+ self,
+ bp2,
+ user_factory,
+ finished_session_factory,
+ wall_manager,
+ session_manager,
+ ):
+ # Complete -> Fail
+ # User Payout xform 40%: adjusted_user_payout is 0 (not null)
+ user: User = user_factory(product=bp2)
+ s: Session = finished_session_factory(
+ user=user, started=start1, wall_req_cpi=Decimal(1)
+ )
+ wall_manager.adjust_status(
+ wall=s.wall_events[-1],
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL,
+ adjusted_timestamp=recon1,
+ )
+ session_manager.adjust_status(s)
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "c",
+ "payout": 95,
+ "user_payout": 38,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": "$0.38",
+ "kwargs": {},
+ "status_code_1": StatusCode1.COMPLETE.value,
+ "status_code_2": None,
+ "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_FAIL.value,
+ "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "adjusted_payout": 0,
+ "adjusted_user_payout": 0,
+ "adjusted_user_payout_string": "$0.00",
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.4"},
+ },
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_adj_complete_from_abandon(
+ self,
+ bp1,
+ user_factory,
+ session_factory,
+ wall_manager,
+ session_manager,
+ ):
+ # User Payout xform NULL
+ user: User = user_factory(product=bp1)
+ s: Session = session_factory(
+ user=user,
+ started=start1,
+ wall_req_cpi=Decimal(1),
+ wall_count=2,
+ final_status=Status.ABANDON,
+ )
+ w = s.wall_events[-1]
+ wall_manager.adjust_status(
+ wall=w,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=w.cpi,
+ adjusted_timestamp=recon1,
+ )
+ session_manager.adjust_status(s)
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": None,
+ "status": None,
+ "payout": None,
+ "user_payout": None,
+ "payout_format": None,
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": None,
+ "status_code_2": None,
+ "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_COMPLETE.value,
+ "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "adjusted_payout": 95,
+ "adjusted_user_payout": None,
+ "adjusted_user_payout_string": None,
+ "payout_transformation": None,
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_adj_complete_from_abandon_xform(
+ self,
+ bp2,
+ user_factory,
+ session_factory,
+ wall_manager,
+ session_manager,
+ ):
+ # User Payout xform 40%
+ user: User = user_factory(product=bp2)
+ s: Session = session_factory(
+ user=user,
+ started=start1,
+ wall_req_cpi=Decimal(1),
+ wall_count=2,
+ final_status=Status.ABANDON,
+ )
+ w = s.wall_events[-1]
+ wall_manager.adjust_status(
+ wall=w,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=w.cpi,
+ adjusted_timestamp=recon1,
+ )
+ session_manager.adjust_status(s)
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": None,
+ "status": None,
+ "payout": None,
+ "user_payout": None,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": None,
+ "status_code_2": None,
+ "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_COMPLETE.value,
+ "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "adjusted_payout": 95,
+ "adjusted_user_payout": 38,
+ "adjusted_user_payout_string": "$0.38",
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.4"},
+ },
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_adj_complete_from_fail(
+ self,
+ bp1,
+ user_factory,
+ finished_session_factory,
+ wall_manager,
+ session_manager,
+ ):
+ # User Payout xform NULL
+ user: User = user_factory(product=bp1)
+ s: Session = finished_session_factory(
+ user=user,
+ started=start1,
+ wall_req_cpi=Decimal(1),
+ wall_count=2,
+ final_status=Status.FAIL,
+ )
+ w = s.wall_events[-1]
+ wall_manager.adjust_status(
+ wall=w,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=w.cpi,
+ adjusted_timestamp=recon1,
+ )
+ session_manager.adjust_status(s)
+
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "f",
+ "payout": 0,
+ "user_payout": None,
+ "payout_format": None,
+ "user_payout_string": None,
+ "kwargs": {},
+ "status_code_1": s.status_code_1.value,
+ "status_code_2": None,
+ "adjusted_status": "ac",
+ "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "adjusted_payout": 95,
+ "adjusted_user_payout": None,
+ "adjusted_user_payout_string": None,
+ "payout_transformation": None,
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
+
+ def test_task_status_adj_complete_from_fail_xform(
+ self,
+ bp2,
+ user_factory,
+ finished_session_factory,
+ wall_manager,
+ session_manager,
+ ):
+ # User Payout xform 40%
+ user: User = user_factory(product=bp2)
+ s: Session = finished_session_factory(
+ user=user,
+ started=start1,
+ wall_req_cpi=Decimal(1),
+ wall_count=2,
+ final_status=Status.FAIL,
+ )
+ w = s.wall_events[-1]
+ wall_manager.adjust_status(
+ wall=w,
+ adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE,
+ adjusted_cpi=w.cpi,
+ adjusted_timestamp=recon1,
+ )
+ session_manager.adjust_status(s)
+ expected_tsr = TaskStatusResponse.model_validate(
+ {
+ "tsid": s.uuid,
+ "product_id": user.product_id,
+ "product_user_id": user.product_user_id,
+ "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "status": "f",
+ "payout": 0,
+ "user_payout": 0,
+ "payout_format": "${payout/100:.2f}",
+ "user_payout_string": "$0.00",
+ "kwargs": {},
+ "status_code_1": s.status_code_1.value,
+ "status_code_2": None,
+ "adjusted_status": "ac",
+ "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
+ "adjusted_payout": 95,
+ "adjusted_user_payout": 38,
+ "adjusted_user_payout_string": "$0.38",
+ "payout_transformation": {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.4"},
+ },
+ }
+ )
+ tsr = session_manager.get_task_status_response(s.uuid)
+ # Not bothering with wall events ...
+ tsr.wall_events = None
+ assert tsr == expected_tsr
diff --git a/tests/managers/thl/test_user_manager/__init__.py b/tests/managers/thl/test_user_manager/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/managers/thl/test_user_manager/__init__.py
diff --git a/tests/managers/thl/test_user_manager/test_base.py b/tests/managers/thl/test_user_manager/test_base.py
new file mode 100644
index 0000000..3c7ee38
--- /dev/null
+++ b/tests/managers/thl/test_user_manager/test_base.py
@@ -0,0 +1,274 @@
+import logging
+from datetime import datetime, timezone
+from random import randint
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.managers.thl.user_manager import (
+ UserCreateNotAllowedError,
+ get_bp_user_create_limit_hourly,
+)
+from generalresearch.managers.thl.user_manager.rate_limit import (
+ RateLimitItemPerHourConstantKey,
+)
+from generalresearch.models.thl.product import UserCreateConfig, Product
+from generalresearch.models.thl.user import User
+from test_utils.models.conftest import (
+ user,
+ product,
+ user_manager,
+ product_manager,
+)
+
+logger = logging.getLogger()
+
+
+class TestUserManager:
+
+ def test_copying_lru_cache(self, user_manager, user):
+ # Before adding the deepcopy_return decorator, this would fail b/c the returned user
+ # is mutable, and it would mutate in the cache
+ # user_manager = self.get_user_manager()
+
+ user_manager.clear_user_inmemory_cache(user)
+ u = user_manager.get_user(user_id=user.user_id)
+ assert not u.blocked
+
+ u.blocked = True
+ u = user_manager.get_user(user_id=user.user_id)
+ assert not u.blocked
+
+ def test_get_user_no_inmemory(self, user, user_manager):
+ user_manager.clear_user_inmemory_cache(user)
+ user_manager.get_user.__wrapped__.cache_clear()
+ u = user_manager.get_user(user_id=user.user_id)
+ # this should hit mysql
+ assert u == user
+
+ cache_info = user_manager.get_user.__wrapped__.cache_info()
+ assert cache_info.hits == 0, cache_info
+ assert cache_info.misses == 1, cache_info
+
+ # this should hit the lru cache
+ u = user_manager.get_user(user_id=user.user_id)
+ assert u == user
+
+ cache_info = user_manager.get_user.__wrapped__.cache_info()
+ assert cache_info.hits == 1, cache_info
+ assert cache_info.misses == 1, cache_info
+
+ def test_get_user_with_inmemory(self, user_manager, user):
+ # user_manager = self.get_user_manager()
+
+ user_manager.set_user_inmemory_cache(user)
+ user_manager.get_user.__wrapped__.cache_clear()
+ u = user_manager.get_user(user_id=user.user_id)
+ # this should hit inmemory cache
+ assert u == user
+
+ cache_info = user_manager.get_user.__wrapped__.cache_info()
+ assert cache_info.hits == 0, cache_info
+ assert cache_info.misses == 1, cache_info
+
+ # this should hit the lru cache
+ u = user_manager.get_user(user_id=user.user_id)
+ assert u == user
+
+ cache_info = user_manager.get_user.__wrapped__.cache_info()
+ assert cache_info.hits == 1, cache_info
+ assert cache_info.misses == 1, cache_info
+
+
+class TestBlockUserManager:
+
+ def test_block_user(self, product, user_manager):
+ product_user_id = f"user-{uuid4().hex[:10]}"
+
+ # mysql_user_manager to skip user creation limit check
+ user: User = user_manager.mysql_user_manager.create_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+ assert not user.blocked
+
+ # get_user to make sure caches are populated
+ user = user_manager.get_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+ assert not user.blocked
+
+ assert user_manager.block_user(user) is True
+ assert user.blocked
+
+ user = user_manager.get_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+ assert user.blocked
+
+ user = user_manager.get_user(user_id=user.user_id)
+ assert user.blocked
+
+ def test_block_user_whitelist(self, product, user_manager, thl_web_rw):
+ product_user_id = f"user-{uuid4().hex[:10]}"
+
+ # mysql_user_manager to skip user creation limit check
+ user: User = user_manager.mysql_user_manager.create_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+ assert not user.blocked
+
+ now = datetime.now(tz=timezone.utc)
+ # Adds user to whitelist
+ thl_web_rw.execute_write(
+ """
+ INSERT INTO userprofile_userstat
+ (user_id, key, value, date)
+ VALUES (%(user_id)s, 'USER_HEALTH.access_control', 1, %(date)s)
+ ON CONFLICT (user_id, key) DO UPDATE SET value=1""",
+ params={"user_id": user.user_id, "date": now},
+ )
+ assert user_manager.is_whitelisted(user)
+ assert user_manager.block_user(user) is False
+ assert not user.blocked
+
+
+class TestCreateUserManager:
+
+ def test_create_user(self, product_manager, thl_web_rw, user_manager):
+ product: Product = product_manager.create_dummy(
+ user_create_config=UserCreateConfig(
+ min_hourly_create_limit=10, max_hourly_create_limit=69
+ ),
+ )
+
+ product_user_id = f"user-{uuid4().hex[:10]}"
+
+ user: User = user_manager.mysql_user_manager.create_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+ assert isinstance(user, User)
+ assert user.product_id == product.id
+ assert user.product_user_id == product_user_id
+
+ assert user.user_id is not None
+ assert user.uuid is not None
+
+ # make sure thl_user row is created
+ res_thl_user = thl_web_rw.execute_sql_query(
+ query=f"""
+ SELECT *
+ FROM thl_user AS u
+ WHERE u.id = %s
+ """,
+ params=[user.user_id],
+ )
+
+ assert len(res_thl_user) == 1
+
+ u2 = user_manager.get_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+ assert u2.user_id == user.user_id
+ assert u2.uuid == user.uuid
+
+ def test_create_user_integrity_error(self, product_manager, user_manager, caplog):
+ product: Product = product_manager.create_dummy(
+ product_id=uuid4().hex,
+ team_id=uuid4().hex,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ user_create_config=UserCreateConfig(
+ min_hourly_create_limit=10, max_hourly_create_limit=69
+ ),
+ )
+
+ product_user_id = f"user-{uuid4().hex[:10]}"
+ rand_msg = f"log-{uuid4().hex}"
+
+ with caplog.at_level(logging.INFO):
+ logger.info(rand_msg)
+ user1 = user_manager.mysql_user_manager.create_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+
+ assert len(caplog.records) == 1
+ assert caplog.records[0].getMessage() == rand_msg
+
+ # Should cause a constraint error, triggering a lookup instead
+ with caplog.at_level(logging.INFO):
+ user2 = user_manager.mysql_user_manager.create_user(
+ product_id=product.id, product_user_id=product_user_id
+ )
+
+ assert len(caplog.records) == 3
+ assert caplog.records[0].getMessage() == rand_msg
+ assert (
+ caplog.records[1].getMessage()
+ == f"mysql_user_manager.create_user_new integrity error: {product.id} {product_user_id}"
+ )
+ assert (
+ caplog.records[2].getMessage()
+ == f"get_user_from_mysql: {product.id}, {product_user_id}, None, None"
+ )
+
+ assert user1 == user2
+
+ def test_raise_allow_user_create(self, product_manager, user_manager):
+ rand_num = randint(25, 200)
+ product: Product = product_manager.create_dummy(
+ product_id=uuid4().hex,
+ team_id=uuid4().hex,
+ name=f"Test Product ID #{uuid4().hex[:6]}",
+ user_create_config=UserCreateConfig(
+ min_hourly_create_limit=rand_num,
+ max_hourly_create_limit=rand_num,
+ ),
+ )
+
+ instance: Product = user_manager.product_manager.get_by_uuid(
+ product_uuid=product.id
+ )
+
+ # get_bp_user_create_limit_hourly is dynamically generated, make sure
+ # we use this value in our tests and not the
+ # UserCreateConfig.max_hourly_create_limit value
+ rl_value: int = get_bp_user_create_limit_hourly(product=instance)
+
+ # This is a randomly generated product_id, which means we'll always
+ # use the global defaults
+ assert rand_num == instance.user_create_config.min_hourly_create_limit
+ assert rand_num == instance.user_create_config.max_hourly_create_limit
+ assert rand_num == rl_value
+
+ rl = RateLimitItemPerHourConstantKey(rl_value)
+ assert str(rl) == f"{rand_num} per 1 hour"
+
+ key = rl.key_for("thl-grpc", "allow_user_create", instance.id)
+ assert key == f"LIMITER/thl-grpc/allow_user_create/{instance.id}"
+
+ # make sure we clear the key or subsequent tests will fail
+ user_manager.user_manager_limiter.storage.clear(key=key)
+
+ n = 0
+ with pytest.raises(expected_exception=UserCreateNotAllowedError) as cm:
+ for n, _ in enumerate(range(rl_value + 5)):
+ user_manager.user_manager_limiter.raise_allow_user_create(
+ product=product
+ )
+ assert rl_value == n
+
+
+class TestUserManagerMethods:
+
+ def test_audit_log(self, user_manager, user, audit_log_manager):
+ from generalresearch.models.thl.userhealth import AuditLog
+
+ res = audit_log_manager.filter_by_user_id(user_id=user.user_id)
+ assert len(res) == 0
+
+ msg = uuid4().hex
+ user_manager.audit_log(user=user, level=30, event_type=msg)
+
+ res = audit_log_manager.filter_by_user_id(user_id=user.user_id)
+ assert len(res) == 1
+ assert isinstance(res[0], AuditLog)
+ assert res[0].event_type == msg
diff --git a/tests/managers/thl/test_user_manager/test_mysql.py b/tests/managers/thl/test_user_manager/test_mysql.py
new file mode 100644
index 0000000..0313bbf
--- /dev/null
+++ b/tests/managers/thl/test_user_manager/test_mysql.py
@@ -0,0 +1,25 @@
+from test_utils.models.conftest import user, user_manager
+
+
+class TestUserManagerMysqlNew:
+
+ def test_get_notset(self, user_manager):
+ assert (
+ user_manager.mysql_user_manager.get_user_from_mysql(user_id=-3105) is None
+ )
+
+ def test_get_user_id(self, user, user_manager):
+ assert (
+ user_manager.mysql_user_manager.get_user_from_mysql(user_id=user.user_id)
+ == user
+ )
+
+ def test_get_uuid(self, user, user_manager):
+ u = user_manager.mysql_user_manager.get_user_from_mysql(user_uuid=user.uuid)
+ assert u == user
+
+ def test_get_ubp(self, user, user_manager):
+ u = user_manager.mysql_user_manager.get_user_from_mysql(
+ product_id=user.product_id, product_user_id=user.product_user_id
+ )
+ assert u == user
diff --git a/tests/managers/thl/test_user_manager/test_redis.py b/tests/managers/thl/test_user_manager/test_redis.py
new file mode 100644
index 0000000..a69519e
--- /dev/null
+++ b/tests/managers/thl/test_user_manager/test_redis.py
@@ -0,0 +1,80 @@
+import pytest
+
+from generalresearch.managers.base import Permission
+
+
+class TestUserManagerRedis:
+
+ def test_get_notset(self, user_manager, user):
+ user_manager.clear_user_inmemory_cache(user=user)
+ assert user_manager.redis_user_manager.get_user(user_id=user.user_id) is None
+
+ def test_get_user_id(self, user_manager, user):
+ user_manager.redis_user_manager.set_user(user=user)
+
+ assert user_manager.redis_user_manager.get_user(user_id=user.user_id) == user
+
+ def test_get_uuid(self, user_manager, user):
+ user_manager.redis_user_manager.set_user(user=user)
+
+ assert user_manager.redis_user_manager.get_user(user_uuid=user.uuid) == user
+
+ def test_get_ubp(self, user_manager, user):
+ user_manager.redis_user_manager.set_user(user=user)
+
+ assert (
+ user_manager.redis_user_manager.get_user(
+ product_id=user.product_id, product_user_id=user.product_user_id
+ )
+ == user
+ )
+
+ @pytest.mark.skip(reason="TODO")
+ def test_set(self):
+ # I mean, the sets are implicitly tested by the get tests above. no point
+ pass
+
+ def test_get_with_cache_prefix(self, settings, user, thl_web_rw, thl_web_rr):
+ """
+ Confirm the prefix functionality is working; we do this so it
+ is easier to migrate between any potentially breaking versions
+ if we don't want any broken keys; not as important after
+ pydantic usage...
+ """
+ from generalresearch.managers.thl.user_manager.user_manager import (
+ UserManager,
+ )
+
+ um1 = UserManager(
+ pg_config=thl_web_rw,
+ pg_config_rr=thl_web_rr,
+ sql_permissions=[Permission.UPDATE, Permission.CREATE],
+ redis=settings.redis,
+ redis_timeout=settings.redis_timeout,
+ )
+
+ um2 = UserManager(
+ pg_config=thl_web_rw,
+ pg_config_rr=thl_web_rr,
+ sql_permissions=[Permission.UPDATE, Permission.CREATE],
+ redis=settings.redis,
+ redis_timeout=settings.redis_timeout,
+ cache_prefix="user-lookup-v2",
+ )
+
+ um1.get_or_create_user(
+ product_id=user.product_id, product_user_id=user.product_user_id
+ )
+ um2.get_or_create_user(
+ product_id=user.product_id, product_user_id=user.product_user_id
+ )
+
+ res1 = um1.redis_user_manager.client.get(f"user-lookup:user_id:{user.user_id}")
+ assert res1 is not None
+
+ res2 = um2.redis_user_manager.client.get(
+ f"user-lookup-v2:user_id:{user.user_id}"
+ )
+ assert res2 is not None
+
+ assert res1 == res2
diff --git a/tests/managers/thl/test_user_manager/test_user_fetch.py b/tests/managers/thl/test_user_manager/test_user_fetch.py
new file mode 100644
index 0000000..a4b3d57
--- /dev/null
+++ b/tests/managers/thl/test_user_manager/test_user_fetch.py
@@ -0,0 +1,48 @@
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models.thl.user import User
+from test_utils.models.conftest import product, user_manager, user_factory
+
+
+class TestUserManagerFetch:
+
+ def test_fetch(self, user_factory, product, user_manager):
+ user1: User = user_factory(product=product)
+ user2: User = user_factory(product=product)
+ res = user_manager.fetch_by_bpuids(
+ product_id=product.uuid,
+ product_user_ids=[user1.product_user_id, user2.product_user_id],
+ )
+ assert len(res) == 2
+
+ res = user_manager.fetch(user_ids=[user1.user_id, user2.user_id])
+ assert len(res) == 2
+
+ res = user_manager.fetch(user_uuids=[user1.uuid, user2.uuid])
+ assert len(res) == 2
+
+ # filter including bogus values
+ res = user_manager.fetch(user_uuids=[user1.uuid, uuid4().hex])
+ assert len(res) == 1
+
+ res = user_manager.fetch(user_uuids=[uuid4().hex])
+ assert len(res) == 0
+
+ def test_fetch_invalid(self, user_manager):
+ with pytest.raises(AssertionError) as e:
+ user_manager.fetch(user_uuids=[], user_ids=None)
+ assert "Must pass ONE of user_ids, user_uuids" in str(e.value)
+
+ with pytest.raises(AssertionError) as e:
+ user_manager.fetch(user_uuids=uuid4().hex)
+ assert "must pass a collection of user_uuids" in str(e.value)
+
+ with pytest.raises(AssertionError) as e:
+ user_manager.fetch(user_uuids=[uuid4().hex], user_ids=[1, 2, 3])
+ assert "Must pass ONE of user_ids, user_uuids" in str(e.value)
+
+ with pytest.raises(AssertionError) as e:
+ user_manager.fetch(user_ids=list(range(501)))
+ assert "limit 500 user_ids" in str(e.value)
diff --git a/tests/managers/thl/test_user_manager/test_user_metadata.py b/tests/managers/thl/test_user_manager/test_user_metadata.py
new file mode 100644
index 0000000..91dc16a
--- /dev/null
+++ b/tests/managers/thl/test_user_manager/test_user_metadata.py
@@ -0,0 +1,88 @@
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models.thl.user_profile import UserMetadata
+from test_utils.models.conftest import user, user_manager, user_factory
+
+
+class TestUserMetadataManager:
+
+ def test_get_notset(self, user, user_manager, user_metadata_manager):
+ # The row in the db won't exist. It just returns the default obj with everything None (except for the user_id)
+ um1 = user_metadata_manager.get(user_id=user.user_id)
+ assert um1 == UserMetadata(user_id=user.user_id)
+
+ def test_create(self, user_factory, product, user_metadata_manager):
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+
+ email_address = f"{uuid4().hex}@example.com"
+ um = UserMetadata(user_id=u1.user_id, email_address=email_address)
+ # This happens in the model itself, nothing to do with the manager (a model_validator)
+ assert um.email_sha256 is not None
+
+ user_metadata_manager.update(um)
+ um2 = user_metadata_manager.get(email_address=email_address)
+ assert um == um2
+
+ def test_create_no_email(self, product, user_factory, user_metadata_manager):
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+ um = UserMetadata(user_id=u1.user_id)
+ assert um.email_sha256 is None
+
+ user_metadata_manager.update(um)
+ um2 = user_metadata_manager.get(user_id=u1.user_id)
+ assert um == um2
+
+ def test_update(self, product, user_factory, user_metadata_manager):
+ from generalresearch.models.thl.user import User
+
+ u: User = user_factory(product=product)
+
+ email_address = f"{uuid4().hex}@example1.com"
+ um = UserMetadata(user_id=u.user_id, email_address=email_address)
+ user_metadata_manager.update(user_metadata=um)
+
+ um.email_address = email_address.replace("example1", "example2")
+ user_metadata_manager.update(user_metadata=um)
+
+ um2 = user_metadata_manager.get(email_address=um.email_address)
+ assert um2.email_address != email_address
+
+ assert um2 == UserMetadata(
+ user_id=u.user_id,
+ email_address=email_address.replace("example1", "example2"),
+ )
+
+ def test_filter(self, user_factory, product, user_metadata_manager):
+ from generalresearch.models.thl.user import User
+
+ user1: User = user_factory(product=product)
+ user2: User = user_factory(product=product)
+
+ email_address = f"{uuid4().hex}@example.com"
+ res = user_metadata_manager.filter(email_addresses=[email_address])
+ assert len(res) == 0
+
+ # Create 2 user metadata with the same email address
+ user_metadata_manager.update(
+ user_metadata=UserMetadata(
+ user_id=user1.user_id, email_address=email_address
+ )
+ )
+ user_metadata_manager.update(
+ user_metadata=UserMetadata(
+ user_id=user2.user_id, email_address=email_address
+ )
+ )
+
+ res = user_metadata_manager.filter(email_addresses=[email_address])
+ assert len(res) == 2
+
+ with pytest.raises(expected_exception=ValueError) as e:
+ res = user_metadata_manager.get(email_address=email_address)
+ assert "More than 1 result returned!" in str(e.value)
diff --git a/tests/managers/thl/test_user_streak.py b/tests/managers/thl/test_user_streak.py
new file mode 100644
index 0000000..7728f9f
--- /dev/null
+++ b/tests/managers/thl/test_user_streak.py
@@ -0,0 +1,225 @@
+import copy
+from datetime import datetime, timezone, timedelta, date
+from decimal import Decimal
+from zoneinfo import ZoneInfo
+
+import pytest
+
+from generalresearch.managers.thl.user_streak import compute_streaks_from_days
+from generalresearch.models.thl.definitions import StatusCode1, Status
+from generalresearch.models.thl.user_streak import (
+ UserStreak,
+ StreakState,
+ StreakPeriod,
+ StreakFulfillment,
+)
+
+
+def test_compute_streaks_from_days():
+ days = [
+ date(2026, 1, 1),
+ date(2026, 1, 4),
+ date(2026, 1, 5),
+ date(2026, 1, 6),
+ date(2026, 1, 8),
+ date(2026, 1, 9),
+ date(2026, 2, 11),
+ date(2026, 2, 12),
+ ]
+
+ # Active
+ today = date(2026, 2, 12)
+ res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today)
+ assert res == (2, 3, StreakState.ACTIVE, date(2026, 2, 12))
+
+ # At Risk
+ today = date(2026, 2, 13)
+ res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today)
+ assert res == (2, 3, StreakState.AT_RISK, date(2026, 2, 12))
+
+ # Broken
+ today = date(2026, 2, 14)
+ res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today)
+ assert res == (0, 3, StreakState.BROKEN, date(2026, 2, 12))
+
+ # Monthly, active
+ today = date(2026, 2, 14)
+ res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today)
+ assert res == (2, 2, StreakState.ACTIVE, date(2026, 2, 1))
+
+ # monthly, at risk
+ today = date(2026, 3, 1)
+ res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today)
+ assert res == (2, 2, StreakState.AT_RISK, date(2026, 2, 1))
+
+ # monthly, broken
+ today = date(2026, 4, 1)
+ res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today)
+ assert res == (0, 2, StreakState.BROKEN, date(2026, 2, 1))
+
+
+@pytest.fixture
+def broken_active_streak(user):
+ return [
+ UserStreak(
+ period=StreakPeriod.DAY,
+ fulfillment=StreakFulfillment.ACTIVE,
+ country_iso="us",
+ user_id=user.user_id,
+ last_fulfilled_period_start=date(2025, 2, 11),
+ current_streak=0,
+ longest_streak=1,
+ state=StreakState.BROKEN,
+ ),
+ UserStreak(
+ period=StreakPeriod.WEEK,
+ fulfillment=StreakFulfillment.ACTIVE,
+ country_iso="us",
+ user_id=user.user_id,
+ last_fulfilled_period_start=date(2025, 2, 10),
+ current_streak=0,
+ longest_streak=1,
+ state=StreakState.BROKEN,
+ ),
+ UserStreak(
+ period=StreakPeriod.MONTH,
+ fulfillment=StreakFulfillment.ACTIVE,
+ country_iso="us",
+ user_id=user.user_id,
+ last_fulfilled_period_start=date(2025, 2, 1),
+ current_streak=0,
+ longest_streak=1,
+ state=StreakState.BROKEN,
+ ),
+ ]
+
+
+def create_session_fail(session_manager, start, user):
+ session = session_manager.create_dummy(started=start, country_iso="us", user=user)
+ session_manager.finish_with_status(
+ session,
+ finished=start + timedelta(minutes=1),
+ status=Status.FAIL,
+ status_code_1=StatusCode1.BUYER_FAIL,
+ )
+
+
+def create_session_complete(session_manager, start, user):
+ session = session_manager.create_dummy(started=start, country_iso="us", user=user)
+ session_manager.finish_with_status(
+ session,
+ finished=start + timedelta(minutes=1),
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ payout=Decimal(1),
+ )
+
+
+def test_user_streak_empty(user_streak_manager, user):
+ streaks = user_streak_manager.get_user_streaks(
+ user_id=user.user_id, country_iso="us"
+ )
+ assert streaks == []
+
+
+def test_user_streaks_active_broken(
+ user_streak_manager, user, session_manager, broken_active_streak
+):
+ # Testing active streak, but broken (not today or yesterday)
+ start1 = datetime(2025, 2, 12, tzinfo=timezone.utc)
+ end1 = start1 + timedelta(minutes=1)
+
+ # abandon counts as inactive
+ session = session_manager.create_dummy(started=start1, country_iso="us", user=user)
+ streak = user_streak_manager.get_user_streaks(user_id=user.user_id)
+ assert streak == []
+
+ # session start fail counts as inactive
+ session_manager.finish_with_status(
+ session,
+ finished=end1,
+ status=Status.FAIL,
+ status_code_1=StatusCode1.SESSION_START_QUALITY_FAIL,
+ )
+ streak = user_streak_manager.get_user_streaks(
+ user_id=user.user_id, country_iso="us"
+ )
+ assert streak == []
+
+ # Create another as a real failure
+ create_session_fail(session_manager, start1, user)
+
+ # This is going to be the day before b/c midnight utc is 8pm US
+ last_active_day = start1.astimezone(ZoneInfo("America/New_York")).date()
+ assert last_active_day == date(2025, 2, 11)
+ expected_streaks = broken_active_streak.copy()
+ streaks = user_streak_manager.get_user_streaks(user_id=user.user_id)
+ assert streaks == expected_streaks
+
+ # Create another the next day
+ start2 = start1 + timedelta(days=1)
+ create_session_fail(session_manager, start2, user)
+
+ last_active_day = start2.astimezone(ZoneInfo("America/New_York")).date()
+ expected_streaks = copy.deepcopy(broken_active_streak)
+ expected_streaks[0].longest_streak = 2
+ expected_streaks[0].last_fulfilled_period_start = last_active_day
+
+ streaks = user_streak_manager.get_user_streaks(
+ user_id=user.user_id, country_iso="us"
+ )
+ assert streaks == expected_streaks
+
+
+def test_user_streak_complete_active(user_streak_manager, user, session_manager):
+ """Testing active streak that is today"""
+
+ # They completed yesterday NY time. Today isn't over so streak is pending
+ start1 = datetime.now(tz=ZoneInfo("America/New_York")) - timedelta(days=1)
+ create_session_complete(session_manager, start1.astimezone(tz=timezone.utc), user)
+
+ last_complete_day = start1.date()
+ expected_streak = UserStreak(
+ longest_streak=1,
+ current_streak=1,
+ state=StreakState.AT_RISK,
+ last_fulfilled_period_start=last_complete_day,
+ country_iso="us",
+ user_id=user.user_id,
+ fulfillment=StreakFulfillment.COMPLETE,
+ period=StreakPeriod.DAY,
+ )
+ streaks = user_streak_manager.get_user_streaks(
+ user_id=user.user_id, country_iso="us"
+ )
+ streak = [
+ s
+ for s in streaks
+ if s.fulfillment == StreakFulfillment.COMPLETE and s.period == StreakPeriod.DAY
+ ][0]
+ assert streak == expected_streak
+
+ # And now they complete today
+ start2 = datetime.now(tz=ZoneInfo("America/New_York"))
+ create_session_complete(session_manager, start2.astimezone(tz=timezone.utc), user)
+ last_complete_day = start2.date()
+ expected_streak = UserStreak(
+ longest_streak=2,
+ current_streak=2,
+ state=StreakState.ACTIVE,
+ last_fulfilled_period_start=last_complete_day,
+ country_iso="us",
+ user_id=user.user_id,
+ fulfillment=StreakFulfillment.COMPLETE,
+ period=StreakPeriod.DAY,
+ )
+
+ streaks = user_streak_manager.get_user_streaks(
+ user_id=user.user_id, country_iso="us"
+ )
+ streak = [
+ s
+ for s in streaks
+ if s.fulfillment == StreakFulfillment.COMPLETE and s.period == StreakPeriod.DAY
+ ][0]
+ assert streak == expected_streak
diff --git a/tests/managers/thl/test_userhealth.py b/tests/managers/thl/test_userhealth.py
new file mode 100644
index 0000000..1cda8de
--- /dev/null
+++ b/tests/managers/thl/test_userhealth.py
@@ -0,0 +1,367 @@
+from datetime import timezone, datetime
+from uuid import uuid4
+
+import faker
+import pytest
+
+from generalresearch.managers.thl.userhealth import (
+ IPRecordManager,
+ UserIpHistoryManager,
+)
+from generalresearch.models.thl.ipinfo import GeoIPInformation
+from generalresearch.models.thl.user_iphistory import (
+ IPRecord,
+)
+from generalresearch.models.thl.userhealth import AuditLogLevel, AuditLog
+
+fake = faker.Faker()
+
+
+class TestAuditLog:
+
+ def test_init(self, thl_web_rr, audit_log_manager):
+ from generalresearch.managers.thl.userhealth import AuditLogManager
+
+ alm = AuditLogManager(pg_config=thl_web_rr)
+
+ assert isinstance(alm, AuditLogManager)
+ assert isinstance(audit_log_manager, AuditLogManager)
+ assert alm.pg_config.db == thl_web_rr.db
+ assert audit_log_manager.pg_config.db == thl_web_rr.db
+
+ @pytest.mark.parametrize(
+ argnames="level",
+ argvalues=list(AuditLogLevel),
+ )
+ def test_create(self, audit_log_manager, user, level):
+ instance = audit_log_manager.create(
+ user_id=user.user_id, level=level, event_type=uuid4().hex
+ )
+ assert isinstance(instance, AuditLog)
+ assert instance.id != 1
+
+ def test_get_by_id(self, audit_log, audit_log_manager):
+ from generalresearch.models.thl.userhealth import AuditLog
+
+ with pytest.raises(expected_exception=Exception) as cm:
+ audit_log_manager.get_by_id(auditlog_id=999_999_999_999)
+ assert "No AuditLog with id of " in str(cm.value)
+
+ assert isinstance(audit_log, AuditLog)
+ res = audit_log_manager.get_by_id(auditlog_id=audit_log.id)
+ assert isinstance(res, AuditLog)
+ assert res.id == audit_log.id
+ assert res.created.tzinfo == timezone.utc
+
+ def test_filter_by_product(
+ self,
+ user_factory,
+ product_factory,
+ audit_log_factory,
+ audit_log_manager,
+ ):
+ p1 = product_factory()
+ p2 = product_factory()
+
+ audit_log_factory(user_id=user_factory(product=p1).user_id)
+ audit_log_factory(user_id=user_factory(product=p1).user_id)
+ audit_log_factory(user_id=user_factory(product=p1).user_id)
+
+ res = audit_log_manager.filter_by_product(product=p2)
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ res = audit_log_manager.filter_by_product(product=p1)
+ assert isinstance(res, list)
+ assert len(res) == 3
+
+ audit_log_factory(user_id=user_factory(product=p2).user_id)
+ res = audit_log_manager.filter_by_product(product=p2)
+ assert isinstance(res, list)
+ assert isinstance(res[0], AuditLog)
+ assert len(res) == 1
+
+ def test_filter_by_user_id(
+ self, user_factory, product, audit_log_factory, audit_log_manager
+ ):
+ u1 = user_factory(product=product)
+ u2 = user_factory(product=product)
+
+ audit_log_factory(user_id=u1.user_id)
+ audit_log_factory(user_id=u1.user_id)
+ audit_log_factory(user_id=u1.user_id)
+
+ res = audit_log_manager.filter_by_user_id(user_id=u1.user_id)
+ assert isinstance(res, list)
+ assert len(res) == 3
+
+ res = audit_log_manager.filter_by_user_id(user_id=u2.user_id)
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ audit_log_factory(user_id=u2.user_id)
+
+ res = audit_log_manager.filter_by_user_id(user_id=u2.user_id)
+ assert isinstance(res, list)
+ assert isinstance(res[0], AuditLog)
+ assert len(res) == 1
+
+ def test_filter(
+ self,
+ user_factory,
+ product_factory,
+ audit_log_factory,
+ audit_log_manager,
+ ):
+ p1 = product_factory()
+ p2 = product_factory()
+ p3 = product_factory()
+
+ u1 = user_factory(product=p1)
+ u2 = user_factory(product=p2)
+ u3 = user_factory(product=p3)
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ audit_log_manager.filter(user_ids=[])
+ assert "must pass at least 1 user_id" in str(cm.value)
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ audit_log_manager.filter(user_ids=[u1, u2, u3])
+ assert "must pass user_id as int" in str(cm.value)
+
+ res = audit_log_manager.filter(user_ids=[u1.user_id, u2.user_id, u3.user_id])
+ assert isinstance(res, list)
+ assert len(res) == 0
+
+ audit_log_factory(user_id=u1.user_id)
+
+ res = audit_log_manager.filter(user_ids=[u1.user_id, u2.user_id, u3.user_id])
+ assert isinstance(res, list)
+ assert isinstance(res[0], AuditLog)
+ assert len(res) == 1
+
+ def test_filter_count(
+ self,
+ user_factory,
+ product_factory,
+ audit_log_factory,
+ audit_log_manager,
+ ):
+ p1 = product_factory()
+ p2 = product_factory()
+ p3 = product_factory()
+
+ u1 = user_factory(product=p1)
+ u2 = user_factory(product=p2)
+ u3 = user_factory(product=p3)
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ audit_log_manager.filter(user_ids=[])
+ assert "must pass at least 1 user_id" in str(cm.value)
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ audit_log_manager.filter(user_ids=[u1, u2, u3])
+ assert "must pass user_id as int" in str(cm.value)
+
+ res = audit_log_manager.filter_count(
+ user_ids=[u1.user_id, u2.user_id, u3.user_id]
+ )
+ assert isinstance(res, int)
+ assert res == 0
+
+ audit_log_factory(user_id=u1.user_id, level=20)
+
+ res = audit_log_manager.filter_count(
+ user_ids=[u1.user_id, u2.user_id, u3.user_id]
+ )
+ assert isinstance(res, int)
+ assert res == 1
+
+ res = audit_log_manager.filter_count(
+ user_ids=[u1.user_id, u2.user_id, u3.user_id],
+ created_after=datetime.now(tz=timezone.utc),
+ )
+ assert isinstance(res, int)
+ assert res == 0
+
+ res = audit_log_manager.filter_count(
+ user_ids=[u1.user_id], event_type_like="offerwall-enter.%%"
+ )
+ assert res == 1
+
+ audit_log_factory(user_id=u1.user_id, level=50)
+ res = audit_log_manager.filter_count(
+ user_ids=[u1.user_id],
+ event_type_like="offerwall-enter.%%",
+ level_ge=10,
+ )
+ assert res == 2
+
+ res = audit_log_manager.filter_count(
+ user_ids=[u1.user_id], event_type_like="poop.%", level_ge=10
+ )
+ assert res == 0
+
+
+class TestIPRecordManager:
+
+ def test_init(self, thl_web_rr, thl_redis_config, ip_record_manager):
+ instance = IPRecordManager(pg_config=thl_web_rr, redis_config=thl_redis_config)
+ assert isinstance(instance, IPRecordManager)
+ assert isinstance(ip_record_manager, IPRecordManager)
+
+ def test_create(self, ip_record_manager, user, ip_information):
+ instance = ip_record_manager.create_dummy(
+ user_id=user.user_id, ip=ip_information.ip
+ )
+ assert isinstance(instance, IPRecord)
+
+ assert isinstance(instance.forwarded_ips, list)
+ assert isinstance(instance.forwarded_ip_records[0], IPRecord)
+ assert isinstance(instance.forwarded_ips[0], str)
+
+ assert instance.created == instance.forwarded_ip_records[0].created
+
+ ipr1 = ip_record_manager.filter_ip_records(filter_ips=[instance.ip])
+ assert isinstance(ipr1, list)
+ assert instance.model_dump_json() == ipr1[0].model_dump_json()
+
+ def test_prefetch_info(
+ self,
+ ip_record_factory,
+ ip_information_factory,
+ ip_geoname,
+ user,
+ thl_web_rr,
+ thl_redis_config,
+ ):
+
+ ip = fake.ipv4_public()
+ ip_information_factory(ip=ip, geoname=ip_geoname)
+ ipr: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip)
+
+ assert ipr.information is None
+ assert len(ipr.forwarded_ip_records) >= 1
+ fipr = ipr.forwarded_ip_records[0]
+ assert fipr.information is None
+
+ ipr.prefetch_ipinfo(
+ pg_config=thl_web_rr,
+ redis_config=thl_redis_config,
+ include_forwarded=True,
+ )
+ assert isinstance(ipr.information, GeoIPInformation)
+ assert ipr.information.ip == ipr.ip == ip
+ assert fipr.information is None, "the ipinfo doesn't exist in the db yet"
+
+ ip_information_factory(ip=fipr.ip, geoname=ip_geoname)
+ ipr.prefetch_ipinfo(
+ pg_config=thl_web_rr,
+ redis_config=thl_redis_config,
+ include_forwarded=True,
+ )
+ assert fipr.information is not None
+
+
+@pytest.mark.usefixtures("user_iphistory_manager_clear_cache")
+class TestUserIpHistoryManager:
+ def test_init(self, thl_web_rr, thl_redis_config, user_iphistory_manager):
+ instance = UserIpHistoryManager(
+ pg_config=thl_web_rr, redis_config=thl_redis_config
+ )
+ assert isinstance(instance, UserIpHistoryManager)
+ assert isinstance(user_iphistory_manager, UserIpHistoryManager)
+
+ def test_latest_record(
+ self,
+ user_iphistory_manager,
+ user,
+ ip_record_factory,
+ ip_information_factory,
+ ip_geoname,
+ ):
+ ip = fake.ipv4_public()
+ ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True)
+ ipr1: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip)
+
+ ipr = user_iphistory_manager.get_user_latest_ip_record(user=user)
+ assert ipr.ip == ipr1.ip
+ assert ipr.is_anonymous
+ assert ipr.information.lookup_prefix == "/32"
+
+ ip = fake.ipv6()
+ ip_information_factory(ip=ip, geoname=ip_geoname)
+ ipr2: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip)
+
+ ipr = user_iphistory_manager.get_user_latest_ip_record(user=user)
+ assert ipr.ip == ipr2.ip
+ assert ipr.information.lookup_prefix == "/64"
+ assert ipr.information is not None
+ assert not ipr.is_anonymous
+
+ country_iso = user_iphistory_manager.get_user_latest_country(user=user)
+ assert country_iso == ip_geoname.country_iso
+
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert iph.ips[0].information is not None
+ assert iph.ips[1].information is not None
+ assert iph.ips[0].country_iso == country_iso
+ assert iph.ips[0].is_anonymous
+ assert iph.ips[0].ip == ipr1.ip
+ assert iph.ips[1].ip == ipr2.ip
+
+ def test_virgin(self, user, user_iphistory_manager, ip_record_factory):
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert len(iph.ips) == 0
+
+ ip_record_factory(user_id=user.user_id, ip=fake.ipv4_public())
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert len(iph.ips) == 1
+
+ def test_out_of_order(
+ self,
+ ip_record_factory,
+ user,
+ user_iphistory_manager,
+ ip_information_factory,
+ ip_geoname,
+ ):
+ # Create the user-ip association BEFORE the ip even exists in the ipinfo table
+ ip = fake.ipv4_public()
+ ip_record_factory(user_id=user.user_id, ip=ip)
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert len(iph.ips) == 1
+ ipr = iph.ips[0]
+ assert ipr.information is None
+ assert not ipr.is_anonymous
+
+ ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True)
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert len(iph.ips) == 1
+ ipr = iph.ips[0]
+ assert ipr.information is not None
+ assert ipr.is_anonymous
+
+ def test_out_of_order_ipv6(
+ self,
+ ip_record_factory,
+ user,
+ user_iphistory_manager,
+ ip_information_factory,
+ ip_geoname,
+ ):
+ # Create the user-ip association BEFORE the ip even exists in the ipinfo table
+ ip = fake.ipv6()
+ ip_record_factory(user_id=user.user_id, ip=ip)
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert len(iph.ips) == 1
+ ipr = iph.ips[0]
+ assert ipr.information is None
+ assert not ipr.is_anonymous
+
+ ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True)
+ iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id)
+ assert len(iph.ips) == 1
+ ipr = iph.ips[0]
+ assert ipr.information is not None
+ assert ipr.is_anonymous
diff --git a/tests/managers/thl/test_wall_manager.py b/tests/managers/thl/test_wall_manager.py
new file mode 100644
index 0000000..ee44e23
--- /dev/null
+++ b/tests/managers/thl/test_wall_manager.py
@@ -0,0 +1,283 @@
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.models import Source
+from generalresearch.models.thl.session import (
+ ReportValue,
+ Status,
+ StatusCode1,
+)
+from test_utils.models.conftest import user, session
+
+
+class TestWallManager:
+
+ @pytest.mark.parametrize("wall_count", [1, 2, 5, 10, 50, 99])
+ def test_get_wall_events(
+ self, wall_manager, session_factory, user, wall_count, utc_hour_ago
+ ):
+ from generalresearch.models.thl.session import Session
+
+ s1: Session = session_factory(
+ user=user, wall_count=wall_count, started=utc_hour_ago
+ )
+
+ assert len(s1.wall_events) == wall_count
+ assert len(wall_manager.get_wall_events(session_id=s1.id)) == wall_count
+
+ db_wall_events = wall_manager.get_wall_events(session_id=s1.id)
+ assert [w.uuid for w in s1.wall_events] == [w.uuid for w in db_wall_events]
+ assert [w.source for w in s1.wall_events] == [w.source for w in db_wall_events]
+ assert [w.buyer_id for w in s1.wall_events] == [
+ w.buyer_id for w in db_wall_events
+ ]
+ assert [w.req_survey_id for w in s1.wall_events] == [
+ w.req_survey_id for w in db_wall_events
+ ]
+ assert [w.started for w in s1.wall_events] == [
+ w.started for w in db_wall_events
+ ]
+
+ assert sum([w.req_cpi for w in s1.wall_events]) == sum(
+ [w.req_cpi for w in db_wall_events]
+ )
+ assert sum([w.cpi for w in s1.wall_events]) == sum(
+ [w.cpi for w in db_wall_events]
+ )
+
+ assert [w.session_id for w in s1.wall_events] == [
+ w.session_id for w in db_wall_events
+ ]
+ assert [w.user_id for w in s1.wall_events] == [
+ w.user_id for w in db_wall_events
+ ]
+ assert [w.survey_id for w in s1.wall_events] == [
+ w.survey_id for w in db_wall_events
+ ]
+
+ assert [w.finished for w in s1.wall_events] == [
+ w.finished for w in db_wall_events
+ ]
+
+ def test_get_wall_events_list_input(
+ self, wall_manager, session_factory, user, utc_hour_ago
+ ):
+ from generalresearch.models.thl.session import Session
+
+ session_ids = []
+ for idx in range(10):
+ s: Session = session_factory(user=user, wall_count=5, started=utc_hour_ago)
+ session_ids.append(s.id)
+
+ session_ids.sort()
+ res = wall_manager.get_wall_events(session_ids=session_ids)
+
+ assert isinstance(res, list)
+ assert len(res) == 50
+
+ res1 = list(set([w.session_id for w in res]))
+ res1.sort()
+
+ assert session_ids == res1
+
+ def test_create_wall(self, wall_manager, session_manager, user, session):
+ w = wall_manager.create(
+ session_id=session.id,
+ user_id=user.user_id,
+ uuid_id=uuid4().hex,
+ started=datetime.now(tz=timezone.utc),
+ source=Source.DYNATA,
+ buyer_id="123",
+ req_survey_id="456",
+ req_cpi=Decimal("1"),
+ )
+
+ assert w is not None
+ w2 = wall_manager.get_from_uuid(wall_uuid=w.uuid)
+ assert w == w2
+
+ def test_report_wall_abandon(
+ self, wall_manager, session_manager, user, session, utc_hour_ago
+ ):
+ w1 = wall_manager.create(
+ session_id=session.id,
+ user_id=user.user_id,
+ uuid_id=uuid4().hex,
+ started=utc_hour_ago,
+ source=Source.DYNATA,
+ buyer_id="123",
+ req_survey_id="456",
+ req_cpi=Decimal("1"),
+ )
+ wall_manager.report(
+ wall=w1,
+ report_value=ReportValue.REASON_UNKNOWN,
+ report_timestamp=utc_hour_ago + timedelta(minutes=1),
+ )
+ w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid)
+
+ # I Reported a session with no status. It should be marked as an abandon with a finished ts
+ assert ReportValue.REASON_UNKNOWN == w2.report_value
+ assert Status.ABANDON == w2.status
+ assert utc_hour_ago + timedelta(minutes=1) == w2.finished
+ assert w2.report_notes is None
+
+ # There is nothing stopping it from being un-abandoned...
+ finished = w1.started + timedelta(minutes=10)
+ wall_manager.finish(
+ wall=w1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ finished=finished,
+ )
+ w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid)
+ assert ReportValue.REASON_UNKNOWN == w2.report_value
+ assert w2.report_notes is None
+ assert finished == w2.finished
+ assert Status.COMPLETE == w2.status
+ # the status and finished get updated
+
+ def test_report_wall(
+ self, wall_manager, session_manager, user, session, utc_hour_ago
+ ):
+ w1 = wall_manager.create(
+ session_id=session.id,
+ user_id=user.user_id,
+ uuid_id=uuid4().hex,
+ started=utc_hour_ago,
+ source=Source.DYNATA,
+ buyer_id="123",
+ req_survey_id="456",
+ req_cpi=Decimal("1"),
+ )
+
+ finish_ts = utc_hour_ago + timedelta(minutes=10)
+ report_ts = utc_hour_ago + timedelta(minutes=11)
+ wall_manager.finish(
+ wall=w1,
+ status=Status.COMPLETE,
+ status_code_1=StatusCode1.COMPLETE,
+ finished=finish_ts,
+ )
+
+ # I Reported a session that was already completed. Only update the report values
+ wall_manager.report(
+ wall=w1,
+ report_value=ReportValue.TECHNICAL_ERROR,
+ report_timestamp=report_ts,
+ report_notes="This survey blows!",
+ )
+ w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid)
+
+ assert ReportValue.TECHNICAL_ERROR == w2.report_value
+ assert finish_ts == w2.finished
+ assert Status.COMPLETE == w2.status
+ assert "This survey blows!" == w2.report_notes
+
+ def test_filter_wall_attempts(
+ self, wall_manager, session_manager, user, session, utc_hour_ago
+ ):
+ res = wall_manager.filter_wall_attempts(user_id=user.user_id)
+ assert len(res) == 0
+ w1 = wall_manager.create(
+ session_id=session.id,
+ user_id=user.user_id,
+ uuid_id=uuid4().hex,
+ started=utc_hour_ago,
+ source=Source.DYNATA,
+ buyer_id="123",
+ req_survey_id="456",
+ req_cpi=Decimal("1"),
+ )
+ res = wall_manager.filter_wall_attempts(user_id=user.user_id)
+ assert len(res) == 1
+ w2 = wall_manager.create(
+ session_id=session.id,
+ user_id=user.user_id,
+ uuid_id=uuid4().hex,
+ started=utc_hour_ago + timedelta(minutes=1),
+ source=Source.DYNATA,
+ buyer_id="123",
+ req_survey_id="555",
+ req_cpi=Decimal("1"),
+ )
+ res = wall_manager.filter_wall_attempts(user_id=user.user_id)
+ assert len(res) == 2
+
+
+class TestWallCacheManager:
+
+ def test_get_attempts_none(self, wall_cache_manager, user):
+ attempts = wall_cache_manager.get_attempts(user.user_id)
+ assert len(attempts) == 0
+
+ def test_get_wall_events(
+ self, wall_cache_manager, wall_manager, session_manager, user
+ ):
+ start1 = datetime.now(timezone.utc) - timedelta(hours=3)
+ start2 = datetime.now(timezone.utc) - timedelta(hours=2)
+ start3 = datetime.now(timezone.utc) - timedelta(hours=1)
+
+ session = session_manager.create_dummy(started=start1, user=user)
+ wall1 = wall_manager.create_dummy(
+ session_id=session.id,
+ user_id=session.user_id,
+ started=start1,
+ req_cpi=Decimal("1.23"),
+ req_survey_id="11111",
+ source=Source.DYNATA,
+ )
+ # The flag never got set, so no results!
+ attempts = wall_cache_manager.get_attempts(user_id=user.user_id)
+ assert len(attempts) == 0
+
+ wall_cache_manager.set_flag(user_id=user.user_id)
+ attempts = wall_cache_manager.get_attempts(user_id=user.user_id)
+ assert len(attempts) == 1
+
+ wall2 = wall_manager.create_dummy(
+ session_id=session.id,
+ user_id=session.user_id,
+ started=start2,
+ req_cpi=Decimal("1.23"),
+ req_survey_id="22222",
+ source=Source.DYNATA,
+ )
+
+ # We haven't set the flag, so the cache won't update!
+ attempts = wall_cache_manager.get_attempts(user_id=user.user_id)
+ assert len(attempts) == 1
+
+ # Now set the flag
+ wall_cache_manager.set_flag(user_id=user.user_id)
+ attempts = wall_cache_manager.get_attempts(user_id=user.user_id)
+ assert len(attempts) == 2
+ # It is in desc order
+ assert attempts[0].req_survey_id == "22222"
+ assert attempts[1].req_survey_id == "11111"
+
+ # Test the trim. Fill up the cache with 6000 events, then add another,
+ # and it should be first in the list, with only 5k others
+ attempts10000 = [attempts[0]] * 6000
+ wall_cache_manager.update_attempts_redis_(attempts10000, user_id=user.user_id)
+
+ session = session_manager.create_dummy(started=start3, user=user)
+ wall3 = wall_manager.create_dummy(
+ session_id=session.id,
+ user_id=session.user_id,
+ started=start3,
+ req_cpi=Decimal("1.23"),
+ req_survey_id="33333",
+ source=Source.DYNATA,
+ )
+ wall_cache_manager.set_flag(user_id=user.user_id)
+ attempts = wall_cache_manager.get_attempts(user_id=user.user_id)
+
+ redis_key = wall_cache_manager.get_cache_key_(user_id=user.user_id)
+ assert wall_cache_manager.redis_client.llen(redis_key) == 5000
+
+ assert len(attempts) == 5000
+ assert attempts[0].req_survey_id == "33333"