aboutsummaryrefslogtreecommitdiff
path: root/tests/managers/leaderboard.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/managers/leaderboard.py')
-rw-r--r--tests/managers/leaderboard.py274
1 files changed, 274 insertions, 0 deletions
diff --git a/tests/managers/leaderboard.py b/tests/managers/leaderboard.py
new file mode 100644
index 0000000..4d32dd0
--- /dev/null
+++ b/tests/managers/leaderboard.py
@@ -0,0 +1,274 @@
+import os
+import time
+import zoneinfo
+from datetime import datetime, timezone
+from decimal import Decimal
+from uuid import uuid4
+
+import pytest
+
+from generalresearch.managers.leaderboard.manager import LeaderboardManager
+from generalresearch.managers.leaderboard.tasks import hit_leaderboards
+from generalresearch.models.thl.definitions import Status
+from generalresearch.models.thl.user import User
+from generalresearch.models.thl.product import Product
+from generalresearch.models.thl.session import Session
+from generalresearch.models.thl.leaderboard import (
+ LeaderboardCode,
+ LeaderboardFrequency,
+ LeaderboardRow,
+)
+from generalresearch.models.thl.product import (
+ PayoutConfig,
+ PayoutTransformation,
+ PayoutTransformationPercentArgs,
+)
+
+# random uuid for leaderboard tests
+product_id = uuid4().hex
+
+
+@pytest.fixture(autouse=True)
+def set_timezone():
+ os.environ["TZ"] = "UTC"
+ time.tzset()
+ yield
+ # Optionally reset to default
+ os.environ.pop("TZ", None)
+ time.tzset()
+
+
+@pytest.fixture
+def session_factory():
+ return _create_session
+
+
+def _create_session(
+ product_user_id="aaa", country_iso="us", user_payout=Decimal("1.00")
+):
+ user = User(
+ product_id=product_id,
+ product_user_id=product_user_id,
+ )
+ user.product = Product(
+ id=product_id,
+ name="test",
+ redirect_url="https://www.example.com",
+ payout_config=PayoutConfig(
+ payout_transformation=PayoutTransformation(
+ f="payout_transformation_percent",
+ kwargs=PayoutTransformationPercentArgs(pct=0.5),
+ )
+ ),
+ )
+ session = Session(
+ user=user,
+ started=datetime(2025, 2, 5, 6, tzinfo=timezone.utc),
+ id=1,
+ country_iso=country_iso,
+ status=Status.COMPLETE,
+ payout=Decimal("2.00"),
+ user_payout=user_payout,
+ )
+ return session
+
+
+@pytest.fixture(scope="function")
+def setup_leaderboards(thl_redis):
+ complete_count = {
+ "aaa": 10,
+ "bbb": 6,
+ "ccc": 6,
+ "ddd": 6,
+ "eee": 2,
+ "fff": 1,
+ "ggg": 1,
+ }
+ sum_payout = {"aaa": 345, "bbb": 100, "ccc": 100}
+ max_payout = sum_payout
+ country_iso = "us"
+ for freq in [
+ LeaderboardFrequency.DAILY,
+ LeaderboardFrequency.WEEKLY,
+ LeaderboardFrequency.MONTHLY,
+ ]:
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ thl_redis.delete(m.key)
+ thl_redis.zadd(m.key, complete_count)
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.SUM_PAYOUTS,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ thl_redis.delete(m.key)
+ thl_redis.zadd(m.key, sum_payout)
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.LARGEST_PAYOUT,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ thl_redis.delete(m.key)
+ thl_redis.zadd(m.key, max_payout)
+
+
+class TestLeaderboards:
+
+ def test_leaderboard_manager(self, setup_leaderboards, thl_redis):
+ country_iso = "us"
+ board_code = LeaderboardCode.COMPLETE_COUNT
+ freq = LeaderboardFrequency.DAILY
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=board_code,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 0, 0, 0),
+ )
+ lb = m.get_leaderboard()
+ assert lb.period_start_local == datetime(
+ 2025, 2, 5, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="America/New_York")
+ )
+ assert lb.period_end_local == datetime(
+ 2025,
+ 2,
+ 5,
+ 23,
+ 59,
+ 59,
+ 999999,
+ tzinfo=zoneinfo.ZoneInfo(key="America/New_York"),
+ )
+ assert lb.period_start_utc == datetime(2025, 2, 5, 5, tzinfo=timezone.utc)
+ assert lb.row_count == 7
+ assert lb.rows == [
+ LeaderboardRow(bpuid="aaa", rank=1, value=10),
+ LeaderboardRow(bpuid="bbb", rank=2, value=6),
+ LeaderboardRow(bpuid="ccc", rank=2, value=6),
+ LeaderboardRow(bpuid="ddd", rank=2, value=6),
+ LeaderboardRow(bpuid="eee", rank=5, value=2),
+ LeaderboardRow(bpuid="fff", rank=6, value=1),
+ LeaderboardRow(bpuid="ggg", rank=6, value=1),
+ ]
+
+ def test_leaderboard_manager_bpuid(self, setup_leaderboards, thl_redis):
+ country_iso = "us"
+ board_code = LeaderboardCode.COMPLETE_COUNT
+ freq = LeaderboardFrequency.DAILY
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=board_code,
+ freq=freq,
+ product_id=product_id,
+ country_iso=country_iso,
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(bp_user_id="fff", limit=1)
+
+ # TODO: this won't work correctly if I request bpuid 'ggg', because it
+ # is ordered at the end even though it is tied, so it won't get a
+ # row after ('fff')
+
+ assert lb.rows == [
+ LeaderboardRow(bpuid="eee", rank=5, value=2),
+ LeaderboardRow(bpuid="fff", rank=6, value=1),
+ LeaderboardRow(bpuid="ggg", rank=6, value=1),
+ ]
+
+ lb.censor()
+ assert lb.rows[0].bpuid == "ee*"
+
+ def test_leaderboard_hit(self, setup_leaderboards, session_factory, thl_redis):
+ hit_leaderboards(redis_client=thl_redis, session=session_factory())
+
+ for freq in [
+ LeaderboardFrequency.DAILY,
+ LeaderboardFrequency.WEEKLY,
+ LeaderboardFrequency.MONTHLY,
+ ]:
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=freq,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(limit=1)
+ assert lb.row_count == 7
+ assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=11)]
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.LARGEST_PAYOUT,
+ freq=freq,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(limit=1)
+ assert lb.row_count == 3
+ assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345)]
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.SUM_PAYOUTS,
+ freq=freq,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard(limit=1)
+ assert lb.row_count == 3
+ assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345 + 100)]
+
+ def test_leaderboard_hit_new_row(
+ self, setup_leaderboards, session_factory, thl_redis
+ ):
+ session = session_factory(product_user_id="zzz")
+ hit_leaderboards(redis_client=thl_redis, session=session)
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=LeaderboardFrequency.DAILY,
+ product_id=product_id,
+ country_iso="us",
+ within_time=datetime(2025, 2, 5, 12, 12, 12),
+ )
+ lb = m.get_leaderboard()
+ assert lb.row_count == 8
+ assert LeaderboardRow(bpuid="zzz", value=1, rank=6) in lb.rows
+
+ def test_leaderboard_country(self, thl_redis):
+ m = LeaderboardManager(
+ redis_client=thl_redis,
+ board_code=LeaderboardCode.COMPLETE_COUNT,
+ freq=LeaderboardFrequency.DAILY,
+ product_id=product_id,
+ country_iso="jp",
+ within_time=datetime(
+ 2025,
+ 2,
+ 1,
+ ),
+ )
+ lb = m.get_leaderboard()
+ assert lb.row_count == 0
+ assert lb.period_start_local == datetime(
+ 2025, 2, 1, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="Asia/Tokyo")
+ )
+ assert lb.local_start_time == "2025-02-01T00:00:00+09:00"
+ assert lb.local_end_time == "2025-02-01T23:59:59.999999+09:00"
+ assert lb.period_start_utc == datetime(2025, 1, 31, 15, tzinfo=timezone.utc)
+ print(lb.model_dump(mode="json"))