aboutsummaryrefslogtreecommitdiff
path: root/tests/models/thl/test_product.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/models/thl/test_product.py')
-rw-r--r--tests/models/thl/test_product.py1130
1 files changed, 1130 insertions, 0 deletions
diff --git a/tests/models/thl/test_product.py b/tests/models/thl/test_product.py
new file mode 100644
index 0000000..52f60c2
--- /dev/null
+++ b/tests/models/thl/test_product.py
@@ -0,0 +1,1130 @@
+import os
+import shutil
+from datetime import datetime, timezone, timedelta
+from decimal import Decimal
+from typing import Optional
+from uuid import uuid4
+
+import pytest
+from pydantic import ValidationError
+
+from generalresearch.currency import USDCent
+from generalresearch.models import Source
+from generalresearch.models.thl.product import (
+ Product,
+ PayoutConfig,
+ PayoutTransformation,
+ ProfilingConfig,
+ SourcesConfig,
+ IntegrationMode,
+ SupplyConfig,
+ SourceConfig,
+ SupplyPolicy,
+)
+
+
+class TestProduct:
+
+ def test_init(self):
+ # By default, just a Pydantic instance doesn't have an id_int
+ instance = Product.model_validate(
+ dict(
+ id="968a9acc79b74b6fb49542d82516d284",
+ name="test-968a9acc",
+ redirect_url="https://www.google.com/hey",
+ )
+ )
+ assert instance.id_int is None
+
+ res = instance.model_dump_json()
+ # We're not excluding anything here, only in the "*Out" variants
+ assert "id_int" in res
+
+ def test_init_db(self, product_manager):
+ # By default, just a Pydantic instance doesn't have an id_int
+ instance = product_manager.create_dummy()
+ assert isinstance(instance.id_int, int)
+
+ res = instance.model_dump_json()
+
+ # we json skip & exclude
+ res = instance.model_dump()
+
+ def test_redirect_url(self):
+ p = Product.model_validate(
+ dict(
+ id="968a9acc79b74b6fb49542d82516d284",
+ created="2023-09-21T22:13:09.274672Z",
+ commission_pct=Decimal("0.05"),
+ enabled=True,
+ sources=[{"name": "d", "active": True}],
+ name="test-968a9acc",
+ max_session_len=600,
+ team_id="8b5e94afd8a246bf8556ad9986486baa",
+ redirect_url="https://www.google.com/hey",
+ )
+ )
+
+ with pytest.raises(expected_exception=ValidationError):
+ p.redirect_url = ""
+
+ with pytest.raises(expected_exception=ValidationError):
+ p.redirect_url = None
+
+ with pytest.raises(expected_exception=ValidationError):
+ p.redirect_url = "http://www.example.com/test/?a=1&b=2"
+
+ with pytest.raises(expected_exception=ValidationError):
+ p.redirect_url = "http://www.example.com/test/?a=1&b=2&tsid="
+
+ p.redirect_url = "https://www.example.com/test/?a=1&b=2"
+ c = p.generate_bp_redirect(tsid="c6ab6ba1e75b44e2bf5aab00fc68e3b7")
+ assert (
+ c
+ == "https://www.example.com/test/?a=1&b=2&tsid=c6ab6ba1e75b44e2bf5aab00fc68e3b7"
+ )
+
+ def test_harmonizer_domain(self):
+ p = Product(
+ id="968a9acc79b74b6fb49542d82516d284",
+ created="2023-09-21T22:13:09.274672Z",
+ commission_pct=Decimal("0.05"),
+ enabled=True,
+ name="test-968a9acc",
+ team_id="8b5e94afd8a246bf8556ad9986486baa",
+ harmonizer_domain="profile.generalresearch.com",
+ redirect_url="https://www.google.com/hey",
+ )
+ assert p.harmonizer_domain == "https://profile.generalresearch.com/"
+ p.harmonizer_domain = "https://profile.generalresearch.com/"
+ p.harmonizer_domain = "https://profile.generalresearch.com"
+ assert p.harmonizer_domain == "https://profile.generalresearch.com/"
+ with pytest.raises(expected_exception=Exception):
+ p.harmonizer_domain = ""
+ with pytest.raises(expected_exception=Exception):
+ p.harmonizer_domain = None
+ with pytest.raises(expected_exception=Exception):
+ # no https
+ p.harmonizer_domain = "http://profile.generalresearch.com"
+ with pytest.raises(expected_exception=Exception):
+ # "/a" at the end
+ p.harmonizer_domain = "https://profile.generalresearch.com/a"
+
+ def test_payout_xform(self):
+ p = Product(
+ id="968a9acc79b74b6fb49542d82516d284",
+ created="2023-09-21T22:13:09.274672Z",
+ commission_pct=Decimal("0.05"),
+ enabled=True,
+ name="test-968a9acc",
+ team_id="8b5e94afd8a246bf8556ad9986486baa",
+ harmonizer_domain="profile.generalresearch.com",
+ redirect_url="https://www.google.com/hey",
+ )
+
+ p.payout_config.payout_transformation = PayoutTransformation.model_validate(
+ {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": "0.5", "min_payout": "0.10"},
+ }
+ )
+
+ assert (
+ "payout_transformation_percent" == p.payout_config.payout_transformation.f
+ )
+ assert 0.5 == p.payout_config.payout_transformation.kwargs.pct
+ assert (
+ Decimal("0.10") == p.payout_config.payout_transformation.kwargs.min_payout
+ )
+ assert p.payout_config.payout_transformation.kwargs.max_payout is None
+
+ # This calls get_payout_transformation_func
+ # 50% of $1.00
+ assert Decimal("0.50") == p.calculate_user_payment(Decimal(1))
+ # with a min
+ assert Decimal("0.10") == p.calculate_user_payment(Decimal("0.15"))
+
+ with pytest.raises(expected_exception=ValidationError) as cm:
+ p.payout_config.payout_transformation = PayoutTransformation.model_validate(
+ {"f": "payout_transformation_percent", "kwargs": {}}
+ )
+ assert "1 validation error for PayoutTransformation\nkwargs.pct" in str(
+ cm.value
+ )
+
+ with pytest.raises(expected_exception=ValidationError) as cm:
+ p.payout_config.payout_transformation = PayoutTransformation.model_validate(
+ {"f": "payout_transformation_percent"}
+ )
+
+ assert "1 validation error for PayoutTransformation\nkwargs" in str(cm.value)
+
+ with pytest.warns(expected_warning=Warning) as w:
+ p.payout_config.payout_transformation = PayoutTransformation.model_validate(
+ {
+ "f": "payout_transformation_percent",
+ "kwargs": {"pct": 1, "min_payout": "0.5"},
+ }
+ )
+ assert "Are you sure you want to pay respondents >95% of CPI?" in "".join(
+ [str(i.message) for i in w]
+ )
+
+ p.payout_config = PayoutConfig()
+ assert p.calculate_user_payment(Decimal("0.15")) is None
+
+ def test_payout_xform_amt(self):
+ p = Product(
+ id="968a9acc79b74b6fb49542d82516d284",
+ created="2023-09-21T22:13:09.274672Z",
+ commission_pct=Decimal("0.05"),
+ enabled=True,
+ name="test-968a9acc",
+ team_id="8b5e94afd8a246bf8556ad9986486baa",
+ harmonizer_domain="profile.generalresearch.com",
+ redirect_url="https://www.google.com/hey",
+ )
+
+ p.payout_config.payout_transformation = PayoutTransformation.model_validate(
+ {
+ "f": "payout_transformation_amt",
+ }
+ )
+
+ assert "payout_transformation_amt" == p.payout_config.payout_transformation.f
+
+ # This calls get_payout_transformation_func
+ # 95% of $1.00
+ assert p.calculate_user_payment(Decimal(1)) == Decimal("0.95")
+ assert p.calculate_user_payment(Decimal("1.05")) == Decimal("1.00")
+
+ assert p.calculate_user_payment(
+ Decimal("0.10"), user_wallet_balance=Decimal(0)
+ ) == Decimal("0.07")
+ assert p.calculate_user_payment(
+ Decimal("1.05"), user_wallet_balance=Decimal(0)
+ ) == Decimal("0.97")
+ assert p.calculate_user_payment(
+ Decimal(".05"), user_wallet_balance=Decimal(1)
+ ) == Decimal("0.02")
+ # final balance will be <0, so pay the full amount
+ assert p.calculate_user_payment(
+ Decimal(".50"), user_wallet_balance=Decimal(-1)
+ ) == p.calculate_user_payment(Decimal("0.50"))
+ # final balance will be >0, so do the 7c rounding
+ assert p.calculate_user_payment(
+ Decimal(".50"), user_wallet_balance=Decimal("-0.10")
+ ) == (
+ p.calculate_user_payment(Decimal(".40"), user_wallet_balance=Decimal(0))
+ - Decimal("-0.10")
+ )
+
+ def test_payout_xform_none(self):
+ p = Product(
+ id="968a9acc79b74b6fb49542d82516d284",
+ created="2023-09-21T22:13:09.274672Z",
+ commission_pct=Decimal("0.05"),
+ enabled=True,
+ name="test-968a9acc",
+ team_id="8b5e94afd8a246bf8556ad9986486baa",
+ harmonizer_domain="profile.generalresearch.com",
+ redirect_url="https://www.google.com/hey",
+ payout_config=PayoutConfig(payout_format=None, payout_transformation=None),
+ )
+ assert p.format_payout_format(Decimal("1.00")) is None
+
+ pt = PayoutTransformation.model_validate(
+ {"kwargs": {"pct": 0.5}, "f": "payout_transformation_percent"}
+ )
+ p.payout_config = PayoutConfig(
+ payout_format="{payout*10:,.0f} Points", payout_transformation=pt
+ )
+ assert p.format_payout_format(Decimal("1.00")) == "1,000 Points"
+
+ def test_profiling(self):
+ p = Product(
+ id="968a9acc79b74b6fb49542d82516d284",
+ created="2023-09-21T22:13:09.274672Z",
+ commission_pct=Decimal("0.05"),
+ enabled=True,
+ name="test-968a9acc",
+ team_id="8b5e94afd8a246bf8556ad9986486baa",
+ harmonizer_domain="profile.generalresearch.com",
+ redirect_url="https://www.google.com/hey",
+ )
+ assert p.profiling_config.enabled is True
+
+ p.profiling_config = ProfilingConfig(max_questions=1)
+ assert p.profiling_config.max_questions == 1
+
+ def test_bp_account(self, product, thl_lm):
+ assert product.bp_account is None
+
+ product.prefetch_bp_account(thl_lm=thl_lm)
+
+ from generalresearch.models.thl.ledger import LedgerAccount
+
+ assert isinstance(product.bp_account, LedgerAccount)
+
+
+class TestGlobalProduct:
+ # We have one product ID that is special; we call it the Global
+ # Product ID and in prod the. This product stores a bunch of extra
+ # things in the SourcesConfig
+
+ def test_init_and_props(self):
+ instance = Product(
+ name="Global Config",
+ redirect_url="https://www.example.com",
+ sources_config=SupplyConfig(
+ policies=[
+ # This is the config for Dynata that any BP is allowed to use
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ integration_mode=IntegrationMode.PLATFORM,
+ ),
+ # Spectrum that is using OUR credentials, that anyone is allowed to use.
+ # Same as the dynata config above, just that the dynata supplier_id is
+ # inferred by the dynata-grpc; it's not required to be set.
+ SupplyPolicy(
+ address=["https://spectrum.internal:50051"],
+ active=True,
+ name=Source.SPECTRUM,
+ supplier_id="example-supplier-id",
+ # implicit Scope = GLOBAL
+ # default integration_mode=IntegrationMode.PLATFORM,
+ ),
+ # A spectrum config with a different supplier_id, but
+ # it is OUR supplier, and we are paid for the completes. Only a certain BP
+ # can use this config.
+ SupplyPolicy(
+ address=["https://spectrum.internal:50051"],
+ active=True,
+ name=Source.SPECTRUM,
+ supplier_id="example-supplier-id",
+ team_ids=["d42194c2dfe44d7c9bec98123bc4a6c0"],
+ # implicit Scope = TEAM
+ # default integration_mode=IntegrationMode.PLATFORM,
+ ),
+ # The supplier ID is associated with THEIR
+ # credentials, and we do not get paid for this activity.
+ SupplyPolicy(
+ address=["https://cint.internal:50051"],
+ active=True,
+ name=Source.CINT,
+ supplier_id="example-supplier-id",
+ product_ids=["db8918b3e87d4444b60241d0d3a54caa"],
+ integration_mode=IntegrationMode.PASS_THROUGH,
+ ),
+ # We could have another global cint integration available
+ # to anyone also, or we could have another like above
+ SupplyPolicy(
+ address=["https://cint.internal:50051"],
+ active=True,
+ name=Source.CINT,
+ supplier_id="example-supplier-id",
+ team_ids=["b163972a59584de881e5eab01ad10309"],
+ integration_mode=IntegrationMode.PASS_THROUGH,
+ ),
+ ]
+ ),
+ )
+
+ assert Product.model_validate_json(instance.model_dump_json()) == instance
+
+ s = instance.sources_config
+ # Cint should NOT have a global config
+ assert set(s.global_scoped_policies_dict.keys()) == {
+ Source.DYNATA,
+ Source.SPECTRUM,
+ }
+
+ # The spectrum global config is the one that isn't scoped to a
+ # specific supplier
+ assert (
+ s.global_scoped_policies_dict[Source.SPECTRUM].supplier_id
+ == "grl-supplier-id"
+ )
+
+ assert set(s.team_scoped_policies_dict.keys()) == {
+ "b163972a59584de881e5eab01ad10309",
+ "d42194c2dfe44d7c9bec98123bc4a6c0",
+ }
+ # This team has one team-scoped config, and it's for spectrum
+ assert s.team_scoped_policies_dict[
+ "d42194c2dfe44d7c9bec98123bc4a6c0"
+ ].keys() == {Source.SPECTRUM}
+
+ # For a random product/team, it'll just have the globally-scoped config
+ random_product = uuid4().hex
+ random_team = uuid4().hex
+ res = instance.sources_config.get_policies_for(
+ product_id=random_product, team_id=random_team
+ )
+ assert res == s.global_scoped_policies_dict
+
+ # It'll have the global config plus cint, and it should use the PRODUCT
+ # scoped config, not the TEAM scoped!
+ res = instance.sources_config.get_policies_for(
+ product_id="db8918b3e87d4444b60241d0d3a54caa",
+ team_id="b163972a59584de881e5eab01ad10309",
+ )
+ assert set(res.keys()) == {
+ Source.DYNATA,
+ Source.SPECTRUM,
+ Source.CINT,
+ }
+ assert res[Source.CINT].supplier_id == "example-supplier-id"
+
+ def test_source_vs_supply_validate(self):
+ # sources_config can be a SupplyConfig or SourcesConfig.
+ # make sure they get model_validated correctly
+ gp = Product(
+ name="Global Config",
+ redirect_url="https://www.example.com",
+ sources_config=SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ integration_mode=IntegrationMode.PLATFORM,
+ )
+ ]
+ ),
+ )
+ bp = Product(
+ name="test product config",
+ redirect_url="https://www.example.com",
+ sources_config=SourcesConfig(
+ user_defined=[
+ SourceConfig(
+ active=False,
+ name=Source.DYNATA,
+ )
+ ]
+ ),
+ )
+ assert Product.model_validate_json(gp.model_dump_json()) == gp
+ assert Product.model_validate_json(bp.model_dump_json()) == bp
+
+ def test_validations(self):
+ with pytest.raises(
+ ValidationError, match="Can only have one GLOBAL policy per Source"
+ ):
+ SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ integration_mode=IntegrationMode.PLATFORM,
+ ),
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ integration_mode=IntegrationMode.PASS_THROUGH,
+ ),
+ ]
+ )
+ with pytest.raises(
+ ValidationError,
+ match="Can only have one PRODUCT policy per Source per BP",
+ ):
+ SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ product_ids=["7e417dec1c8a406e8554099b46e518ca"],
+ integration_mode=IntegrationMode.PLATFORM,
+ ),
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ product_ids=["7e417dec1c8a406e8554099b46e518ca"],
+ integration_mode=IntegrationMode.PASS_THROUGH,
+ ),
+ ]
+ )
+ with pytest.raises(
+ ValidationError,
+ match="Can only have one TEAM policy per Source per Team",
+ ):
+ SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ team_ids=["7e417dec1c8a406e8554099b46e518ca"],
+ integration_mode=IntegrationMode.PLATFORM,
+ ),
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ team_ids=["7e417dec1c8a406e8554099b46e518ca"],
+ integration_mode=IntegrationMode.PASS_THROUGH,
+ ),
+ ]
+ )
+
+
+class TestGlobalProductConfigFor:
+ def test_no_user_defined(self):
+ sc = SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ )
+ ]
+ )
+ product = Product(
+ name="Test Product Config",
+ redirect_url="https://www.example.com",
+ sources_config=SourcesConfig(),
+ )
+ res = sc.get_config_for_product(product=product)
+ assert len(res.policies) == 1
+
+ def test_user_defined_merge(self):
+ sc = SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ banned_countries=["mx"],
+ active=True,
+ name=Source.DYNATA,
+ ),
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ banned_countries=["ca"],
+ active=True,
+ name=Source.DYNATA,
+ team_ids=[uuid4().hex],
+ ),
+ ]
+ )
+ product = Product(
+ name="Test Product Config",
+ redirect_url="https://www.example.com",
+ sources_config=SourcesConfig(
+ user_defined=[
+ SourceConfig(
+ name=Source.DYNATA,
+ active=False,
+ banned_countries=["us"],
+ )
+ ]
+ ),
+ )
+ res = sc.get_config_for_product(product=product)
+ assert len(res.policies) == 1
+ assert not res.policies[0].active
+ assert res.policies[0].banned_countries == ["mx", "us"]
+
+ def test_no_eligible(self):
+ sc = SupplyConfig(
+ policies=[
+ SupplyPolicy(
+ address=["https://dynata.internal:50051"],
+ active=True,
+ name=Source.DYNATA,
+ team_ids=["7e417dec1c8a406e8554099b46e518ca"],
+ integration_mode=IntegrationMode.PLATFORM,
+ )
+ ]
+ )
+ product = Product(
+ name="Test Product Config",
+ redirect_url="https://www.example.com",
+ sources_config=SourcesConfig(),
+ )
+ res = sc.get_config_for_product(product=product)
+ assert len(res.policies) == 0
+
+
+class TestProductFinancials:
+
+ @pytest.fixture
+ def start(self) -> "datetime":
+ return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc)
+
+ @pytest.fixture
+ def offset(self) -> str:
+ return "30d"
+
+ @pytest.fixture
+ def duration(self) -> Optional["timedelta"]:
+ return None
+
+ def test_balance(
+ self,
+ business,
+ product_factory,
+ user_factory,
+ mnt_filepath,
+ bp_payout_factory,
+ thl_lm,
+ lm,
+ duration,
+ offset,
+ thl_redis_config,
+ start,
+ thl_web_rr,
+ brokerage_product_payout_event_manager,
+ session_with_tx_factory,
+ delete_ledger_db,
+ create_main_accounts,
+ client_no_amm,
+ ledger_collection,
+ pop_ledger_merge,
+ delete_df_collection,
+ ):
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+ from generalresearch.models.thl.finance import ProductBalances
+ from generalresearch.currency import USDCent
+
+ p1: Product = product_factory(business=business)
+ u1: User = user_factory(product=p1)
+ bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=p1)
+ thl_lm.get_account_or_create_user_wallet(user=u1)
+ brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+
+ assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 0
+
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal(".50"),
+ started=start + timedelta(days=1),
+ )
+ assert thl_lm.get_account_balance(account=bp_wallet) == 48
+ assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 1
+
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal("1.00"),
+ started=start + timedelta(days=2),
+ )
+ assert thl_lm.get_account_balance(account=bp_wallet) == 143
+ assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 2
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ p1.prebuild_balance(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ )
+ assert "Cannot build Product Balance" in str(cm.value)
+
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ p1.prebuild_balance(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ )
+ assert isinstance(p1.balance, ProductBalances)
+ assert p1.balance.payout == 143
+ assert p1.balance.adjustment == 0
+ assert p1.balance.expense == 0
+ assert p1.balance.net == 143
+ assert p1.balance.balance == 143
+ assert p1.balance.retainer == 35
+ assert p1.balance.available_balance == 108
+
+ p1.prebuild_payouts(
+ thl_lm=thl_lm,
+ bp_pem=brokerage_product_payout_event_manager,
+ )
+ assert p1.payouts is not None
+ assert len(p1.payouts) == 0
+ assert p1.payouts_total == 0
+ assert p1.payouts_total_str == "$0.00"
+
+ # -- Now pay them out...
+
+ bp_payout_factory(
+ product=p1,
+ amount=USDCent(50),
+ created=start + timedelta(days=3),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 3
+
+ # RM the entire directories
+ shutil.rmtree(ledger_collection.archive_path)
+ os.makedirs(ledger_collection.archive_path, exist_ok=True)
+ shutil.rmtree(pop_ledger_merge.archive_path)
+ os.makedirs(pop_ledger_merge.archive_path, exist_ok=True)
+
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ p1.prebuild_balance(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ )
+ assert isinstance(p1.balance, ProductBalances)
+ assert p1.balance.payout == 143
+ assert p1.balance.adjustment == 0
+ assert p1.balance.expense == 0
+ assert p1.balance.net == 143
+ assert p1.balance.balance == 93
+ assert p1.balance.retainer == 23
+ assert p1.balance.available_balance == 70
+
+ p1.prebuild_payouts(
+ thl_lm=thl_lm,
+ bp_pem=brokerage_product_payout_event_manager,
+ )
+ assert p1.payouts is not None
+ assert len(p1.payouts) == 1
+ assert p1.payouts_total == 50
+ assert p1.payouts_total_str == "$0.50"
+
+ # -- Now pay ou another!.
+
+ bp_payout_factory(
+ product=p1,
+ amount=USDCent(5),
+ created=start + timedelta(days=4),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 4
+
+ # RM the entire directories
+ shutil.rmtree(ledger_collection.archive_path)
+ os.makedirs(ledger_collection.archive_path, exist_ok=True)
+ shutil.rmtree(pop_ledger_merge.archive_path)
+ os.makedirs(pop_ledger_merge.archive_path, exist_ok=True)
+
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ p1.prebuild_balance(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ )
+ assert isinstance(p1.balance, ProductBalances)
+ assert p1.balance.payout == 143
+ assert p1.balance.adjustment == 0
+ assert p1.balance.expense == 0
+ assert p1.balance.net == 143
+ assert p1.balance.balance == 88
+ assert p1.balance.retainer == 22
+ assert p1.balance.available_balance == 66
+
+ p1.prebuild_payouts(
+ thl_lm=thl_lm,
+ bp_pem=brokerage_product_payout_event_manager,
+ )
+ assert p1.payouts is not None
+ assert len(p1.payouts) == 2
+ assert p1.payouts_total == 55
+ assert p1.payouts_total_str == "$0.55"
+
+
+class TestProductBalance:
+
+ @pytest.fixture
+ def start(self) -> "datetime":
+ return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc)
+
+ @pytest.fixture
+ def offset(self) -> str:
+ return "30d"
+
+ @pytest.fixture
+ def duration(self) -> Optional["timedelta"]:
+ return None
+
+ def test_inconsistent(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ brokerage_product_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ payout_event_manager,
+ ):
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+
+ # 1. Complete and Build Parquets 1st time
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal(".75"),
+ started=start + timedelta(days=1),
+ )
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ # 2. Payout and build Parquets 2nd time
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+ bp_payout_factory(
+ product=product,
+ amount=USDCent(71),
+ ext_ref_id=uuid4().hex,
+ created=start + timedelta(days=1, minutes=1),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ with pytest.raises(expected_exception=AssertionError) as cm:
+ product.prebuild_balance(
+ thl_lm=thl_lm, ds=mnt_filepath, client=client_no_amm
+ )
+ assert "Sql and Parquet Balance inconsistent" in str(cm)
+
+ def test_not_inconsistent(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ brokerage_product_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ payout_event_manager,
+ ):
+ # This is very similar to the test_complete_payout_pq_inconsistent
+ # test, however this time we're only going to assign the payout
+ # in real time, and not in the past. This means that even if we
+ # build the parquet files multiple times, they will include the
+ # payout.
+
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+
+ # 1. Complete and Build Parquets 1st time
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal(".75"),
+ started=start + timedelta(days=1),
+ )
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ # 2. Payout and build Parquets 2nd time but this payout is "now"
+ # so it hasn't already been archived
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+ bp_payout_factory(
+ product=product,
+ amount=USDCent(71),
+ ext_ref_id=uuid4().hex,
+ created=datetime.now(tz=timezone.utc),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ # We just want to call this to confirm it doesn't raise.
+ product.prebuild_balance(thl_lm=thl_lm, ds=mnt_filepath, client=client_no_amm)
+
+
+class TestProductPOPFinancial:
+
+ @pytest.fixture
+ def start(self) -> "datetime":
+ return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc)
+
+ @pytest.fixture
+ def offset(self) -> str:
+ return "30d"
+
+ @pytest.fixture
+ def duration(self) -> Optional["timedelta"]:
+ return None
+
+ def test_base(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ brokerage_product_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ payout_event_manager,
+ ):
+ # This is very similar to the test_complete_payout_pq_inconsistent
+ # test, however this time we're only going to assign the payout
+ # in real time, and not in the past. This means that even if we
+ # build the parquet files multiple times, they will include the
+ # payout.
+
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+
+ # 1. Complete and Build Parquets 1st time
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal(".75"),
+ started=start + timedelta(days=1),
+ )
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ # --- test ---
+ assert product.pop_financial is None
+ product.prebuild_pop_financial(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ pop_ledger=pop_ledger_merge,
+ )
+
+ from generalresearch.models.thl.finance import POPFinancial
+
+ assert isinstance(product.pop_financial, list)
+ assert isinstance(product.pop_financial[0], POPFinancial)
+ pf1: POPFinancial = product.pop_financial[0]
+ assert isinstance(pf1.time, datetime)
+ assert pf1.payout == 71
+ assert pf1.net == 71
+ assert pf1.adjustment == 0
+ for adj in pf1.adjustment_types:
+ assert adj.amount == 0
+
+
+class TestProductCache:
+
+ @pytest.fixture
+ def start(self) -> "datetime":
+ return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc)
+
+ @pytest.fixture
+ def offset(self) -> str:
+ return "30d"
+
+ @pytest.fixture
+ def duration(self) -> Optional["timedelta"]:
+ return None
+
+ def test_basic(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ brokerage_product_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ ):
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ # Confirm the default / null behavior
+ rc = thl_redis_config.create_redis_client()
+ res: Optional[str] = rc.get(product.cache_key)
+ assert res is None
+ with pytest.raises(expected_exception=AssertionError):
+ product.set_cache(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ bp_pem=brokerage_product_payout_event_manager,
+ redis_config=thl_redis_config,
+ )
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+
+ session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal(".75"),
+ started=start + timedelta(days=1),
+ )
+
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ # Now try again with everything in place
+ product.set_cache(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ bp_pem=brokerage_product_payout_event_manager,
+ redis_config=thl_redis_config,
+ )
+
+ # Fetch from cache and assert the instance loaded from redis
+ res: Optional[str] = rc.get(product.cache_key)
+ assert isinstance(res, str)
+ from generalresearch.models.thl.ledger import LedgerAccount
+
+ assert isinstance(product.bp_account, LedgerAccount)
+
+ p1: Product = Product.model_validate_json(res)
+ assert p1.balance.product_id == product.uuid
+ assert p1.balance.payout_usd_str == "$0.71"
+ assert p1.balance.retainer_usd_str == "$0.17"
+ assert p1.balance.available_balance_usd_str == "$0.54"
+
+ def test_neg_balance_cache(
+ self,
+ product,
+ mnt_filepath,
+ thl_lm,
+ client_no_amm,
+ thl_redis_config,
+ brokerage_product_payout_event_manager,
+ delete_ledger_db,
+ create_main_accounts,
+ delete_df_collection,
+ ledger_collection,
+ business,
+ user_factory,
+ product_factory,
+ session_with_tx_factory,
+ pop_ledger_merge,
+ start,
+ bp_payout_factory,
+ payout_event_manager,
+ adj_to_fail_with_tx_factory,
+ ):
+ # Now let's load it up and actually test some things
+ delete_ledger_db()
+ create_main_accounts()
+ delete_df_collection(coll=ledger_collection)
+
+ from generalresearch.models.thl.product import Product
+ from generalresearch.models.thl.user import User
+
+ u1: User = user_factory(product=product)
+
+ # 1. Complete
+ s1 = session_with_tx_factory(
+ user=u1,
+ wall_req_cpi=Decimal(".75"),
+ started=start + timedelta(days=1),
+ )
+
+ # 2. Payout
+ payout_event_manager.set_account_lookup_table(thl_lm=thl_lm)
+ bp_payout_factory(
+ product=product,
+ amount=USDCent(71),
+ ext_ref_id=uuid4().hex,
+ created=start + timedelta(days=1, minutes=1),
+ skip_wallet_balance_check=True,
+ skip_one_per_day_check=True,
+ )
+
+ # 3. Recon
+ adj_to_fail_with_tx_factory(
+ session=s1,
+ created=start + timedelta(days=1, minutes=1),
+ )
+
+ # Finally, process everything:
+ ledger_collection.initial_load(client=None, sync=True)
+ pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection)
+
+ product.set_cache(
+ thl_lm=thl_lm,
+ ds=mnt_filepath,
+ client=client_no_amm,
+ bp_pem=brokerage_product_payout_event_manager,
+ redis_config=thl_redis_config,
+ )
+
+ # Fetch from cache and assert the instance loaded from redis
+ rc = thl_redis_config.create_redis_client()
+ res: Optional[str] = rc.get(product.cache_key)
+ assert isinstance(res, str)
+
+ p1: Product = Product.model_validate_json(res)
+ assert p1.balance.product_id == product.uuid
+ assert p1.balance.payout_usd_str == "$0.71"
+ assert p1.balance.adjustment == -71
+ assert p1.balance.expense == 0
+ assert p1.balance.net == 0
+ assert p1.balance.balance == -71
+ assert p1.balance.retainer_usd_str == "$0.00"
+ assert p1.balance.available_balance_usd_str == "$0.00"