diff options
| author | Max Nanis | 2026-03-06 16:49:46 -0500 |
|---|---|---|
| committer | Max Nanis | 2026-03-06 16:49:46 -0500 |
| commit | 91d040211a4ed6e4157896256a762d3854777b5e (patch) | |
| tree | cd95922ea4257dc8d3f4e4cbe8534474709a20dc /tests | |
| download | generalresearch-91d040211a4ed6e4157896256a762d3854777b5e.tar.gz generalresearch-91d040211a4ed6e4157896256a762d3854777b5e.zip | |
Initial commitv3.3.4
Diffstat (limited to 'tests')
158 files changed, 28836 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..30ed1c7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +pytest_plugins = [ + "distributed.utils_test", + "test_utils.conftest", + # -- GRL IQ + "test_utils.grliq.conftest", + "test_utils.grliq.managers.conftest", + "test_utils.grliq.models.conftest", + # -- Incite + "test_utils.incite.conftest", + "test_utils.incite.collections.conftest", + "test_utils.incite.mergers.conftest", + # -- Managers + "test_utils.managers.conftest", + "test_utils.managers.contest.conftest", + "test_utils.managers.ledger.conftest", + "test_utils.managers.upk.conftest", + # -- Models + "test_utils.models.conftest", +] diff --git a/tests/grliq/__init__.py b/tests/grliq/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/grliq/__init__.py diff --git a/tests/grliq/managers/__init__.py b/tests/grliq/managers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/grliq/managers/__init__.py diff --git a/tests/grliq/managers/test_forensic_data.py b/tests/grliq/managers/test_forensic_data.py new file mode 100644 index 0000000..ed4da80 --- /dev/null +++ b/tests/grliq/managers/test_forensic_data.py @@ -0,0 +1,212 @@ +from datetime import timedelta +from uuid import uuid4 + +import pytest + +from generalresearch.grliq.models.events import TimingData, MouseEvent +from generalresearch.grliq.models.forensic_data import GrlIqData +from generalresearch.grliq.models.forensic_result import ( + GrlIqCheckerResults, + GrlIqForensicCategoryResult, +) + +try: + from psycopg.errors import UniqueViolation +except ImportError: + pass + + +class TestGrlIqDataManager: + + def test_create_dummy(self, grliq_dm): + from generalresearch.grliq.managers.forensic_data import GrlIqDataManager + from generalresearch.grliq.models.forensic_data import GrlIqData + + grliq_dm: GrlIqDataManager + gd1: GrlIqData = grliq_dm.create_dummy(is_attempt_allowed=True) + + assert isinstance(gd1, GrlIqData) + assert isinstance(gd1.results, GrlIqCheckerResults) + assert isinstance(gd1.category_result, GrlIqForensicCategoryResult) + + def test_create(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + assert grliq_data.id is not None + + with pytest.raises(UniqueViolation): + grliq_dm.create(grliq_data) + + @pytest.mark.skip(reason="todo") + def test_set_result(self): + pass + + @pytest.mark.skip(reason="todo") + def test_update_fingerprint(self): + pass + + @pytest.mark.skip(reason="todo") + def test_update_fingerprint(self): + pass + + @pytest.mark.skip(reason="todo") + def test_update_data(self): + pass + + def test_get_id(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + + res = grliq_dm.get_data(forensic_id=grliq_data.id) + assert res == grliq_data + + def test_get_uuid(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + + res = grliq_dm.get_data(forensic_uuid=grliq_data.uuid) + assert res == grliq_data + + @pytest.mark.skip(reason="todo") + def test_filter_timing_data(self): + pass + + @pytest.mark.skip(reason="todo") + def test_get_unique_user_count_by_fingerprint(self): + pass + + def test_filter_data(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + res = grliq_dm.filter_data(uuids=[grliq_data.uuid])[0] + assert res == grliq_data + now = res.created_at + res = grliq_dm.filter_data( + created_after=now, created_before=now + timedelta(minutes=1) + ) + assert len(res) == 1 + res = grliq_dm.filter_data( + created_after=now + timedelta(seconds=1), + created_before=now + timedelta(minutes=1), + ) + assert len(res) == 0 + + @pytest.mark.skip(reason="todo") + def test_filter_results(self): + pass + + @pytest.mark.skip(reason="todo") + def test_filter_category_results(self): + pass + + @pytest.mark.skip(reason="todo") + def test_make_filter_str(self): + pass + + def test_filter_count(self, grliq_dm, product): + res = grliq_dm.filter_count(product_id=product.uuid) + + assert isinstance(res, int) + + @pytest.mark.skip(reason="todo") + def test_filter(self): + pass + + @pytest.mark.skip(reason="todo") + def test_temporary_add_missing_fields(self): + pass + + +class TestForensicDataGetAndFilter: + + def test_events(self, grliq_dm, grliq_em): + """If load_events=True, the events and mouse_events attributes should + be an array no matter what. An empty array means that the events were + loaded, but there were no events available. + + If loaded_eventsFalse, the events and mouse_events attributes should + be None + """ + # Load Events == False + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + + instance = grliq_dm.filter_data(uuids=[forensic_uuid])[0] + assert isinstance(instance, GrlIqData) + + assert instance.events is None + assert instance.mouse_events is None + + # Load Events == True + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + # This one doesn't have any events though + assert len(instance.events) == 0 + assert len(instance.mouse_events) == 0 + + def test_timing(self, grliq_dm, grliq_em): + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + + instance = grliq_dm.filter_data(uuids=[forensic_uuid])[0] + + grliq_em.update_or_create_timing( + session_uuid=instance.mid, + timing_data=TimingData( + client_rtts=[100, 200, 150], server_rtts=[150, 120, 120] + ), + ) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + assert isinstance(instance.events, list) + assert isinstance(instance.mouse_events, list) + assert isinstance(instance.timing_data, TimingData) + + def test_events_events(self, grliq_dm, grliq_em): + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + + instance = grliq_dm.filter_data(uuids=[forensic_uuid])[0] + + grliq_em.update_or_create_events( + session_uuid=instance.mid, + events=[{"a": "b"}], + mouse_events=[], + event_start=instance.created_at, + event_end=instance.created_at + timedelta(minutes=1), + ) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + assert isinstance(instance.events, list) + assert isinstance(instance.mouse_events, list) + assert instance.timing_data is None + assert instance.events == [{"a": "b"}] + assert len(instance.mouse_events) == 0 + assert len(instance.pointer_move_events) == 0 + assert len(instance.keyboard_events) == 0 + + def test_events_click(self, grliq_dm, grliq_em): + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + + click_event = { + "type": "click", + "pageX": 0, + "pageY": 0, + "timeStamp": 123, + "pointerType": "mouse", + } + me = MouseEvent.from_dict(click_event) + grliq_em.update_or_create_events( + session_uuid=instance.mid, + events=[click_event], + mouse_events=[], + event_start=instance.created_at, + event_end=instance.created_at + timedelta(minutes=1), + ) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + assert isinstance(instance.events, list) + assert isinstance(instance.mouse_events, list) + assert instance.timing_data is None + assert instance.events == [click_event] + assert instance.mouse_events == [me] + assert len(instance.pointer_move_events) == 0 + assert len(instance.keyboard_events) == 0 diff --git a/tests/grliq/managers/test_forensic_results.py b/tests/grliq/managers/test_forensic_results.py new file mode 100644 index 0000000..a837a64 --- /dev/null +++ b/tests/grliq/managers/test_forensic_results.py @@ -0,0 +1,16 @@ +class TestGrlIqCategoryResultsReader: + + def test_filter_category_results(self, grliq_dm, grliq_crr): + from generalresearch.grliq.models.forensic_result import ( + Phase, + GrlIqForensicCategoryResult, + ) + + # this is just testing that it doesn't fail + grliq_dm.create_dummy(is_attempt_allowed=True) + grliq_dm.create_dummy(is_attempt_allowed=True) + + res = grliq_crr.filter_category_results(limit=2, phase=Phase.OFFERWALL_ENTER)[0] + assert res.get("category_result") + assert isinstance(res["category_result"], GrlIqForensicCategoryResult) + assert res["user_agent"].os.family diff --git a/tests/grliq/models/__init__.py b/tests/grliq/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/grliq/models/__init__.py diff --git a/tests/grliq/models/test_forensic_data.py b/tests/grliq/models/test_forensic_data.py new file mode 100644 index 0000000..653f9a9 --- /dev/null +++ b/tests/grliq/models/test_forensic_data.py @@ -0,0 +1,49 @@ +import pytest +from pydantic import ValidationError + +from generalresearch.grliq.models.forensic_data import GrlIqData, Platform + + +class TestGrlIqData: + + def test_supported_fonts(self, grliq_data): + s = grliq_data.supported_fonts_binary + assert len(s) == 1043 + assert "Ubuntu" in grliq_data.supported_fonts + + def test_battery(self, grliq_data): + assert not grliq_data.battery_charging + assert grliq_data.battery_level == 0.41 + + def test_base(self, grliq_data): + g: GrlIqData = grliq_data + assert g.timezone == "America/Los_Angeles" + assert g.platform == Platform.LINUX_X86_64 + assert g.webgl_extensions + # ... more + + assert g.results is None + assert g.category_result is None + + s = g.model_dump_json() + g2: GrlIqData = GrlIqData.model_validate_json(s) + + assert g2.results is None + assert g2.category_result is None + + assert g == g2 + + # Testing things that will cause a validation error, should only be + # because something is "corrupt", not b/c the user is a baddie + def test_corrupt(self, grliq_data): + """Test for timestamp and timezone offset mismatch validation.""" + d = grliq_data.model_dump(mode="json") + d.update( + { + "timezone": "America/XXX", + } + ) + with pytest.raises(ValidationError) as e: + GrlIqData.model_validate(d) + + assert "Invalid timezone name" in str(e.value) diff --git a/tests/grliq/test_utils.py b/tests/grliq/test_utils.py new file mode 100644 index 0000000..d9034d5 --- /dev/null +++ b/tests/grliq/test_utils.py @@ -0,0 +1,17 @@ +from pathlib import Path +from uuid import uuid4 + + +class TestUtils: + + def test_get_screenshot_fp(self, mnt_grliq_archive_dir, utc_hour_ago): + from generalresearch.grliq.utils import get_screenshot_fp + + fp1 = get_screenshot_fp( + created_at=utc_hour_ago, + forensic_uuid=uuid4(), + grliq_archive_dir=mnt_grliq_archive_dir, + create_dir_if_not_exists=True, + ) + + assert isinstance(fp1, Path) diff --git a/tests/incite/__init__.py b/tests/incite/__init__.py new file mode 100644 index 0000000..2f736e8 --- /dev/null +++ b/tests/incite/__init__.py @@ -0,0 +1,137 @@ +# class TestParquetBehaviors(CleanTempDirectoryTestCls): +# wall_coll = WallDFCollection( +# start=GLOBAL_VARS["wall"].start, +# offset="49h", +# archive_path=f"{settings.incite_mount_dir}/raw/df-collections/{DFCollectionType.WALL.value}", +# ) +# +# def test_filters(self): +# # Using REAL data here +# start = datetime(year=2024, month=1, day=15, hour=12, tzinfo=timezone.utc) +# end = datetime(year=2024, month=1, day=15, hour=20, tzinfo=timezone.utc) +# end_max = datetime( +# year=2024, month=1, day=15, hour=20, tzinfo=timezone.utc +# ) + timedelta(hours=2) +# +# ir = pd.Interval(left=pd.Timestamp(start), right=pd.Timestamp(end)) +# wall_items = [w for w in self.wall_coll.items if w.interval.overlaps(ir)] +# ddf = self.wall_coll.ddf( +# items=wall_items, +# include_partial=True, +# force_rr_latest=False, +# columns=["started", "finished"], +# filters=[ +# ("started", ">=", start), +# ("started", "<", end), +# ], +# ) +# +# df = ddf.compute() +# self.assertIsInstance(df, pd.DataFrame) +# +# # No started=None, and they're all between the started and the end +# self.assertFalse(df.started.isna().any()) +# self.assertFalse((df.started < start).any()) +# self.assertFalse((df.started > end).any()) +# +# # Has finished=None and finished=time, so +# # the finished is all between the started and +# # the end_max +# self.assertTrue(df.finished.isna().any()) +# self.assertTrue((df.finished.dt.year == 2024).any()) +# +# self.assertFalse((df.finished > end_max).any()) +# self.assertFalse((df.finished < start).any()) +# +# # def test_user_id_list(self): +# # # Calling compute turns it into a np.ndarray +# # user_ids = self.instance.ddf( +# # columns=["user_id"] +# # ).user_id.unique().values.compute() +# # self.assertIsInstance(user_ids, np.ndarray) +# # +# # # If ddf filters work with ndarray +# # user_product_merge = <todo: assign> +# # +# # with self.assertRaises(TypeError) as cm: +# # user_product_merge.ddf( +# # filters=[("id", "in", user_ids)]) +# # self.assertIn("Value of 'in' filter must be a list, set or tuple.", str(cm.exception)) +# # +# # # No compute == dask array +# # user_ids = self.instance.ddf( +# # columns=["user_id"] +# # ).user_id.unique().values +# # self.assertIsInstance(user_ids, da.Array) +# # +# # with self.assertRaises(TypeError) as cm: +# # user_product_merge.ddf( +# # filters=[("id", "in", user_ids)]) +# # self.assertIn("Value of 'in' filter must be a list, set or tuple.", str(cm.exception)) +# # +# # # pick a product_id (most active one) +# # self.product_id = instance.df.product_id.value_counts().index[0] +# # self.expected_columns: int = len(instance._schema.columns) +# # self.instance = instance +# +# # def test_basic(self): +# # # now try to load up the data! +# # self.instance.grouped_key = self.product_id +# # +# # # Confirm any of the items are archived +# # self.assertTrue(self.instance.progress.has_archive.eq(True).any()) +# # +# # # Confirm it returns a df +# # df = self.instance.dd().compute() +# # +# # self.assertFalse(df.empty) +# # self.assertEqual(df.shape[1], self.expected_columns) +# # self.assertGreater(df.shape[0], 1) +# # +# # # Confirm that DF only contains this product_id +# # self.assertEqual(df[df.product_id == self.product_id].shape, df.shape) +# +# # def test_god_vs_product_id(self): +# # self.instance.grouped_key = self.product_id +# # df_product_origin = self.instance.dd(columns=None, filters=None).compute() +# # +# # self.instance.grouped_key = None +# # df_god_origin = self.instance.dd(columns=None, +# # filters=[("product_id", "==", self.product_id)]).compute() +# # +# # self.assertTrue(df_god_origin.equals(df_product_origin)) +# +# # +# # instance = POPSessionMerge( +# # start=START, +# # archive_path=self.PATH, +# # group_by="product_id" +# # ) +# # instance.build(U=GLOBAL_VARS["user"], S=GLOBAL_VARS["session"], W=GLOBAL_VARS["wall"]) +# # instance.save(god_only=False) +# # +# # # pick a product_id (most active one) +# # self.product_id = instance.df.product_id.value_counts().index[0] +# # self.expected_columns: int = len(instance._schema.columns) +# # self.instance = instance +# +# +# class TestValidItem(CleanTempDirectoryTestCls): +# +# def test_interval(self): +# for k in GLOBAL_VARS.keys(): +# coll = GLOBAL_VARS[k] +# item = coll.items[0] +# ir = item.interval +# +# self.assertIsInstance(ir, pd.Interval) +# self.assertLess(a=ir.left, b=ir.right) +# +# def test_str(self): +# for k in GLOBAL_VARS.keys(): +# coll = GLOBAL_VARS[k] +# item = coll.items[0] +# +# offset = coll.offset or "–" +# +# self.assertIn(offset, str(item)) diff --git a/tests/incite/collections/__init__.py b/tests/incite/collections/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/incite/collections/__init__.py diff --git a/tests/incite/collections/test_df_collection_base.py b/tests/incite/collections/test_df_collection_base.py new file mode 100644 index 0000000..5aaa729 --- /dev/null +++ b/tests/incite/collections/test_df_collection_base.py @@ -0,0 +1,113 @@ +from datetime import datetime, timezone + +import pandas as pd +import pytest +from pandera import DataFrameSchema + +from generalresearch.incite.collections import ( + DFCollectionType, + DFCollection, +) +from test_utils.incite.conftest import mnt_filepath + +df_collection_types = [e for e in DFCollectionType if e is not DFCollectionType.TEST] + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionBase: + """None of these tests are about the DFCollection with any specific + data_type... that will be handled in other parameterized tests + + """ + + def test_init(self, mnt_filepath, df_coll_type): + """Try to initialize the DFCollection with various invalid parameters""" + with pytest.raises(expected_exception=ValueError) as cm: + DFCollection(archive_path=mnt_filepath.data_src) + assert "Must explicitly provide a data_type" in str(cm.value) + + # with pytest.raises(expected_exception=ValueError) as cm: + # DFCollection( + # data_type=DFCollectionType.TEST, archive_path=mnt_filepath.data_src + # ) + # assert "Must provide a supported data_type" in str(cm.value) + + instance = DFCollection( + data_type=DFCollectionType.WALL, archive_path=mnt_filepath.data_src + ) + assert instance.data_type == DFCollectionType.WALL + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionBaseProperties: + + @pytest.mark.skip + def test_df_collection_items(self, mnt_filepath, df_coll_type): + instance = DFCollection( + data_type=df_coll_type, + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + offset="100d", + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + assert len(instance.interval_range) == len(instance.items) + assert len(instance.items) == 366 + + def test_df_collection_progress(self, mnt_filepath, df_coll_type): + instance = DFCollection( + data_type=df_coll_type, + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + offset="100d", + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + # Progress returns a dataframe with a row each Item + assert isinstance(instance.progress, pd.DataFrame) + assert instance.progress.shape == (366, 6) + + def test_df_collection_schema(self, mnt_filepath, df_coll_type): + instance1 = DFCollection( + data_type=DFCollectionType.WALL, archive_path=mnt_filepath.data_src + ) + + instance2 = DFCollection( + data_type=DFCollectionType.SESSION, archive_path=mnt_filepath.data_src + ) + + assert instance1._schema != instance2._schema + assert isinstance(instance1._schema, DataFrameSchema) + assert isinstance(instance2._schema, DataFrameSchema) + + +class TestDFCollectionBaseMethods: + + @pytest.mark.skip + def test_initial_load(self, mnt_filepath, thl_web_rr): + instance = DFCollection( + pg_config=thl_web_rr, + data_type=DFCollectionType.USER, + start=datetime(year=2022, month=1, day=1, minute=0, tzinfo=timezone.utc), + finished=datetime(year=2022, month=1, day=1, minute=5, tzinfo=timezone.utc), + offset="2min", + archive_path=mnt_filepath.data_src, + ) + + # Confirm that there are no archives available yet + assert instance.progress.has_archive.eq(False).all() + + instance.initial_load() + assert 47 == len(instance.ddf().index) + assert instance.progress.should_archive.eq(True).all() + + # A few archives should have been made + assert not instance.progress.has_archive.eq(False).all() + + @pytest.mark.skip + def test_fetch_force_rr_latest(self): + pass + + @pytest.mark.skip + def test_force_rr_latest(self): + pass diff --git a/tests/incite/collections/test_df_collection_item_base.py b/tests/incite/collections/test_df_collection_item_base.py new file mode 100644 index 0000000..a0c0b0b --- /dev/null +++ b/tests/incite/collections/test_df_collection_item_base.py @@ -0,0 +1,72 @@ +from datetime import datetime, timezone + +import pytest + +from generalresearch.incite.collections import ( + DFCollectionType, + DFCollectionItem, + DFCollection, +) +from test_utils.incite.conftest import mnt_filepath + +df_collection_types = [e for e in DFCollectionType if e is not DFCollectionType.TEST] + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionItemBase: + + def test_init(self, mnt_filepath, df_coll_type): + collection = DFCollection( + data_type=df_coll_type, + offset="100d", + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + item = DFCollectionItem() + item._collection = collection + + assert isinstance(item, DFCollectionItem) + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionItemProperties: + + @pytest.mark.skip + def test_filename(self, df_coll_type): + pass + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionItemMethods: + + def test_has_mysql_false(self, mnt_filepath, df_coll_type): + collection = DFCollection( + data_type=df_coll_type, + offset="100d", + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + instance1: DFCollectionItem = collection.items[0] + assert not instance1.has_mysql() + + def test_has_mysql_true(self, thl_web_rr, mnt_filepath, df_coll_type): + collection = DFCollection( + data_type=df_coll_type, + offset="100d", + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + pg_config=thl_web_rr, + ) + + # Has RR, assume unittest server is online + instance2: DFCollectionItem = collection.items[0] + assert instance2.has_mysql() + + @pytest.mark.skip + def test_update_partial_archive(self, df_coll_type): + pass diff --git a/tests/incite/collections/test_df_collection_item_thl_web.py b/tests/incite/collections/test_df_collection_item_thl_web.py new file mode 100644 index 0000000..9c3d67a --- /dev/null +++ b/tests/incite/collections/test_df_collection_item_thl_web.py @@ -0,0 +1,994 @@ +from datetime import datetime, timezone, timedelta +from itertools import product as iter_product +from os.path import join as pjoin +from pathlib import PurePath, Path +from uuid import uuid4 + +import dask.dataframe as dd +import pandas as pd +import pytest +from distributed import Client, Scheduler, Worker + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) +from faker import Faker +from pandera import DataFrameSchema +from pydantic import FilePath + +from generalresearch.incite.base import CollectionItemBase +from generalresearch.incite.collections import ( + DFCollectionItem, + DFCollectionType, +) +from generalresearch.incite.schemas import ARCHIVE_AFTER +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig +from generalresearch.sql_helper import PostgresDsn +from test_utils.incite.conftest import mnt_filepath, incite_item_factory + +fake = Faker() + +df_collections = [ + DFCollectionType.WALL, + DFCollectionType.SESSION, + DFCollectionType.LEDGER, + DFCollectionType.TASK_ADJUSTMENT, +] + +unsupported_mock_types = { + DFCollectionType.IP_INFO, + DFCollectionType.IP_HISTORY, + DFCollectionType.IP_HISTORY_WS, + DFCollectionType.TASK_ADJUSTMENT, +} + + +def combo_object(): + for x in iter_product( + df_collections, + ["15min", "45min", "1H"], + ): + yield x + + +class TestDFCollectionItemBase: + def test_init(self): + instance = CollectionItemBase() + assert isinstance(instance, CollectionItemBase) + assert isinstance(instance.start, datetime) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollectionItemProperties: + + def test_filename(self, df_collection_data_type, df_collection, offset): + for i in df_collection.items: + assert isinstance(i.filename, str) + + assert isinstance(i.path, PurePath) + assert i.path.name == i.filename + + assert i._collection.data_type.name.lower() in i.filename + assert i._collection.offset in i.filename + assert i.start.strftime("%Y-%m-%d-%H-%M-%S") in i.filename + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollectionItemPropertiesBase: + + def test_name(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.name, str) + + def test_finish(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.finish, datetime) + + def test_interval(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.interval, pd.Interval) + + def test_partial_filename(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.partial_filename, str) + + def test_empty_filename(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.empty_filename, str) + + def test_path(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.path, FilePath) + + def test_partial_path(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.partial_path, FilePath) + + def test_empty_path(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.empty_path, FilePath) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset, duration", + argvalues=list( + iter_product( + df_collections, + ["12h", "10D"], + [timedelta(days=10), timedelta(days=45)], + ) + ), +) +class TestDFCollectionItemMethod: + + def test_has_mysql( + self, + df_collection, + thl_web_rr, + offset, + duration, + df_collection_data_type, + delete_df_collection, + ): + delete_df_collection(coll=df_collection) + + df_collection.pg_config = None + for i in df_collection.items: + assert not i.has_mysql() + + # Confirm that the regular connection should work as expected + df_collection.pg_config = thl_web_rr + for i in df_collection.items: + assert i.has_mysql() + + # Make a fake connection and confirm it does NOT work + df_collection.pg_config = PostgresConfig( + dsn=PostgresDsn(f"postgres://root:@127.0.0.1/{uuid4().hex}"), + connect_timeout=5, + statement_timeout=1, + ) + for i in df_collection.items: + assert not i.has_mysql() + + @pytest.mark.skip + def test_update_partial_archive( + self, + df_collection, + offset, + duration, + thl_web_rw, + df_collection_data_type, + delete_df_collection, + ): + # for i in collection.items: + # assert i.update_partial_archive() + # assert df.created.max() < _last_time_block[1] + pass + + @pytest.mark.skip + def test_create_partial_archive( + self, + df_collection, + offset, + duration, + create_main_accounts, + thl_web_rw, + thl_lm, + df_collection_data_type, + user_factory, + product, + client_no_amm, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + assert 1 + 1 == 2 + + def test_dict( + self, + df_collection_data_type, + offset, + duration, + df_collection, + delete_df_collection, + ): + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + res = item.to_dict() + assert isinstance(res, dict) + assert len(res.keys()) == 6 + + assert isinstance(res["should_archive"], bool) + assert isinstance(res["has_archive"], bool) + assert isinstance(res["path"], Path) + assert isinstance(res["filename"], str) + + assert isinstance(res["start"], datetime) + assert isinstance(res["finish"], datetime) + assert res["start"] < res["finish"] + + def test_from_mysql( + self, + df_collection_data_type, + df_collection, + offset, + duration, + create_main_accounts, + thl_web_rw, + user_factory, + product, + incite_item_factory, + delete_df_collection, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + + delete_df_collection(coll=df_collection) + u1: User = user_factory(product=product) + + # No data has been loaded, but we can confirm the from_mysql returns + # back an empty data with the correct columns + for item in df_collection.items: + # Unlike .from_mysql_ledger(), .from_mysql_standard() will return + # back and empty df with the correct columns in place + delete_df_collection(coll=df_collection) + df = item.from_mysql() + if df_collection.data_type == DFCollectionType.LEDGER: + assert df is None + else: + assert df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + + incite_item_factory(user=u1, item=item) + + df = item.from_mysql() + assert not df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + if df_collection.data_type == DFCollectionType.LEDGER: + # The number of rows in this dataframe will change depending + # on the mocking of data. It's because if the account has + # user wallet on, then there will be more transactions for + # example. + assert df.shape[0] > 0 + + def test_from_mysql_standard( + self, + df_collection_data_type, + df_collection, + offset, + duration, + user_factory, + product, + incite_item_factory, + delete_df_collection, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + item: DFCollectionItem + + if df_collection.data_type == DFCollectionType.LEDGER: + # We're using parametrize, so this If statement is just to + # confirm other Item Types will always raise an assertion + with pytest.raises(expected_exception=AssertionError) as cm: + res = item.from_mysql_standard() + assert ( + "Can't call from_mysql_standard for Ledger DFCollectionItem" + in str(cm.value) + ) + + continue + + # Unlike .from_mysql_ledger(), .from_mysql_standard() will return + # back and empty df with the correct columns in place + df = item.from_mysql_standard() + assert df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + + incite_item_factory(user=u1, item=item) + + df = item.from_mysql_standard() + assert not df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + assert df.shape[0] > 0 + + def test_from_mysql_ledger( + self, + df_collection, + user, + create_main_accounts, + offset, + duration, + thl_web_rw, + thl_lm, + df_collection_data_type, + user_factory, + product, + client_no_amm, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type != DFCollectionType.LEDGER: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + item: DFCollectionItem + delete_df_collection(coll=df_collection) + + # Okay, now continue with the actual Ledger Item tests... we need + # to ensure that this item.start - item.finish range hasn't had + # any prior transactions created within that range. + assert item.from_mysql_ledger() is None + + # Create main accounts doesn't matter because it doesn't + # add any transactions to the db + assert item.from_mysql_ledger() is None + + incite_item_factory(user=u1, item=item) + df = item.from_mysql_ledger() + assert isinstance(df, pd.DataFrame) + + # Not only is this a np.int64 to int comparison, but I also know it + # isn't actually measuring anything meaningful. However, it's useful + # as it tells us if the DF contains all the correct TX Entries. I + # figured it's better to count the amount rather than just the + # number of rows. DF == transactions * 2 because there are two + # entries per transactions + # assert df.amount.sum() == total_amt + # assert total_entries == df.shape[0] + + assert not df.tx_id.is_unique + df["net"] = df.direction * df.amount + assert df.groupby("tx_id").net.sum().sum() == 0 + + def test_to_archive( + self, + df_collection, + user, + offset, + duration, + df_collection_data_type, + user_factory, + product, + client_no_amm, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + + # Load up the data that we'll be using for various to_archive + # methods. + df = item.from_mysql() + ddf = dd.from_pandas(df, npartitions=1) + + # (1) Write the basic archive, the issue is that because it's + # an empty pd.DataFrame, it never makes an actual parquet file + assert item.to_archive(ddf=ddf, is_partial=False, overwrite=False) + assert item.has_archive() + assert item.has_archive(include_empty=False) + + def test__to_archive( + self, + df_collection_data_type, + df_collection, + user_factory, + product, + offset, + duration, + client_no_amm, + user, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + """We already have a test for the "non-private" version of this, + which primarily just uses the respective Client to determine if + the ddf is_empty or not. + + Therefore, use the private test to check the manual behavior of + passing in the is_empty or overwrite. + """ + if df_collection.data_type in unsupported_mock_types: + return + + delete_df_collection(coll=df_collection) + u1: User = user_factory(product=product) + + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + + # Load up the data that we'll be using for various to_archive + # methods. Will always be empty pd.DataFrames for now... + df = item.from_mysql() + ddf = dd.from_pandas(df, npartitions=1) + + # (1) Confirm a missing ddf (shouldn't bc of type hint) should + # immediately return back False + assert not item._to_archive(ddf=None, is_empty=True) + assert not item._to_archive(ddf=None, is_empty=False) + + # (2) Setting empty overrides any possible state of the ddf + for rand_val in [df, ddf, True, 1_000]: + assert not item.empty_path.exists() + item._to_archive(ddf=rand_val, is_empty=True) + assert item.empty_path.exists() + item.empty_path.unlink() + + # (3) Trigger a warning with overwrite. First write an empty, + # then write it again with override default to confirm it worked, + # then write it again with override=False to confirm it does + # not work. + assert item._to_archive(ddf=ddf, is_empty=True) + res1 = item.empty_path.stat() + + # Returns none because it knows the file (regular, empty, or + # partial) already exists + assert not item._to_archive(ddf=ddf, is_empty=True, overwrite=False) + + # Currently override=True doesn't actually work on empty files + # because it's checked again in .set_empty() and isn't + # aware of the override flag that may be passed in to + # item._to_archive() + with pytest.raises(expected_exception=AssertionError) as cm: + item._to_archive(ddf=rand_val, is_empty=True, overwrite=True) + assert "set_empty is already set; why are you doing this?" in str(cm.value) + + # We can assert the file stats are the same because we were never + # able to go ahead and rewrite or update it in anyway + res2 = item.empty_path.stat() + assert res1 == res2 + + @pytest.mark.skip + def test_to_archive_numbered_partial( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_initial_load( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_clear_corrupt_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset, duration", + argvalues=list(iter_product(df_collections, ["12h", "10D"], [timedelta(days=15)])), +) +class TestDFCollectionItemMethodBase: + + @pytest.mark.skip + def test_path_exists(self, df_collection_data_type, offset, duration): + pass + + @pytest.mark.skip + def test_next_numbered_path(self, df_collection_data_type, offset, duration): + pass + + @pytest.mark.skip + def test_search_highest_numbered_path( + self, df_collection_data_type, offset, duration + ): + pass + + @pytest.mark.skip + def test_tmp_filename(self, df_collection_data_type, offset, duration): + pass + + @pytest.mark.skip + def test_tmp_path(self, df_collection_data_type, offset, duration): + pass + + def test_is_empty(self, df_collection_data_type, df_collection, offset, duration): + """ + test_has_empty was merged into this because item.has_empty is + an alias for is_empty.. or vis-versa + """ + + for item in df_collection.items: + assert not item.is_empty() + assert not item.has_empty() + + item.empty_path.touch() + + assert item.is_empty() + assert item.has_empty() + + def test_has_partial_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + for item in df_collection.items: + assert not item.has_partial_archive() + item.partial_path.touch() + assert item.has_partial_archive() + + def test_has_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + for item in df_collection.items: + # (1) Originally, nothing exists... so let's just make a file and + # confirm that it works if just touch that path (no validation + # occurs at all). + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + item.path.touch() + assert item.has_archive(include_empty=False) + assert item.has_archive(include_empty=True) + + item.path.unlink() + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + + # (2) Same as the above, except make an empty directory + # instead of a file + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + item.path.mkdir() + assert item.has_archive(include_empty=False) + assert item.has_archive(include_empty=True) + + item.path.rmdir() + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + + # (3) Rather than make a empty file or dir at the path, let's + # touch the empty_path and confirm the include_empty option + # works + + item.empty_path.touch() + assert not item.has_archive(include_empty=False) + assert item.has_archive(include_empty=True) + + def test_delete_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + for item in df_collection.items: + item: DFCollectionItem + # (1) Confirm that it doesn't raise an error or anything if we + # try to delete files or folders that do not exist + CollectionItemBase.delete_archive(generic_path=item.path) + CollectionItemBase.delete_archive(generic_path=item.empty_path) + CollectionItemBase.delete_archive(generic_path=item.partial_path) + + item.path.touch() + item.empty_path.touch() + item.partial_path.touch() + + CollectionItemBase.delete_archive(generic_path=item.path) + CollectionItemBase.delete_archive(generic_path=item.empty_path) + CollectionItemBase.delete_archive(generic_path=item.partial_path) + + assert not item.path.exists() + assert not item.empty_path.exists() + assert not item.partial_path.exists() + + def test_should_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + schema: DataFrameSchema = df_collection._schema + aa = schema.metadata[ARCHIVE_AFTER] + + # It shouldn't be None, it can be timedelta(seconds=0) + assert isinstance(aa, timedelta) + + for item in df_collection.items: + item: DFCollectionItem + + if datetime.now(tz=timezone.utc) > item.finish + aa: + assert item.should_archive() + else: + assert not item.should_archive() + + @pytest.mark.skip + def test_set_empty(self, df_collection_data_type, df_collection, offset, duration): + pass + + def test_valid_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + # Originally, nothing has been saved or anything.. so confirm it + # always comes back as None + for item in df_collection.items: + assert not item.valid_archive(generic_path=None, sample=None) + + _path = Path(pjoin(df_collection.archive_path, uuid4().hex)) + + # (1) Fail if isfile, but doesn't exist and if we can't read + # it as valid ParquetFile + assert not item.valid_archive(generic_path=_path, sample=None) + _path.touch() + assert not item.valid_archive(generic_path=_path, sample=None) + _path.unlink() + + # (2) Fail if isdir and we can't read it as a valid ParquetFile + _path.mkdir() + assert _path.is_dir() + assert not item.valid_archive(generic_path=_path, sample=None) + _path.rmdir() + + @pytest.mark.skip + def test_validate_df( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_from_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + def test__to_dict(self, df_collection_data_type, df_collection, offset, duration): + + for item in df_collection.items: + res = item._to_dict() + assert isinstance(res, dict) + assert len(res.keys()) == 6 + + assert isinstance(res["should_archive"], bool) + assert isinstance(res["has_archive"], bool) + assert isinstance(res["path"], Path) + assert isinstance(res["filename"], str) + + assert isinstance(res["start"], datetime) + assert isinstance(res["finish"], datetime) + assert res["start"] < res["finish"] + + @pytest.mark.skip + def test_delete_partial( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_cleanup_partials( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_delete_dangling_partials( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +async def test_client(client, s, worker): + """c,s,a are all required - the secondary Worker (b) is not required""" + + assert isinstance(client, Client) + assert isinstance(s, Scheduler) + assert isinstance(worker, Worker) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", + argvalues=combo_object(), +) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +@pytest.mark.anyio +async def test_client_parametrize(c, s, w, df_collection_data_type, offset): + """c,s,a are all required - the secondary Worker (b) is not required""" + + assert isinstance(c, Client), f"c is not Client, it's {type(c)}" + assert isinstance(s, Scheduler), f"s is not Scheduler, it's {type(s)}" + assert isinstance(w, Worker), f"w is not Worker, it's {type(w)}" + + assert df_collection_data_type is not None + assert isinstance(offset, str) + + +# I cannot figure out how to define the parametrize on the Test, but then have +# sync or async methods within it, with some having or not having the +# gen_cluster decorator set. + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset, duration", + argvalues=list(iter_product(df_collections, ["12h", "10D"], [timedelta(days=15)])), +) +class TestDFCollectionItemFunctionalTest: + + def test_to_archive_and_ddf( + self, + df_collection_data_type, + offset, + duration, + client_no_amm, + df_collection, + user, + user_factory, + product, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + df_collection._client = client_no_amm + + # Assert that there are no pre-existing archives + assert df_collection.progress.has_archive.eq(False).all() + res = df_collection.ddf() + assert res is None + + delete_df_collection(coll=df_collection) + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + item.initial_load() + + # I know it seems weird to delete items from the database before we + # proceed with the test. However, the content should have already + # been saved out into an parquet at this point, and I am too lazy + # to write a separate teardown for a collection (and not a + # single Item) + + # Now that we went ahead with the initial_load, Assert that all + # items have archives files saved + assert isinstance(df_collection.progress, pd.DataFrame) + assert df_collection.progress.has_archive.eq(True).all() + + ddf = df_collection.ddf() + shape = df_collection._client.compute(collections=ddf.shape, sync=True) + assert shape[0] > 5 + + def test_filesize_estimate( + self, + df_collection, + user, + offset, + duration, + client_no_amm, + user_factory, + product, + df_collection_data_type, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + """A functional test to write some Parquet files for the + DFCollection and then confirm that the files get written + correctly. + + Confirm the files are written correctly by: + (1) Validating their passing the pandera schema + (2) The file or dir has an expected size on disk + """ + import pyarrow.parquet as pq + from generalresearch.models.thl.user import User + import os + + if df_collection.data_type in unsupported_mock_types: + return + delete_df_collection(coll=df_collection) + u1: User = user_factory(product=product) + + # Pick 3 random items to sample for correct filesize + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + item.initial_load(overwrite=True) + + total_bytes = 0 + for fp in pq.ParquetDataset(item.path).files: + total_bytes += os.stat(fp).st_size + + total_mb = total_bytes / 1_048_576 + + assert total_bytes > 1_000 + assert total_mb < 1 + + def test_to_archive_client( + self, + client_no_amm, + df_collection, + user_factory, + product, + offset, + duration, + df_collection_data_type, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=df_collection) + df_collection._client = client_no_amm + u1: User = user_factory(product=product) + + for item in df_collection.items: + item: DFCollectionItem + + if df_collection.data_type in unsupported_mock_types: + continue + + incite_item_factory(user=u1, item=item) + + # Load up the data that we'll be using for various to_archive + # methods. Will always be empty pd.DataFrames for now... + df = item.from_mysql() + ddf = dd.from_pandas(df, npartitions=1) + assert isinstance(ddf, dd.DataFrame) + + # (1) Write the basic archive, the issue is that because it's + # an empty pd.DataFrame, it never makes an actual parquet file + assert not item.has_archive() + saved = item.to_archive(ddf=ddf, is_partial=False, overwrite=False) + assert saved + assert item.has_archive(include_empty=True) + + @pytest.mark.skip + def test_get_items(self, df_collection, product, offset, duration): + with pytest.warns(expected_warning=ResourceWarning) as cm: + df_collection.get_items_last365() + assert "DFCollectionItem has missing archives" in str( + [w.message for w in cm.list] + ) + + res = df_collection.get_items_last365() + assert len(res) == len(df_collection.items) + + def test_saving_protections( + self, + client_no_amm, + df_collection_data_type, + df_collection, + incite_item_factory, + delete_df_collection, + user_factory, + product, + offset, + duration, + mnt_filepath, + ): + """Don't allow creating an archive for data that will likely be + overwritten or updated + """ + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + schema: DataFrameSchema = df_collection._schema + aa = schema.metadata[ARCHIVE_AFTER] + assert isinstance(aa, timedelta) + + delete_df_collection(df_collection) + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + + should_archive = item.should_archive() + res = item.initial_load() + + # self.assertIn("Cannot create archive for such new data", str(cm.records)) + + # .to_archive() will return back True or False depending on if it + # was successful. We want to compare that result to the + # .should_archive() method result + assert should_archive == res + + def test_empty_item( + self, + client_no_amm, + df_collection_data_type, + df_collection, + incite_item_factory, + delete_df_collection, + user, + offset, + duration, + mnt_filepath, + ): + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + assert not item.has_empty() + df: pd.DataFrame = item.from_mysql() + + # We do this check b/c the Ledger returns back None and + # I don't want it to fail when we go to make a ddf + if df is None: + item.set_empty() + else: + ddf = dd.from_pandas(df, npartitions=1) + item.to_archive(ddf=ddf) + + assert item.has_empty() + + def test_file_touching( + self, + client_no_amm, + df_collection_data_type, + df_collection, + incite_item_factory, + delete_df_collection, + user_factory, + product, + offset, + duration, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=df_collection) + df_collection._client = client_no_amm + u1: User = user_factory(product=product) + + for item in df_collection.items: + # Confirm none of the paths exist yet + assert not item.has_archive() + assert not item.path_exists(generic_path=item.path) + assert not item.has_empty() + assert not item.path_exists(generic_path=item.empty_path) + + if df_collection.data_type in unsupported_mock_types: + assert not item.has_archive(include_empty=False) + assert not item.has_empty() + assert not item.path_exists(generic_path=item.empty_path) + else: + incite_item_factory(user=u1, item=item) + item.initial_load() + + assert item.has_archive(include_empty=False) + assert item.path_exists(generic_path=item.path) + assert not item.has_empty() diff --git a/tests/incite/collections/test_df_collection_thl_marketplaces.py b/tests/incite/collections/test_df_collection_thl_marketplaces.py new file mode 100644 index 0000000..0a77938 --- /dev/null +++ b/tests/incite/collections/test_df_collection_thl_marketplaces.py @@ -0,0 +1,75 @@ +from datetime import datetime, timezone +from itertools import product + +import pytest +from pandera import Column, Index, DataFrameSchema + +from generalresearch.incite.collections import DFCollection +from generalresearch.incite.collections import DFCollectionType +from generalresearch.incite.collections.thl_marketplaces import ( + InnovateSurveyHistoryCollection, + MorningSurveyTimeseriesCollection, + SagoSurveyHistoryCollection, + SpectrumSurveyTimeseriesCollection, +) +from test_utils.incite.conftest import mnt_filepath + + +def combo_object(): + for x in product( + [ + InnovateSurveyHistoryCollection, + MorningSurveyTimeseriesCollection, + SagoSurveyHistoryCollection, + SpectrumSurveyTimeseriesCollection, + ], + ["5min", "6H", "30D"], + ): + yield x + + +@pytest.mark.parametrize("df_coll, offset", combo_object()) +class TestDFCollection_thl_marketplaces: + + def test_init(self, mnt_filepath, df_coll, offset, spectrum_rw): + assert issubclass(df_coll, DFCollection) + + # This is stupid, but we need to pull the default from the + # Pydantic field + data_type = df_coll.model_fields["data_type"].default + assert isinstance(data_type, DFCollectionType) + + # (1) Can't be totally empty, needs a path... + with pytest.raises(expected_exception=Exception) as cm: + instance = df_coll() + + # (2) Confirm it only needs the archive_path + instance = df_coll( + archive_path=mnt_filepath.archive_path(enum_type=data_type), + ) + assert isinstance(instance, DFCollection) + + # (3) Confirm it loads with all + instance = df_coll( + archive_path=mnt_filepath.archive_path(enum_type=data_type), + sql_helper=spectrum_rw, + offset=offset, + start=datetime(year=2023, month=6, day=1, minute=0, tzinfo=timezone.utc), + finished=datetime(year=2023, month=6, day=1, minute=5, tzinfo=timezone.utc), + ) + assert isinstance(instance, DFCollection) + + # (4) Now that we initialize the Class, we can access the _schema + assert isinstance(instance._schema, DataFrameSchema) + assert isinstance(instance._schema.index, Index) + + for c in instance._schema.columns.keys(): + assert isinstance(c, str) + col = instance._schema.columns[c] + assert isinstance(col, Column) + + assert instance._schema.coerce, "coerce on all Schemas" + assert isinstance(instance._schema.checks, list) + assert len(instance._schema.checks) == 0 + assert isinstance(instance._schema.metadata, dict) + assert len(instance._schema.metadata.keys()) == 2 diff --git a/tests/incite/collections/test_df_collection_thl_web.py b/tests/incite/collections/test_df_collection_thl_web.py new file mode 100644 index 0000000..e6f464b --- /dev/null +++ b/tests/incite/collections/test_df_collection_thl_web.py @@ -0,0 +1,160 @@ +from datetime import datetime +from itertools import product + +import dask.dataframe as dd +import pandas as pd +import pytest +from pandera import DataFrameSchema + +from generalresearch.incite.collections import DFCollection, DFCollectionType + + +def combo_object(): + for x in product( + [ + DFCollectionType.USER, + DFCollectionType.WALL, + DFCollectionType.SESSION, + DFCollectionType.TASK_ADJUSTMENT, + DFCollectionType.AUDIT_LOG, + DFCollectionType.LEDGER, + ], + ["30min", "1H"], + ): + yield x + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web: + + def test_init(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection_data_type, DFCollectionType) + assert isinstance(df_collection, DFCollection) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_Properties: + + def test_items(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.items, list) + for i in df_collection.items: + assert i._collection == df_collection + + def test__schema(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection._schema, DataFrameSchema) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_BaseProperties: + + @pytest.mark.skip + def test__interval_range(self, df_collection_data_type, offset, df_collection): + pass + + def test_interval_start(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.interval_start, datetime) + + def test_interval_range(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.interval_range, list) + + def test_progress(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.progress, pd.DataFrame) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_Methods: + + @pytest.mark.skip + def test_initial_loads(self, df_collection_data_type, df_collection, offset): + pass + + @pytest.mark.skip + def test_fetch_force_rr_latest( + self, df_collection_data_type, df_collection, offset + ): + pass + + @pytest.mark.skip + def test_force_rr_latest(self, df_collection_data_type, df_collection, offset): + pass + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_BaseMethods: + + def test_fetch_all_paths(self, df_collection_data_type, offset, df_collection): + res = df_collection.fetch_all_paths( + items=None, force_rr_latest=False, include_partial=False + ) + assert isinstance(res, list) + + @pytest.mark.skip + def test_ddf(self, df_collection_data_type, offset, df_collection): + res = df_collection.ddf() + assert isinstance(res, dd.DataFrame) + + # -- cleanup -- + @pytest.mark.skip + def test_schedule_cleanup(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_cleanup(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_cleanup_partials(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_clear_tmp_archives(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_clear_corrupt_archives( + self, df_collection_data_type, offset, df_collection + ): + pass + + @pytest.mark.skip + def test_rebuild_symlinks(self, df_collection_data_type, offset, df_collection): + pass + + # -- Source timing -- + @pytest.mark.skip + def test_get_item(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_item_start(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_items(self, df_collection_data_type, offset, df_collection): + # If we get all the items from the start of the collection, it + # should include all the items! + res1 = df_collection.items + res2 = df_collection.get_items(since=df_collection.start) + assert len(res1) == len(res2) + + @pytest.mark.skip + def test_get_items_from_year(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_items_last90(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_items_last365(self, df_collection_data_type, offset, df_collection): + pass diff --git a/tests/incite/collections/test_df_collection_thl_web_ledger.py b/tests/incite/collections/test_df_collection_thl_web_ledger.py new file mode 100644 index 0000000..599d979 --- /dev/null +++ b/tests/incite/collections/test_df_collection_thl_web_ledger.py @@ -0,0 +1,32 @@ +# def test_loaded(self, client_no_amm, collection, new_user_fixture, pop_ledger_merge): +# collection._client = client_no_amm +# +# teardown_events(collection) +# THL_LM.create_main_accounts() +# +# for item in collection.items: +# populate_events(item, user=new_user_fixture) +# item.initial_load() +# +# ddf = collection.ddf( +# force_rr_latest=False, +# include_partial=True, +# filters=[ +# ("created", ">=", collection.start), +# ("created", "<", collection.finished), +# ], +# ) +# +# assert isinstance(ddf, dd.DataFrame) +# df = client_no_amm.compute(collections=ddf, sync=True) +# assert isinstance(df, pd.DataFrame) +# +# # Simple validation check(s) +# assert not df.tx_id.is_unique +# df["net"] = df.direction * df.amount +# assert df.groupby("tx_id").net.sum().sum() == 0 +# +# teardown_events(collection) +# +# +# diff --git a/tests/incite/mergers/__init__.py b/tests/incite/mergers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/incite/mergers/__init__.py diff --git a/tests/incite/mergers/foundations/__init__.py b/tests/incite/mergers/foundations/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/incite/mergers/foundations/__init__.py diff --git a/tests/incite/mergers/foundations/test_enriched_session.py b/tests/incite/mergers/foundations/test_enriched_session.py new file mode 100644 index 0000000..ec15d38 --- /dev/null +++ b/tests/incite/mergers/foundations/test_enriched_session.py @@ -0,0 +1,138 @@ +from datetime import timedelta, timezone, datetime +from decimal import Decimal +from itertools import product +from typing import Optional + +from generalresearch.incite.schemas.admin_responses import ( + AdminPOPSessionSchema, +) + +import dask.dataframe as dd +import pandas as pd +import pytest + +from test_utils.incite.collections.conftest import ( + session_collection, + wall_collection, +) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + product( + ["12h", "3D"], + [timedelta(days=5)], + ) + ), +) +class TestEnrichedSession: + + def test_base( + self, + client_no_amm, + product, + user_factory, + wall_collection, + session_collection, + enriched_session_merge, + thl_web_rr, + delete_df_collection, + incite_item_factory, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=session_collection) + + u1: User = user_factory(product=product, created=session_collection.start) + + for item in session_collection.items: + incite_item_factory(item=item, user=u1) + item.initial_load() + + for item in wall_collection.items: + item.initial_load() + + enriched_session_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + # -- + + ddf = enriched_session_merge.ddf() + assert isinstance(ddf, dd.DataFrame) + + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + + assert not df.empty + + # -- Teardown + delete_df_collection(session_collection) + + +class TestEnrichedSessionAdmin: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2020, month=3, day=14, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "1d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=5) + + def test_to_admin_response( + self, + event_report_request, + enriched_session_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + ): + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory() + p2 = product_factory() + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + df = enriched_session_merge.to_admin_response( + rr=session_report_request, client=client_no_amm + ) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert isinstance(AdminPOPSessionSchema.validate(df), pd.DataFrame) + assert df.index.get_level_values(1).nunique() == 2 diff --git a/tests/incite/mergers/foundations/test_enriched_task_adjust.py b/tests/incite/mergers/foundations/test_enriched_task_adjust.py new file mode 100644 index 0000000..96c214f --- /dev/null +++ b/tests/incite/mergers/foundations/test_enriched_task_adjust.py @@ -0,0 +1,76 @@ +from datetime import timedelta +from itertools import product as iter_product + +import dask.dataframe as dd +import pandas as pd +import pytest + +from test_utils.incite.collections.conftest import ( + wall_collection, + task_adj_collection, + session_collection, +) +from test_utils.incite.mergers.conftest import enriched_wall_merge + + +@pytest.mark.parametrize( + argnames="offset, duration,", + argvalues=list( + iter_product( + ["12h", "3D"], + [timedelta(days=5)], + ) + ), +) +class TestEnrichedTaskAdjust: + + @pytest.mark.skip + def test_base( + self, + client_no_amm, + user_factory, + product, + task_adj_collection, + wall_collection, + session_collection, + enriched_wall_merge, + enriched_task_adjust_merge, + incite_item_factory, + delete_df_collection, + thl_web_rr, + ): + from generalresearch.models.thl.user import User + + # -- Build & Setup + delete_df_collection(coll=session_collection) + u1: User = user_factory(product=product) + + for item in session_collection.items: + incite_item_factory(user=u1, item=item) + item.initial_load() + for item in wall_collection.items: + item.initial_load() + + enriched_wall_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + enriched_task_adjust_merge.build( + client=client_no_amm, + task_adjust_coll=task_adj_collection, + enriched_wall=enriched_wall_merge, + pg_config=thl_web_rr, + ) + + # -- + + ddf = enriched_task_adjust_merge.ddf() + assert isinstance(ddf, dd.DataFrame) + + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + + assert not df.empty diff --git a/tests/incite/mergers/foundations/test_enriched_wall.py b/tests/incite/mergers/foundations/test_enriched_wall.py new file mode 100644 index 0000000..8f4995b --- /dev/null +++ b/tests/incite/mergers/foundations/test_enriched_wall.py @@ -0,0 +1,236 @@ +from datetime import timedelta, timezone, datetime +from decimal import Decimal +from itertools import product as iter_product +from typing import Optional + +import dask.dataframe as dd +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) + +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMergeItem, +) +from test_utils.incite.collections.conftest import ( + session_collection, + wall_collection, +) +from test_utils.incite.conftest import incite_item_factory +from test_utils.incite.mergers.conftest import ( + enriched_wall_merge, +) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list(iter_product(["48h", "3D"], [timedelta(days=5)])), +) +class TestEnrichedWall: + + def test_base( + self, + client_no_amm, + product, + user_factory, + wall_collection, + thl_web_rr, + session_collection, + enriched_wall_merge, + delete_df_collection, + incite_item_factory, + ): + from generalresearch.models.thl.user import User + + # -- Build & Setup + delete_df_collection(coll=session_collection) + delete_df_collection(coll=wall_collection) + u1: User = user_factory(product=product, created=session_collection.start) + + for item in session_collection.items: + incite_item_factory(item=item, user=u1) + item.initial_load() + + for item in wall_collection.items: + item.initial_load() + + enriched_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + # -- + + ddf = enriched_wall_merge.ddf() + assert isinstance(ddf, dd.DataFrame) + + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + + assert not df.empty + + def test_base_item( + self, + client_no_amm, + product, + user_factory, + wall_collection, + session_collection, + enriched_wall_merge, + delete_df_collection, + thl_web_rr, + incite_item_factory, + ): + # -- Build & Setup + delete_df_collection(coll=session_collection) + u = user_factory(product=product, created=session_collection.start) + + for item in session_collection.items: + incite_item_factory(item=item, user=u) + item.initial_load() + for item in wall_collection.items: + item.initial_load() + + enriched_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + # -- + + for item in enriched_wall_merge.items: + assert isinstance(item, EnrichedWallMergeItem) + + path = item.path + + try: + modified_time1 = path.stat().st_mtime + except (Exception,): + modified_time1 = 0 + + item.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + modified_time2 = path.stat().st_mtime + + # Merger Items can't be updated unless it's a partial, confirm + # that even after attempting to rebuild, it doesn't re-touch + # the file + assert modified_time2 == modified_time1 + + # def test_admin_pop_session_device_type(ew_merge_setup): + # self.build() + # + # rr = ReportRequest( + # report_type=ReportType.POP_EVENT, + # index0="started", + # index1="device_type", + # freq="min", + # start=start, + # ) + # + # df, categories, updated = self.instance.to_admin_response( + # rr=rr, product_ids=[self.product.id], client=client + # ) + # + # assert isinstance(df, pd.DataFrame) + # device_types_str = [str(e.value) for e in DeviceType] + # device_types = df.index.get_level_values(1).values + # assert all([dt in device_types_str for dt in device_types]) + + +class TestEnrichedWallToAdmin: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2020, month=3, day=14, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "1d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=5) + + def test_empty(self, enriched_wall_merge, client_no_amm, start): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest.model_validate({"interval": "5min", "start": start}) + + res = enriched_wall_merge.to_admin_response( + rr=rr, + client=client_no_amm, + ) + + assert isinstance(res, pd.DataFrame) + + assert res.empty + assert len(res.columns) > 5 + + def test_to_admin_response( + self, + event_report_request, + enriched_wall_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + user, + session_factory, + delete_df_collection, + product_factory, + user_factory, + start, + ): + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory() + p2 = product_factory() + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=2, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + df = enriched_wall_merge.to_admin_response( + rr=event_report_request, client=client_no_amm + ) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + # assert len(df) == 1 + # assert user.product_id == df.reset_index().loc[0, "index1"] + assert df.index.get_level_values(1).nunique() == 2 diff --git a/tests/incite/mergers/foundations/test_user_id_product.py b/tests/incite/mergers/foundations/test_user_id_product.py new file mode 100644 index 0000000..f96bfb4 --- /dev/null +++ b/tests/incite/mergers/foundations/test_user_id_product.py @@ -0,0 +1,73 @@ +from datetime import timedelta, datetime, timezone +from itertools import product + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) + +from generalresearch.incite.mergers.foundations.user_id_product import ( + UserIdProductMergeItem, +) +from test_utils.incite.mergers.conftest import user_id_product_merge + + +@pytest.mark.parametrize( + argnames="offset, duration, start", + argvalues=list( + product( + ["12h", "3D"], + [timedelta(days=5)], + [ + (datetime.now(tz=timezone.utc) - timedelta(days=35)).replace( + microsecond=0 + ) + ], + ) + ), +) +class TestUserIDProduct: + + @pytest.mark.skip + def test_base(self, client_no_amm, user_id_product_merge): + ddf = user_id_product_merge.ddf() + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + assert not df.empty + + @pytest.mark.skip + def test_base_item(self, client_no_amm, user_id_product_merge, user_collection): + assert len(user_id_product_merge.items) == 1 + + for item in user_id_product_merge.items: + assert isinstance(item, UserIdProductMergeItem) + + path = item.path + + try: + modified_time1 = path.stat().st_mtime + except (Exception,): + modified_time1 = 0 + + user_id_product_merge.build(client=client_no_amm, user_coll=user_collection) + modified_time2 = path.stat().st_mtime + + assert modified_time2 > modified_time1 + + @pytest.mark.skip + def test_read(self, client_no_amm, user_id_product_merge): + users_ddf = user_id_product_merge.ddf() + df = client_no_amm.compute(collections=users_ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert len(df.columns) == 1 + assert str(df.product_id.dtype) == "category" diff --git a/tests/incite/mergers/test_merge_collection.py b/tests/incite/mergers/test_merge_collection.py new file mode 100644 index 0000000..692cac3 --- /dev/null +++ b/tests/incite/mergers/test_merge_collection.py @@ -0,0 +1,102 @@ +from datetime import datetime, timezone, timedelta +from itertools import product + +import pandas as pd +import pytest +from pandera import DataFrameSchema + +from generalresearch.incite.mergers import ( + MergeCollection, + MergeType, +) +from test_utils.incite.conftest import mnt_filepath + +merge_types = list(e for e in MergeType if e != MergeType.TEST) + + +@pytest.mark.parametrize( + argnames="merge_type, offset, duration, start", + argvalues=list( + product( + merge_types, + ["5min", "6h", "14D"], + [timedelta(days=30)], + [ + (datetime.now(tz=timezone.utc) - timedelta(days=35)).replace( + microsecond=0 + ) + ], + ) + ), +) +class TestMergeCollection: + + def test_init(self, mnt_filepath, merge_type, offset, duration, start): + with pytest.raises(expected_exception=ValueError) as cm: + MergeCollection(archive_path=mnt_filepath.data_src) + assert "Must explicitly provide a merge_type" in str(cm.value) + + instance = MergeCollection( + merge_type=merge_type, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + assert instance.merge_type == merge_type + + def test_items(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + offset=offset, + start=start, + finished=start + duration, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + assert len(instance.interval_range) == len(instance.items) + + def test_progress(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + offset=offset, + start=start, + finished=start + duration, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + assert isinstance(instance.progress, pd.DataFrame) + assert instance.progress.shape[0] > 0 + assert instance.progress.shape[1] == 7 + assert instance.progress["group_by"].isnull().all() + + def test_schema(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + assert isinstance(instance._schema, DataFrameSchema) + + def test_load(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + start=start, + finished=start + duration, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + # Confirm that there are no archives available yet + assert instance.progress.has_archive.eq(False).all() + + def test_get_items(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + start=start, + finished=start + duration, + offset=offset, + merge_type=merge_type, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + # with pytest.raises(expected_exception=ResourceWarning) as cm: + res = instance.get_items_last365() + # assert "has missing archives", str(cm.value) + assert len(res) == len(instance.items) diff --git a/tests/incite/mergers/test_merge_collection_item.py b/tests/incite/mergers/test_merge_collection_item.py new file mode 100644 index 0000000..96f8789 --- /dev/null +++ b/tests/incite/mergers/test_merge_collection_item.py @@ -0,0 +1,66 @@ +from datetime import datetime, timezone, timedelta +from itertools import product +from pathlib import PurePath + +import pytest + +from generalresearch.incite.mergers import MergeCollectionItem, MergeType +from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, +) +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, +) +from test_utils.incite.mergers.conftest import merge_collection + + +@pytest.mark.parametrize( + argnames="merge_type, offset, duration", + argvalues=list( + product( + [MergeType.ENRICHED_SESSION, MergeType.ENRICHED_WALL], + ["1h"], + [timedelta(days=1)], + ) + ), +) +class TestMergeCollectionItem: + + def test_file_naming(self, merge_collection, offset, duration, start): + assert len(merge_collection.items) == 25 + + items: list[MergeCollectionItem] = merge_collection.items + + for i in items: + i: MergeCollectionItem + + assert isinstance(i.path, PurePath) + assert i.path.name == i.filename + + assert i._collection.merge_type.name.lower() in i.filename + assert i._collection.offset in i.filename + assert i.start.strftime("%Y-%m-%d-%H-%M-%S") in i.filename + + def test_archives(self, merge_collection, offset, duration, start): + assert len(merge_collection.items) == 25 + + for i in merge_collection.items: + assert not i.has_archive() + assert not i.has_empty() + assert not i.is_empty() + assert not i.has_partial_archive() + assert i.has_archive() == i.path_exists(generic_path=i.path) + + res = set([i.should_archive() for i in merge_collection.items]) + assert len(res) == 1 + + def test_item_to_archive(self, merge_collection, offset, duration, start): + for item in merge_collection.items: + item: MergeCollectionItem + assert not item.has_archive() + + # TODO: setup build methods + # ddf = self.build + # saved = instance.to_archive(ddf=ddf) + # self.assertTrue(saved) + # self.assertTrue(instance.has_archive()) diff --git a/tests/incite/mergers/test_pop_ledger.py b/tests/incite/mergers/test_pop_ledger.py new file mode 100644 index 0000000..6f96108 --- /dev/null +++ b/tests/incite/mergers/test_pop_ledger.py @@ -0,0 +1,307 @@ +from datetime import timedelta, datetime, timezone +from itertools import product as iter_product +from typing import Optional + +import pandas as pd +import pytest +from distributed.utils_test import client_no_amm + +from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, +) +from test_utils.incite.collections.conftest import ledger_collection +from test_utils.incite.conftest import mnt_filepath, incite_item_factory +from test_utils.incite.mergers.conftest import pop_ledger_merge +from test_utils.managers.ledger.conftest import create_main_accounts + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "3D"], + [timedelta(days=4)], + ) + ), +) +class TestMergePOPLedger: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2020, month=3, day=14, tzinfo=timezone.utc) + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=5) + + def test_base( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + product, + user_factory, + create_main_accounts, + thl_lm, + delete_df_collection, + incite_item_factory, + delete_ledger_db, + ): + from generalresearch.models.thl.ledger import LedgerAccount + + u = user_factory(product=product, created=ledger_collection.start) + + # -- Build & Setup + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + # assert ledger_collection.start is None + # assert ledger_collection.offset is None + + for item in ledger_collection.items: + incite_item_factory(item=item, user=u) + item.initial_load() + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + + pop_ledger_merge.build( + client=client_no_amm, + ledger_coll=ledger_collection, + ) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + ddf = pop_ledger_merge.ddf() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + + user_wallet_account: LedgerAccount = thl_lm.get_account_or_create_user_wallet( + user=u + ) + cash_account: LedgerAccount = thl_lm.get_account_cash() + rev_account: LedgerAccount = thl_lm.get_account_task_complete_revenue() + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + # Pure SQL based lookups + cash_balance: int = thl_lm.get_account_balance(account=cash_account) + rev_balance: int = thl_lm.get_account_balance(account=rev_account) + assert cash_balance > rev_balance + + # (1) Test Cash Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names, + filters=[ + ("account_id", "==", cash_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + assert df["mp_payment.CREDIT"].sum() == 0 + assert cash_balance > 0 + assert df["mp_payment.DEBIT"].sum() == cash_balance + + # (2) Test Revenue Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names, + filters=[ + ("account_id", "==", rev_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert rev_balance == 0 + assert df["bp_payment.CREDIT"].sum() == 0 + assert df["mp_payment.DEBIT"].sum() == 0 + assert df["mp_payment.CREDIT"].sum() > 0 + + # -- Cleanup + delete_ledger_db() + + def test_pydantic_init( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + product, + user_factory, + create_main_accounts, + offset, + duration, + start, + thl_lm, + incite_item_factory, + delete_df_collection, + delete_ledger_db, + session_collection, + ): + from generalresearch.models.thl.ledger import LedgerAccount + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.finance import ProductBalances + + u = user_factory(product=product, created=session_collection.start) + + assert ledger_collection.finished is not None + assert isinstance(u.product, Product) + delete_ledger_db() + create_main_accounts(), + delete_df_collection(coll=ledger_collection) + + bp_account: LedgerAccount = thl_lm.get_account_or_create_bp_wallet( + product=u.product + ) + cash_account: LedgerAccount = thl_lm.get_account_cash() + rev_account: LedgerAccount = thl_lm.get_account_task_complete_revenue() + + for item in ledger_collection.items: + incite_item_factory(item=item, user=u) + item.initial_load(overwrite=True) + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + # (1) Filter by the Product Account, this means no cash_account, or + # rev_account transactions will be present in here... + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", bp_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + df = df.set_index("time_idx") + assert not df.empty + + instance = ProductBalances.from_pandas(input_data=df.sum()) + assert instance.payout == instance.net == instance.bp_payment_credit + assert instance.available_balance < instance.net + assert instance.available_balance + instance.retainer == instance.net + assert instance.balance == thl_lm.get_account_balance(bp_account) + assert df["bp_payment.CREDIT"].sum() == thl_lm.get_account_balance(bp_account) + + # (2) Filter by the Cash Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", cash_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + cash_balance: int = thl_lm.get_account_balance(account=cash_account) + assert df["bp_payment.CREDIT"].sum() == 0 + assert cash_balance > 0 + assert df["mp_payment.CREDIT"].sum() == 0 + assert df["mp_payment.DEBIT"].sum() == cash_balance + + # (2) Filter by the Revenue Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", rev_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + rev_balance: int = thl_lm.get_account_balance(account=rev_account) + assert rev_balance == 0 + assert df["bp_payment.CREDIT"].sum() == 0 + assert df["mp_payment.DEBIT"].sum() == 0 + assert df["mp_payment.CREDIT"].sum() > 0 + + def test_resample( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + user_factory, + product, + create_main_accounts, + offset, + duration, + start, + thl_lm, + delete_df_collection, + incite_item_factory, + ): + from generalresearch.models.thl.user import User + + assert ledger_collection.finished is not None + delete_df_collection(coll=ledger_collection) + u1: User = user_factory(product=product) + + bp_account = thl_lm.get_account_or_create_bp_wallet(product=u1.product) + + for item in ledger_collection.items: + incite_item_factory(user=u1, item=item) + item.initial_load(overwrite=True) + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", bp_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + assert isinstance(df.index, pd.Index) + assert not isinstance(df.index, pd.RangeIndex) + + # Now change the index so we can easily resample it + df = df.set_index("time_idx") + assert isinstance(df.index, pd.Index) + assert isinstance(df.index, pd.DatetimeIndex) + + bp_account_balance = thl_lm.get_account_balance(account=bp_account) + + # Initial sum + initial_sum = df.sum().sum() + # assert len(df) == 48 # msg="Original df should be 48 rows" + + # Original (1min) to 5min + df_5min = df.resample("5min").sum() + # assert len(df_5min) == 12 + assert initial_sum == df_5min.sum().sum() + + # 30min + df_30min = df.resample("30min").sum() + # assert len(df_30min) == 2 + assert initial_sum == df_30min.sum().sum() + + # 1hr + df_1hr = df.resample("1h").sum() + # assert len(df_1hr) == 1 + assert initial_sum == df_1hr.sum().sum() + + # 1 day + df_1day = df.resample("1d").sum() + # assert len(df_1day) == 1 + assert initial_sum == df_1day.sum().sum() diff --git a/tests/incite/mergers/test_ym_survey_merge.py b/tests/incite/mergers/test_ym_survey_merge.py new file mode 100644 index 0000000..4c2df6b --- /dev/null +++ b/tests/incite/mergers/test_ym_survey_merge.py @@ -0,0 +1,125 @@ +from datetime import timedelta, timezone, datetime +from itertools import product + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) + +from test_utils.incite.collections.conftest import wall_collection, session_collection +from test_utils.incite.mergers.conftest import ( + enriched_session_merge, + ym_survey_wall_merge, +) + + +@pytest.mark.parametrize( + argnames="offset, duration, start", + argvalues=list( + product( + ["12h", "3D"], + [timedelta(days=30)], + [ + (datetime.now(tz=timezone.utc) - timedelta(days=35)).replace( + microsecond=0 + ) + ], + ) + ), +) +class TestYMSurveyMerge: + """We override start, not because it's needed on the YMSurveyWall merge, + which operates on a rolling 10-day window, but because we don't want + to mock data in the wall collection and enriched_session_merge from + the 1800s and then wonder why there is no data available in the past + 10 days in the database. + """ + + def test_base( + self, + client_no_amm, + user_factory, + product, + ym_survey_wall_merge, + wall_collection, + session_collection, + enriched_session_merge, + delete_df_collection, + incite_item_factory, + thl_web_rr, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=session_collection) + user: User = user_factory(product=product, created=session_collection.start) + + # -- Build & Setup + assert ym_survey_wall_merge.start is None + assert ym_survey_wall_merge.offset == "10D" + + for item in session_collection.items: + incite_item_factory(item=item, user=user) + item.initial_load() + for item in wall_collection.items: + item.initial_load() + + # Confirm any of the items are archived + assert session_collection.progress.has_archive.eq(True).all() + assert wall_collection.progress.has_archive.eq(True).all() + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + assert enriched_session_merge.progress.has_archive.eq(True).all() + + ddf = enriched_session_merge.ddf() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + + ym_survey_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + enriched_session=enriched_session_merge, + ) + assert ym_survey_wall_merge.progress.has_archive.eq(True).all() + + # -- + + ddf = ym_survey_wall_merge.ddf() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + assert df.product_id.nunique() == 1 + assert df.team_id.nunique() == 1 + assert df.source.nunique() > 1 + + started_min_ts = df.started.min() + started_max_ts = df.started.max() + + assert type(started_min_ts) is pd.Timestamp + assert type(started_max_ts) is pd.Timestamp + + started_min: datetime = datetime.fromisoformat(str(started_min_ts)) + started_max: datetime = datetime.fromisoformat(str(started_max_ts)) + + started_delta = started_max - started_min + assert started_delta >= timedelta(days=3) diff --git a/tests/incite/schemas/__init__.py b/tests/incite/schemas/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/incite/schemas/__init__.py diff --git a/tests/incite/schemas/test_admin_responses.py b/tests/incite/schemas/test_admin_responses.py new file mode 100644 index 0000000..43aa399 --- /dev/null +++ b/tests/incite/schemas/test_admin_responses.py @@ -0,0 +1,239 @@ +from datetime import datetime, timezone, timedelta +from random import sample +from typing import List + +import numpy as np +import pandas as pd +import pytest + +from generalresearch.incite.schemas import empty_dataframe_from_schema +from generalresearch.incite.schemas.admin_responses import ( + AdminPOPSchema, + SIX_HOUR_SECONDS, +) +from generalresearch.locales import Localelator + + +class TestAdminPOPSchema: + schema_df = empty_dataframe_from_schema(AdminPOPSchema) + countries = list(Localelator().get_all_countries())[:5] + dates = [datetime(year=2024, month=1, day=i, tzinfo=None) for i in range(1, 10)] + + @classmethod + def assign_valid_vals(cls, df: pd.DataFrame) -> pd.DataFrame: + for c in df.columns: + check_attrs: dict = AdminPOPSchema.columns[c].checks[0].statistics + df[c] = np.random.randint( + check_attrs["min_value"], check_attrs["max_value"], df.shape[0] + ) + + return df + + def test_empty(self): + with pytest.raises(Exception): + AdminPOPSchema.validate(pd.DataFrame()) + + def test_new_empty_df(self): + df = empty_dataframe_from_schema(AdminPOPSchema) + + assert isinstance(df, pd.DataFrame) + assert isinstance(df.index, pd.MultiIndex) + assert df.columns.size == len(AdminPOPSchema.columns) + + def test_valid(self): + # (1) Works with raw naive datetime + dates = [ + datetime(year=2024, month=1, day=i, tzinfo=None).isoformat() + for i in range(1, 10) + ] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + df = AdminPOPSchema.validate(df) + assert isinstance(df, pd.DataFrame) + + # (2) Works with isoformat naive datetime + dates = [datetime(year=2024, month=1, day=i, tzinfo=None) for i in range(1, 10)] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + df = AdminPOPSchema.validate(df) + assert isinstance(df, pd.DataFrame) + + def test_index_tz_parser(self): + tz_dates = [ + datetime(year=2024, month=1, day=i, tzinfo=timezone.utc) + for i in range(1, 10) + ] + + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[tz_dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + # Initially, they're all set with a timezone + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz == timezone.utc for ts in timestmaps]) + + # After validation, the timezone is removed + df = AdminPOPSchema.validate(df) + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is None for ts in timestmaps]) + + def test_index_tz_no_future_beyond_one_year(self): + now = datetime.now(tz=timezone.utc) + tz_dates = [now + timedelta(days=i * 365) for i in range(1, 10)] + + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[tz_dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + with pytest.raises(Exception) as cm: + AdminPOPSchema.validate(df) + + assert ( + "Index 'index0' failed element-wise validator " + "number 0: less_than(" in str(cm.value) + ) + + def test_index_only_str(self): + # --- float64 to str! --- + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, np.random.rand(1, 10)[0]], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, float) for v in vals]) + + df = AdminPOPSchema.validate(df, lazy=True) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, str) for v in vals]) + + # --- int to str --- + + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, sample(range(100), 20)], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, int) for v in vals]) + + df = AdminPOPSchema.validate(df, lazy=True) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, str) for v in vals]) + + # a = 1 + assert isinstance(df, pd.DataFrame) + + def test_invalid_parsing(self): + # (1) Timezones AND as strings will still parse correctly + tz_str_dates = [ + datetime( + year=2024, month=1, day=1, minute=i, tzinfo=timezone.utc + ).isoformat() + for i in range(1, 10) + ] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[tz_str_dates, self.countries], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + df = AdminPOPSchema.validate(df, lazy=True) + + assert isinstance(df, pd.DataFrame) + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is None for ts in timestmaps]) + + # (2) Timezones are removed + dates = [ + datetime(year=2024, month=1, day=1, minute=i, tzinfo=timezone.utc) + for i in range(1, 10) + ] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + # Has tz before validation, and none after + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is timezone.utc for ts in timestmaps]) + + df = AdminPOPSchema.validate(df, lazy=True) + + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is None for ts in timestmaps]) + + def test_clipping(self): + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, self.countries], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + df = AdminPOPSchema.validate(df) + assert df.elapsed_avg.max() < SIX_HOUR_SECONDS + + # Now that we know it's valid, break the elapsed avg + df["elapsed_avg"] = np.random.randint( + SIX_HOUR_SECONDS, SIX_HOUR_SECONDS + 10_000, df.shape[0] + ) + assert df.elapsed_avg.max() > SIX_HOUR_SECONDS + + # Confirm it doesn't fail if the values are greater, and that + # all the values are clipped to the max + df = AdminPOPSchema.validate(df) + assert df.elapsed_avg.eq(SIX_HOUR_SECONDS).all() + + def test_rounding(self): + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, self.countries], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + df["payout_avg"] = 2.123456789900002 + + assert df.payout_avg.sum() == 95.5555555455001 + + df = AdminPOPSchema.validate(df) + assert df.payout_avg.sum() == 95.40000000000003 diff --git a/tests/incite/schemas/test_thl_web.py b/tests/incite/schemas/test_thl_web.py new file mode 100644 index 0000000..7f4434b --- /dev/null +++ b/tests/incite/schemas/test_thl_web.py @@ -0,0 +1,70 @@ +import pandas as pd +import pytest +from pandera.errors import SchemaError + + +class TestWallSchema: + + def test_empty(self): + from generalresearch.incite.schemas.thl_web import THLWallSchema + + with pytest.raises(SchemaError): + THLWallSchema.validate(pd.DataFrame()) + + def test_index_missing(self): + from generalresearch.incite.schemas.thl_web import THLWallSchema + + df = pd.DataFrame(columns=THLWallSchema.columns.keys()) + + with pytest.raises(SchemaError) as cm: + THLWallSchema.validate(df) + + def test_no_rows(self): + from generalresearch.incite.schemas.thl_web import THLWallSchema + + df = pd.DataFrame(index=["uuid"], columns=THLWallSchema.columns.keys()) + + with pytest.raises(SchemaError) as cm: + THLWallSchema.validate(df) + + def test_new_empty_df(self): + from generalresearch.incite.schemas import empty_dataframe_from_schema + from generalresearch.incite.schemas.thl_web import THLWallSchema + + df = empty_dataframe_from_schema(THLWallSchema) + assert isinstance(df, pd.DataFrame) + assert df.columns.size == 20 + + +class TestSessionSchema: + + def test_empty(self): + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + with pytest.raises(SchemaError): + THLSessionSchema.validate(pd.DataFrame()) + + def test_index_missing(self): + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + df = pd.DataFrame(columns=THLSessionSchema.columns.keys()) + df.set_index("uuid", inplace=True) + + with pytest.raises(SchemaError) as cm: + THLSessionSchema.validate(df) + + def test_no_rows(self): + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + df = pd.DataFrame(index=["id"], columns=THLSessionSchema.columns.keys()) + + with pytest.raises(SchemaError) as cm: + THLSessionSchema.validate(df) + + def test_new_empty_df(self): + from generalresearch.incite.schemas import empty_dataframe_from_schema + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + df = empty_dataframe_from_schema(THLSessionSchema) + assert isinstance(df, pd.DataFrame) + assert df.columns.size == 21 diff --git a/tests/incite/test_collection_base.py b/tests/incite/test_collection_base.py new file mode 100644 index 0000000..497e5ab --- /dev/null +++ b/tests/incite/test_collection_base.py @@ -0,0 +1,318 @@ +from datetime import datetime, timezone, timedelta +from os.path import exists as pexists, join as pjoin +from pathlib import Path +from uuid import uuid4 + +import numpy as np +import pandas as pd +import pytest +from _pytest._code.code import ExceptionInfo + +from generalresearch.incite.base import CollectionBase +from test_utils.incite.conftest import mnt_filepath + +AGO_15min = (datetime.now(tz=timezone.utc) - timedelta(minutes=15)).replace( + microsecond=0 +) +AGO_1HR = (datetime.now(tz=timezone.utc) - timedelta(hours=1)).replace(microsecond=0) +AGO_2HR = (datetime.now(tz=timezone.utc) - timedelta(hours=2)).replace(microsecond=0) + + +class TestCollectionBase: + def test_init(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.df.empty is True + + def test_init_df(self, mnt_filepath): + # Only an empty pd.DataFrame can ever be provided + instance = CollectionBase( + df=pd.DataFrame({}), archive_path=mnt_filepath.data_src + ) + assert isinstance(instance.df, pd.DataFrame) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + df=pd.DataFrame(columns=[0, 1, 2]), archive_path=mnt_filepath.data_src + ) + assert "Do not provide a pd.DataFrame" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + df=pd.DataFrame(np.random.randint(100, size=(1000, 1)), columns=["A"]), + archive_path=mnt_filepath.data_src, + ) + assert "Do not provide a pd.DataFrame" in str(cm.value) + + def test_init_start(self, mnt_filepath): + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + start=datetime.now(tz=timezone.utc) - timedelta(days=10), + archive_path=mnt_filepath.data_src, + ) + assert "Collection.start must not have microseconds" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + tz = timezone(timedelta(hours=-5), "EST") + + CollectionBase( + start=datetime(year=2000, month=1, day=1, tzinfo=tz), + archive_path=mnt_filepath.data_src, + ) + assert "Timezone is not UTC" in str(cm.value) + + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.start == datetime( + year=2018, month=1, day=1, tzinfo=timezone.utc + ) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + start=AGO_2HR, offset="3h", archive_path=mnt_filepath.data_src + ) + assert "Offset must be equal to, or smaller the start timestamp" in str( + cm.value + ) + + def test_init_archive_path(self, mnt_filepath): + """DirectoryPath is apparently smart enough to confirm that the + directory path exists. + """ + + # (1) Basic, confirm an existing path works + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.archive_path == mnt_filepath.data_src + + # (2) It can't point to a file + file_path = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}.zip")) + assert not pexists(file_path) + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(archive_path=file_path) + assert "Path does not point to a directory" in str(cm.value) + + # (3) It doesn't create the directory if it doesn't exist + new_path = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}/")) + assert not pexists(new_path) + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(archive_path=new_path) + assert "Path does not point to a directory" in str(cm.value) + + def test_init_offset(self, mnt_filepath): + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(offset="1:X", archive_path=mnt_filepath.data_src) + assert "Invalid offset alias provided. Please review:" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(offset=f"59sec", archive_path=mnt_filepath.data_src) + assert "Must be equal to, or longer than 1 min" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(offset=f"{365 * 101}d", archive_path=mnt_filepath.data_src) + assert "String should have at most 5 characters" in str(cm.value) + + +class TestCollectionBaseProperties: + + def test_items(self, mnt_filepath): + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + instance = CollectionBase(archive_path=mnt_filepath.data_src) + x = instance.items + assert "Must override" in str(cm.value) + + def test_interval_range(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + # Private method requires the end parameter + with pytest.raises(expected_exception=AssertionError) as cm: + cm: ExceptionInfo + instance._interval_range(end=None) + assert "an end value must be provided" in str(cm.value) + + # End param must be same as started (which forces utc) + tz = timezone(timedelta(hours=-5), "EST") + with pytest.raises(expected_exception=AssertionError) as cm: + cm: ExceptionInfo + instance._interval_range(end=datetime.now(tz=tz)) + assert "Timezones must match" in str(cm.value) + + res = instance._interval_range(end=datetime.now(tz=timezone.utc)) + assert isinstance(res, pd.IntervalIndex) + assert res.closed_left + assert res.is_non_overlapping_monotonic + assert res.is_monotonic_increasing + assert res.is_unique + + def test_interval_range2(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert isinstance(instance.interval_range, list) + + # 1 hrs ago has 2 x 30min + the future 30min + OFFSET = "30min" + instance = CollectionBase( + start=AGO_1HR, offset=OFFSET, archive_path=mnt_filepath.data_src + ) + assert len(instance.interval_range) == 3 + assert instance.interval_range[0][0] == AGO_1HR + + # 1 hrs ago has 1 x 60min + the future 60min + OFFSET = "60min" + instance = CollectionBase( + start=AGO_1HR, offset=OFFSET, archive_path=mnt_filepath.data_src + ) + assert len(instance.interval_range) == 2 + + def test_progress(self, mnt_filepath): + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + instance = CollectionBase( + start=AGO_15min, offset="3min", archive_path=mnt_filepath.data_src + ) + x = instance.progress + assert "Must override" in str(cm.value) + + def test_progress2(self, mnt_filepath): + instance = CollectionBase( + start=AGO_2HR, + offset="15min", + archive_path=mnt_filepath.data_src, + ) + assert instance.df.empty + + with pytest.raises(expected_exception=NotImplementedError) as cm: + df = instance.progress + assert "Must override" in str(cm.value) + + def test_items2(self, mnt_filepath): + """There can't be a test for this because the Items need a path whic + isn't possible in the generic form + """ + instance = CollectionBase( + start=AGO_1HR, offset="5min", archive_path=mnt_filepath.data_src + ) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + items = instance.items + assert "Must override" in str(cm.value) + + # item = items[-3] + # ddf = instance.ddf(items=[item], include_partial=True, force_rr_latest=False) + # df = item.validate_ddf(ddf=ddf) + # assert isinstance(df, pd.DataFrame) + # assert len(df.columns) == 16 + # assert str(df.product_id.dtype) == "object" + # assert str(ddf.product_id.dtype) == "string" + + def test_items3(self, mnt_filepath): + instance = CollectionBase( + start=AGO_2HR, + offset="15min", + archive_path=mnt_filepath.data_src, + ) + with pytest.raises(expected_exception=NotImplementedError) as cm: + item = instance.items[0] + assert "Must override" in str(cm.value) + + +class TestCollectionBaseMethodsCleanup: + def test_fetch_force_rr_latest(self, mnt_filepath): + coll = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=Exception) as cm: + cm: ExceptionInfo + coll.fetch_force_rr_latest(sources=[]) + assert "Must override" in str(cm.value) + + def test_fetch_all_paths(self, mnt_filepath): + coll = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + coll.fetch_all_paths( + items=None, force_rr_latest=False, include_partial=False + ) + assert "Must override" in str(cm.value) + + +class TestCollectionBaseMethodsCleanup: + @pytest.mark.skip + def test_cleanup_partials(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.cleanup_partials() is None # it doesn't return anything + + def test_clear_tmp_archives(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.clear_tmp_archives() is None # it doesn't return anything + + @pytest.mark.skip + def test_clear_corrupt_archives(self, mnt_filepath): + """TODO: expand this so it actually has corrupt archives that we + check to see if they're removed + """ + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.clear_corrupt_archives() is None # it doesn't return anything + + @pytest.mark.skip + def test_rebuild_symlinks(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.rebuild_symlinks() is None + + +class TestCollectionBaseMethodsSourceTiming: + + def test_get_item(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + i = pd.Interval(left=1, right=2, closed="left") + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_item(interval=i) + assert "Must override" in str(cm.value) + + def test_get_item_start(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + dt = datetime.now(tz=timezone.utc) + start = pd.Timestamp(dt) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_item_start(start=start) + assert "Must override" in str(cm.value) + + def test_get_items(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + dt = datetime.now(tz=timezone.utc) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items(since=dt) + assert "Must override" in str(cm.value) + + def test_get_items_from_year(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items_from_year(year=2020) + assert "Must override" in str(cm.value) + + def test_get_items_last90(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items_last90() + assert "Must override" in str(cm.value) + + def test_get_items_last365(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items_last365() + assert "Must override" in str(cm.value) diff --git a/tests/incite/test_collection_base_item.py b/tests/incite/test_collection_base_item.py new file mode 100644 index 0000000..e5d1d02 --- /dev/null +++ b/tests/incite/test_collection_base_item.py @@ -0,0 +1,223 @@ +from datetime import datetime, timezone +from os.path import join as pjoin +from pathlib import Path +from uuid import uuid4 + +import dask.dataframe as dd +import pandas as pd +import pytest +from pydantic import ValidationError + +from generalresearch.incite.base import CollectionItemBase + + +class TestCollectionItemBase: + def test_init(self): + dt = datetime.now(tz=timezone.utc).replace(microsecond=0) + + instance = CollectionItemBase() + instance2 = CollectionItemBase(start=dt) + + assert isinstance(instance, CollectionItemBase) + assert isinstance(instance2, CollectionItemBase) + + assert instance.start.second == instance2.start.second + assert 0 == instance.start.microsecond == instance2.start.microsecond + + def test_init_start(self): + dt = datetime.now(tz=timezone.utc) + + with pytest.raises(expected_exception=ValidationError) as cm: + CollectionItemBase(start=dt) + + assert "CollectionItem.start must not have microsecond precision" in str( + cm.value + ) + + +class TestCollectionItemBaseProperties: + + def test_finish(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.finish + + def test_interval(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.interval + + def test_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.filename + + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_partial_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.filename + + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_empty_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.filename + + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.path + + def test_partial_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.partial_path + + def test_empty_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.empty_path + + +class TestCollectionItemBaseMethods: + + @pytest.mark.skip + def test_next_numbered_path(self): + pass + + @pytest.mark.skip + def test_search_highest_numbered_path(self): + pass + + def test_tmp_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.tmp_filename() + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_tmp_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.tmp_path() + + def test_is_empty(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.is_empty() + + def test_has_empty(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.has_empty() + + def test_has_partial_archive(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.has_partial_archive() + + @pytest.mark.parametrize("include_empty", [True, False]) + def test_has_archive(self, include_empty): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.has_archive(include_empty=include_empty) + + def test_delete_archive_file(self, mnt_filepath): + path1 = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}.zip")) + + # Confirm it doesn't exist, and that delete_archive() doesn't throw + # an error when trying to delete a non-existent file or folder + assert not path1.exists() + CollectionItemBase.delete_archive(generic_path=path1) + # TODO: LOG.warning(f"tried removing non-existent file: {generic_path}") + + # Create it, confirm it exists, delete it, and confirm it doesn't exist + path1.touch() + assert path1.exists() + CollectionItemBase.delete_archive(generic_path=path1) + assert not path1.exists() + + def test_delete_archive_dir(self, mnt_filepath): + path1 = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}")) + + # Confirm it doesn't exist, and that delete_archive() doesn't throw + # an error when trying to delete a non-existent file or folder + assert not path1.exists() + CollectionItemBase.delete_archive(generic_path=path1) + # TODO: LOG.warning(f"tried removing non-existent file: {generic_path}") + + # Create it, confirm it exists, delete it, and confirm it doesn't exist + path1.mkdir() + assert path1.exists() + assert path1.is_dir() + CollectionItemBase.delete_archive(generic_path=path1) + assert not path1.exists() + + def test_should_archive(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.should_archive() + + def test_set_empty(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.set_empty() + + def test_valid_archive(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.valid_archive(generic_path=None, sample=None) + + +class TestCollectionItemBaseMethodsORM: + + @pytest.mark.skip + def test_from_archive(self): + pass + + @pytest.mark.parametrize("is_partial", [True, False]) + def test_to_archive(self, is_partial): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.to_archive( + ddf=dd.from_pandas(data=pd.DataFrame()), is_partial=is_partial + ) + assert "Must override" in str(cm.value) + + @pytest.mark.skip + def test__to_dict(self): + pass + + @pytest.mark.skip + def test_delete_partial(self): + pass + + @pytest.mark.skip + def test_cleanup_partials(self): + pass + + @pytest.mark.skip + def test_delete_dangling_partials(self): + pass diff --git a/tests/incite/test_grl_flow.py b/tests/incite/test_grl_flow.py new file mode 100644 index 0000000..c632f9a --- /dev/null +++ b/tests/incite/test_grl_flow.py @@ -0,0 +1,23 @@ +class TestGRLFlow: + + def test_init(self, mnt_filepath, thl_web_rr): + from generalresearch.incite.defaults import ( + ledger_df_collection, + task_df_collection, + pop_ledger as plm, + ) + + from generalresearch.incite.collections.thl_web import ( + LedgerDFCollection, + TaskAdjustmentDFCollection, + ) + from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge + + ledger_df = ledger_df_collection(ds=mnt_filepath, pg_config=thl_web_rr) + assert isinstance(ledger_df, LedgerDFCollection) + + task_df = task_df_collection(ds=mnt_filepath, pg_config=thl_web_rr) + assert isinstance(task_df, TaskAdjustmentDFCollection) + + pop_ledger = plm(ds=mnt_filepath) + assert isinstance(pop_ledger, PopLedgerMerge) diff --git a/tests/incite/test_interval_idx.py b/tests/incite/test_interval_idx.py new file mode 100644 index 0000000..ea2bced --- /dev/null +++ b/tests/incite/test_interval_idx.py @@ -0,0 +1,23 @@ +import pandas as pd +from datetime import datetime, timezone, timedelta + + +class TestIntervalIndex: + + def test_init(self): + start = datetime(year=2000, month=1, day=1) + end = datetime(year=2000, month=1, day=10) + + iv_r: pd.IntervalIndex = pd.interval_range( + start=start, end=end, freq="1d", closed="left" + ) + assert isinstance(iv_r, pd.IntervalIndex) + assert len(iv_r.to_list()) == 9 + + # If the offset is longer than the end - start it will not + # error. It will simply have 0 rows. + iv_r: pd.IntervalIndex = pd.interval_range( + start=start, end=end, freq="30d", closed="left" + ) + assert isinstance(iv_r, pd.IntervalIndex) + assert len(iv_r.to_list()) == 0 diff --git a/tests/managers/__init__.py b/tests/managers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/__init__.py diff --git a/tests/managers/gr/__init__.py b/tests/managers/gr/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/gr/__init__.py diff --git a/tests/managers/gr/test_authentication.py b/tests/managers/gr/test_authentication.py new file mode 100644 index 0000000..53b6931 --- /dev/null +++ b/tests/managers/gr/test_authentication.py @@ -0,0 +1,125 @@ +import logging +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.models.gr.authentication import GRUser +from test_utils.models.conftest import gr_user + +SSO_ISSUER = "" + + +class TestGRUserManager: + + def test_create(self, gr_um): + from generalresearch.models.gr.authentication import GRUser + + user: GRUser = gr_um.create_dummy() + instance = gr_um.get_by_id(user.id) + assert user.id == instance.id + + instance2 = gr_um.get_by_id(user.id) + assert user.model_dump_json() == instance2.model_dump_json() + + def test_get_by_id(self, gr_user, gr_um): + with pytest.raises(expected_exception=ValueError) as cm: + gr_um.get_by_id(gr_user_id=999_999_999) + assert "GRUser not found" in str(cm.value) + + instance = gr_um.get_by_id(gr_user_id=gr_user.id) + assert instance.sub == gr_user.sub + + def test_get_by_sub(self, gr_user, gr_um): + with pytest.raises(expected_exception=ValueError) as cm: + gr_um.get_by_sub(sub=uuid4().hex) + assert "GRUser not found" in str(cm.value) + + instance = gr_um.get_by_sub(sub=gr_user.sub) + assert instance.id == gr_user.id + + def test_get_by_sub_or_create(self, gr_user, gr_um): + sub = f"{uuid4().hex}-{uuid4().hex}" + + with pytest.raises(expected_exception=ValueError) as cm: + gr_um.get_by_sub(sub=sub) + assert "GRUser not found" in str(cm.value) + + instance = gr_um.get_by_sub_or_create(sub=sub) + assert isinstance(instance, GRUser) + assert instance.sub == sub + + def test_get_all(self, gr_um): + res1 = gr_um.get_all() + assert isinstance(res1, list) + + gr_um.create_dummy() + res2 = gr_um.get_all() + assert len(res1) == len(res2) - 1 + + def test_get_by_team(self, gr_um): + res = gr_um.get_by_team(team_id=999_999_999) + assert isinstance(res, list) + assert res == [] + + def test_list_product_uuids(self, caplog, gr_user, gr_um, thl_web_rr): + with caplog.at_level(logging.WARNING): + gr_um.list_product_uuids(user=gr_user, thl_pg_config=thl_web_rr) + assert "prefetch not run" in caplog.text + + +class TestGRTokenManager: + + def test_create(self, gr_user, gr_tm): + assert gr_tm.create(user_id=gr_user.id) is None + + token = gr_tm.get_by_user_id(user_id=gr_user.id) + assert gr_user.id == token.user_id + + def test_get_by_user_id(self, gr_user, gr_tm): + assert gr_tm.create(user_id=gr_user.id) is None + + token = gr_tm.get_by_user_id(user_id=gr_user.id) + assert gr_user.id == token.user_id + + def test_prefetch_user(self, gr_user, gr_tm, gr_db, gr_redis_config): + from generalresearch.models.gr.authentication import GRToken + + gr_tm.create(user_id=gr_user.id) + + token: GRToken = gr_tm.get_by_user_id(user_id=gr_user.id) + assert token.user is None + + token.prefetch_user(pg_config=gr_db, redis_config=gr_redis_config) + assert token.user.id == gr_user.id + + def test_get_by_key(self, gr_user, gr_um, gr_tm): + gr_tm.create(user_id=gr_user.id) + token = gr_tm.get_by_user_id(user_id=gr_user.id) + + instance = gr_tm.get_by_key(api_key=token.key) + assert token.created == instance.created + + # Search for non-existent key + with pytest.raises(expected_exception=Exception) as cm: + gr_tm.get_by_key(api_key=uuid4().hex) + assert "No GRUser with token of " in str(cm.value) + + @pytest.mark.skip(reason="no idea how to actually test this...") + def test_get_by_sso_key(self, gr_user, gr_um, gr_tm, gr_redis_config): + from generalresearch.models.gr.authentication import GRToken + + api_key = "..." + jwks = { + # ... + } + + instance = gr_tm.get_by_key( + api_key=api_key, + jwks=jwks, + audience="...", + issuer=SSO_ISSUER, + gr_redis_config=gr_redis_config, + ) + + assert isinstance(instance, GRToken) diff --git a/tests/managers/gr/test_business.py b/tests/managers/gr/test_business.py new file mode 100644 index 0000000..7eb77f8 --- /dev/null +++ b/tests/managers/gr/test_business.py @@ -0,0 +1,150 @@ +from uuid import uuid4 + +import pytest + +from test_utils.models.conftest import business + + +class TestBusinessBankAccountManager: + + def test_init(self, business_bank_account_manager, gr_db): + assert business_bank_account_manager.pg_config == gr_db + + def test_create(self, business, business_bank_account_manager): + from generalresearch.models.gr.business import ( + TransferMethod, + BusinessBankAccount, + ) + + instance = business_bank_account_manager.create( + business_id=business.id, + uuid=uuid4().hex, + transfer_method=TransferMethod.ACH, + ) + assert isinstance(instance, BusinessBankAccount) + assert isinstance(instance.id, int) + + res = business_bank_account_manager.get_by_business_id( + business_id=instance.business_id + ) + assert isinstance(res, list) + assert len(res) == 1 + assert isinstance(res[0], BusinessBankAccount) + assert res[0].business_id == instance.business_id + + +class TestBusinessAddressManager: + + def test_create(self, business, business_address_manager): + from generalresearch.models.gr.business import BusinessAddress + + res = business_address_manager.create(uuid=uuid4().hex, business_id=business.id) + assert isinstance(res, BusinessAddress) + assert isinstance(res.id, int) + + +class TestBusinessManager: + + def test_create(self, business_manager): + from generalresearch.models.gr.business import Business + + instance = business_manager.create_dummy() + assert isinstance(instance, Business) + assert isinstance(instance.id, int) + + def test_get_or_create(self, business_manager): + uuid_key = uuid4().hex + + assert business_manager.get_by_uuid(business_uuid=uuid_key) is None + + instance = business_manager.get_or_create( + uuid=uuid_key, + name=f"name-{uuid4().hex[:6]}", + ) + + res = business_manager.get_by_uuid(business_uuid=uuid_key) + assert res.id == instance.id + + def test_get_all(self, business_manager): + res1 = business_manager.get_all() + assert isinstance(res1, list) + + business_manager.create_dummy() + res2 = business_manager.get_all() + assert len(res1) == len(res2) - 1 + + @pytest.mark.skip(reason="TODO") + def test_get_by_team(self): + pass + + def test_get_by_user_id( + self, business_manager, gr_user, team_manager, membership_manager + ): + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Create a Business, but don't add it to anything + b1 = business_manager.create_dummy() + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Create a Team, but don't create any Memberships + t1 = team_manager.create_dummy() + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Create a Membership for the gr_user to the Team... but it doesn't + # matter because the Team doesn't have any Business yet + m1 = membership_manager.create(team=t1, gr_user=gr_user) + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Add the Business to the Team... now the Business should be available + # to the gr_user + team_manager.add_business(team=t1, business=b1) + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 1 + + # Add another Business to the Team! + b2 = business_manager.create_dummy() + team_manager.add_business(team=t1, business=b2) + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 2 + + @pytest.mark.skip(reason="TODO") + def test_get_uuids_by_user_id(self): + pass + + def test_get_by_uuid(self, business, business_manager): + instance = business_manager.get_by_uuid(business_uuid=business.uuid) + assert business.id == instance.id + + def test_get_by_id(self, business, business_manager): + instance = business_manager.get_by_id(business_id=business.id) + assert business.uuid == instance.uuid + + def test_cache_key(self, business): + assert "business:" in business.cache_key + + # def test_create_raise_on_duplicate(self): + # b_uuid = uuid4().hex + # + # # Make the first one + # business = BusinessManager.create( + # uuid=b_uuid, + # name=f"test-{b_uuid[:6]}") + # assert isinstance(business, Business) + # + # # Try to make it again + # with pytest.raises(expected_exception=psycopg.errors.UniqueViolation): + # business = BusinessManager.create( + # uuid=b_uuid, + # name=f"test-{b_uuid[:6]}") + # + # def test_get_by_team(self, team): + # for idx in range(5): + # BusinessManager.create(name=f"Business Name #{uuid4().hex[:6]}", team=team) + # + # res = BusinessManager.get_by_team(team_id=team.id) + # assert isinstance(res, list) + # assert 5 == len(res) diff --git a/tests/managers/gr/test_team.py b/tests/managers/gr/test_team.py new file mode 100644 index 0000000..9215da4 --- /dev/null +++ b/tests/managers/gr/test_team.py @@ -0,0 +1,125 @@ +from uuid import uuid4 + +from test_utils.models.conftest import team + + +class TestMembershipManager: + + def test_init(self, membership_manager, gr_db): + assert membership_manager.pg_config == gr_db + + +class TestTeamManager: + + def test_init(self, team_manager, gr_db): + assert team_manager.pg_config == gr_db + + def test_get_or_create(self, team_manager): + from generalresearch.models.gr.team import Team + + new_uuid = uuid4().hex + + team: Team = team_manager.get_or_create(uuid=new_uuid) + + assert isinstance(team, Team) + assert isinstance(team.id, int) + assert team.uuid == new_uuid + assert team.name == "< Unknown >" + + def test_get_all(self, team_manager): + res1 = team_manager.get_all() + assert isinstance(res1, list) + + team_manager.create_dummy() + res2 = team_manager.get_all() + assert len(res1) == len(res2) - 1 + + def test_create(self, team_manager): + from generalresearch.models.gr.team import Team + + team: Team = team_manager.create_dummy() + assert isinstance(team, Team) + assert isinstance(team.id, int) + + def test_add_user(self, team, team_manager, gr_um, gr_db, gr_redis_config): + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.team import Membership + + user: GRUser = gr_um.create_dummy() + + instance = team_manager.add_user(team=team, gr_user=user) + assert isinstance(instance, Membership) + + # assert team.gr_users is None + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert isinstance(team.gr_users, list) + assert len(team.gr_users) + assert team.gr_users == [user] + + def test_get_by_uuid(self, team_manager): + from generalresearch.models.gr.team import Team + + team: Team = team_manager.create_dummy() + + instance = team_manager.get_by_uuid(team_uuid=team.uuid) + assert team.id == instance.id + + def test_get_by_id(self, team_manager): + from generalresearch.models.gr.team import Team + + team: Team = team_manager.create_dummy() + + instance = team_manager.get_by_id(team_id=team.id) + assert team.uuid == instance.uuid + + def test_get_by_user(self, team, team_manager, gr_um): + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.team import Team + + user: GRUser = gr_um.create_dummy() + team_manager.add_user(team=team, gr_user=user) + + res = team_manager.get_by_user(gr_user=user) + assert isinstance(res, list) + assert len(res) == 1 + instance = res[0] + assert isinstance(instance, Team) + assert instance.uuid == team.uuid + + def test_get_by_user_duplicates( + self, + gr_user_token, + gr_user, + membership, + product_factory, + membership_factory, + team, + thl_web_rr, + gr_redis_config, + gr_db, + ): + product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + gr_user.prefetch_teams( + pg_config=gr_db, + redis_config=gr_redis_config, + ) + + assert len(gr_user.teams) == 1 + + # def test_create_raise_on_duplicate(self): + # t_uuid = uuid4().hex + # + # # Make the first one + # team = TeamManager.create( + # uuid=t_uuid, + # name=f"test-{t_uuid[:6]}") + # assert isinstance(team, Team) + # + # # Try to make it again + # with pytest.raises(expected_exception=psycopg.errors.UniqueViolation): + # TeamManager.create( + # uuid=t_uuid, + # name=f"test-{t_uuid[:6]}") + # diff --git a/tests/managers/leaderboard.py b/tests/managers/leaderboard.py new file mode 100644 index 0000000..4d32dd0 --- /dev/null +++ b/tests/managers/leaderboard.py @@ -0,0 +1,274 @@ +import os +import time +import zoneinfo +from datetime import datetime, timezone +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.managers.leaderboard.manager import LeaderboardManager +from generalresearch.managers.leaderboard.tasks import hit_leaderboards +from generalresearch.models.thl.definitions import Status +from generalresearch.models.thl.user import User +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.session import Session +from generalresearch.models.thl.leaderboard import ( + LeaderboardCode, + LeaderboardFrequency, + LeaderboardRow, +) +from generalresearch.models.thl.product import ( + PayoutConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, +) + +# random uuid for leaderboard tests +product_id = uuid4().hex + + +@pytest.fixture(autouse=True) +def set_timezone(): + os.environ["TZ"] = "UTC" + time.tzset() + yield + # Optionally reset to default + os.environ.pop("TZ", None) + time.tzset() + + +@pytest.fixture +def session_factory(): + return _create_session + + +def _create_session( + product_user_id="aaa", country_iso="us", user_payout=Decimal("1.00") +): + user = User( + product_id=product_id, + product_user_id=product_user_id, + ) + user.product = Product( + id=product_id, + name="test", + redirect_url="https://www.example.com", + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.5), + ) + ), + ) + session = Session( + user=user, + started=datetime(2025, 2, 5, 6, tzinfo=timezone.utc), + id=1, + country_iso=country_iso, + status=Status.COMPLETE, + payout=Decimal("2.00"), + user_payout=user_payout, + ) + return session + + +@pytest.fixture(scope="function") +def setup_leaderboards(thl_redis): + complete_count = { + "aaa": 10, + "bbb": 6, + "ccc": 6, + "ddd": 6, + "eee": 2, + "fff": 1, + "ggg": 1, + } + sum_payout = {"aaa": 345, "bbb": 100, "ccc": 100} + max_payout = sum_payout + country_iso = "us" + for freq in [ + LeaderboardFrequency.DAILY, + LeaderboardFrequency.WEEKLY, + LeaderboardFrequency.MONTHLY, + ]: + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + thl_redis.delete(m.key) + thl_redis.zadd(m.key, complete_count) + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.SUM_PAYOUTS, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + thl_redis.delete(m.key) + thl_redis.zadd(m.key, sum_payout) + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.LARGEST_PAYOUT, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + thl_redis.delete(m.key) + thl_redis.zadd(m.key, max_payout) + + +class TestLeaderboards: + + def test_leaderboard_manager(self, setup_leaderboards, thl_redis): + country_iso = "us" + board_code = LeaderboardCode.COMPLETE_COUNT + freq = LeaderboardFrequency.DAILY + m = LeaderboardManager( + redis_client=thl_redis, + board_code=board_code, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 0, 0, 0), + ) + lb = m.get_leaderboard() + assert lb.period_start_local == datetime( + 2025, 2, 5, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="America/New_York") + ) + assert lb.period_end_local == datetime( + 2025, + 2, + 5, + 23, + 59, + 59, + 999999, + tzinfo=zoneinfo.ZoneInfo(key="America/New_York"), + ) + assert lb.period_start_utc == datetime(2025, 2, 5, 5, tzinfo=timezone.utc) + assert lb.row_count == 7 + assert lb.rows == [ + LeaderboardRow(bpuid="aaa", rank=1, value=10), + LeaderboardRow(bpuid="bbb", rank=2, value=6), + LeaderboardRow(bpuid="ccc", rank=2, value=6), + LeaderboardRow(bpuid="ddd", rank=2, value=6), + LeaderboardRow(bpuid="eee", rank=5, value=2), + LeaderboardRow(bpuid="fff", rank=6, value=1), + LeaderboardRow(bpuid="ggg", rank=6, value=1), + ] + + def test_leaderboard_manager_bpuid(self, setup_leaderboards, thl_redis): + country_iso = "us" + board_code = LeaderboardCode.COMPLETE_COUNT + freq = LeaderboardFrequency.DAILY + m = LeaderboardManager( + redis_client=thl_redis, + board_code=board_code, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(bp_user_id="fff", limit=1) + + # TODO: this won't work correctly if I request bpuid 'ggg', because it + # is ordered at the end even though it is tied, so it won't get a + # row after ('fff') + + assert lb.rows == [ + LeaderboardRow(bpuid="eee", rank=5, value=2), + LeaderboardRow(bpuid="fff", rank=6, value=1), + LeaderboardRow(bpuid="ggg", rank=6, value=1), + ] + + lb.censor() + assert lb.rows[0].bpuid == "ee*" + + def test_leaderboard_hit(self, setup_leaderboards, session_factory, thl_redis): + hit_leaderboards(redis_client=thl_redis, session=session_factory()) + + for freq in [ + LeaderboardFrequency.DAILY, + LeaderboardFrequency.WEEKLY, + LeaderboardFrequency.MONTHLY, + ]: + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=freq, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(limit=1) + assert lb.row_count == 7 + assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=11)] + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.LARGEST_PAYOUT, + freq=freq, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(limit=1) + assert lb.row_count == 3 + assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345)] + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.SUM_PAYOUTS, + freq=freq, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(limit=1) + assert lb.row_count == 3 + assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345 + 100)] + + def test_leaderboard_hit_new_row( + self, setup_leaderboards, session_factory, thl_redis + ): + session = session_factory(product_user_id="zzz") + hit_leaderboards(redis_client=thl_redis, session=session) + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=LeaderboardFrequency.DAILY, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard() + assert lb.row_count == 8 + assert LeaderboardRow(bpuid="zzz", value=1, rank=6) in lb.rows + + def test_leaderboard_country(self, thl_redis): + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=LeaderboardFrequency.DAILY, + product_id=product_id, + country_iso="jp", + within_time=datetime( + 2025, + 2, + 1, + ), + ) + lb = m.get_leaderboard() + assert lb.row_count == 0 + assert lb.period_start_local == datetime( + 2025, 2, 1, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="Asia/Tokyo") + ) + assert lb.local_start_time == "2025-02-01T00:00:00+09:00" + assert lb.local_end_time == "2025-02-01T23:59:59.999999+09:00" + assert lb.period_start_utc == datetime(2025, 1, 31, 15, tzinfo=timezone.utc) + print(lb.model_dump(mode="json")) diff --git a/tests/managers/test_events.py b/tests/managers/test_events.py new file mode 100644 index 0000000..a0fab38 --- /dev/null +++ b/tests/managers/test_events.py @@ -0,0 +1,530 @@ +import random +import time +from datetime import timedelta, datetime, timezone +from decimal import Decimal +from functools import partial +from typing import Optional +from uuid import uuid4 + +import math +import pytest +from math import floor + +from generalresearch.managers.events import EventSubscriber +from generalresearch.models import Source +from generalresearch.models.events import ( + MessageKind, + EventType, + AggregateBySource, + MaxGaugeBySource, +) +from generalresearch.models.legacy.bucket import Bucket +from generalresearch.models.thl.definitions import Status, StatusCode1 +from generalresearch.models.thl.session import Session, Wall +from generalresearch.models.thl.user import User + + +# We don't need anything in the db, so not using the db fixtures +@pytest.fixture(scope="function") +def product_id(product_manager): + return uuid4().hex + + +@pytest.fixture(scope="function") +def user_factory(product_id): + return partial(create_dummy, product_id=product_id) + + +@pytest.fixture(scope="function") +def event_subscriber(thl_redis_config, product_id): + return EventSubscriber(redis_config=thl_redis_config, product_id=product_id) + + +def create_dummy( + product_id: Optional[str] = None, product_user_id: Optional[str] = None +) -> User: + return User( + product_id=product_id, + product_user_id=product_user_id or uuid4().hex, + uuid=uuid4().hex, + created=datetime.now(tz=timezone.utc), + user_id=random.randint(0, floor(2**32 / 2)), + ) + + +class TestActiveUsers: + + def test_run_empty(self, event_manager, product_id): + res = event_manager.get_user_stats(product_id) + assert res == { + "active_users_last_1h": 0, + "active_users_last_24h": 0, + "signups_last_24h": 0, + "in_progress_users": 0, + } + + def test_run(self, event_manager, product_id, user_factory): + event_manager.clear_global_user_stats() + user1: User = user_factory() + + # No matter how many times we do this, they're only active once + event_manager.handle_user(user1) + event_manager.handle_user(user1) + event_manager.handle_user(user1) + + res = event_manager.get_user_stats(product_id) + assert res == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 1, + "in_progress_users": 0, + } + + assert event_manager.get_global_user_stats() == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 1, + "in_progress_users": 0, + } + + # Create a 2nd user in another product + product_id2 = uuid4().hex + user2: User = user_factory(product_id=product_id2) + # Change to say user was created >24 hrs ago + user2.created = user2.created - timedelta(hours=25) + event_manager.handle_user(user2) + + # And now each have 1 active user + assert event_manager.get_user_stats(product_id) == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 1, + "in_progress_users": 0, + } + # user2 was created older than 24hrs ago + assert event_manager.get_user_stats(product_id2) == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 0, + "in_progress_users": 0, + } + # 2 globally active + assert event_manager.get_global_user_stats() == { + "active_users_last_1h": 2, + "active_users_last_24h": 2, + "signups_last_24h": 1, + "in_progress_users": 0, + } + + def test_inprogress(self, event_manager, product_id, user_factory): + event_manager.clear_global_user_stats() + user1: User = user_factory() + user2: User = user_factory() + + # No matter how many times we do this, they're only active once + event_manager.mark_user_inprogress(user1) + event_manager.mark_user_inprogress(user1) + event_manager.mark_user_inprogress(user1) + event_manager.mark_user_inprogress(user2) + + res = event_manager.get_user_stats(product_id) + assert res["in_progress_users"] == 2 + + event_manager.unmark_user_inprogress(user1) + res = event_manager.get_user_stats(product_id) + assert res["in_progress_users"] == 1 + + # Shouldn't do anything + event_manager.unmark_user_inprogress(user1) + res = event_manager.get_user_stats(product_id) + assert res["in_progress_users"] == 1 + + def test_expiry(self, event_manager, product_id, user_factory): + event_manager.clear_global_user_stats() + user1: User = user_factory() + event_manager.handle_user(user1) + event_manager.mark_user_inprogress(user1) + sec_24hr = timedelta(hours=24).total_seconds() + + # We don't want to wait an hour to test this, so we're going to + # just confirm that the keys will expire + time.sleep(1.1) + ttl = event_manager.redis_client.httl( + f"active_users_last_1h:{product_id}", user1.product_user_id + )[0] + assert 3600 - 60 <= ttl <= 3600 + ttl = event_manager.redis_client.httl( + f"active_users_last_24h:{product_id}", user1.product_user_id + )[0] + assert sec_24hr - 60 <= ttl <= sec_24hr + + ttl = event_manager.redis_client.httl("signups_last_24h", user1.uuid)[0] + assert sec_24hr - 60 <= ttl <= sec_24hr + + ttl = event_manager.redis_client.httl("in_progress_users", user1.uuid)[0] + assert 3600 - 60 <= ttl <= 3600 + + +class TestSessionStats: + + def test_run_empty(self, event_manager, product_id): + res = event_manager.get_session_stats(product_id) + assert res == { + "session_enters_last_1h": 0, + "session_enters_last_24h": 0, + "session_fails_last_1h": 0, + "session_fails_last_24h": 0, + "session_completes_last_1h": 0, + "session_completes_last_24h": 0, + "sum_payouts_last_1h": 0, + "sum_payouts_last_24h": 0, + "sum_user_payouts_last_1h": 0, + "sum_user_payouts_last_24h": 0, + "session_avg_payout_last_24h": None, + "session_avg_user_payout_last_24h": None, + "session_complete_avg_loi_last_24h": None, + "session_fail_avg_loi_last_24h": None, + } + + def test_run(self, event_manager, product_id, user_factory, utc_now, utc_hour_ago): + event_manager.clear_global_session_stats() + + user: User = user_factory() + session = Session( + country_iso="us", + started=utc_hour_ago + timedelta(minutes=10), + user=user, + ) + event_manager.session_on_enter(session=session, user=user) + session.update( + finished=utc_now, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("1.00"), + user_payout=Decimal("0.95"), + ) + event_manager.session_on_finish(session=session, user=user) + assert event_manager.get_session_stats(product_id) == { + "session_enters_last_1h": 1, + "session_enters_last_24h": 1, + "session_fails_last_1h": 0, + "session_fails_last_24h": 0, + "session_completes_last_1h": 1, + "session_completes_last_24h": 1, + "sum_payouts_last_1h": 100, + "sum_payouts_last_24h": 100, + "sum_user_payouts_last_1h": 95, + "sum_user_payouts_last_24h": 95, + "session_avg_payout_last_24h": 100, + "session_avg_user_payout_last_24h": 95, + "session_complete_avg_loi_last_24h": round(session.elapsed.total_seconds()), + "session_fail_avg_loi_last_24h": None, + } + + # The session gets inserted into redis using the session.finished + # timestamp, no matter what time it is right now. So we can + # kind of test the expiry by setting it to 61 min ago + session2 = Session( + country_iso="us", + started=utc_hour_ago - timedelta(minutes=10), + user=user, + ) + event_manager.session_on_enter(session=session2, user=user) + session2.update( + finished=utc_hour_ago - timedelta(minutes=1), + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("2.00"), + user_payout=Decimal("1.50"), + ) + event_manager.session_on_finish(session=session2, user=user) + avg_loi = ( + round(session.elapsed.total_seconds()) + + round(session2.elapsed.total_seconds()) + ) / 2 + assert event_manager.get_session_stats(product_id) == { + "session_enters_last_1h": 1, + "session_enters_last_24h": 2, + "session_fails_last_1h": 0, + "session_fails_last_24h": 0, + "session_completes_last_1h": 1, + "session_completes_last_24h": 2, + "sum_payouts_last_1h": 100, + "sum_payouts_last_24h": 300, + "sum_user_payouts_last_1h": 95, + "sum_user_payouts_last_24h": 95 + 150, + "session_avg_payout_last_24h": math.ceil((100 + 200) / 2), + "session_avg_user_payout_last_24h": math.ceil((95 + 150) / 2), + "session_complete_avg_loi_last_24h": avg_loi, + "session_fail_avg_loi_last_24h": None, + } + + # Don't want to wait an hour, so confirm the keys will expire + name = "session_completes_last_1h:" + product_id + res = event_manager.redis_client.hgetall(name) + field = (int(utc_now.timestamp()) // 60) * 60 + field_name = str(field) + assert res == {field_name: "1"} + assert ( + 3600 - 60 < event_manager.redis_client.httl(name, field_name)[0] < 3600 + 60 + ) + + # Second BP, fail + product_id2 = uuid4().hex + user2: User = user_factory(product_id=product_id2) + session3 = Session( + country_iso="us", + started=utc_now - timedelta(minutes=1), + user=user2, + ) + event_manager.session_on_enter(session=session3, user=user) + session3.update( + finished=utc_now, + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + ) + event_manager.session_on_finish(session=session3, user=user) + avg_loi_complete = ( + round(session.elapsed.total_seconds()) + + round(session2.elapsed.total_seconds()) + ) / 2 + assert event_manager.get_session_stats(product_id) == { + "session_enters_last_1h": 2, + "session_enters_last_24h": 3, + "session_fails_last_1h": 1, + "session_fails_last_24h": 1, + "session_completes_last_1h": 1, + "session_completes_last_24h": 2, + "sum_payouts_last_1h": 100, + "sum_payouts_last_24h": 300, + "sum_user_payouts_last_1h": 95, + "sum_user_payouts_last_24h": 95 + 150, + "session_avg_payout_last_24h": math.ceil((100 + 200) / 2), + "session_avg_user_payout_last_24h": math.ceil((95 + 150) / 2), + "session_complete_avg_loi_last_24h": avg_loi_complete, + "session_fail_avg_loi_last_24h": round(session3.elapsed.total_seconds()), + } + + +class TestTaskStatsManager: + def test_empty(self, event_manager): + event_manager.clear_task_stats() + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource(total=0), + "live_tasks_max_payout": MaxGaugeBySource(value=None), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + assert event_manager.get_latest_task_stats() is None + + sm = event_manager.get_stats_message(product_id=uuid4().hex) + assert sm.data.task_created_count_last_24h.total == 0 + assert sm.data.live_tasks_max_payout.value is None + + def test(self, event_manager): + event_manager.clear_task_stats() + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=100, + live_tasks_max_payout=Decimal("1.00"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=100, by_source={Source.TESTING: 100} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=100, by_source={Source.TESTING: 100} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING2, + live_task_count=50, + live_tasks_max_payout=Decimal("2.00"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=150, by_source={Source.TESTING: 100, Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=200, by_source={Source.TESTING: 100, Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=101, + live_tasks_max_payout=Decimal("1.50"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=151, by_source={Source.TESTING: 101, Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=200, by_source={Source.TESTING: 150, Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=99, + live_tasks_max_payout=Decimal("2.50"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=149, by_source={Source.TESTING: 99, Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=250, by_source={Source.TESTING: 250, Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING, live_task_count=0, live_tasks_max_payout=Decimal("0") + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=50, by_source={Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=200, by_source={Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=0, + live_tasks_max_payout=Decimal("0"), + created_count=10, + ) + res = event_manager.get_task_stats_raw() + assert res["task_created_count_last_1h"] == AggregateBySource( + total=10, by_source={Source.TESTING: 10} + ) + assert res["task_created_count_last_24h"] == AggregateBySource( + total=10, by_source={Source.TESTING: 10} + ) + + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=0, + live_tasks_max_payout=Decimal("0"), + created_count=10, + ) + res = event_manager.get_task_stats_raw() + assert res["task_created_count_last_1h"] == AggregateBySource( + total=20, by_source={Source.TESTING: 20} + ) + assert res["task_created_count_last_24h"] == AggregateBySource( + total=20, by_source={Source.TESTING: 20} + ) + + event_manager.set_source_task_stats( + source=Source.TESTING2, + live_task_count=0, + live_tasks_max_payout=Decimal("0"), + created_count=1, + ) + res = event_manager.get_task_stats_raw() + assert res["task_created_count_last_1h"] == AggregateBySource( + total=21, by_source={Source.TESTING: 20, Source.TESTING2: 1} + ) + assert res["task_created_count_last_24h"] == AggregateBySource( + total=21, by_source={Source.TESTING: 20, Source.TESTING2: 1} + ) + + sm = event_manager.get_stats_message(product_id=uuid4().hex) + assert sm.data.task_created_count_last_24h.total == 21 + + +class TestChannelsSubscriptions: + def test_stats_worker( + self, + event_manager, + event_subscriber, + product_id, + user_factory, + utc_hour_ago, + utc_now, + ): + event_manager.clear_stats() + assert event_subscriber.pubsub + # We subscribed, so manually trigger the stats worker + # to get all product_ids subscribed and publish a + # stats message into that channel + event_manager.stats_worker_task() + assert product_id in event_manager.get_active_subscribers() + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.STATS + + user = user_factory() + session = Session( + country_iso="us", + started=utc_hour_ago, + user=user, + clicked_bucket=Bucket(user_payout_min=Decimal("0.50")), + id=1, + ) + + event_manager.handle_session_enter(session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.SESSION_ENTER + + wall = Wall( + req_survey_id="a", + req_cpi=Decimal("1"), + source=Source.TESTING, + session_id=session.id, + user_id=user.user_id, + ) + + event_manager.handle_task_enter(wall, session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.TASK_ENTER + + wall.update( + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=datetime.now(tz=timezone.utc), + cpi=Decimal("1"), + ) + event_manager.handle_task_finish(wall, session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.TASK_FINISH + + session.update( + finished=utc_now, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("1.00"), + user_payout=Decimal("1.00"), + ) + + event_manager.handle_session_finish(session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.SESSION_FINISH + + event_manager.stats_worker_task() + assert product_id in event_manager.get_active_subscribers() + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.STATS + assert msg.data.active_users_last_1h == 1 + assert msg.data.session_enters_last_24h == 1 + assert msg.data.session_completes_last_1h == 1 + assert msg.data.signups_last_24h == 1 diff --git a/tests/managers/test_lucid.py b/tests/managers/test_lucid.py new file mode 100644 index 0000000..1a1bae7 --- /dev/null +++ b/tests/managers/test_lucid.py @@ -0,0 +1,23 @@ +import pytest + +from generalresearch.managers.lucid.profiling import get_profiling_library + +qids = ["42", "43", "45", "97", "120", "639", "15297"] + + +class TestLucidProfiling: + + @pytest.mark.skip + def test_get_library(self, thl_web_rr): + pks = [(qid, "us", "eng") for qid in qids] + qs = get_profiling_library(thl_web_rr, pks=pks) + assert len(qids) == len(qs) + + # just making sure this doesn't raise errors + for q in qs: + q.to_upk_question() + + # a lot will fail parsing because they have no options or the options are blank + # just asserting that we get some back + qs = get_profiling_library(thl_web_rr, country_iso="mx", language_iso="spa") + assert len(qs) > 100 diff --git a/tests/managers/test_userpid.py b/tests/managers/test_userpid.py new file mode 100644 index 0000000..4a3f699 --- /dev/null +++ b/tests/managers/test_userpid.py @@ -0,0 +1,68 @@ +import pytest +from pydantic import MySQLDsn + +from generalresearch.managers.marketplace.user_pid import UserPidMultiManager +from generalresearch.sql_helper import SqlHelper +from generalresearch.managers.cint.user_pid import CintUserPidManager +from generalresearch.managers.dynata.user_pid import DynataUserPidManager +from generalresearch.managers.innovate.user_pid import InnovateUserPidManager +from generalresearch.managers.morning.user_pid import MorningUserPidManager + +# from generalresearch.managers.precision import PrecisionUserPidManager +from generalresearch.managers.prodege.user_pid import ProdegeUserPidManager +from generalresearch.managers.repdata.user_pid import RepdataUserPidManager +from generalresearch.managers.sago.user_pid import SagoUserPidManager +from generalresearch.managers.spectrum.user_pid import SpectrumUserPidManager + +dsn = "" + + +class TestCintUserPidManager: + + def test_filter(self): + m = CintUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-cint"))) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + m.filter() + assert str(excinfo.value) == "Must pass ONE of user_ids, pids" + + with pytest.raises(expected_exception=AssertionError) as excinfo: + m.filter(user_ids=[1, 2, 3], pids=["ed5b47c8551d453d985501391f190d3f"]) + assert str(excinfo.value) == "Must pass ONE of user_ids, pids" + + # pids get .hex before and after + assert m.filter(pids=["ed5b47c8551d453d985501391f190d3f"]) == m.filter( + pids=["ed5b47c8-551d-453d-9855-01391f190d3f"] + ) + + # user_ids and pids are for the same 3 users + res1 = m.filter(user_ids=[61586871, 61458915, 61390116]) + res2 = m.filter( + pids=[ + "ed5b47c8551d453d985501391f190d3f", + "7e640732c59f43d1b7c00137ab66600c", + "5160aeec9c3b4dbb85420128e6da6b5a", + ] + ) + assert res1 == res2 + + +class TestUserPidMultiManager: + + def test_filter(self): + managers = [ + CintUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-cint"))), + DynataUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-dynata"))), + InnovateUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-innovate"))), + MorningUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-morning"))), + ProdegeUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-prodege"))), + RepdataUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-repdata"))), + SagoUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-sago"))), + SpectrumUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-spectrum"))), + ] + m = UserPidMultiManager(sql_helper=SqlHelper(MySQLDsn(dsn)), managers=managers) + res = m.filter(user_ids=[1]) + assert len(res) == len(managers) + + res = m.filter(user_ids=[1, 2, 3]) + assert len(res) > len(managers) diff --git a/tests/managers/thl/__init__.py b/tests/managers/thl/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/thl/__init__.py diff --git a/tests/managers/thl/test_buyer.py b/tests/managers/thl/test_buyer.py new file mode 100644 index 0000000..69ea105 --- /dev/null +++ b/tests/managers/thl/test_buyer.py @@ -0,0 +1,25 @@ +from generalresearch.models import Source + + +class TestBuyer: + + def test( + self, + delete_buyers_surveys, + buyer_manager, + ): + + bs = buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=["a", "b"]) + assert len(bs) == 2 + buyer_a = bs[0] + assert buyer_a.id is not None + bs2 = buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=["a", "c"]) + assert len(bs2) == 2 + buyer_a2 = bs2[0] + buyer_c = bs2[1] + # a isn't created again + assert buyer_a == buyer_a2 + assert bs2[0].id is not None + + # and its cached + assert buyer_c.id == buyer_manager.source_code_pk[f"{Source.TESTING.value}:c"] diff --git a/tests/managers/thl/test_cashout_method.py b/tests/managers/thl/test_cashout_method.py new file mode 100644 index 0000000..fe561d3 --- /dev/null +++ b/tests/managers/thl/test_cashout_method.py @@ -0,0 +1,139 @@ +import pytest + +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashMailCashoutMethodData, + USDeliveryAddress, + PaypalCashoutMethodData, +) +from test_utils.managers.cashout_methods import ( + EXAMPLE_TANGO_CASHOUT_METHODS, + AMT_ASSIGNMENT_CASHOUT_METHOD, + AMT_BONUS_CASHOUT_METHOD, +) + + +class TestTangoCashoutMethods: + + def test_create_and_get(self, cashout_method_manager, setup_cashoutmethod_db): + res = cashout_method_manager.filter(payout_types=[PayoutType.TANGO]) + assert len(res) == 2 + cm = [x for x in res if x.ext_id == "U025035"][0] + assert EXAMPLE_TANGO_CASHOUT_METHODS[0] == cm + + def test_user( + self, cashout_method_manager, user_with_wallet, setup_cashoutmethod_db + ): + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + # This user ONLY has the two tango cashout methods, no AMT + assert len(res) == 2 + + +class TestAMTCashoutMethods: + + def test_create_and_get(self, cashout_method_manager, setup_cashoutmethod_db): + res = cashout_method_manager.filter(payout_types=[PayoutType.AMT]) + assert len(res) == 2 + cm = [x for x in res if x.name == "AMT Assignment"][0] + assert AMT_ASSIGNMENT_CASHOUT_METHOD == cm + cm = [x for x in res if x.name == "AMT Bonus"][0] + assert AMT_BONUS_CASHOUT_METHOD == cm + + def test_user( + self, cashout_method_manager, user_with_wallet_amt, setup_cashoutmethod_db + ): + res = cashout_method_manager.get_cashout_methods(user_with_wallet_amt) + # This user has the 2 tango, plus amt bonus & assignment + assert len(res) == 4 + + +class TestUserCashoutMethods: + + def test(self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db): + delete_cashoutmethod_db() + + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 0 + + def test_cash_in_mail( + self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db + ): + delete_cashoutmethod_db() + + data = CashMailCashoutMethodData( + delivery_address=USDeliveryAddress.model_validate( + { + "name_or_attn": "Josh Ackerman", + "address": "123 Fake St", + "city": "San Francisco", + "state": "CA", + "postal_code": "12345", + } + ) + ) + cashout_method_manager.create_cash_in_mail_cashout_method( + data=data, user=user_with_wallet + ) + + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + assert res[0].data.delivery_address.postal_code == "12345" + + # try to create the same one again. should just do nothing + cashout_method_manager.create_cash_in_mail_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + + # Create with a new address, will create a new one + data.delivery_address.postal_code = "99999" + cashout_method_manager.create_cash_in_mail_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 2 + + def test_paypal( + self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db + ): + delete_cashoutmethod_db() + + data = PaypalCashoutMethodData(email="test@example.com") + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + assert res[0].data.email == "test@example.com" + + # try to create the same one again. should just do nothing + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + + # Create with a new email, will error! must delete the old one first. + # We can only have one paypal active + data.email = "test2@example.com" + with pytest.raises( + ValueError, + match="User already has a cashout method of this type. Delete the existing one and try again.", + ): + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + + cashout_method_manager.delete_cashout_method(res[0].id) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 0 + + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + assert res[0].data.email == "test2@example.com" diff --git a/tests/managers/thl/test_category.py b/tests/managers/thl/test_category.py new file mode 100644 index 0000000..ad0f07b --- /dev/null +++ b/tests/managers/thl/test_category.py @@ -0,0 +1,100 @@ +import pytest + +from generalresearch.models.thl.category import Category + + +class TestCategory: + + @pytest.fixture + def beauty_fitness(self, thl_web_rw): + + return Category( + uuid="12c1e96be82c4642a07a12a90ce6f59e", + adwords_vertical_id="44", + label="Beauty & Fitness", + path="/Beauty & Fitness", + ) + + @pytest.fixture + def hair_care(self, beauty_fitness): + + return Category( + uuid="dd76c4b565d34f198dad3687326503d6", + adwords_vertical_id="146", + label="Hair Care", + path="/Beauty & Fitness/Hair Care", + ) + + @pytest.fixture + def hair_loss(self, hair_care): + + return Category( + uuid="aacff523c8e246888215611ec3b823c0", + adwords_vertical_id="235", + label="Hair Loss", + path="/Beauty & Fitness/Hair Care/Hair Loss", + ) + + @pytest.fixture + def category_data( + self, category_manager, thl_web_rw, beauty_fitness, hair_care, hair_loss + ): + cats = [beauty_fitness, hair_care, hair_loss] + data = [x.model_dump(mode="json") for x in cats] + # We need the parent pk's to set the parent_id. So insert all without a parent, + # then pull back all pks and map to the parents as parsed by the parent_path + query = """ + INSERT INTO marketplace_category + (uuid, adwords_vertical_id, label, path) + VALUES + (%(uuid)s, %(adwords_vertical_id)s, %(label)s, %(path)s) + ON CONFLICT (uuid) DO NOTHING; + """ + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=data) + conn.commit() + + res = thl_web_rw.execute_sql_query("SELECT id, path FROM marketplace_category") + path_id = {x["path"]: x["id"] for x in res} + data = [ + {"id": path_id[c.path], "parent_id": path_id[c.parent_path]} + for c in cats + if c.parent_path + ] + query = """ + UPDATE marketplace_category + SET parent_id = %(parent_id)s + WHERE id = %(id)s; + """ + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=data) + conn.commit() + + category_manager.populate_caches() + + def test( + self, + category_data, + category_manager, + beauty_fitness, + hair_care, + hair_loss, + ): + # category_manager on init caches the category info. This rarely/never changes so this is fine, + # but now that tests get run on a new db each time, the category_manager is inited before + # the fixtures run. so category_manager's cache needs to be rerun + + # path='/Beauty & Fitness/Hair Care/Hair Loss' + c: Category = category_manager.get_by_label("Hair Loss") + # Beauty & Fitness + assert beauty_fitness.uuid == category_manager.get_top_level(c).uuid + + c: Category = category_manager.categories[beauty_fitness.uuid] + # The root is itself + assert c == category_manager.get_category_root(c) + + # The root is Beauty & Fitness + c: Category = category_manager.get_by_label("Hair Loss") + assert beauty_fitness.uuid == category_manager.get_category_root(c).uuid diff --git a/tests/managers/thl/test_contest/__init__.py b/tests/managers/thl/test_contest/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/thl/test_contest/__init__.py diff --git a/tests/managers/thl/test_contest/test_leaderboard.py b/tests/managers/thl/test_contest/test_leaderboard.py new file mode 100644 index 0000000..80a88a5 --- /dev/null +++ b/tests/managers/thl/test_contest/test_leaderboard.py @@ -0,0 +1,138 @@ +from datetime import datetime, timezone, timedelta +from zoneinfo import ZoneInfo + +from generalresearch.currency import USDCent +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestEndReason, +) +from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, + LeaderboardContestCreate, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from test_utils.managers.contest.conftest import ( + leaderboard_contest_in_db as contest_in_db, + leaderboard_contest_create as contest_create, +) + + +class TestLeaderboardContestCRUD: + + def test_create( + self, + contest_create: LeaderboardContestCreate, + product_user_wallet_yes: Product, + thl_lm, + contest_manager, + ): + c = contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=contest_create + ) + c_out = contest_manager.get(c.uuid) + assert c == c_out + + assert isinstance(c, LeaderboardContest) + assert c.prize_count == 2 + assert c.status == ContestStatus.ACTIVE + # We have it set in the fixture as the daily contest for 2025-01-01 + assert c.end_condition.ends_at == datetime( + 2025, 1, 1, 23, 59, 59, 999999, tzinfo=ZoneInfo("America/New_York") + ).astimezone(tz=timezone.utc) + timedelta(minutes=90) + + def test_enter( + self, + user_with_wallet: User, + contest_in_db: LeaderboardContest, + thl_lm, + contest_manager, + user_manager, + thl_redis, + ): + contest = contest_in_db + user = user_with_wallet + + c: LeaderboardContest = contest_manager.get(contest_uuid=contest.uuid) + + c = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert c.user_rank is None + + lbm = c.get_leaderboard_manager() + lbm.hit_complete_count(user.product_user_id) + + c = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert c.user_rank == 1 + + def test_contest_ends( + self, + user_with_wallet: User, + contest_in_db: LeaderboardContest, + thl_lm, + contest_manager, + user_manager, + thl_redis, + ): + # The contest should be over. We need to trigger it. + contest = contest_in_db + contest._redis_client = thl_redis + contest._user_manager = user_manager + user = user_with_wallet + + lbm = contest.get_leaderboard_manager() + lbm.hit_complete_count(user.product_user_id) + + c = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert c.user_rank == 1 + + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid(user.product_id) + bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet) + assert bp_wallet_balance == 0 + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user) + user_balance = thl_lm.get_account_balance(user_wallet) + assert user_balance == 0 + + decision, reason = contest.should_end() + assert decision + assert reason == ContestEndReason.ENDS_AT + + contest_manager.end_contest_if_over(contest=contest, ledger_manager=thl_lm) + + c: LeaderboardContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.COMPLETED + print(c) + + user_contest = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert len(user_contest.user_winnings) == 1 + w = user_contest.user_winnings[0] + assert w.product_user_id == user.product_user_id + assert w.prize.cash_amount == USDCent(15_00) + + # The prize is $15.00, so the user should get $15, paid by the bp + assert thl_lm.get_account_balance(account=user_wallet) == 15_00 + # contest wallet is 0, and the BP gets 20c + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=c.uuid + ) + assert thl_lm.get_account_balance(account=contest_wallet) == 0 + assert thl_lm.get_account_balance(account=bp_wallet) == -15_00 diff --git a/tests/managers/thl/test_contest/test_milestone.py b/tests/managers/thl/test_contest/test_milestone.py new file mode 100644 index 0000000..7312a64 --- /dev/null +++ b/tests/managers/thl/test_contest/test_milestone.py @@ -0,0 +1,296 @@ +from datetime import datetime, timezone + +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestEndReason, +) +from generalresearch.models.thl.contest.milestone import ( + MilestoneContest, + MilestoneContestCreate, + MilestoneUserView, + ContestEntryTrigger, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from test_utils.managers.contest.conftest import ( + milestone_contest as contest, + milestone_contest_in_db as contest_in_db, + milestone_contest_create as contest_create, + milestone_contest_factory as contest_factory, +) + + +class TestMilestoneContest: + + def test_should_end(self, contest: MilestoneContest, thl_lm, contest_manager): + # contest is active and has no entries + should, msg = contest.should_end() + assert not should, msg + + # Change so that the contest ends now + contest.end_condition.ends_at = datetime.now(tz=timezone.utc) + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.ENDS_AT + + # Change the win amount it thinks it past over the target + contest.end_condition.ends_at = None + contest.end_condition.max_winners = 10 + contest.win_count = 10 + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.MAX_WINNERS + + +class TestMilestoneContestCRUD: + + def test_create( + self, + contest_create: MilestoneContestCreate, + product_user_wallet_yes: Product, + thl_lm, + contest_manager, + ): + c = contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=contest_create + ) + c_out = contest_manager.get(c.uuid) + assert c == c_out + + assert isinstance(c, MilestoneContest) + assert c.prize_count == 2 + assert c.status == ContestStatus.ACTIVE + assert c.end_condition.max_winners == 5 + assert c.entry_trigger == ContestEntryTrigger.TASK_COMPLETE + assert c.target_amount == 3 + assert c.win_count == 0 + + def test_enter( + self, + user_with_wallet: User, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # Users CANNOT directly enter a milestone contest through the api, + # but we'll call this manager method when a trigger is hit. + contest = contest_in_db + user = user_with_wallet + + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=user, + country_iso="us", + ledger_manager=thl_lm, + incr=1, + ) + + c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.ACTIVE + assert not hasattr(c, "current_amount") + assert not hasattr(c, "current_participants") + + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 1 + + # Contest wallet should have 0 bc there is no ledger + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + assert thl_lm.get_account_balance(contest_wallet) == 0 + + # Enter again! + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=user, + country_iso="us", + ledger_manager=thl_lm, + incr=1, + ) + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 2 + + # We should have ONE entry with a value of 2 + e = contest_manager.get_entries_by_contest_id(c.id) + assert len(e) == 1 + assert e[0].amount == 2 + + def test_enter_win( + self, + user_with_wallet: User, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # User enters contest, which brings the USER'S total amount above the limit, + # and the user reaches the milestone + contest = contest_in_db + user = user_with_wallet + + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user) + user_balance = thl_lm.get_account_balance(account=user_wallet) + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid( + product_uuid=user.product_id + ) + bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet) + + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 0 + res, msg = c.is_user_eligible(country_iso="us") + assert res, msg + + # User reaches the milestone after 3 completes/whatevers. + for _ in range(3): + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=user, + country_iso="us", + ledger_manager=thl_lm, + incr=1, + ) + + # to be clear, the contest itself doesn't end! + c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.ACTIVE + + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 3 + res, msg = c.is_user_eligible(country_iso="us") + assert not res + assert msg == "User should have won already" + + assert len(c.user_winnings) == 2 + assert c.win_count == 1 + + # The prize was awarded! User should have won $1.00 + assert thl_lm.get_account_balance(user_wallet) - user_balance == 100 + # Which was paid from the BP's balance + assert thl_lm.get_account_balance(bp_wallet) - bp_wallet_balance == -100 + + # winnings = cm.get_winnings_by_user(user=user) + # assert len(winnings) == 1 + # win = winnings[0] + # assert win.product_user_id == user.product_user_id + + def test_enter_ends( + self, + user_factory, + product_user_wallet_yes: Product, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # Multiple users reach the milestone. Contest ends after 5 wins. + users = [user_factory(product=product_user_wallet_yes) for _ in range(5)] + contest = contest_in_db + + for u in users: + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=u, + country_iso="us", + ledger_manager=thl_lm, + incr=3, + ) + + c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.COMPLETED + assert c.end_reason == ContestEndReason.MAX_WINNERS + + def test_trigger( + self, + user_with_wallet: User, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # Pretend user just got a complete + cnt = contest_manager.hit_milestone_triggers( + country_iso="us", + user=user_with_wallet, + event=ContestEntryTrigger.TASK_COMPLETE, + ledger_manager=thl_lm, + ) + assert cnt == 1 + + # Assert this contest got entered + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest_in_db.uuid, user=user_with_wallet + ) + assert c.user_amount == 1 + + +class TestMilestoneContestUserViews: + def test_list_user_eligible_country( + self, user_with_wallet: User, contest_factory, thl_lm, contest_manager + ): + # No contests exists + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 0 + + # Create a contest. It'll be in the US/CA + contest_factory(country_isos={"us", "ca"}) + + # Not eligible in mexico + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 0 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 1 + + # Create another, any country + contest_factory(country_isos=None) + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 1 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 2 + + def test_list_user_eligible( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + # User reaches milestone after 1 complete + c = contest_factory(target_amount=1) + user = user_with_money + + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 1 + + contest_manager.enter_milestone_contest( + contest_uuid=c.uuid, user=user, country_iso="us", ledger_manager=thl_lm + ) + + # User isn't eligible anymore + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 0 + + # But it comes back in the list entered + cs = contest_manager.get_many_by_user_entered(user=user_with_money) + assert len(cs) == 1 + c = cs[0] + assert c.user_amount == 1 + assert isinstance(c, MilestoneUserView) + assert not hasattr(c, "current_win_probability") + + # They won one contest with 2 prizes + assert len(contest_manager.get_winnings_by_user(user_with_money)) == 2 diff --git a/tests/managers/thl/test_contest/test_raffle.py b/tests/managers/thl/test_contest/test_raffle.py new file mode 100644 index 0000000..060055a --- /dev/null +++ b/tests/managers/thl/test_contest/test_raffle.py @@ -0,0 +1,474 @@ +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError +from pytest import approx + +from generalresearch.currency import USDCent +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionConditionFailedError, +) +from generalresearch.models.thl.contest import ( + ContestPrize, + ContestEntryRule, + ContestEndCondition, +) +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestPrizeKind, + ContestEndReason, +) +from generalresearch.models.thl.contest.exceptions import ContestError +from generalresearch.models.thl.contest.raffle import ( + ContestEntry, + ContestEntryType, +) +from generalresearch.models.thl.contest.raffle import ( + RaffleContest, + RaffleContestCreate, + RaffleUserView, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from test_utils.managers.contest.conftest import ( + raffle_contest as contest, + raffle_contest_in_db as contest_in_db, + raffle_contest_create as contest_create, + raffle_contest_factory as contest_factory, +) + + +class TestRaffleContest: + + def test_should_end(self, contest: RaffleContest, thl_lm, contest_manager): + # contest is active and has no entries + should, msg = contest.should_end() + assert not should, msg + + # Change so that the contest ends now + contest.end_condition.ends_at = datetime.now(tz=timezone.utc) + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.ENDS_AT + + # Change the entry amount it thinks it has to over the target + contest.end_condition.ends_at = None + contest.current_amount = USDCent(100) + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.TARGET_ENTRY_AMOUNT + + +class TestRaffleContestCRUD: + + def test_create( + self, + contest_create: RaffleContestCreate, + product_user_wallet_yes: Product, + thl_lm, + contest_manager, + ): + c = contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=contest_create + ) + c_out = contest_manager.get(c.uuid) + assert c == c_out + + assert isinstance(c, RaffleContest) + assert c.prize_count == 1 + assert c.status == ContestStatus.ACTIVE + assert c.end_condition.target_entry_amount == USDCent(100) + assert c.current_amount == 0 + assert c.current_participants == 0 + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 60}], indirect=True) + def test_enter( + self, + user_with_money: User, + contest_in_db: RaffleContest, + thl_lm, + contest_manager, + ): + # Raffle ends at $1.00. User enters for $0.60 + print(user_with_money.product_id) + print(contest_in_db.product_id) + print(contest_in_db.uuid) + contest = contest_in_db + + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user_with_money) + user_balance = thl_lm.get_account_balance(account=user_wallet) + + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(60) + ) + entry = contest_manager.enter_contest( + contest_uuid=contest.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + c: RaffleContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.current_amount == USDCent(60) + assert c.current_participants == 1 + assert c.status == ContestStatus.ACTIVE + + c: RaffleUserView = contest_manager.get_raffle_user_view( + contest_uuid=contest.uuid, user=user_with_money + ) + assert c.user_amount == USDCent(60) + assert c.user_amount_today == USDCent(60) + assert c.projected_win_probability == approx(60 / 100, rel=0.01) + + # Contest wallet should have $0.60 + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + assert thl_lm.get_account_balance(account=contest_wallet) == 60 + # User spent 60c + assert user_balance - thl_lm.get_account_balance(account=user_wallet) == 60 + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 120}], indirect=True) + def test_enter_ends( + self, + user_with_money: User, + contest_in_db: RaffleContest, + thl_lm, + contest_manager, + ): + # User enters contest, which brings the total amount above the limit, + # and the contest should end, with a winner selected + contest = contest_in_db + + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid( + user_with_money.product_id + ) + # I bribed the user, so the balance is not 0 + bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet) + + for _ in range(2): + entry = ContestEntry( + entry_type=ContestEntryType.CASH, + user=user_with_money, + amount=USDCent(60), + ) + contest_manager.enter_contest( + contest_uuid=contest.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + c: RaffleContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.COMPLETED + print(c) + + user_contest = contest_manager.get_raffle_user_view( + contest_uuid=contest.uuid, user=user_with_money + ) + assert user_contest.current_win_probability == 1 + assert user_contest.projected_win_probability == 1 + assert len(user_contest.user_winnings) == 1 + + # todo: make a all winning method + winnings = contest_manager.get_winnings_by_user(user=user_with_money) + assert len(winnings) == 1 + win = winnings[0] + assert win.product_user_id == user_with_money.product_user_id + + # Contest wallet should have gotten zeroed out + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + assert thl_lm.get_account_balance(contest_wallet) == 0 + # Expense wallet gets the $1.00 expense + expense_wallet = thl_lm.get_account_or_create_bp_expense_by_uuid( + product_uuid=user_with_money.product_id, expense_name="Prize" + ) + assert thl_lm.get_account_balance(expense_wallet) == -100 + # And the BP gets 20c + assert thl_lm.get_account_balance(bp_wallet) - bp_wallet_balance == 20 + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 120}], indirect=True) + def test_enter_ends_cash_prize( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + # Same as test_enter_ends, but the prize is cash. Just + # testing the ledger methods + c = contest_factory( + prizes=[ + ContestPrize( + name="$1.00 bonus", + kind=ContestPrizeKind.CASH, + estimated_cash_value=USDCent(100), + cash_amount=USDCent(100), + ) + ] + ) + assert c.prizes[0].kind == ContestPrizeKind.CASH + + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user_with_money) + user_balance = thl_lm.get_account_balance(user_wallet) + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid( + user_with_money.product_id + ) + bp_wallet_balance = thl_lm.get_account_balance(bp_wallet) + + ## Enter Contest + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(120) + ) + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + + # The prize is $1.00, so the user spent $1.20 entering, won, then got $1.00 back + assert ( + thl_lm.get_account_balance(account=user_wallet) == user_balance + 100 - 120 + ) + # contest wallet is 0, and the BP gets 20c + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=c.uuid + ) + assert thl_lm.get_account_balance(account=contest_wallet) == 0 + assert thl_lm.get_account_balance(account=bp_wallet) - bp_wallet_balance == 20 + + def test_enter_failure( + self, + user_with_wallet: User, + contest_in_db: RaffleContest, + thl_lm, + contest_manager, + ): + c = contest_in_db + user = user_with_wallet + + # Tries to enter $0 + with pytest.raises(ValidationError) as e: + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user, amount=USDCent(0) + ) + assert "Input should be greater than 0" in str(e.value) + + # User has no money + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user, amount=USDCent(20) + ) + with pytest.raises(LedgerTransactionConditionFailedError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert e.value.args[0] == "insufficient balance" + + # Tries to enter with the wrong entry type (count, on a cash contest) + entry = ContestEntry(entry_type=ContestEntryType.COUNT, user=user, amount=1) + with pytest.raises(AssertionError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "incompatible entry type" in str(e.value) + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 100}], indirect=True) + def test_enter_not_eligible( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + # Max entry amount per user $0.10. Contest still ends at $1.00 + c = contest_factory( + entry_rule=ContestEntryRule( + max_entry_amount_per_user=USDCent(10), + max_daily_entries_per_user=USDCent(8), + ) + ) + c: RaffleContest = contest_manager.get(c.uuid) + assert c.entry_rule.max_entry_amount_per_user == USDCent(10) + assert c.entry_rule.max_daily_entries_per_user == USDCent(8) + + # User tries to enter $0.20 + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(20) + ) + with pytest.raises(ContestError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "Entry would exceed max amount per user." in str(e.value) + + # User tries to enter $0.10 + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(10) + ) + with pytest.raises(ContestError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "Entry would exceed max amount per user per day." in str(e.value) + + # User enters $0.08 successfully + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(8) + ) + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + + # Then can't anymore + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(1) + ) + with pytest.raises(ContestError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "Entry would exceed max amount per user per day." in str(e.value) + + +class TestRaffleContestUserViews: + def test_list_user_eligible_country( + self, user_with_wallet: User, contest_factory, thl_lm, contest_manager + ): + # No contests exists + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 0 + + # Create a contest. It'll be in the US/CA + contest_factory(country_isos={"us", "ca"}) + + # Not eligible in mexico + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 0 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 1 + + # Create another, any country + contest_factory(country_isos=None) + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 1 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 2 + + def test_list_user_eligible( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + c = contest_factory( + end_condition=ContestEndCondition(target_entry_amount=USDCent(10)), + entry_rule=ContestEntryRule( + max_entry_amount_per_user=USDCent(1), + ), + ) + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 1 + + entry = ContestEntry( + entry_type=ContestEntryType.CASH, + user=user_with_money, + amount=USDCent(1), + ) + contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + + # User isn't eligible anymore + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 0 + + # But it comes back in the list entered + cs = contest_manager.get_many_by_user_entered(user=user_with_money) + assert len(cs) == 1 + c = cs[0] + assert c.user_amount == USDCent(1) + assert c.user_amount_today == USDCent(1) + assert c.current_win_probability == 1 + assert c.projected_win_probability == approx(1 / 10, rel=0.01) + + # And nothing won yet #todo + # cs = cm.get_many_by_user_won(user=user_with_money) + + assert len(contest_manager.get_winnings_by_user(user_with_money)) == 0 + + def test_list_user_winnings( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + c = contest_factory( + end_condition=ContestEndCondition(target_entry_amount=USDCent(100)), + ) + entry = ContestEntry( + entry_type=ContestEntryType.CASH, + user=user_with_money, + amount=USDCent(100), + ) + contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + # Contest ends after 100 entry, user enters 100 entry, user wins! + ws = contest_manager.get_winnings_by_user(user_with_money) + assert len(ws) == 1 + w = ws[0] + assert w.user.user_id == user_with_money.user_id + assert w.prize == c.prizes[0] + assert w.awarded_cash_amount is None + + cs = contest_manager.get_many_by_user_won(user_with_money) + assert len(cs) == 1 + c = cs[0] + w = c.user_winnings[0] + assert w.prize == c.prizes[0] + assert w.user.user_id == user_with_money.user_id + + +class TestRaffleContestCRUDCount: + # This is a COUNT contest. No cash moves. Not really fleshed out what we'd do with this. + @pytest.mark.skip + def test_enter( + self, user_with_wallet: User, contest_factory, thl_lm, contest_manager + ): + c = contest_factory(entry_type=ContestEntryType.COUNT) + entry = ContestEntry( + entry_type=ContestEntryType.COUNT, + user=user_with_wallet, + amount=1, + ) + contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) diff --git a/tests/managers/thl/test_harmonized_uqa.py b/tests/managers/thl/test_harmonized_uqa.py new file mode 100644 index 0000000..6bbbbe1 --- /dev/null +++ b/tests/managers/thl/test_harmonized_uqa.py @@ -0,0 +1,116 @@ +from datetime import datetime, timezone + +import pytest + +from generalresearch.managers.thl.profiling.uqa import UQAManager +from generalresearch.models.thl.profiling.user_question_answer import ( + UserQuestionAnswer, + DUMMY_UQA, +) +from generalresearch.models.thl.user import User + + +@pytest.mark.usefixtures("uqa_db_index", "upk_data", "uqa_manager_clear_cache") +class TestUQAManager: + + def test_init(self, uqa_manager: UQAManager, user: User): + uqas = uqa_manager.get(user) + assert len(uqas) == 0 + + def test_create(self, uqa_manager: UQAManager, user: User): + now = datetime.now(tz=timezone.utc) + uqas = [ + UserQuestionAnswer( + user_id=user.user_id, + question_id="fd5bd491b75a491aa7251159680bf1f1", + country_iso="us", + language_iso="eng", + answer=("2",), + timestamp=now, + property_code="m:job_role", + calc_answers={"m:job_role": ("2",)}, + ) + ] + uqa_manager.create(user, uqas) + + res = uqa_manager.get(user=user) + assert len(res) == 1 + assert res[0] == uqas[0] + + # Same question, so this gets updated + now = datetime.now(tz=timezone.utc) + uqas_update = [ + UserQuestionAnswer( + user_id=user.user_id, + question_id="fd5bd491b75a491aa7251159680bf1f1", + country_iso="us", + language_iso="eng", + answer=("3",), + timestamp=now, + property_code="m:job_role", + calc_answers={"m:job_role": ("3",)}, + ) + ] + uqa_manager.create(user, uqas_update) + res = uqa_manager.get(user=user) + assert len(res) == 1 + assert res[0] == uqas_update[0] + + # Add a new answer + now = datetime.now(tz=timezone.utc) + uqas_new = [ + UserQuestionAnswer( + user_id=user.user_id, + question_id="3b65220db85f442ca16bb0f1c0e3a456", + country_iso="us", + language_iso="eng", + answer=("3",), + timestamp=now, + property_code="gr:children_age_gender", + calc_answers={"gr:children_age_gender": ("3",)}, + ) + ] + uqa_manager.create(user, uqas_new) + res = uqa_manager.get(user=user) + assert len(res) == 2 + assert res[1] == uqas_update[0] + assert res[0] == uqas_new[0] + + +@pytest.mark.usefixtures("uqa_db_index", "upk_data", "uqa_manager_clear_cache") +class TestUQAManagerCache: + + def test_get_uqa_empty(self, uqa_manager: UQAManager, user: User, caplog): + res = uqa_manager.get(user=user) + assert len(res) == 0 + + res = uqa_manager.get_from_db(user=user) + assert len(res) == 0 + + # Checking that the cache has only the dummy_uqa in it + res = uqa_manager.get_from_cache(user=user) + assert res == [DUMMY_UQA] + + with caplog.at_level("INFO"): + res = uqa_manager.get(user=user) + assert f"thl-grpc:uqa-cache-v2:{user.user_id} exists" in caplog.text + assert len(res) == 0 + + def test_get_uqa(self, uqa_manager: UQAManager, user: User, caplog): + + # Now the user sends an answer + uqas = [ + UserQuestionAnswer( + question_id="5d6d9f3c03bb40bf9d0a24f306387d7c", + answer=("1",), + timestamp=datetime.now(tz=timezone.utc), + country_iso="us", + language_iso="eng", + property_code="gr:gender", + user_id=user.user_id, + calc_answers={}, + ) + ] + uqa_manager.update_cache(user=user, uqas=uqas) + res = uqa_manager.get_from_cache(user=user) + assert res == uqas diff --git a/tests/managers/thl/test_ipinfo.py b/tests/managers/thl/test_ipinfo.py new file mode 100644 index 0000000..847b00c --- /dev/null +++ b/tests/managers/thl/test_ipinfo.py @@ -0,0 +1,117 @@ +import faker + +from generalresearch.managers.thl.ipinfo import ( + IPGeonameManager, + IPInformationManager, + GeoIpInfoManager, +) +from generalresearch.models.thl.ipinfo import IPGeoname, IPInformation + +fake = faker.Faker() + + +class TestIPGeonameManager: + + def test_init(self, thl_web_rr, ip_geoname_manager: IPGeonameManager): + + instance = IPGeonameManager(pg_config=thl_web_rr) + assert isinstance(instance, IPGeonameManager) + assert isinstance(ip_geoname_manager, IPGeonameManager) + + def test_create(self, ip_geoname_manager: IPGeonameManager): + + instance = ip_geoname_manager.create_dummy() + + assert isinstance(instance, IPGeoname) + + res = ip_geoname_manager.fetch_geoname_ids(filter_ids=[instance.geoname_id]) + + assert res[0].model_dump_json() == instance.model_dump_json() + + +class TestIPInformationManager: + + def test_init(self, thl_web_rr, ip_information_manager: IPInformationManager): + instance = IPInformationManager(pg_config=thl_web_rr) + assert isinstance(instance, IPInformationManager) + assert isinstance(ip_information_manager, IPInformationManager) + + def test_create(self, ip_information_manager: IPInformationManager): + instance = ip_information_manager.create_dummy() + + assert isinstance(instance, IPInformation) + + res = ip_information_manager.fetch_ip_information(filter_ips=[instance.ip]) + + assert res[0].model_dump_json() == instance.model_dump_json() + + def test_prefetch_geoname(self, ip_information, ip_geoname, thl_web_rr): + assert isinstance(ip_information, IPInformation) + + assert ip_information.geoname_id == ip_geoname.geoname_id + assert ip_information.geoname is None + + ip_information.prefetch_geoname(pg_config=thl_web_rr) + assert isinstance(ip_information.geoname, IPGeoname) + + +class TestGeoIpInfoManager: + def test_init( + self, thl_web_rr, thl_redis_config, geoipinfo_manager: GeoIpInfoManager + ): + instance = GeoIpInfoManager(pg_config=thl_web_rr, redis_config=thl_redis_config) + assert isinstance(instance, GeoIpInfoManager) + assert isinstance(geoipinfo_manager, GeoIpInfoManager) + + def test_multi(self, ip_information_factory, ip_geoname, geoipinfo_manager): + ip = fake.ipv4_public() + ip_information_factory(ip=ip, geoname=ip_geoname) + ips = [ip] + + # This only looks up in redis. They don't exist yet + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res == {ip: None} + + # Looks up in redis, if not exists, looks in mysql, then sets + # the caches that didn't exist. + res = geoipinfo_manager.get_multi(ip_addresses=ips) + assert res[ip] is not None + + ip2 = fake.ipv4_public() + ip_information_factory(ip=ip2, geoname=ip_geoname) + ips = [ip, ip2] + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res[ip] is not None + assert res[ip2] is None + res = geoipinfo_manager.get_multi(ip_addresses=ips) + assert res[ip] is not None + assert res[ip2] is not None + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res[ip] is not None + assert res[ip2] is not None + + def test_multi_ipv6(self, ip_information_factory, ip_geoname, geoipinfo_manager): + ip = fake.ipv6() + # Make another IP that will be in the same /64 block. + ip2 = ip[:-1] + "a" if ip[-1] != "a" else ip[:-1] + "b" + ip_information_factory(ip=ip, geoname=ip_geoname) + ips = [ip, ip2] + print(f"{ips=}") + + # This only looks up in redis. They don't exist yet + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res == {ip: None, ip2: None} + + # Looks up in redis, if not exists, looks in mysql, then sets + # the caches that didn't exist. + res = geoipinfo_manager.get_multi(ip_addresses=ips) + assert res[ip].ip == ip + assert res[ip].lookup_prefix == "/64" + assert res[ip2].ip == ip2 + assert res[ip2].lookup_prefix == "/64" + # they should be the same basically, except for the ip + + def test_doesnt_exist(self, geoipinfo_manager): + ip = fake.ipv4_public() + res = geoipinfo_manager.get_multi(ip_addresses=[ip]) + assert res == {ip: None} diff --git a/tests/managers/thl/test_ledger/__init__.py b/tests/managers/thl/test_ledger/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/thl/test_ledger/__init__.py diff --git a/tests/managers/thl/test_ledger/test_lm_accounts.py b/tests/managers/thl/test_ledger/test_lm_accounts.py new file mode 100644 index 0000000..e0d9b0b --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_accounts.py @@ -0,0 +1,268 @@ +from itertools import product as iproduct +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.currency import LedgerCurrency +from generalresearch.managers.base import Permission +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, +) +from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager +from generalresearch.models.thl.ledger import LedgerAccount, AccountType, Direction +from generalresearch.models.thl.ledger import ( + LedgerEntry, +) +from test_utils.managers.ledger.conftest import ledger_account + + +@pytest.mark.parametrize( + argnames="currency, kind, acct_id", + argvalues=list( + iproduct( + ["USD", "test", "EUR"], + ["expense", "wallet", "revenue", "cash"], + [uuid4().hex for i in range(3)], + ) + ), +) +class TestLedgerAccountManagerNoResults: + + def test_get_account_no_results(self, currency, kind, acct_id, lm): + """Try to query for accounts that we know don't exist and confirm that + we either get the expected None result or it raises the correct + exception + """ + qn = ":".join([currency, kind, acct_id]) + + # (1) .get_account is just a wrapper for .get_account_many_ but + # call it either way + assert lm.get_account(qualified_name=qn, raise_on_error=False) is None + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_account(qualified_name=qn, raise_on_error=True) + + # (2) .get_account_if_exists is another wrapper + assert lm.get_account(qualified_name=qn, raise_on_error=False) is None + + def test_get_account_no_results_many(self, currency, kind, acct_id, lm): + qn = ":".join([currency, kind, acct_id]) + + # (1) .get_many_ + assert lm.get_account_many_(qualified_names=[qn], raise_on_error=False) == [] + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_account_many_(qualified_names=[qn], raise_on_error=True) + + # (2) .get_many + assert lm.get_account_many(qualified_names=[qn], raise_on_error=False) == [] + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_account_many(qualified_names=[qn], raise_on_error=True) + + # (3) .get_accounts(..) + assert lm.get_accounts_if_exists(qualified_names=[qn]) == [] + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_accounts(qualified_names=[qn]) + + +@pytest.mark.parametrize( + argnames="currency, account_type, direction", + argvalues=list( + iproduct( + list(LedgerCurrency), + list(AccountType), + list(Direction), + ) + ), +) +class TestLedgerAccountManagerCreate: + + def test_create_account_error_permission( + self, currency, account_type, direction, lm + ): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + acct_uuid = uuid4().hex + + account = LedgerAccount( + display_name=f"test-{uuid4().hex}", + currency=currency, + qualified_name=f"{currency.value}:{account_type.value}:{acct_uuid}", + account_type=account_type, + normal_balance=direction, + ) + + # (1) With no Permissions defined + test_lm = LedgerManager( + pg_config=lm.pg_config, + permissions=[], + redis_config=lm.redis_config, + cache_prefix=lm.cache_prefix, + testing=lm.testing, + ) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + test_lm.create_account(account=account) + assert ( + str(excinfo.value) == "LedgerManager does not have sufficient permissions" + ) + + # (2) With Permissions defined, but not CREATE + test_lm = LedgerManager( + pg_config=lm.pg_config, + permissions=[Permission.READ, Permission.UPDATE, Permission.DELETE], + redis_config=lm.redis_config, + cache_prefix=lm.cache_prefix, + testing=lm.testing, + ) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + test_lm.create_account(account=account) + assert ( + str(excinfo.value) == "LedgerManager does not have sufficient permissions" + ) + + def test_create(self, currency, account_type, direction, lm): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + + acct_uuid = uuid4().hex + qn = f"{currency.value}:{account_type.value}:{acct_uuid}" + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{uuid4().hex}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=direction, + ) + account = lm.create_account(account=acct_model) + assert isinstance(account, LedgerAccount) + + # Query for, and make sure the Account was saved in the DB + res = lm.get_account(qualified_name=qn, raise_on_error=True) + assert account.uuid == res.uuid + + def test_get_or_create(self, currency, account_type, direction, lm): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + + acct_uuid = uuid4().hex + qn = f"{currency.value}:{account_type.value}:{acct_uuid}" + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{uuid4().hex}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=direction, + ) + account = lm.get_account_or_create(account=acct_model) + assert isinstance(account, LedgerAccount) + + # Query for, and make sure the Account was saved in the DB + res = lm.get_account(qualified_name=qn, raise_on_error=True) + assert account.uuid == res.uuid + + +class TestLedgerAccountManagerGet: + + def test_get(self, ledger_account, lm): + res = lm.get_account(qualified_name=ledger_account.qualified_name) + assert res.uuid == ledger_account.uuid + + res = lm.get_account_many(qualified_names=[ledger_account.qualified_name]) + assert len(res) == 1 + assert res[0].uuid == ledger_account.uuid + + res = lm.get_accounts(qualified_names=[ledger_account.qualified_name]) + assert len(res) == 1 + assert res[0].uuid == ledger_account.uuid + + # TODO: I can't test the get_balance without first having Transaction + # creation working + + def test_get_balance_empty( + self, ledger_account, ledger_account_credit, ledger_account_debit, ledger_tx, lm + ): + res = lm.get_account_balance(account=ledger_account) + assert res == 0 + + res = lm.get_account_balance(account=ledger_account_credit) + assert res == 100 + + res = lm.get_account_balance(account=ledger_account_debit) + assert res == 100 + + @pytest.mark.parametrize("n_times", range(5)) + def test_get_account_filtered_balance( + self, + ledger_account, + ledger_account_credit, + ledger_account_debit, + ledger_tx, + n_times, + lm, + ): + """Try searching for random metadata and confirm it's always 0 because + Tx can be found. + """ + rand_key = f"key-{uuid4().hex[:10]}" + rand_value = uuid4().hex + + assert ( + lm.get_account_filtered_balance( + account=ledger_account, metadata_key=rand_key, metadata_value=rand_value + ) + == 0 + ) + + # Let's create a transaction with this metadata to confirm it saves + # and that we can filter it back + rand_amount = randint(10, 1_000) + + lm.create_tx( + entries=[ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=rand_amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=rand_amount, + ), + ], + metadata={rand_key: rand_value}, + ) + + assert ( + lm.get_account_filtered_balance( + account=ledger_account_credit, + metadata_key=rand_key, + metadata_value=rand_value, + ) + == rand_amount + ) + + assert ( + lm.get_account_filtered_balance( + account=ledger_account_debit, + metadata_key=rand_key, + metadata_value=rand_value, + ) + == rand_amount + ) + + def test_get_balance_timerange_empty(self, ledger_account, lm): + res = lm.get_account_balance_timerange(account=ledger_account) + assert res == 0 diff --git a/tests/managers/thl/test_ledger/test_lm_tx.py b/tests/managers/thl/test_ledger/test_lm_tx.py new file mode 100644 index 0000000..37b7ba3 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx.py @@ -0,0 +1,235 @@ +from decimal import Decimal +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.currency import LedgerCurrency +from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager +from generalresearch.models.thl.ledger import ( + Direction, + LedgerEntry, + LedgerTransaction, +) + + +class TestLedgerManagerCreateTx: + + def test_create_account_error_permission(self, lm): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + acct_uuid = uuid4().hex + + # (1) With no Permissions defined + test_lm = LedgerManager( + pg_config=lm.pg_config, + permissions=[], + redis_config=lm.redis_config, + cache_prefix=lm.cache_prefix, + testing=lm.testing, + ) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + test_lm.create_tx(entries=[]) + assert ( + str(excinfo.value) + == "LedgerTransactionManager has insufficient Permissions" + ) + + def test_create_assertions(self, ledger_account_debit, ledger_account_credit, lm): + with pytest.raises(expected_exception=ValueError) as excinfo: + lm.create_tx( + entries=[ + { + "direction": Direction.CREDIT, + "account_uuid": uuid4().hex, + "amount": randint(a=1, b=100), + } + ] + ) + assert ( + "Assertion failed, ledger transaction must have 2 or more entries" + in str(excinfo.value) + ) + + def test_create(self, ledger_account_credit, ledger_account_debit, lm): + amount = int(Decimal("1.00") * 100) + + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + # Create a Transaction and validate the operation was successful + tx = lm.create_tx(entries=entries) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert isinstance(res, LedgerTransaction) + assert len(res.entries) == 2 + assert tx.id == res.id + + def test_create_and_reverse(self, ledger_account_credit, ledger_account_debit, lm): + amount = int(Decimal("1.00") * 100) + + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + tx = lm.create_tx(entries=entries) + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.id == tx.id + + assert lm.get_account_balance(account=ledger_account_credit) == 100 + assert lm.get_account_balance(account=ledger_account_debit) == 100 + assert lm.check_ledger_balanced() is True + + # Reverse it + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + tx = lm.create_tx(entries=entries) + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.id == tx.id + + assert lm.get_account_balance(ledger_account_credit) == 0 + assert lm.get_account_balance(ledger_account_debit) == 0 + assert lm.check_ledger_balanced() + + # subtract again + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + tx = lm.create_tx(entries=entries) + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.id == tx.id + + assert lm.get_account_balance(ledger_account_credit) == -100 + assert lm.get_account_balance(ledger_account_debit) == -100 + assert lm.check_ledger_balanced() + + +class TestLedgerManagerGetTx: + + # @pytest.mark.parametrize("currency", [LedgerCurrency.TEST], indirect=True) + def test_get_tx_by_id(self, ledger_tx, lm): + with pytest.raises(expected_exception=AssertionError): + lm.get_tx_by_id(transaction_id=ledger_tx) + + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert res.id == ledger_tx.id + + # @pytest.mark.parametrize("currency", [LedgerCurrency.TEST], indirect=True) + def test_get_tx_by_ids(self, ledger_tx, lm): + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert res.id == ledger_tx.id + + @pytest.mark.parametrize( + "tag", [f"{LedgerCurrency.TEST}:{uuid4().hex}"], indirect=True + ) + def test_get_tx_ids_by_tag(self, ledger_tx, tag, lm): + # (1) search for a random tag + res = lm.get_tx_ids_by_tag(tag="aaa:bbb") + assert isinstance(res, set) + assert len(res) == 0 + + # (2) search for the tag that was used during ledger_transaction creation + res = lm.get_tx_ids_by_tag(tag=tag) + assert isinstance(res, set) + assert len(res) == 1 + + def test_get_tx_by_tag(self, ledger_tx, tag, lm): + # (1) search for a random tag + res = lm.get_tx_by_tag(tag="aaa:bbb") + assert isinstance(res, list) + assert len(res) == 0 + + # (2) search for the tag that was used during ledger_transaction creation + res = lm.get_tx_by_tag(tag=tag) + assert isinstance(res, list) + assert len(res) == 1 + + assert isinstance(res[0], LedgerTransaction) + assert ledger_tx.id == res[0].id + + def test_get_tx_filtered_by_account( + self, ledger_tx, ledger_account, ledger_account_debit, ledger_account_credit, lm + ): + # (1) Do basic assertion checks first + with pytest.raises(expected_exception=AssertionError) as excinfo: + lm.get_tx_filtered_by_account(account_uuid=ledger_account) + assert str(excinfo.value) == "account_uuid must be a str" + + # (2) This search doesn't return anything because this ledger account + # wasn't actually used in the entries for the ledger_transaction + res = lm.get_tx_filtered_by_account(account_uuid=ledger_account.uuid) + assert len(res) == 0 + + # (3) Either the credit or the debit example ledger_accounts wll work + # to find this transaction because they're both used in the entries + res = lm.get_tx_filtered_by_account(account_uuid=ledger_account_debit.uuid) + assert len(res) == 1 + assert res[0].id == ledger_tx.id + + res = lm.get_tx_filtered_by_account(account_uuid=ledger_account_credit.uuid) + assert len(res) == 1 + assert ledger_tx.id == res[0].id + + res2 = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert res2.model_dump_json() == res[0].model_dump_json() + + def test_filter_metadata(self, ledger_tx, tx_metadata, lm): + key, value = next(iter(tx_metadata.items())) + + # (1) Confirm a random key,value pair returns nothing + res = lm.get_tx_filtered_by_metadata( + metadata_key=f"key-{uuid4().hex[:10]}", metadata_value=uuid4().hex[:12] + ) + assert len(res) == 0 + + # (2) confirm a key,value pair return the correct results + res = lm.get_tx_filtered_by_metadata(metadata_key=key, metadata_value=value) + assert len(res) == 1 + + # assert 0 == THL_lm.get_filtered_account_balance(account2, "thl_wall", "ccc") + # assert 300 == THL_lm.get_filtered_account_balance(account1, "thl_wall", "aaa") + # assert 300 == THL_lm.get_filtered_account_balance(account2, "thl_wall", "aaa") + # assert 0 == THL_lm.get_filtered_account_balance(account3, "thl_wall", "ccc") + # assert THL_lm.check_ledger_balanced() diff --git a/tests/managers/thl/test_ledger/test_lm_tx_entries.py b/tests/managers/thl/test_ledger/test_lm_tx_entries.py new file mode 100644 index 0000000..5bf1c48 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx_entries.py @@ -0,0 +1,26 @@ +from generalresearch.models.thl.ledger import LedgerEntry + + +class TestLedgerEntryManager: + + def test_get_tx_entries_by_tx(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert len(res.entries) == 2 + + tx_entries = lm.get_tx_entries_by_tx(transaction=ledger_tx) + assert len(tx_entries) == 2 + + assert res.entries == tx_entries + assert isinstance(tx_entries[0], LedgerEntry) + + def test_get_tx_entries_by_txs(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert len(res.entries) == 2 + + tx_entries = lm.get_tx_entries_by_txs(transactions=[ledger_tx]) + assert len(tx_entries) == 2 + + assert res.entries == tx_entries + assert isinstance(tx_entries[0], LedgerEntry) diff --git a/tests/managers/thl/test_ledger/test_lm_tx_locks.py b/tests/managers/thl/test_ledger/test_lm_tx_locks.py new file mode 100644 index 0000000..df2611b --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx_locks.py @@ -0,0 +1,371 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Callable + +import pytest + +from generalresearch.managers.thl.ledger_manager.conditions import ( + generate_condition_mp_payment, +) +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionCreateLockError, + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionCreateError, +) +from generalresearch.models import Source +from generalresearch.models.thl.ledger import LedgerTransaction +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + Session, + WallAdjustedStatus, +) +from generalresearch.models.thl.user import User +from test_utils.models.conftest import user_factory, session, product_user_wallet_no + +logger = logging.getLogger("LedgerManager") + + +class TestLedgerLocks: + + def test_a( + self, + user_factory, + session_factory, + product_user_wallet_no, + create_main_accounts, + caplog, + thl_lm, + lm, + utc_hour_ago, + currency, + wall_factory, + delete_ledger_db, + ): + """ + TODO: This whole test is confusing a I don't really understand. + It needs to be better documented and explained what we want + it to do and evaluate... + """ + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_no) + s1 = session_factory( + user=user, + wall_count=3, + wall_req_cpis=[Decimal("1.23"), Decimal("3.21"), Decimal("4")], + wall_statuses=[Status.COMPLETE, Status.COMPLETE, Status.COMPLETE], + ) + + # A User does a Wall Completion in Session=1 + w1 = s1.wall_events[0] + tx = thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + assert isinstance(tx, LedgerTransaction) + + # A User does another Wall Completion in Session=1 + w2 = s1.wall_events[1] + tx = thl_lm.create_tx_task_complete(wall=w2, user=user, created=w2.started) + assert isinstance(tx, LedgerTransaction) + + # That first Wall Complete was "adjusted" to instead be marked + # as a Failure + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + tx = thl_lm.create_tx_task_adjustment(wall=w1, user=user) + assert isinstance(tx, LedgerTransaction) + + # A User does another! Wall Completion in Session=1; however, we + # don't create a transaction for it + w3 = s1.wall_events[2] + + # Make sure we clear any flags/locks first + lock_key = f"{currency.value}:thl_wall:{w3.uuid}" + lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}" + flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}" + lm.redis_client.delete(lock_name) + lm.redis_client.delete(flag_name) + + # Despite the + f1 = generate_condition_mp_payment(wall=w1) + f2 = generate_condition_mp_payment(wall=w2) + f3 = generate_condition_mp_payment(wall=w3) + assert f1(lm=lm) is False + assert f2(lm=lm) is False + assert f3(lm=lm) is True + + condition = f3 + create_tx_func = lambda: thl_lm.create_tx_task_complete_(wall=w3, user=user) + assert isinstance(create_tx_func, Callable) + assert f3(lm) is True + + lm.redis_client.delete(flag_name) + lm.redis_client.delete(lock_name) + + tx = thl_lm.create_tx_protected( + lock_key=lock_key, condition=condition, create_tx_func=create_tx_func + ) + assert f3(lm) is False + + # purposely hold the lock open + tx = None + lm.redis_client.set(lock_name, "1") + with caplog.at_level(logging.ERROR): + with pytest.raises(expected_exception=LedgerTransactionCreateLockError): + tx = thl_lm.create_tx_protected( + lock_key=lock_key, + condition=condition, + create_tx_func=create_tx_func, + ) + assert tx is None + assert "Unable to acquire lock within the time specified" in caplog.text + lm.redis_client.delete(lock_name) + + def test_locking( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + caplog, + thl_lm, + lm, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_user_wallet_no) + + # A User does a Wall complete on Session.id=1 and the transaction is + # logged to the ledger + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.23"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + + # A User does a Wall complete on Session.id=1 and the transaction is + # logged to the ledger + wall2 = Wall( + user_id=user.user_id, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.21"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall2, user=user, created=wall2.started) + + # An hour later, the first wall complete is adjusted to a Failure and + # it's tracked in the ledger + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=now + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=wall1, user=user) + + # A User does a Wall complete on Session.id=1 and the transaction + # IS NOT logged to the ledger + wall3 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("4"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + uuid="867a282d8b4d40d2a2093d75b802b629", + ) + + revenue_account = thl_lm.get_account_task_complete_revenue() + assert 0 == thl_lm.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall3.uuid, + ) + # Make sure we clear any flags/locks first + lock_key = f"test:thl_wall:{wall3.uuid}" + lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}" + flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}" + lm.redis_client.delete(lock_name) + lm.redis_client.delete(flag_name) + + # Purposely hold the lock open + lm.redis_client.set(name=lock_name, value="1") + with caplog.at_level(logging.DEBUG): + with pytest.raises(expected_exception=LedgerTransactionCreateLockError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + assert isinstance(tx, LedgerTransaction) + assert "Unable to acquire lock within the time specified" in caplog.text + + # Release the lock + lm.redis_client.delete(lock_name) + + # Set the redis flag to indicate it has been run + lm.redis_client.set(flag_name, "1") + # with self.assertLogs(logger=logger, level=logging.DEBUG) as cm2: + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + # self.assertIn("entered_lock: True, flag_set: True", cm2.output[0]) + + # Unset the flag + lm.redis_client.delete(flag_name) + + assert 0 == lm.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall3.uuid, + ) + + # Now actually run it + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + assert tx is not None + + # Run it again, should return None + # Confirm the Exception inheritance works + tx = None + with pytest.raises(expected_exception=LedgerTransactionCreateError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + assert tx is None + + # clear the redis flag, it should query the db + assert lm.redis_client.get(flag_name) is not None + lm.redis_client.delete(flag_name) + assert lm.redis_client.get(flag_name) is None + + with pytest.raises(expected_exception=LedgerTransactionCreateError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + + assert 400 == thl_lm.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall3.uuid, + ) + + def test_bp_payment_without_locks( + self, user_factory, product_user_wallet_no, create_main_accounts, thl_lm, lm + ): + user: User = user_factory(product=product_user_wallet_no) + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.50"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + print(thl_net, commission_amount, bp_pay, user_pay) + + # Run it 3 times without any checks, and it gets made three times! + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + thl_lm.create_tx_bp_payment_(session=session, created=wall1.started) + thl_lm.create_tx_bp_payment_(session=session, created=wall1.started) + + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + assert 48 * 3 == lm.get_account_balance(account=bp_wallet) + assert 48 * 3 == thl_lm.get_account_filtered_balance( + account=bp_wallet, metadata_key="thl_session", metadata_value=session.uuid + ) + assert lm.check_ledger_balanced() + + def test_bp_payment_with_locks( + self, user_factory, product_user_wallet_no, create_main_accounts, thl_lm, lm + ): + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.50"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall1, user, created=wall1.started) + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + print(thl_net, commission_amount, bp_pay, user_pay) + + # Make sure we clear any flags/locks first + lock_key = f"test:thl_wall:{wall1.uuid}" + lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}" + flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}" + lm.redis_client.delete(lock_name) + lm.redis_client.delete(flag_name) + + # Run it 3 times with check, and it gets made once! + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + with pytest.raises(expected_exception=LedgerTransactionCreateError): + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + with pytest.raises(expected_exception=LedgerTransactionCreateError): + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + assert 48 == thl_lm.get_account_balance(bp_wallet) + assert 48 == thl_lm.get_account_filtered_balance( + account=bp_wallet, + metadata_key="thl_session", + metadata_value=session.uuid, + ) + assert lm.check_ledger_balanced() diff --git a/tests/managers/thl/test_ledger/test_lm_tx_metadata.py b/tests/managers/thl/test_ledger/test_lm_tx_metadata.py new file mode 100644 index 0000000..5d12633 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx_metadata.py @@ -0,0 +1,34 @@ +class TestLedgerMetadataManager: + + def test_get_tx_metadata_by_txs(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert isinstance(res.metadata, dict) + + tx_metadatas = lm.get_tx_metadata_by_txs(transactions=[ledger_tx]) + assert isinstance(tx_metadatas, dict) + assert isinstance(tx_metadatas[ledger_tx.id], dict) + + assert res.metadata == tx_metadatas[ledger_tx.id] + + def test_get_tx_metadata_ids_by_tx(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + tx_metadata_cnt = len(res.metadata.keys()) + + tx_metadata_ids = lm.get_tx_metadata_ids_by_tx(transaction=ledger_tx) + assert isinstance(tx_metadata_ids, set) + assert isinstance(list(tx_metadata_ids)[0], int) + + assert tx_metadata_cnt == len(tx_metadata_ids) + + def test_get_tx_metadata_ids_by_txs(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + tx_metadata_cnt = len(res.metadata.keys()) + + tx_metadata_ids = lm.get_tx_metadata_ids_by_txs(transactions=[ledger_tx]) + assert isinstance(tx_metadata_ids, set) + assert isinstance(list(tx_metadata_ids)[0], int) + + assert tx_metadata_cnt == len(tx_metadata_ids) diff --git a/tests/managers/thl/test_ledger/test_thl_lm_accounts.py b/tests/managers/thl/test_ledger/test_thl_lm_accounts.py new file mode 100644 index 0000000..01d5fe1 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_accounts.py @@ -0,0 +1,411 @@ +from uuid import uuid4 + +import pytest + + +class TestThlLedgerManagerAccounts: + + def test_get_account_or_create_user_wallet(self, user, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_user_wallet(user=user) + assert isinstance(account, LedgerAccount) + + assert user.uuid in account.qualified_name + assert account.display_name == f"User Wallet {user.uuid}" + assert account.account_type == AccountType.USER_WALLET + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "user" + assert account.reference_uuid == user.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_account_or_create_bp_wallet(self, product, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_bp_wallet(product=product) + assert isinstance(account, LedgerAccount) + + assert product.uuid in account.qualified_name + assert account.display_name == f"BP Wallet {product.uuid}" + assert account.account_type == AccountType.BP_WALLET + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_account_or_create_bp_commission(self, product, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_bp_commission(product=product) + + assert product.uuid in account.qualified_name + assert account.display_name == f"Revenue from commission {product.uuid}" + assert account.account_type == AccountType.REVENUE + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + @pytest.mark.parametrize("expense", ["tango", "paypal", "gift", "tremendous"]) + def test_get_account_or_create_bp_expense(self, product, expense, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_bp_expense( + product=product, expense_name=expense + ) + assert product.uuid in account.qualified_name + assert account.display_name == f"Expense {expense} {product.uuid}" + assert account.account_type == AccountType.EXPENSE + assert account.normal_balance == Direction.DEBIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_or_create_bp_pending_payout_account(self, product, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + ) + + account = thl_lm.get_or_create_bp_pending_payout_account(product=product) + + assert product.uuid in account.qualified_name + assert account.display_name == f"BP Wallet Pending {product.uuid}" + assert account.account_type == AccountType.BP_WALLET + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_account_task_complete_revenue_raises( + self, delete_ledger_db, thl_lm, lm + ): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + delete_ledger_db() + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + thl_lm.get_account_task_complete_revenue() + + def test_get_account_task_complete_revenue( + self, account_cash, account_revenue_task_complete, thl_lm, lm + ): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + AccountType, + ) + + res = thl_lm.get_account_task_complete_revenue() + assert isinstance(res, LedgerAccount) + assert res.reference_type is None + assert res.reference_uuid is None + assert res.account_type == AccountType.REVENUE + assert res.display_name == "Cash flow task complete" + + def test_get_account_cash_raises(self, delete_ledger_db, thl_lm, lm): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + delete_ledger_db() + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + thl_lm.get_account_cash() + + def test_get_account_cash(self, account_cash, thl_lm, lm): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + AccountType, + ) + + res = thl_lm.get_account_cash() + assert isinstance(res, LedgerAccount) + assert res.reference_type is None + assert res.reference_uuid is None + assert res.account_type == AccountType.CASH + assert res.display_name == "Operating Cash Account" + + def test_get_accounts(self, setup_accounts, product, user_factory, thl_lm, lm, lam): + from generalresearch.models.thl.user import User + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + user1: User = user_factory(product=product) + user2: User = user_factory(product=product) + + account1 = thl_lm.get_account_or_create_bp_wallet(product=product) + + # (1) known account and confirm it comes back + res = lm.get_account(qualified_name=account1.qualified_name) + assert account1.model_dump_json() == res.model_dump_json() + + # (2) known accounts and confirm they both come back + res = lam.get_accounts(qualified_names=[account1.qualified_name]) + assert isinstance(res, list) + assert len(res) == 1 + assert account1 in res + + # Get 2 known and 1 made up qualified names, and confirm it raises + # an error + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_accounts( + qualified_names=[ + account1.qualified_name, + f"test:bp_wall:{uuid4().hex}", + ] + ) + + def test_get_accounts_if_exists(self, product_factory, currency, thl_lm, lm): + from generalresearch.models.thl.product import Product + + p1: Product = product_factory() + p2: Product = product_factory() + + account1 = thl_lm.get_account_or_create_bp_wallet(product=p1) + account2 = thl_lm.get_account_or_create_bp_wallet(product=p2) + + # (1) known account and confirm it comes back + res = lm.get_account(qualified_name=account1.qualified_name) + assert account1.model_dump_json() == res.model_dump_json() + + # (2) known accounts and confirm they both come back + res = lm.get_accounts( + qualified_names=[account1.qualified_name, account2.qualified_name] + ) + assert isinstance(res, list) + assert len(res) == 2 + assert account1 in res + assert account2 in res + + # Get 2 known and 1 made up qualified names, and confirm only 2 + # come back + lm.get_accounts_if_exists( + qualified_names=[ + account1.qualified_name, + account2.qualified_name, + f"{currency.value}:bp_wall:{uuid4().hex}", + ] + ) + + assert isinstance(res, list) + assert len(res) == 2 + + # Confirm an empty array comes back for all unknown qualified names + res = lm.get_accounts_if_exists( + qualified_names=[ + f"{lm.currency.value}:bp_wall:{uuid4().hex}" for i in range(5) + ] + ) + assert isinstance(res, list) + assert len(res) == 0 + + def test_get_accounts_for_products(self, product_factory, thl_lm, lm): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + from generalresearch.models.thl.ledger import ( + LedgerAccount, + ) + + # Create 5 Products + product_uuids = [] + for i in range(5): + _p = product_factory() + product_uuids.append(_p.uuid) + + # Confirm that this fails.. because none of those accounts have been + # created yet + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + thl_lm.get_accounts_bp_wallet_for_products(product_uuids=product_uuids) + + # Create the bp_wallet accounts and then try again + for p_uuid in product_uuids: + thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=p_uuid) + + res = thl_lm.get_accounts_bp_wallet_for_products(product_uuids=product_uuids) + assert len(res) == len(product_uuids) + assert all([isinstance(i, LedgerAccount) for i in res]) + + +class TestLedgerAccountManager: + + def test_get_or_create(self, thl_lm, lm, lam): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + u = uuid4().hex + name = f"test-{u[:8]}" + + account = LedgerAccount( + display_name=name, + qualified_name=f"test:bp_wallet:{u}", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + currency="test", + reference_type="bp", + reference_uuid=u, + ) + + # First we want to validate that using the get_account method raises + # an error for a random LedgerAccount which we know does not exist. + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_account(qualified_name=account.qualified_name) + + # Now that we know it doesn't exist, get_or_create for it + instance = lam.get_account_or_create(account=account) + + # It should always return + assert isinstance(instance, LedgerAccount) + assert instance.reference_uuid == u + + def test_get(self, user, thl_lm, lm, lam): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + from generalresearch.models.thl.ledger import ( + LedgerAccount, + AccountType, + ) + + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_account(qualified_name=f"test:bp_wallet:{user.product.id}") + + thl_lm.get_account_or_create_bp_wallet(product=user.product) + account = lam.get_account(qualified_name=f"test:bp_wallet:{user.product.id}") + + assert isinstance(account, LedgerAccount) + assert AccountType.BP_WALLET == account.account_type + assert user.product.uuid == account.reference_uuid + + def test_get_many(self, product_factory, thl_lm, lm, lam, currency): + from generalresearch.models.thl.product import Product + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + p1: Product = product_factory() + p2: Product = product_factory() + + account1 = thl_lm.get_account_or_create_bp_wallet(product=p1) + account2 = thl_lm.get_account_or_create_bp_wallet(product=p2) + + # Get 1 known account and confirm it comes back + res = lam.get_account_many( + qualified_names=[account1.qualified_name, account2.qualified_name] + ) + assert isinstance(res, list) + assert len(res) == 2 + assert account1 in res + + # Get 2 known accounts and confirm they both come back + res = lam.get_account_many( + qualified_names=[account1.qualified_name, account2.qualified_name] + ) + assert isinstance(res, list) + assert len(res) == 2 + assert account1 in res + assert account2 in res + + # Get 2 known and 1 made up qualified names, and confirm only 2 come + # back. Don't raise on error, so we can confirm the array is "short" + res = lam.get_account_many( + qualified_names=[ + account1.qualified_name, + account2.qualified_name, + f"test:bp_wall:{uuid4().hex}", + ], + raise_on_error=False, + ) + assert isinstance(res, list) + assert len(res) == 2 + + # Same as above, but confirm the raise works on checking res length + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_account_many( + qualified_names=[ + account1.qualified_name, + account2.qualified_name, + f"test:bp_wall:{uuid4().hex}", + ], + raise_on_error=True, + ) + + # Confirm an empty array comes back for all unknown qualified names + res = lam.get_account_many( + qualified_names=[f"test:bp_wall:{uuid4().hex}" for i in range(5)], + raise_on_error=False, + ) + assert isinstance(res, list) + assert len(res) == 0 + + def test_create_account(self, thl_lm, lm, lam): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + u = uuid4().hex + name = f"test-{u[:8]}" + + account = LedgerAccount( + display_name=name, + qualified_name=f"test:bp_wallet:{u}", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + currency="test", + reference_type="bp", + reference_uuid=u, + ) + + lam.create_account(account=account) + assert lam.get_account(f"test:bp_wallet:{u}") == account + assert lam.get_account_or_create(account) == account diff --git a/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py b/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py new file mode 100644 index 0000000..294d092 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py @@ -0,0 +1,516 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint +from uuid import uuid4 + +import pytest +import redis +from pydantic import RedisDsn +from redis.lock import Lock + +from generalresearch.currency import USDCent +from generalresearch.managers.base import Permission +from generalresearch.managers.thl.ledger_manager.thl_ledger import ThlLedgerManager +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionConditionFailedError, + LedgerTransactionReleaseLockError, + LedgerTransactionCreateError, +) +from generalresearch.managers.thl.ledger_manager.ledger import LedgerTransaction +from generalresearch.models import Source +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.ledger import Direction, TransactionType +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + Session, +) +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.redis_helper import RedisConfig + + +def broken_acquire(self, *args, **kwargs): + raise redis.exceptions.TimeoutError("Simulated timeout during acquire") + + +def broken_release(self, *args, **kwargs): + raise redis.exceptions.TimeoutError("Simulated timeout during release") + + +class TestThlLedgerManagerBPPayout: + + def test_create_tx_with_bp_payment( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + caplog, + thl_lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("6.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": now + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + lock_key = f"test:bp_payout:{user.product.id}" + flag_name = f"{thl_lm.cache_prefix}:transaction_flag:{lock_key}" + thl_lm.redis_client.delete(flag_name) + + payoutevent_uuid = uuid4().hex + thl_lm.create_tx_bp_payout( + product=user.product, + amount=USDCent(200), + created=now, + payoutevent_uuid=payoutevent_uuid, + ) + + payoutevent_uuid = uuid4().hex + thl_lm.create_tx_bp_payout( + product=user.product, + amount=USDCent(200), + created=now + timedelta(minutes=2), + skip_one_per_day_check=True, + payoutevent_uuid=payoutevent_uuid, + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + assert 170 == thl_lm.get_account_balance(bp_wallet_account) + assert 200 == thl_lm.get_account_balance(cash) + + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + thl_lm.create_tx_bp_payout( + user.product, + amount=USDCent(200), + created=now + timedelta(minutes=2), + skip_one_per_day_check=False, + skip_wallet_balance_check=False, + payoutevent_uuid=payoutevent_uuid, + ) + + payoutevent_uuid = uuid4().hex + with caplog.at_level(logging.INFO): + with pytest.raises(LedgerTransactionConditionFailedError): + thl_lm.create_tx_bp_payout( + user.product, + amount=USDCent(10_000), + created=now + timedelta(minutes=2), + skip_one_per_day_check=True, + skip_wallet_balance_check=False, + payoutevent_uuid=payoutevent_uuid, + ) + assert "failed condition check balance:" in caplog.text + + thl_lm.create_tx_bp_payout( + product=user.product, + amount=USDCent(10_00), + created=now + timedelta(minutes=2), + skip_one_per_day_check=True, + skip_wallet_balance_check=True, + payoutevent_uuid=payoutevent_uuid, + ) + assert 170 - 1000 == thl_lm.get_account_balance(bp_wallet_account) + + def test_create_tx(self, product, caplog, thl_lm, currency): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + + # Create a BP Payout for a Product without any activity. By issuing, + # the skip_* checks, we should be able to force it to work, and will + # then ultimately result in a negative balance + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + skip_flag_check=True, + ) + + # Check the basic attributes + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == "BP Payout" + assert ( + tx.tag + == f"{currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}" + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Check the Product's balance, it should be negative the amount that was + # paid out. That's because the Product earned nothing.. and then was + # sent something. + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + # Test some basic assertions + with caplog.at_level(logging.INFO): + with pytest.raises(expected_exception=Exception): + thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=uuid4().hex, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + skip_flag_check=False, + ) + assert "failed condition check >1 tx per day" in caplog.text + + def test_create_tx_redis_failure(self, product, thl_web_rw, thl_lm): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + now = datetime.now(tz=timezone.utc) + + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + + # Non routable IP address. Redis will fail + thl_lm_redis_0 = ThlLedgerManager( + pg_config=thl_web_rw, + permissions=[ + Permission.CREATE, + Permission.READ, + Permission.UPDATE, + Permission.DELETE, + ], + testing=True, + redis_config=RedisConfig( + dsn=RedisDsn("redis://10.255.255.1:6379"), + socket_connect_timeout=0.1, + ), + ) + + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm_redis_0.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is redis.exceptions.TimeoutError + # No txs were created + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 0 + + def test_create_tx_multiple_per_day(self, product, thl_lm): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + now = datetime.now(tz=timezone.utc) + + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount * USDCent(2), now, direction=Direction.CREDIT + ) + + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + + # Try to create another + # Will fail b/c it has the same payout event uuid + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionFlagAlreadyExistsError + + # Try to create another + # Will fail due to multiple per day + payoutevent_uuid2 = uuid4().hex + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid2, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionConditionFailedError + assert str(e.value) == ">1 tx per day" + + # Make it run by skipping one per day check + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid2, + created=datetime.now(tz=timezone.utc), + skip_one_per_day_check=True, + ) + + def test_create_tx_redis_lock_release_error(self, product, thl_lm): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount * USDCent(2), now, direction=Direction.CREDIT + ) + + original_acquire = Lock.acquire + original_release = Lock.release + Lock.acquire = broken_acquire + + # Create TX will fail on lock enter, no tx will actually get created + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionCreateError + assert str(e.value) == "Redis error: Simulated timeout during acquire" + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 0 + + Lock.acquire = original_acquire + Lock.release = broken_release + + # Create TX will fail on lock exit, after the tx was created! + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionReleaseLockError + assert str(e.value) == "Redis error: Simulated timeout during release" + + # Transaction was still created! + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 1 + Lock.release = original_release + + +class TestPayoutEventManagerBPPayout: + + def test_create(self, product, thl_lm, brokerage_product_payout_event_manager): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert brokerage_product_payout_event_manager.check_for_ledger_tx( + thl_ledger_manager=thl_lm, + product_id=product.id, + amount=rand_amount, + payout_event=pe, + ) + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + + def test_create_with_redis_error( + self, product, caplog, thl_lm, brokerage_product_payout_event_manager + ): + caplog.set_level("WARNING") + original_acquire = Lock.acquire + original_release = Lock.release + + rand_amount: USDCent = USDCent(randint(100, 1_000)) + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # Will fail on lock enter, no tx will actually get created + Lock.acquire = broken_acquire + with pytest.raises(expected_exception=Exception) as e: + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert e.type is LedgerTransactionCreateError + assert str(e.value) == "Redis error: Simulated timeout during acquire" + assert any( + "Simulated timeout during acquire. No ledger tx was created" in m + for m in caplog.messages + ) + + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + # One payout event is created, status is failed, and no ledger txs exist + assert len(txs) == 0 + pes = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.id] + ) + ) + assert len(pes) == 1 + assert pes[0].status == PayoutStatus.FAILED + pe = pes[0] + + # Fix the redis method + Lock.acquire = original_acquire + + # Try to fix the failed payout, by trying ledger tx again + brokerage_product_payout_event_manager.retry_create_bp_payout_event_tx( + product=product, thl_ledger_manager=thl_lm, payout_event_uuid=pe.uuid + ) + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 1 + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + + # And then try to run it again, it'll fail because a payout event with the same info exists + with pytest.raises(expected_exception=Exception) as e: + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert e.type is ValueError + assert "Payout event already exists!" in str(e.value) + + # We wouldn't do this in practice, because this is paying out the BP again, but + # we can if want to. + # Change the timestamp so it'll create a new payout event + now = datetime.now(tz=timezone.utc) + with pytest.raises(LedgerTransactionConditionFailedError) as e: + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + # But it will fail due to 1 per day check + assert str(e.value) == ">1 tx per day" + pe = brokerage_product_payout_event_manager.get_by_uuid(e.value.pe_uuid) + assert pe.status == PayoutStatus.FAILED + + # And if we really want to, we can make it again + now = datetime.now(tz=timezone.utc) + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + skip_one_per_day_check=True, + skip_wallet_balance_check=True, + ) + + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 2 + # since they were paid twice + assert thl_lm.get_account_balance(bp_wallet_account) == 0 - rand_amount + + Lock.release = original_release + Lock.acquire = original_acquire + + def test_create_with_redis_error_release( + self, product, caplog, thl_lm, brokerage_product_payout_event_manager + ): + caplog.set_level("WARNING") + + original_release = Lock.release + + rand_amount: USDCent = USDCent(randint(100, 1_000)) + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount + + # Will fail on lock exit, after the tx was created! + # But it'll see that the tx was created and so everything will be fine + Lock.release = broken_release + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert any( + "Simulated timeout during release but ledger tx exists" in m + for m in caplog.messages + ) + + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 1 + pes = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.uuid] + ) + ) + assert len(pes) == 1 + assert pes[0].status == PayoutStatus.COMPLETE + Lock.release = original_release diff --git a/tests/managers/thl/test_ledger/test_thl_lm_tx.py b/tests/managers/thl/test_ledger/test_thl_lm_tx.py new file mode 100644 index 0000000..31c7107 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_tx.py @@ -0,0 +1,1762 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.currency import USDCent +from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerTransaction, +) +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + WALL_ALLOWED_STATUS_STATUS_CODE, +) +from generalresearch.models.thl.ledger import Direction +from generalresearch.models.thl.ledger import TransactionType +from generalresearch.models.thl.product import ( + PayoutConfig, + PayoutTransformation, + UserWalletConfig, +) +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + Session, + WallAdjustedStatus, +) +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.payout import UserPayoutEvent + +logger = logging.getLogger("LedgerManager") + + +class TestThlLedgerTxManager: + + def test_create_tx_task_complete( + self, + wall, + user, + account_revenue_task_complete, + create_main_accounts, + thl_lm, + lm, + ): + create_main_accounts() + tx = thl_lm.create_tx_task_complete(wall=wall, user=user) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_task_complete_( + self, wall, user, account_revenue_task_complete, thl_lm, lm + ): + tx = thl_lm.create_tx_task_complete_(wall=wall, user=user) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_bp_payment( + self, + session_factory, + user, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + session_manager, + ): + delete_ledger_db() + create_main_accounts() + s1 = session_factory(user=user) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=status_code_1, + finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + + tx = thl_lm.create_tx_bp_payment(session=s1) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_bp_payment_amt( + self, + session_factory, + user_factory, + product_manager, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + session_manager, + ): + delete_ledger_db() + create_main_accounts() + product = product_manager.create_dummy( + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_amt" + ) + ), + user_wallet_config=UserWalletConfig(amt=True, enabled=True), + ) + user = user_factory(product=product) + s1 = session_factory(user=user, wall_req_cpi=Decimal("1")) + + status, status_code_1 = s1.determine_session_status() + assert status == Status.COMPLETE + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments( + thl_ledger_manager=thl_lm + ) + print(thl_net, commission_amount, bp_pay, user_pay) + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=status_code_1, + finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + + tx = thl_lm.create_tx_bp_payment(session=s1) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_bp_payment_( + self, + session_factory, + user, + create_main_accounts, + thl_lm, + lm, + session_manager, + utc_hour_ago, + ): + s1 = session_factory(user=user) + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=utc_hour_ago + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + + s1.determine_payments() + tx = thl_lm.create_tx_bp_payment_(session=s1) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_task_adjustment( + self, wall_factory, session, user, create_main_accounts, thl_lm, lm + ): + """Create Wall event Complete, and Create a Tx Task Adjustment + + - I don't know what this does exactly... but we can confirm + the transaction comes back with balanced amounts, and that + the name of the Source is in the Tx description + """ + + wall_status = Status.COMPLETE + wall: Wall = wall_factory(session=session, wall_status=wall_status) + + tx = thl_lm.create_tx_task_adjustment(wall=wall, user=user) + assert isinstance(tx, LedgerTransaction) + res = lm.get_tx_by_id(transaction_id=tx.id) + + assert res.entries[0].amount == int(wall.cpi * 100) + assert res.entries[1].amount == int(wall.cpi * 100) + assert wall.source.name in res.ext_description + assert res.created == tx.created + + def test_create_tx_bp_adjustment(self, session, user, caplog, thl_lm, lm): + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + + # The default session fixture is just an unfinished wall event + assert len(session.wall_events) == 1 + assert session.finished is None + assert status == Status.TIMEOUT + assert status_code_1 in list( + WALL_ALLOWED_STATUS_STATUS_CODE.get(Status.TIMEOUT, {}) + ) + assert thl_net == Decimal(0) + assert commission_amount == Decimal(0) + assert bp_pay == Decimal(0) + assert user_pay is None + + # Update the finished timestamp, but nothing else. This means that + # there is no financial changes needed + session.update( + **{ + "finished": datetime.now(tz=timezone.utc) + timedelta(minutes=10), + } + ) + assert session.finished + with caplog.at_level(logging.INFO): + tx = thl_lm.create_tx_bp_adjustment(session=session) + assert tx is None + assert "No transactions needed." in caplog.text + + def test_create_tx_bp_payout(self, product, caplog, thl_lm, currency): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + + # Create a BP Payout for a Product without any activity. By issuing, + # the skip_* checks, we should be able to force it to work, and will + # then ultimately result in a negative balance + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + skip_flag_check=True, + ) + + # Check the basic attributes + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == "BP Payout" + assert ( + tx.tag + == f"{thl_lm.currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}" + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Check the Product's balance, it should be negative the amount that was + # paid out. That's because the Product earned nothing.. and then was + # sent something. + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + # Test some basic assertions + with caplog.at_level(logging.INFO): + with pytest.raises(expected_exception=Exception): + thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=uuid4().hex, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + skip_flag_check=False, + ) + assert "failed condition check >1 tx per day" in caplog.text + + def test_create_tx_bp_payout_(self, product, thl_lm, lm, currency): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + + # Create a BP Payout for a Product without any activity. + tx = thl_lm.create_tx_bp_payout_( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + + # Check the basic attributes + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == "BP Payout" + assert ( + tx.tag + == f"{currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}" + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + def test_create_tx_plug_bp_wallet( + self, product, create_main_accounts, thl_lm, lm, currency + ): + """A BP Wallet "plug" is a way to makeup discrepancies and simply + add or remove money + """ + rand_amount: USDCent = USDCent(randint(100, 1_000)) + + tx = thl_lm.create_tx_plug_bp_wallet( + product=product, + amount=rand_amount, + created=datetime.now(tz=timezone.utc), + direction=Direction.DEBIT, + skip_flag_check=False, + ) + + assert isinstance(tx, LedgerTransaction) + + # We issued the BP money they didn't earn, so now they have a + # negative balance + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + def test_create_tx_plug_bp_wallet_( + self, product, create_main_accounts, thl_lm, lm, currency + ): + """A BP Wallet "plug" is a way to fix discrepancies and simply + add or remove money. + + Similar to above, but because it's unprotected, we can immediately + issue another to see if the balance changes + """ + rand_amount: USDCent = USDCent(randint(100, 1_000)) + + tx = thl_lm.create_tx_plug_bp_wallet_( + product=product, + amount=rand_amount, + created=datetime.now(tz=timezone.utc), + direction=Direction.DEBIT, + ) + + assert isinstance(tx, LedgerTransaction) + + # We issued the BP money they didn't earn, so now they have a + # negative balance + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + # Issue a positive one now, and confirm the balance goes positive + thl_lm.create_tx_plug_bp_wallet_( + product=product, + amount=rand_amount + rand_amount, + created=datetime.now(tz=timezone.utc), + direction=Direction.CREDIT, + ) + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) + + def test_create_tx_user_payout_request( + self, + user, + product_user_wallet_yes, + user_factory, + delete_df_collection, + thl_lm, + lm, + currency, + ): + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # The default user fixture uses a product that doesn't have wallet + # mode enabled + with pytest.raises(expected_exception=AssertionError): + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + + # Now try it for a user on a product with wallet mode + u2 = user_factory(product=product_user_wallet_yes) + + # User's pre-balance is 0 because no activity has occurred yet + pre_balance = lm.get_account_balance( + account=thl_lm.get_account_or_create_user_wallet(user=u2) + ) + assert pre_balance == 0 + + tx = thl_lm.create_tx_user_payout_request( + user=u2, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + assert isinstance(tx, LedgerTransaction) + assert tx.entries[0].amount == pe.amount + assert tx.entries[1].amount == pe.amount + assert tx.ext_description == "User Payout Paypal Request $5.00" + + # + # (TODO): This key ":user_payout:" is NOT part of the TransactionType + # enum and was manually set. It should be based off the + # TransactionType names. + # + + assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:request" + + # Post balance is -$5.00 because it comes out of the wallet before + # it's Approved or Completed + post_balance = lm.get_account_balance( + account=thl_lm.get_account_or_create_user_wallet(user=u2) + ) + assert post_balance == -500 + + def test_create_tx_user_payout_request_( + self, + user, + product_user_wallet_yes, + user_factory, + delete_ledger_db, + thl_lm, + lm, + ): + delete_ledger_db() + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + rand_description = uuid4().hex + tx = thl_lm.create_tx_user_payout_request_( + user=user, payout_event=pe, description=rand_description + ) + + assert tx.ext_description == rand_description + + post_balance = lm.get_account_balance( + account=thl_lm.get_account_or_create_user_wallet(user=user) + ) + assert post_balance == -500 + + def test_create_tx_user_payout_complete( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + currency, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + # Ensure the user starts out with nothing... + assert lm.get_account_balance(account=user_account) == 0 + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # Confirm it's not possible unless a request occurred happen + with pytest.raises(expected_exception=ValueError): + thl_lm.create_tx_user_payout_complete( + user=user, + payout_event=pe, + fee_amount=None, + skip_flag_check=False, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + # (2) Complete the request + tx = thl_lm.create_tx_user_payout_complete( + user=user, + payout_event=pe, + fee_amount=Decimal(0), + skip_flag_check=False, + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:complete" + assert isinstance(tx, LedgerTransaction) + + # The amount that comes out of the user wallet doesn't change after + # it's approved becuase it's already been withdrawn + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + def test_create_tx_user_payout_complete_( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + + # (2) Complete the request + rand_desc = uuid4().hex + + bp_expense_account = thl_lm.get_account_or_create_bp_expense( + product=user.product, expense_name="paypal" + ) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + + tx = thl_lm.create_tx_user_payout_complete_( + user=user, + payout_event=pe, + fee_amount=Decimal("0.00"), + fee_expense_account=bp_expense_account, + fee_payer_account=bp_wallet_account, + description=rand_desc, + ) + assert tx.ext_description == rand_desc + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + def test_create_tx_user_payout_cancelled( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + # (2) Cancel the request + tx = thl_lm.create_tx_user_payout_cancelled( + user=user, + payout_event=pe, + skip_flag_check=False, + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:cancel" + assert isinstance(tx, LedgerTransaction) + + # Assert the balance comes back to 0 after it was cancelled + assert lm.get_account_balance(account=user_account) == 0 + + def test_create_tx_user_payout_cancelled_( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + # (2) Cancel the request + rand_desc = uuid4().hex + tx = thl_lm.create_tx_user_payout_cancelled_( + user=user, payout_event=pe, description=rand_desc + ) + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == rand_desc + assert lm.get_account_balance(account=user_account) == 0 + + def test_create_tx_user_bonus( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + rand_ref_uuid = uuid4().hex + rand_desc = uuid4().hex + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == 0 + + tx = thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(rand_amount / 100), + ref_uuid=rand_ref_uuid, + description=rand_desc, + skip_flag_check=True, + ) + assert tx.ext_description == rand_desc + assert tx.tag == f"{thl_lm.currency.value}:user_bonus:{rand_ref_uuid}" + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount + + def test_create_tx_user_bonus_( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + rand_ref_uuid = uuid4().hex + rand_desc = uuid4().hex + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == 0 + + tx = thl_lm.create_tx_user_bonus_( + user=user, + amount=Decimal(rand_amount / 100), + ref_uuid=rand_ref_uuid, + description=rand_desc, + ) + assert tx.ext_description == rand_desc + assert tx.tag == f"{thl_lm.currency.value}:user_bonus:{rand_ref_uuid}" + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount + + +class TestThlLedgerTxManagerFlows: + """Combine the various THL_LM methods to create actual "real world" + examples + """ + + def test_create_tx_task_complete( + self, user, create_main_accounts, thl_lm, lm, currency, delete_ledger_db + ): + delete_ledger_db() + create_main_accounts() + + wall1 = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.23"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + + wall2 = Wall( + user_id=1, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.21"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall2, user=user, created=wall2.started) + + cash = thl_lm.get_account_cash() + revenue = thl_lm.get_account_task_complete_revenue() + + assert lm.get_account_balance(cash) == 123 + 321 + assert lm.get_account_balance(revenue) == 123 + 321 + assert lm.check_ledger_balanced() + + assert ( + lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="d" + ) + == 123 + ) + + assert ( + lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="f" + ) + == 321 + ) + + assert ( + lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="x" + ) + == 0 + ) + + assert ( + thl_lm.get_account_filtered_balance( + account=revenue, + metadata_key="thl_wall", + metadata_value=wall1.uuid, + ) + == 123 + ) + + def test_create_transaction_task_complete_1_cent( + self, user, create_main_accounts, thl_lm, lm, currency + ): + wall1 = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("0.007"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + + assert isinstance(tx, LedgerTransaction) + + def test_create_transaction_bp_payment( + self, + user, + create_main_accounts, + thl_lm, + lm, + currency, + delete_ledger_db, + session_factory, + utc_hour_ago, + ): + delete_ledger_db() + create_main_accounts() + + s1: Session = session_factory( + user=user, + wall_count=1, + started=utc_hour_ago, + wall_source=Source.TESTING, + ) + w1: Wall = s1.wall_events[0] + + tx = thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + assert isinstance(tx, LedgerTransaction) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": s1.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + print(thl_net, commission_amount, bp_pay, user_pay) + thl_lm.create_tx_bp_payment(session=s1, created=w1.started) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission = thl_lm.get_account_or_create_bp_commission(product=user.product) + + assert 0 == lm.get_account_balance(account=revenue) + assert 50 == lm.get_account_filtered_balance( + account=revenue, + metadata_key="source", + metadata_value=Source.TESTING, + ) + assert 48 == lm.get_account_balance(account=bp_wallet) + assert 48 == lm.get_account_filtered_balance( + account=bp_wallet, + metadata_key="thl_session", + metadata_value=s1.uuid, + ) + assert 2 == thl_lm.get_account_balance(account=bp_commission) + assert thl_lm.check_ledger_balanced() + + def test_create_transaction_bp_payment_round( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + thl_lm, + lm, + currency, + ): + product_user_wallet_no.commission_pct = Decimal("0.085") + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.287"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + + print(thl_net, commission_amount, bp_pay, user_pay) + tx = thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + assert isinstance(tx, LedgerTransaction) + + def test_create_transaction_bp_payment_round2( + self, delete_ledger_db, user, create_main_accounts, thl_lm, lm, currency + ): + delete_ledger_db() + create_main_accounts() + # user must be no user wallet + # e.g. session 869b5bfa47f44b4f81cd095ed01df2ff this fails if you dont round properly + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("1.64500"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + # thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": Decimal("1.53"), + "user_payout": Decimal("1.53"), + } + ) + + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + def test_create_transaction_bp_payment_round3( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + # e.g. session ___ fails b/c we rounded incorrectly + # before, and now we are off by a penny... + user: User = user_factory(product=product_user_wallet_yes) + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.385"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + # thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": Decimal("0.39"), + "user_payout": Decimal("0.26"), + } + ) + # with pytest.logs(logger, level=logging.WARNING) as cm: + # tx = thl_lm.create_transaction_bp_payment(session, created=wall1.started) + # assert "Capping bp_pay to thl_net" in cm.output[0] + + def test_create_transaction_bp_payment_user_wallet( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + delete_ledger_db, + thl_lm, + session_manager, + wall_manager, + lm, + session_factory, + currency, + utc_hour_ago, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_yes) + assert user.product.user_wallet_enabled + + s1: Session = session_factory( + user=user, + wall_count=1, + started=utc_hour_ago, + wall_req_cpi=Decimal(".50"), + wall_source=Source.TESTING, + ) + w1: Wall = s1.wall_events[0] + + thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=s1.started + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + thl_lm.create_tx_bp_payment(session=s1, created=w1.started) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission = thl_lm.get_account_or_create_bp_commission(product=user.product) + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user) + + assert 0 == thl_lm.get_account_balance(account=revenue) + assert 50 == thl_lm.get_account_filtered_balance( + account=revenue, + metadata_key="source", + metadata_value=Source.TESTING, + ) + + assert 48 - 19 == thl_lm.get_account_balance(account=bp_wallet) + assert 48 - 19 == thl_lm.get_account_filtered_balance( + account=bp_wallet, + metadata_key="thl_session", + metadata_value=s1.uuid, + ) + assert 2 == thl_lm.get_account_balance(bp_commission) + assert 19 == thl_lm.get_account_balance(user_wallet) + assert 19 == thl_lm.get_account_filtered_balance( + account=user_wallet, + metadata_key="thl_session", + metadata_value=s1.uuid, + ) + + assert 0 == thl_lm.get_account_filtered_balance( + account=user_wallet, metadata_key="thl_session", metadata_value="x" + ) + assert thl_lm.check_ledger_balanced() + + +class TestThlLedgerManagerAdj: + + def test_create_tx_task_adjustment( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + utc_hour_ago, + currency, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.23"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall1, user, created=wall1.started) + + wall2 = Wall( + user_id=1, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.21"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall2, user, created=wall2.started) + + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + print(wall1.get_cpi_after_adjustment()) + thl_lm.create_tx_task_adjustment(wall1, user) + + cash = thl_lm.get_account_cash() + revenue = thl_lm.get_account_task_complete_revenue() + + assert 123 + 321 - 123 == thl_lm.get_account_balance(account=cash) + assert 123 + 321 - 123 == thl_lm.get_account_balance(account=revenue) + assert thl_lm.check_ledger_balanced() + assert 0 == thl_lm.get_account_filtered_balance( + revenue, metadata_key="source", metadata_value="d" + ) + assert 321 == thl_lm.get_account_filtered_balance( + revenue, metadata_key="source", metadata_value="f" + ) + assert 0 == thl_lm.get_account_filtered_balance( + revenue, metadata_key="source", metadata_value="x" + ) + assert 123 - 123 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="thl_wall", metadata_value=wall1.uuid + ) + + # un-reconcile it + wall1.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + print(wall1.get_cpi_after_adjustment()) + thl_lm.create_tx_task_adjustment(wall1, user) + # and then run it again to make sure it does nothing + thl_lm.create_tx_task_adjustment(wall1, user) + + cash = thl_lm.get_account_cash() + revenue = thl_lm.get_account_task_complete_revenue() + + assert 123 + 321 - 123 + 123 == thl_lm.get_account_balance(cash) + assert 123 + 321 - 123 + 123 == thl_lm.get_account_balance(revenue) + assert thl_lm.check_ledger_balanced() + assert 123 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="d" + ) + assert 321 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="f" + ) + assert 0 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="x" + ) + assert 123 - 123 + 123 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="thl_wall", metadata_value=wall1.uuid + ) + + def test_create_tx_bp_adjustment( + self, + user, + product_user_wallet_no, + create_main_accounts, + caplog, + thl_lm, + lm, + currency, + session_manager, + wall_manager, + session_factory, + utc_hour_ago, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(3)], + wall_statuses=[Status.COMPLETE, Status.COMPLETE], + started=utc_hour_ago, + ) + + w1: Wall = s1.wall_events[0] + w2: Wall = s1.wall_events[1] + + thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + thl_lm.create_tx_task_complete(wall=w2, user=user, created=w2.started) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=utc_hour_ago + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + thl_lm.create_tx_bp_payment(session=s1, created=w1.started) + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission_account = thl_lm.get_account_or_create_bp_commission( + product=user.product + ) + assert 380 == thl_lm.get_account_balance(account=bp_wallet_account) + assert 0 == thl_lm.get_account_balance(account=revenue) + assert 20 == thl_lm.get_account_balance(account=bp_commission_account) + thl_lm.check_ledger_balanced() + + # This should do nothing (since we haven't adjusted any wall events) + s1.adjust_status() + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session=s1) + + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + # self.assertEqual(380, ledger_manager.get_account_balance(bp_wallet_account)) + # self.assertEqual(0, ledger_manager.get_account_balance(revenue)) + # self.assertEqual(20, ledger_manager.get_account_balance(bp_commission_account)) + # self.assertTrue(ledger_manager.check_ledger_balanced()) + + # recon $1 survey. + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=Decimal(0), + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=w1, user=user) + # -$1.00 b/c the MP took the $1 back, but we haven't yet taken the BP payment back + assert -100 == thl_lm.get_account_balance(revenue) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session=s1) + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + assert 380 - 95 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # unrecon the $1 survey + wall_manager.adjust_status( + wall=w1, + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment( + wall=w1, + user=user, + created=utc_hour_ago + timedelta(minutes=45), + ) + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + assert 380 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20, thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + def test_create_tx_bp_adjustment_small( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + utc_hour_ago, + currency, + ): + delete_ledger_db() + create_main_accounts() + + # This failed when I didn't check that `change_commission` > 0 in + # create_transaction_bp_adjustment + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("0.10"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + thl_lm.create_tx_bp_payment(session, created=wall1.started) + + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session) + + def test_create_tx_bp_adjustment_abandon( + self, + user_factory, + product_user_wallet_no, + delete_ledger_db, + session_factory, + create_main_accounts, + caplog, + thl_lm, + lm, + currency, + utc_hour_ago, + session_manager, + wall_manager, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_user_wallet_no) + s1: Session = session_factory( + user=user, final_status=Status.ABANDON, wall_req_cpi=Decimal(1) + ) + w1 = s1.wall_events[-1] + + # Adjust to complete. + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=w1, user=user) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + # And then adjust it back (it was abandon before, but now it should be + # fail (?) or back to abandon?) + wall_manager.adjust_status( + wall=w1, + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=w1, user=user) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission_account = thl_lm.get_account_or_create_bp_commission( + product=user.product + ) + assert 0 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 0 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # This should do nothing + s1.adjust_status() + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session=s1) + assert "No transactions needed" in caplog.text + + # Now back to complete again + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + assert 95 == thl_lm.get_account_balance(bp_wallet_account) + + def test_create_tx_bp_adjustment_user_wallet( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + delete_ledger_db, + caplog, + thl_lm, + lm, + currency, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(days=1) + user: User = user_factory(product=product_user_wallet_yes) + + # Create 2 Wall completes and create the respective transaction for + # them. We then create a 3rd wall event which is a failure but we + # do NOT create a transaction for it + + wall3 = Wall( + user_id=8, + source=Source.CINT, + req_survey_id="zzz", + req_cpi=Decimal("2.00"), + session_id=1, + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + started=now, + finished=now + timedelta(minutes=1), + ) + + now_w1 = now + timedelta(minutes=1, milliseconds=1) + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now_w1, + finished=now_w1 + timedelta(minutes=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + now_w2 = now + timedelta(minutes=2, milliseconds=1) + wall2 = Wall( + user_id=user.user_id, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now_w2, + finished=now_w2 + timedelta(minutes=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall2, user=user, created=wall2.started + ) + assert isinstance(tx, LedgerTransaction) + + # It doesn't matter what order these wall events go in as because + # we have a pydantic valiator that sorts them + wall_events = [wall3, wall1, wall2] + # shuffle(wall_events) + session = Session(started=wall1.started, user=user, wall_events=wall_events) + status, status_code_1 = session.determine_session_status() + assert status == Status.COMPLETE + assert status_code_1 == StatusCode1.COMPLETE + + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + assert thl_net == Decimal("4.00") + assert commission_amount == Decimal("0.20") + assert bp_pay == Decimal("3.80") + assert user_pay == Decimal("1.52") + + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": now + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + + tx = thl_lm.create_tx_bp_adjustment(session=session, created=wall1.started) + assert isinstance(tx, LedgerTransaction) + + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + assert 228 == thl_lm.get_account_balance(account=bp_wallet_account) + + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + assert 152 == thl_lm.get_account_balance(account=user_account) + + revenue = thl_lm.get_account_task_complete_revenue() + assert 0 == thl_lm.get_account_balance(account=revenue) + + bp_commission_account = thl_lm.get_account_or_create_bp_commission( + product=user.product + ) + assert 20 == thl_lm.get_account_balance(account=bp_commission_account) + + # the total (4.00) = 2.28 + 1.52 + .20 + assert thl_lm.check_ledger_balanced() + + # This should do nothing (since we haven't adjusted any wall events) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session) + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + # recon $1 survey. + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=now + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + # -$1.00 b/c the MP took the $1 back, but we haven't yet taken the BP payment back + assert -100 == thl_lm.get_account_balance(revenue) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + + # running this twice b/c it should do nothing the 2nd time + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session) + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + assert 228 - 57 == thl_lm.get_account_balance(bp_wallet_account) + assert 152 - 38 == thl_lm.get_account_balance(user_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # unrecon the $1 survey + wall1.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=now + timedelta(hours=2), + ) + tx = thl_lm.create_tx_task_adjustment(wall=wall1, user=user) + assert isinstance(tx, LedgerTransaction) + + new_status, new_payout, new_user_payout = ( + session.determine_new_status_and_payouts() + ) + print(new_status, new_payout, new_user_payout, session.adjusted_payout) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + + assert 228 - 57 + 57 == thl_lm.get_account_balance(bp_wallet_account) + assert 152 - 38 + 38 == thl_lm.get_account_balance(user_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 + 5 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # make the $2 failure into a complete also + wall3.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=wall3.cpi, + adjusted_timestamp=now + timedelta(hours=2), + ) + thl_lm.create_tx_task_adjustment(wall3, user) + new_status, new_payout, new_user_payout = ( + session.determine_new_status_and_payouts() + ) + print(new_status, new_payout, new_user_payout, session.adjusted_payout) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + assert 228 - 57 + 57 + 114 == thl_lm.get_account_balance(bp_wallet_account) + assert 152 - 38 + 38 + 76 == thl_lm.get_account_balance(user_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 + 5 + 10 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + def test_create_transaction_bp_adjustment_cpi_adjustment( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + caplog, + thl_lm, + lm, + utc_hour_ago, + currency, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + wall2 = Wall( + user_id=user.user_id, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall2, user=user, created=wall2.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1, wall2]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + thl_lm.create_tx_bp_payment(session, created=wall1.started) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_commission_account = thl_lm.get_account_or_create_bp_commission(user.product) + assert 380 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # cpi adjustment $1 -> $.60. + wall1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("0.60"), + adjusted_timestamp=utc_hour_ago + timedelta(minutes=30), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + + # -$0.40 b/c the MP took $0.40 back, but we haven't yet taken the BP payment back + assert -40 == thl_lm.get_account_balance(revenue) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + + # running this twice b/c it should do nothing the 2nd time + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session) + assert "create_transaction_bp_adjustment." in caplog.text + assert "No transactions needed." in caplog.text + + assert 380 - 38 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 2 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # adjust it to failure + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session) + assert 300 - (300 * 0.05) == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 300 * 0.05 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # and then back to cpi adj again, but this time for more than the orig amount + wall1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("2.00"), + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session) + assert 500 - (500 * 0.05) == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 500 * 0.05 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # And adjust again + wall1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("3.00"), + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment(wall=wall1, user=user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session=session) + assert 600 - (600 * 0.05) == thl_lm.get_account_balance( + account=bp_wallet_account + ) + assert 0 == thl_lm.get_account_balance(account=revenue) + assert 600 * 0.05 == thl_lm.get_account_balance(account=bp_commission_account) + assert thl_lm.check_ledger_balanced() diff --git a/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py b/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py new file mode 100644 index 0000000..1e7146a --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py @@ -0,0 +1,505 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionConditionFailedError, +) +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.payout import UserPayoutEvent +from test_utils.managers.ledger.conftest import create_main_accounts + + +class TestLedgerManagerAMT: + + def test_create_transaction_amt_ass_request( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_amt_true) + + # debit_account_uuid nothing checks they match the ledger ... todo? + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_HIT, + amount=5, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag_key = f"test:user_payout:{pe.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + + # User has $0 in their wallet. They are allowed amt_assignment payouts until -$1.00 + thl_lm.create_tx_user_payout_request(user=user, payout_event=pe) + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + thl_lm.create_tx_user_payout_request( + user=user, payout_event=pe, skip_flag_check=False + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + thl_lm.create_tx_user_payout_request( + user=user, payout_event=pe, skip_flag_check=True + ) + pe2 = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_HIT, + amount=96, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + flag_key = f"test:user_payout:{pe2.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + # 96 cents would put them over the -$1.00 limit + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + thl_lm.create_tx_user_payout_request(user, payout_event=pe2) + + # But they could do 0.95 cents + pe2.amount = 95 + thl_lm.create_tx_user_payout_request( + user, payout_event=pe2, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + product=user.product + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user) + + assert 0 == lm.get_account_balance(account=bp_wallet_account) + assert 0 == lm.get_account_balance(account=cash) + assert 100 == lm.get_account_balance(account=bp_pending_account) + assert -100 == lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + assert -5 == thl_lm.get_account_filtered_balance( + account=user_wallet_account, + metadata_key="payoutevent", + metadata_value=pe.uuid, + ) + + assert -95 == thl_lm.get_account_filtered_balance( + account=user_wallet_account, + metadata_key="payoutevent", + metadata_value=pe2.uuid, + ) + + def test_create_transaction_amt_ass_complete( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_amt_true) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_HIT, + amount=5, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:request" + lm.redis_client.delete(flag) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:complete" + lm.redis_client.delete(flag) + + # User has $0 in their wallet. They are allowed amt_assignment payouts until -$1.00 + thl_lm.create_tx_user_payout_request(user, payout_event=pe) + thl_lm.create_tx_user_payout_complete(user, payout_event=pe) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + + # BP wallet pays the 1cent fee + assert -1 == thl_lm.get_account_balance(bp_wallet_account) + assert -5 == thl_lm.get_account_balance(cash) + assert -1 == thl_lm.get_account_balance(bp_amt_expense_account) + assert 0 == thl_lm.get_account_balance(bp_pending_account) + assert -5 == lm.get_account_balance(user_wallet_account) + assert thl_lm.check_ledger_balanced() + + def test_create_transaction_amt_bonus( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_BONUS, + amount=34, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:request" + lm.redis_client.delete(flag) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:complete" + lm.redis_client.delete(flag) + + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # User has $0 in their wallet. No amt bonus allowed + thl_lm.create_tx_user_payout_request(user, payout_event=pe) + + thl_lm.create_tx_user_bonus( + user, + amount=Decimal(5), + ref_uuid="e703830dec124f17abed2d697d8d7701", + description="Bribe", + skip_flag_check=True, + ) + pe.amount = 101 + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=False + ) + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=False + ) + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + # duplicate, even if amount changed + pe.amount = 200 + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=False + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # duplicate + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + pe.uuid = "533364150de4451198e5774e221a2acb" + pe.amount = 9900 + with pytest.raises(expected_exception=ValueError): + # Trying to complete payout with no pending tx + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # trying to payout $99 with only a $5 balance + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + assert -500 + round(-101 * 0.20) == thl_lm.get_account_balance( + bp_wallet_account + ) + assert -101 == lm.get_account_balance(cash) + assert -20 == lm.get_account_balance(bp_amt_expense_account) + assert 0 == lm.get_account_balance(bp_pending_account) + assert 500 - 101 == lm.get_account_balance(user_wallet_account) + assert lm.check_ledger_balanced() is True + + def test_create_transaction_amt_bonus_cancel( + self, + user_factory, + product_amt_true, + create_main_accounts, + caplog, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_amt_true) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_BONUS, + amount=101, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + thl_lm.create_tx_user_bonus( + user, + amount=Decimal(5), + ref_uuid="c44f4da2db1d421ebc6a5e5241ca4ce6", + description="Bribe", + skip_flag_check=True, + ) + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + thl_lm.create_tx_user_payout_cancelled( + user, payout_event=pe, skip_flag_check=True + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + with caplog.at_level(logging.WARNING): + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + assert "trying to complete payout that was already cancelled" in caplog.text + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + assert -500 == thl_lm.get_account_balance(account=bp_wallet_account) + assert 0 == thl_lm.get_account_balance(account=cash) + assert 0 == thl_lm.get_account_balance(account=bp_amt_expense_account) + assert 0 == thl_lm.get_account_balance(account=bp_pending_account) + assert 500 == thl_lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + + pe2 = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_BONUS, + amount=200, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + thl_lm.create_tx_user_payout_request( + user, payout_event=pe2, skip_flag_check=True + ) + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe2, skip_flag_check=True + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + with caplog.at_level(logging.WARNING): + thl_lm.create_tx_user_payout_cancelled( + user, payout_event=pe2, skip_flag_check=True + ) + assert "trying to cancel payout that was already completed" in caplog.text + + +class TestLedgerManagerTango: + + def test_create_transaction_tango_request( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + + # debit_account_uuid nothing checks they match the ledger ... todo? + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.TANGO, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag_key = f"test:user_payout:{pe.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + thl_lm.create_tx_user_bonus( + user, + amount=Decimal(6), + ref_uuid="e703830dec124f17abed2d697d8d7701", + description="Bribe", + skip_flag_check=True, + ) + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_tango_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="tango" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + assert -600 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(cash) + assert 0 == thl_lm.get_account_balance(bp_tango_expense_account) + assert 500 == thl_lm.get_account_balance(bp_pending_account) + assert 600 - 500 == thl_lm.get_account_balance(user_wallet_account) + assert thl_lm.check_ledger_balanced() + + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + assert -600 - round(500 * 0.035) == thl_lm.get_account_balance( + bp_wallet_account + ) + assert -500, thl_lm.get_account_balance(cash) + assert round(-500 * 0.035) == thl_lm.get_account_balance( + bp_tango_expense_account + ) + assert 0 == lm.get_account_balance(bp_pending_account) + assert 100 == lm.get_account_balance(user_wallet_account) + assert lm.check_ledger_balanced() + + +class TestLedgerManagerPaypal: + + def test_create_transaction_paypal_request( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(tz=timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_amt_true) + + # debit_account_uuid nothing checks they match the ledger ... todo? + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag_key = f"test:user_payout:{pe.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(6), + ref_uuid="e703830dec124f17abed2d697d8d7701", + description="Bribe", + skip_flag_check=True, + ) + + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + product=user.product + ) + bp_paypal_expense_account = thl_lm.get_account_or_create_bp_expense( + product=user.product, expense_name="paypal" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user) + assert -600 == lm.get_account_balance(account=bp_wallet_account) + assert 0 == lm.get_account_balance(account=cash) + assert 0 == lm.get_account_balance(account=bp_paypal_expense_account) + assert 500 == lm.get_account_balance(account=bp_pending_account) + assert 600 - 500 == lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + + thl_lm.create_tx_user_payout_complete( + user=user, payout_event=pe, skip_flag_check=True, fee_amount=Decimal("0.50") + ) + assert -600 - 50 == thl_lm.get_account_balance(bp_wallet_account) + assert -500 == thl_lm.get_account_balance(cash) + assert -50 == thl_lm.get_account_balance(bp_paypal_expense_account) + assert 0 == thl_lm.get_account_balance(bp_pending_account) + assert 100 == thl_lm.get_account_balance(user_wallet_account) + assert thl_lm.check_ledger_balanced() + + +class TestLedgerManagerBonus: + + def test_create_transaction_bonus( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_yes) + + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(5), + ref_uuid="8d0aaf612462448a9ebdd57fab0fc660", + description="Bribe", + skip_flag_check=True, + ) + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + product=user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user) + + assert -500 == lm.get_account_balance(account=bp_wallet_account) + assert 0 == lm.get_account_balance(account=cash) + assert 0 == lm.get_account_balance(account=bp_amt_expense_account) + assert 0 == lm.get_account_balance(account=bp_pending_account) + assert 500 == lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(5), + ref_uuid="8d0aaf612462448a9ebdd57fab0fc660", + description="Bribe", + skip_flag_check=False, + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(5), + ref_uuid="8d0aaf612462448a9ebdd57fab0fc660", + description="Bribe", + skip_flag_check=True, + ) diff --git a/tests/managers/thl/test_ledger/test_thl_pem.py b/tests/managers/thl/test_ledger/test_thl_pem.py new file mode 100644 index 0000000..5fb9e7d --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_pem.py @@ -0,0 +1,251 @@ +import uuid +from random import randint +from uuid import uuid4, UUID + +import pytest + +from generalresearch.currency import USDCent +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.payout import BrokerageProductPayoutEvent +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.wallet.cashout_method import ( + CashoutRequestInfo, +) + + +class TestThlPayoutEventManager: + + def test_get_by_uuid(self, brokerage_product_payout_event_manager, thl_lm): + """This validates that the method raises an exception if it + fails. There are plenty of other tests that use this method so + it seems silly to duplicate it here again + """ + + with pytest.raises(expected_exception=AssertionError) as excinfo: + brokerage_product_payout_event_manager.get_by_uuid(pe_uuid=uuid4().hex) + assert "expected 1 result, got 0" in str(excinfo.value) + + def test_filter_by( + self, + product_factory, + usd_cent, + bp_payout_event_factory, + thl_lm, + brokerage_product_payout_event_manager, + ): + from generalresearch.models.thl.payout import UserPayoutEvent + + N_PRODUCTS = randint(3, 10) + N_PAYOUT_EVENTS = randint(3, 10) + amounts = [] + products = [] + + for x_idx in range(N_PRODUCTS): + product: Product = product_factory() + thl_lm.get_account_or_create_bp_wallet(product=product) + products.append(product) + brokerage_product_payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm + ) + + for y_idx in range(N_PAYOUT_EVENTS): + pe = bp_payout_event_factory(product=product, usd_cent=usd_cent) + amounts.append(int(usd_cent)) + assert isinstance(pe, BrokerageProductPayoutEvent) + + # We just added Payout Events for Products, now go ahead and + # query for them + accounts = thl_lm.get_accounts_bp_wallet_for_products( + product_uuids=[i.uuid for i in products] + ) + res = brokerage_product_payout_event_manager.filter_by( + debit_account_uuids=[i.uuid for i in accounts] + ) + + assert len(res) == (N_PRODUCTS * N_PAYOUT_EVENTS) + assert sum([i.amount for i in res]) == sum(amounts) + + def test_get_bp_payout_events_for_product( + self, + product_factory, + usd_cent, + bp_payout_event_factory, + brokerage_product_payout_event_manager, + thl_lm, + ): + from generalresearch.models.thl.payout import UserPayoutEvent + + N_PRODUCTS = randint(3, 10) + N_PAYOUT_EVENTS = randint(3, 10) + amounts = [] + products = [] + + for x_idx in range(N_PRODUCTS): + product: Product = product_factory() + products.append(product) + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm + ) + + for y_idx in range(N_PAYOUT_EVENTS): + pe = bp_payout_event_factory(product=product, usd_cent=usd_cent) + amounts.append(usd_cent) + assert isinstance(pe, BrokerageProductPayoutEvent) + + # We just added 5 Payouts for a specific Product, now go + # ahead and query for them + res = brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.id] + ) + + assert len(res) == N_PAYOUT_EVENTS + + # Now that all the Payouts for all the Products have been added, go + # ahead and query for them + res = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[i.uuid for i in products] + ) + ) + + assert len(res) == (N_PRODUCTS * N_PAYOUT_EVENTS) + assert sum([i.amount for i in res]) == sum(amounts) + + @pytest.mark.skip + def test_get_payout_detail(self, user_payout_event_manager): + """This fails because the description coming back is None, but then + it tries to return a PayoutEvent which validates that the + description can't be None + """ + from generalresearch.models.thl.payout import ( + UserPayoutEvent, + PayoutType, + ) + + rand_amount = randint(a=99, b=999) + + pe = user_payout_event_manager.create( + debit_account_uuid=uuid4().hex, + account_reference_type="str-type-random", + account_reference_uuid=uuid4().hex, + cashout_method_uuid=uuid4().hex, + description="Best payout !", + amount=rand_amount, + status=PayoutStatus.PENDING, + ext_ref_id="123", + payout_type=PayoutType.CASH_IN_MAIL, + request_data={"foo": 123}, + order_data={}, + ) + + res = user_payout_event_manager.get_payout_detail(pe_uuid=pe.uuid) + assert isinstance(res, CashoutRequestInfo) + + # def test_filter_by(self): + # raise NotImplementedError + + def test_create(self, user_payout_event_manager): + from generalresearch.models.thl.payout import UserPayoutEvent + + # Confirm the creation method returns back an instance. + pe = user_payout_event_manager.create_dummy() + assert isinstance(pe, UserPayoutEvent) + + # Now query the DB for that PayoutEvent to confirm it was actually + # saved. + res = user_payout_event_manager.get_by_uuid(pe_uuid=pe.uuid) + assert isinstance(res, UserPayoutEvent) + assert UUID(res.uuid) + + # Confirm they're the same + # assert pe.model_dump_json() == res2.model_dump_json() + assert res.description is None + + # def test_update(self): + # raise NotImplementedError + + def test_create_bp_payout( + self, + product, + delete_ledger_db, + create_main_accounts, + thl_lm, + brokerage_product_payout_event_manager, + lm, + ): + from generalresearch.models.thl.payout import UserPayoutEvent + + delete_ledger_db() + create_main_accounts() + + account_bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + rand_amount = randint(a=99, b=999) + + # Save a Brokerage Product Payout, so we have something in the + # Payout Event table and the respective ledger TX and Entry rows for it + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(rand_amount), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert isinstance(pe, BrokerageProductPayoutEvent) + + # Now try to query for it! + res = thl_lm.get_tx_bp_payouts(account_uuids=[account_bp_wallet.uuid]) + assert len(res) == 1 + res = thl_lm.get_tx_bp_payouts(account_uuids=[uuid4().hex]) + assert len(res) == 0 + + # Confirm it added to the users balance. The amount is negative because + # money was sent to the Brokerage Product, but they didn't have + # any activity that earned them money + bal = lm.get_account_balance(account=account_bp_wallet) + assert rand_amount == bal * -1 + + +class TestBPPayoutEvent: + + def test_get_bp_bp_payout_events_for_products( + self, + product_factory, + bp_payout_event_factory, + usd_cent, + delete_ledger_db, + create_main_accounts, + brokerage_product_payout_event_manager, + thl_lm, + ): + delete_ledger_db() + create_main_accounts() + + N_PAYOUT_EVENTS = randint(3, 10) + amounts = [] + + product: Product = product_factory() + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + for y_idx in range(N_PAYOUT_EVENTS): + bp_payout_event_factory(product=product, usd_cent=usd_cent) + amounts.append(usd_cent) + + # Fetch using the _bp_bp_ approach, so we have an + # array of BPPayoutEvents + bp_bp_res = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.uuid] + ) + ) + assert isinstance(bp_bp_res, list) + assert sum(amounts) == sum([i.amount for i in bp_bp_res]) + for i in bp_bp_res: + assert isinstance(i, BrokerageProductPayoutEvent) + assert isinstance(i.amount, int) + assert isinstance(i.amount_usd, USDCent) + assert isinstance(i.amount_usd_str, str) + assert i.amount_usd_str[0] == "$" diff --git a/tests/managers/thl/test_ledger/test_user_txs.py b/tests/managers/thl/test_ledger/test_user_txs.py new file mode 100644 index 0000000..d81b244 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_user_txs.py @@ -0,0 +1,288 @@ +from datetime import timedelta, datetime, timezone +from decimal import Decimal +from uuid import uuid4 + +from generalresearch.managers.thl.payout import ( + AMT_ASSIGNMENT_CASHOUT_METHOD, + AMT_BONUS_CASHOUT_METHOD, +) +from generalresearch.managers.thl.user_compensate import user_compensate +from generalresearch.models.thl.definitions import ( + Status, + WallAdjustedStatus, +) +from generalresearch.models.thl.ledger import ( + UserLedgerTransactionTypesSummary, + UserLedgerTransactionTypeSummary, + TransactionType, +) +from generalresearch.models.thl.session import Session +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType + + +def test_user_txs( + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + session_with_tx_factory, + adj_to_fail_with_tx_factory, + adj_to_complete_with_tx_factory, + session_factory, + user_payout_event_manager, + utc_now, +): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + account = thl_lm.get_account_or_create_user_wallet(user) + print(f"{account.uuid=}") + + s: Session = session_with_tx_factory(user=user, wall_req_cpi=Decimal("1.00")) + + bribe_uuid = user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + ) + + pe = user_payout_event_manager.create( + uuid=uuid4().hex, + debit_account_uuid=account.uuid, + cashout_method_uuid=AMT_ASSIGNMENT_CASHOUT_METHOD, + amount=5, + created=utc_now, + payout_type=PayoutType.AMT_HIT, + request_data=dict(), + ) + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + ) + pe = user_payout_event_manager.create( + uuid=uuid4().hex, + debit_account_uuid=account.uuid, + cashout_method_uuid=AMT_BONUS_CASHOUT_METHOD, + amount=127, + created=utc_now, + payout_type=PayoutType.AMT_BONUS, + request_data=dict(), + ) + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + ) + + wall = s.wall_events[-1] + adj_to_fail_with_tx_factory(session=s, created=wall.finished) + + # And a fail -> complete adjustment + s_fail: Session = session_factory( + user=user, + wall_count=1, + final_status=Status.FAIL, + wall_req_cpi=Decimal("2.00"), + ) + adj_to_complete_with_tx_factory(session=s_fail, created=utc_now) + + # txs = thl_lm.get_tx_filtered_by_account(account.uuid) + # print(len(txs), txs) + txs = thl_lm.get_user_txs(user) + assert len(txs.transactions) == 6 + assert txs.total == 6 + assert txs.page == 1 + assert txs.size == 50 + + # print(len(txs.transactions), txs) + d = txs.model_dump_json() + # print(d) + + descriptions = {x.description for x in txs.transactions} + assert descriptions == { + "Compensation Bonus", + "HIT Bonus", + "HIT Reward", + "Task Adjustment", + "Task Complete", + } + amounts = {x.amount for x in txs.transactions} + assert amounts == {-127, 100, 38, -38, -5, 76} + + assert txs.summary == UserLedgerTransactionTypesSummary( + bp_adjustment=UserLedgerTransactionTypeSummary( + entry_count=2, min_amount=-38, max_amount=76, total_amount=76 - 38 + ), + bp_payment=UserLedgerTransactionTypeSummary( + entry_count=1, min_amount=38, max_amount=38, total_amount=38 + ), + user_bonus=UserLedgerTransactionTypeSummary( + entry_count=1, min_amount=100, max_amount=100, total_amount=100 + ), + user_payout_request=UserLedgerTransactionTypeSummary( + entry_count=2, min_amount=-127, max_amount=-5, total_amount=-132 + ), + ) + tx_adj_c = [ + tx for tx in txs.transactions if tx.tx_type == TransactionType.BP_ADJUSTMENT + ] + assert sorted([tx.amount for tx in tx_adj_c]) == [-38, 76] + + +def test_user_txs_pagination( + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + session_with_tx_factory, + adj_to_fail_with_tx_factory, + user_payout_event_manager, + utc_now, +): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + account = thl_lm.get_account_or_create_user_wallet(user) + print(f"{account.uuid=}") + + for _ in range(12): + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + + txs = thl_lm.get_user_txs(user, page=1, size=5) + assert len(txs.transactions) == 5 + assert txs.total == 12 + assert txs.page == 1 + assert txs.size == 5 + assert txs.summary.user_bonus.total_amount == 1200 + assert txs.summary.user_bonus.entry_count == 12 + + # Skip to the 3rd page. We made 12, so there are 2 left + txs = thl_lm.get_user_txs(user, page=3, size=5) + assert len(txs.transactions) == 2 + assert txs.total == 12 + assert txs.page == 3 + assert txs.summary.user_bonus.total_amount == 1200 + assert txs.summary.user_bonus.entry_count == 12 + + # Should be empty, not fail + txs = thl_lm.get_user_txs(user, page=4, size=5) + assert len(txs.transactions) == 0 + assert txs.total == 12 + assert txs.page == 4 + assert txs.summary.user_bonus.total_amount == 1200 + assert txs.summary.user_bonus.entry_count == 12 + + # Test filtering. We should pull back only this one + now = datetime.now(tz=timezone.utc) + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + txs = thl_lm.get_user_txs(user, page=1, size=5, time_start=now) + assert len(txs.transactions) == 1 + assert txs.total == 1 + assert txs.page == 1 + # And the summary is restricted to this time range also! + assert txs.summary.user_bonus.total_amount == 100 + assert txs.summary.user_bonus.entry_count == 1 + + # And filtering with 0 results + now = datetime.now(tz=timezone.utc) + txs = thl_lm.get_user_txs(user, page=1, size=5, time_start=now) + assert len(txs.transactions) == 0 + assert txs.total == 0 + assert txs.page == 1 + assert txs.pages == 0 + # And the summary is restricted to this time range also! + assert txs.summary.user_bonus.total_amount == None + assert txs.summary.user_bonus.entry_count == 0 + + +def test_user_txs_rolling_balance( + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + session_with_tx_factory, + adj_to_fail_with_tx_factory, + user_payout_event_manager, +): + """ + Creates 3 $1.00 bonuses (postive), + then 1 cashout (negative), $1.50 + then 3 more $1.00 bonuses. + Note: pagination + rolling balance will BREAK if txs have + identical timestamps. In practice, they do not. + """ + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + account = thl_lm.get_account_or_create_user_wallet(user) + + for _ in range(3): + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + pe = user_payout_event_manager.create( + uuid=uuid4().hex, + debit_account_uuid=account.uuid, + cashout_method_uuid=AMT_BONUS_CASHOUT_METHOD, + amount=150, + payout_type=PayoutType.AMT_BONUS, + request_data=dict(), + ) + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + ) + for _ in range(3): + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + + txs = thl_lm.get_user_txs(user, page=1, size=10) + assert txs.transactions[0].balance_after == 100 + assert txs.transactions[1].balance_after == 200 + assert txs.transactions[2].balance_after == 300 + assert txs.transactions[3].balance_after == 150 + assert txs.transactions[4].balance_after == 250 + assert txs.transactions[5].balance_after == 350 + assert txs.transactions[6].balance_after == 450 + + # Ascending order, get 2nd page, make sure the balances include + # the previous txs. (will return last 3 txs) + txs = thl_lm.get_user_txs(user, page=2, size=4) + assert len(txs.transactions) == 3 + assert txs.transactions[0].balance_after == 250 + assert txs.transactions[1].balance_after == 350 + assert txs.transactions[2].balance_after == 450 + + # Descending order, get 1st page. Will + # return most recent 3 txs in desc order + txs = thl_lm.get_user_txs(user, page=1, size=3, order_by="-created") + assert len(txs.transactions) == 3 + assert txs.transactions[0].balance_after == 450 + assert txs.transactions[1].balance_after == 350 + assert txs.transactions[2].balance_after == 250 diff --git a/tests/managers/thl/test_ledger/test_wallet.py b/tests/managers/thl/test_ledger/test_wallet.py new file mode 100644 index 0000000..a0abd7c --- /dev/null +++ b/tests/managers/thl/test_ledger/test_wallet.py @@ -0,0 +1,78 @@ +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.product import ( + UserWalletConfig, + PayoutConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, +) +from generalresearch.models.thl.user import User + + +@pytest.fixture() +def schrute_product(product_manager): + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=True, amt=False), + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.4), + ), + payout_format="{payout:,.0f} Schrute Bucks", + ), + ) + + +class TestGetUserWalletBalance: + def test_get_user_wallet_balance_non_managed(self, user, thl_lm): + with pytest.raises( + AssertionError, + match="Can't get wallet balance on non-managed account.", + ): + thl_lm.get_user_wallet_balance(user=user) + + def test_get_user_wallet_balance_managed_0( + self, schrute_product, user_factory, thl_lm + ): + assert ( + schrute_product.payout_config.payout_format == "{payout:,.0f} Schrute Bucks" + ) + user: User = user_factory(schrute_product) + balance = thl_lm.get_user_wallet_balance(user=user) + assert balance == 0 + balance_string = user.product.format_payout_format(Decimal(balance) / 100) + assert balance_string == "0 Schrute Bucks" + redeemable_balance = thl_lm.get_user_redeemable_wallet_balance( + user=user, user_wallet_balance=balance + ) + assert redeemable_balance == 0 + redeemable_balance_string = user.product.format_payout_format( + Decimal(redeemable_balance) / 100 + ) + assert redeemable_balance_string == "0 Schrute Bucks" + + def test_get_user_wallet_balance_managed( + self, schrute_product, user_factory, thl_lm, session_with_tx_factory + ): + user: User = user_factory(schrute_product) + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(1), + ref_uuid=uuid4().hex, + description="cheese", + ) + session_with_tx_factory(user=user, wall_req_cpi=Decimal("1.23")) + + # This product has a payout xform of 40% and commission of 5% + # 1.23 * 0.05 = 0.06 of commission + # 1.17 of payout * 0.40 = 0.47 of user pay and (1.17-0.47) 0.70 bp pay + balance = thl_lm.get_user_wallet_balance(user=user) + assert balance == 47 + 100 # plus the $1 bribe + + redeemable_balance = thl_lm.get_user_redeemable_wallet_balance( + user=user, user_wallet_balance=balance + ) + assert redeemable_balance == 20 + 100 diff --git a/tests/managers/thl/test_maxmind.py b/tests/managers/thl/test_maxmind.py new file mode 100644 index 0000000..c588c58 --- /dev/null +++ b/tests/managers/thl/test_maxmind.py @@ -0,0 +1,273 @@ +import json +import logging +from typing import Callable + +import geoip2.models +import pytest +from faker import Faker +from faker.providers.address.en_US import Provider as USAddressProvider + +from generalresearch.managers.thl.ipinfo import GeoIpInfoManager +from generalresearch.managers.thl.maxmind import MaxmindManager +from generalresearch.managers.thl.maxmind.basic import ( + MaxmindBasicManager, +) +from generalresearch.models.thl.ipinfo import ( + GeoIPInformation, + normalize_ip, +) +from generalresearch.models.thl.maxmind.definitions import UserType + +fake = Faker() + +US_STATES = {x.lower() for x in USAddressProvider.states} + +IP_v4_INDIA = "106.203.146.157" +IP_v6_INDIA = "2402:3a80:4649:de3f:0:24:74aa:b601" +IP_v4_US = "174.218.60.101" +IP_v6_US = "2600:1700:ece0:9410:55d:faf3:c15d:6e4" +IP_v6_US_SAME_64 = "2600:1700:ece0:9410:55d:faf3:c15d:aaaa" + + +@pytest.fixture(scope="session") +def delete_ipinfo(thl_web_rw) -> Callable: + def _delete_ipinfo(ip): + thl_web_rw.execute_write( + query="DELETE FROM thl_geoname WHERE geoname_id IN (SELECT geoname_id FROM thl_ipinformation WHERE ip = %s);", + params=[ip], + ) + thl_web_rw.execute_write( + query="DELETE FROM thl_ipinformation WHERE ip = %s;", + params=[ip], + ) + + return _delete_ipinfo + + +class TestMaxmindBasicManager: + + def test_init(self, maxmind_basic_manager): + + assert isinstance(maxmind_basic_manager, MaxmindBasicManager) + + def test_get_basic_ip_information(self, maxmind_basic_manager): + ip = IP_v4_INDIA + maxmind_basic_manager.run_update_geoip_db() + + res1 = maxmind_basic_manager.get_basic_ip_information(ip_address=ip) + assert isinstance(res1, geoip2.models.Country) + assert res1.country.iso_code == "IN" + assert res1.country.name == "India" + + res2 = maxmind_basic_manager.get_basic_ip_information( + ip_address=fake.ipv4_private() + ) + assert res2 is None + + def test_get_country_iso_from_ip_geoip2db(self, maxmind_basic_manager): + ip = IP_v4_INDIA + maxmind_basic_manager.run_update_geoip_db() + + res1 = maxmind_basic_manager.get_country_iso_from_ip_geoip2db(ip=ip) + assert res1 == "in" + + res2 = maxmind_basic_manager.get_country_iso_from_ip_geoip2db( + ip=fake.ipv4_private() + ) + assert res2 is None + + def test_get_basic_ip_information_ipv6(self, maxmind_basic_manager): + ip = IP_v6_INDIA + maxmind_basic_manager.run_update_geoip_db() + + res1 = maxmind_basic_manager.get_basic_ip_information(ip_address=ip) + assert isinstance(res1, geoip2.models.Country) + assert res1.country.iso_code == "IN" + assert res1.country.name == "India" + + +class TestMaxmindManager: + + def test_init(self, thl_web_rr, thl_redis_config, maxmind_manager: MaxmindManager): + instance = MaxmindManager(pg_config=thl_web_rr, redis_config=thl_redis_config) + assert isinstance(instance, MaxmindManager) + assert isinstance(maxmind_manager, MaxmindManager) + + def test_create_basic( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in India, and so it should only do the basic lookup + ip = IP_v4_INDIA + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + + maxmind_manager.run_ip_information(ip, force_insights=False) + # Check that it is in the cache and in mysql + res = geoipinfo_manager.get_cache(ip) + assert res.ip == ip + assert res.basic + res = geoipinfo_manager.get_mysql(ip) + assert res.ip == ip + assert res.basic + + def test_create_basic_ipv6( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in India, and so it should only do the basic lookup + ip = IP_v6_INDIA + normalized_ip, lookup_prefix = normalize_ip(ip) + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + delete_ipinfo(normalized_ip) + geoipinfo_manager.clear_cache(normalized_ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_cache(normalized_ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(normalized_ip) is None + + maxmind_manager.run_ip_information(ip, force_insights=False) + + # Check that it is in the cache + res = geoipinfo_manager.get_cache(ip) + # The looked up IP (/128) is returned, + assert res.ip == ip + assert res.lookup_prefix == "/64" + assert res.basic + + # ... but the normalized version was stored (/64) + assert geoipinfo_manager.get_cache_raw(ip) is None + res = json.loads(geoipinfo_manager.get_cache_raw(normalized_ip)) + assert res["ip"] == normalized_ip + + # Check mysql + res = geoipinfo_manager.get_mysql(ip) + assert res.ip == ip + assert res.lookup_prefix == "/64" + assert res.basic + with pytest.raises(AssertionError): + geoipinfo_manager.get_mysql_raw(ip) + res = geoipinfo_manager.get_mysql_raw(normalized_ip) + assert res["ip"] == normalized_ip + + def test_create_insights( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in the US, so it should do insights + ip = IP_v4_US + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + + res1 = maxmind_manager.run_ip_information(ip, force_insights=False) + assert isinstance(res1, GeoIPInformation) + + # Check that it is in the cache and in mysql + res2 = geoipinfo_manager.get_cache(ip) + assert isinstance(res2, GeoIPInformation) + assert res2.ip == ip + assert not res2.basic + + res3 = geoipinfo_manager.get_mysql(ip) + assert isinstance(res3, GeoIPInformation) + assert res3.ip == ip + assert not res3.basic + assert res3.is_anonymous is False + assert res3.subdivision_1_name.lower() in US_STATES + # this might change ... + assert res3.user_type == UserType.CELLULAR + + assert res1 == res2 == res3, "runner, cache, mysql all return same instance" + + def test_create_insights_ipv6( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in the US, so it should do insights + ip = IP_v6_US + normalized_ip, lookup_prefix = normalize_ip(ip) + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + delete_ipinfo(normalized_ip) + geoipinfo_manager.clear_cache(normalized_ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_cache(normalized_ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(normalized_ip) is None + + res1 = maxmind_manager.run_ip_information(ip, force_insights=False) + assert isinstance(res1, GeoIPInformation) + assert res1.lookup_prefix == "/64" + + # Check that it is in the cache and in mysql + res2 = geoipinfo_manager.get_cache(ip) + assert isinstance(res2, GeoIPInformation) + assert res2.ip == ip + assert not res2.basic + + res3 = geoipinfo_manager.get_mysql(ip) + assert isinstance(res3, GeoIPInformation) + assert res3.ip == ip + assert not res3.basic + assert res3.is_anonymous is False + assert res3.subdivision_1_name.lower() in US_STATES + # this might change ... + assert res3.user_type == UserType.RESIDENTIAL + + assert res1 == res2 == res3, "runner, cache, mysql all return same instance" + + def test_get_or_create_ip_information(self, maxmind_manager): + ip = IP_v4_US + + res1 = maxmind_manager.get_or_create_ip_information(ip_address=ip) + assert isinstance(res1, GeoIPInformation) + + res2 = maxmind_manager.get_or_create_ip_information( + ip_address=fake.ipv4_private() + ) + assert res2 is None + + def test_get_or_create_ip_information_ipv6( + self, maxmind_manager, delete_ipinfo, geoipinfo_manager, caplog + ): + ip = IP_v6_US + normalized_ip, lookup_prefix = normalize_ip(ip) + delete_ipinfo(normalized_ip) + geoipinfo_manager.clear_cache(normalized_ip) + + with caplog.at_level(logging.INFO): + res1 = maxmind_manager.get_or_create_ip_information(ip_address=ip) + assert isinstance(res1, GeoIPInformation) + assert res1.ip == ip + # It looks up in insight using the normalize IP! + assert f"get_insights_ip_information: {normalized_ip}" in caplog.text + + # And it should NOT do the lookup again with an ipv6 in the same /64 block! + ip = IP_v6_US_SAME_64 + caplog.clear() + with caplog.at_level(logging.INFO): + res2 = maxmind_manager.get_or_create_ip_information(ip_address=ip) + assert isinstance(res2, GeoIPInformation) + assert res2.ip == ip + assert "get_insights_ip_information" not in caplog.text + + def test_run_ip_information(self, maxmind_manager): + ip = IP_v4_US + + res = maxmind_manager.run_ip_information(ip_address=ip) + assert isinstance(res, GeoIPInformation) + assert res.country_name == "United States" + assert res.country_iso == "us" diff --git a/tests/managers/thl/test_payout.py b/tests/managers/thl/test_payout.py new file mode 100644 index 0000000..31087b8 --- /dev/null +++ b/tests/managers/thl/test_payout.py @@ -0,0 +1,1269 @@ +import logging +import os +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import choice as rand_choice, randint +from typing import Optional +from uuid import uuid4 + +import pandas as pd +import pytest + +from generalresearch.currency import USDCent +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionConditionFailedError, +) +from generalresearch.managers.thl.payout import UserPayoutEventManager +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.ledger import LedgerEntry, Direction +from generalresearch.models.thl.payout import BusinessPayoutEvent +from generalresearch.models.thl.payout import UserPayoutEvent +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.ledger import LedgerAccount + +logger = logging.getLogger() + +cashout_method_uuid = uuid4().hex + + +class TestPayout: + + def test_get_by_uuid_and_create( + self, + user, + user_payout_event_manager: UserPayoutEventManager, + thl_lm, + utc_now, + ): + + user_account: LedgerAccount = thl_lm.get_account_or_create_user_wallet( + user=user + ) + + pe1: UserPayoutEvent = user_payout_event_manager.create( + debit_account_uuid=user_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=100, + created=utc_now, + ) + + # these get added by the query + pe1.account_reference_type = "user" + pe1.account_reference_uuid = user.uuid + # pe1.description = "PayPal" + + pe2 = user_payout_event_manager.get_by_uuid(pe_uuid=pe1.uuid) + + assert pe1 == pe2 + + def test_update(self, user, user_payout_event_manager, lm, thl_lm, utc_now): + from generalresearch.models.thl.definitions import PayoutStatus + from generalresearch.models.thl.wallet import PayoutType + + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + + pe1 = user_payout_event_manager.create( + status=PayoutStatus.PENDING, + debit_account_uuid=user_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=100, + created=utc_now, + ) + user_payout_event_manager.update( + payout_event=pe1, + status=PayoutStatus.APPROVED, + order_data={"foo": "bar"}, + ext_ref_id="abc", + ) + pe = user_payout_event_manager.get_by_uuid(pe_uuid=pe1.uuid) + assert pe.status == PayoutStatus.APPROVED + assert pe.order_data == {"foo": "bar"} + + with pytest.raises(expected_exception=AssertionError) as cm: + user_payout_event_manager.update( + payout_event=pe, status=PayoutStatus.PENDING + ) + assert "status APPROVED can only be" in str(cm.value) + + def test_create_bp_payout( + self, + user, + thl_web_rr, + user_payout_event_manager, + lm, + thl_lm, + product, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + ): + delete_ledger_db() + create_main_accounts() + from generalresearch.models.thl.ledger import LedgerAccount + + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # wallet balance failure + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + ) + + # (we don't have a special method for this) Put money in the BP's account + amount_cents = 100 + cash_account: LedgerAccount = thl_lm.get_account_cash() + bp_wallet: LedgerAccount = thl_lm.get_account_or_create_bp_wallet( + product=product + ) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=cash_account.uuid, + amount=amount_cents, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet.uuid, + amount=amount_cents, + ), + ] + + lm.create_tx(entries=entries) + assert 100 == lm.get_account_balance(account=bp_wallet) + + # Then run it again for $1.00 + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + ) + assert 0 == lm.get_account_balance(account=bp_wallet) + + # Run again should without balance check, should still fail due to day check + with pytest.raises(LedgerTransactionConditionFailedError): + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=False, + ) + + # And then we can run again skip both checks + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert -100 == lm.get_account_balance(account=bp_wallet) + + pe = brokerage_product_payout_event_manager.get_by_uuid(pe.uuid) + txs = lm.get_tx_filtered_by_metadata( + metadata_key="event_payout", metadata_value=pe.uuid + ) + + assert 1 == len(txs) + + def test_create_bp_payout_quick_dupe( + self, + user, + product, + thl_web_rw, + brokerage_product_payout_event_manager, + thl_lm, + lm, + utc_now, + create_main_accounts, + ): + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + created=utc_now, + ) + + with pytest.raises(ValueError) as cm: + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + created=utc_now, + ) + assert "Payout event already exists!" in str(cm.value) + + def test_filter( + self, + thl_web_rw, + thl_lm, + lm, + product, + user, + user_payout_event_manager, + utc_now, + ): + from generalresearch.models.thl.definitions import PayoutStatus + from generalresearch.models.thl.wallet import PayoutType + + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + bp_account = thl_lm.get_account_or_create_bp_wallet(product=product) + + user_payout_event_manager.create( + status=PayoutStatus.PENDING, + debit_account_uuid=user_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=100, + created=utc_now, + ) + + user_payout_event_manager.create( + status=PayoutStatus.PENDING, + debit_account_uuid=bp_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=200, + created=utc_now, + ) + + pes = user_payout_event_manager.filter_by( + reference_uuid=user.uuid, created=utc_now + ) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + debit_account_uuids=[bp_account.uuid], created=utc_now + ) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + debit_account_uuids=[bp_account.uuid], amount=123 + ) + assert 0 == len(pes) + + pes = user_payout_event_manager.filter_by( + product_ids=[user.product_id], bp_user_ids=["x"] + ) + assert 0 == len(pes) + + pes = user_payout_event_manager.filter_by(product_ids=[user.product_id]) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + cashout_types=[PayoutType.PAYPAL], + bp_user_ids=[user.product_user_id], + ) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + statuses=[PayoutStatus.FAILED], + cashout_types=[PayoutType.PAYPAL], + bp_user_ids=[user.product_user_id], + ) + assert 0 == len(pes) + + pes = user_payout_event_manager.filter_by( + statuses=[PayoutStatus.PENDING], + cashout_types=[PayoutType.PAYPAL], + bp_user_ids=[user.product_user_id], + ) + assert 1 == len(pes) + + +class TestPayoutEventManager: + + def test_set_account_lookup_table( + self, payout_event_manager, thl_redis_config, thl_lm, delete_ledger_db + ): + delete_ledger_db() + rc = thl_redis_config.create_redis_client() + rc.delete("pem:account_to_product") + rc.delete("pem:product_to_account") + N = 5 + + for idx in range(N): + thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=uuid4().hex) + + res = rc.hgetall(name="pem:account_to_product") + assert len(res.items()) == 0 + + res = rc.hgetall(name="pem:product_to_account") + assert len(res.items()) == 0 + + payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm, + ) + + res = rc.hgetall(name="pem:account_to_product") + assert len(res.items()) == N + + res = rc.hgetall(name="pem:product_to_account") + assert len(res.items()) == N + + thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=uuid4().hex) + payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm, + ) + + res = rc.hgetall(name="pem:account_to_product") + assert len(res.items()) == N + 1 + + res = rc.hgetall(name="pem:product_to_account") + assert len(res.items()) == N + 1 + + +class TestBusinessPayoutEventManager: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "5d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=10) + + def test_base( + self, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + thl_lm, + thl_web_rr, + product_factory, + bp_payout_factory, + business, + ): + delete_ledger_db() + create_main_accounts() + + from generalresearch.models.thl.product import Product + + p1: Product = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + bp_payout_factory( + product=p1, + amount=USDCent(1), + ext_ref_id=None, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(1), + ext_ref_id=ach_id1, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(25), + ext_ref_id=ach_id1, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(50), + ext_ref_id=ach_id2, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + + assert len(business.payouts) == 3 + assert business.payouts_total == sum([pe.amount for pe in business.payouts]) + assert business.payouts[0].created > business.payouts[1].created + assert len(business.payouts[0].bp_payouts) == 1 + assert len(business.payouts[1].bp_payouts) == 2 + + assert business.payouts[0].ext_ref_id == ach_id2 + assert business.payouts[1].ext_ref_id == ach_id1 + assert business.payouts[2].ext_ref_id is None + + def test_update_ext_reference_ids( + self, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + thl_lm, + thl_web_rr, + product_factory, + bp_payout_factory, + delete_df_collection, + user_factory, + ledger_collection, + session_with_tx_factory, + pop_ledger_merge, + client_no_amm, + mnt_filepath, + lm, + product_manager, + start, + business, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # $250.00 to work with + for idx in range(1, 10): + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("25.00"), + started=start + timedelta(days=1, minutes=idx), + ) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + with pytest.raises(expected_exception=Warning) as cm: + business_payout_event_manager.update_ext_reference_ids( + new_value=ach_id2, + current_value=ach_id1, + ) + assert "No event_payouts found to UPDATE" in str(cm) + + # We must build the balance to issue ACH/Wire + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + res = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_01), + pm=product_manager, + thl_lm=thl_lm, + transaction_id=ach_id1, + ) + assert isinstance(res, BusinessPayoutEvent) + + # Okay, now that there is a payout_event, let's try to update the + # ext_reference_id + business_payout_event_manager.update_ext_reference_ids( + new_value=ach_id2, + current_value=ach_id1, + ) + + res = business_payout_event_manager.filter_by(ext_ref_id=ach_id1) + assert len(res) == 0 + + res = business_payout_event_manager.filter_by(ext_ref_id=ach_id2) + assert len(res) == 1 + + def test_delete_failed_business_payout( + self, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + thl_lm, + thl_web_rr, + product_factory, + bp_payout_factory, + currency, + delete_df_collection, + user_factory, + ledger_collection, + session_with_tx_factory, + pop_ledger_merge, + client_no_amm, + mnt_filepath, + lm, + product_manager, + start, + business, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # $250.00 to work with + for idx in range(1, 10): + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("25.00"), + started=start + timedelta(days=1, minutes=idx), + ) + + # We must build the balance to issue ACH/Wire + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + ach_id1 = uuid4().hex + + res = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_01), + pm=product_manager, + thl_lm=thl_lm, + transaction_id=ach_id1, + ) + assert isinstance(res, BusinessPayoutEvent) + + # (1) Confirm the initial Event Payout, Tx, TxMeta, TxEntry all exist + event_payouts = business_payout_event_manager.filter_by(ext_ref_id=ach_id1) + event_payout_uuids = [i.uuid for i in event_payouts] + assert len(event_payout_uuids) == 1 + tags = [f"{currency.value}:bp_payout:{x}" for x in event_payout_uuids] + transactions = thl_lm.get_txs_by_tags(tags=tags) + assert len(transactions) == 1 + tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions) + assert len(tx_metadata_ids) == 2 + tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions) + assert len(tx_entries) == 2 + + # (2) Delete! + business_payout_event_manager.delete_failed_business_payout( + ext_ref_id=ach_id1, thl_lm=thl_lm + ) + + # (3) Confirm the initial Event Payout, Tx, TxMeta, TxEntry have + # all been deleted + res = business_payout_event_manager.filter_by(ext_ref_id=ach_id1) + assert len(res) == 0 + + # Note: b/c the event_payout shouldn't exist anymore, we are taking + # the tag strings and transactions from when they did.. + res = thl_lm.get_txs_by_tags(tags=tags) + assert len(res) == 0 + + tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions) + assert len(tx_metadata_ids) == 0 + tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions) + assert len(tx_entries) == 0 + + def test_recoup_empty(self, business_payout_event_manager): + res = {uuid4().hex: USDCent(0) for i in range(100)} + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + with pytest.raises(expected_exception=ValueError) as cm: + business_payout_event_manager.recoup_proportional( + df=df, target_amount=USDCent(1) + ) + assert "Total available amount is empty, cannot recoup" in str(cm) + + def test_recoup_exceeds(self, business_payout_event_manager): + from random import randint + + res = {uuid4().hex: USDCent(randint(a=0, b=1_000_00)) for i in range(100)} + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + avail_balance = USDCent(int(df.available_balance.sum())) + + with pytest.raises(expected_exception=ValueError) as cm: + business_payout_event_manager.recoup_proportional( + df=df, target_amount=avail_balance + USDCent(1) + ) + assert " exceeds total available " in str(cm) + + def test_recoup(self, business_payout_event_manager): + from random import randint, random + + res = {uuid4().hex: USDCent(randint(a=0, b=1_000_00)) for i in range(100)} + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + avail_balance = USDCent(int(df.available_balance.sum())) + random_recoup_amount = USDCent(1 + int(int(avail_balance) * random() * 0.5)) + + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=random_recoup_amount + ) + + assert isinstance(res, pd.DataFrame) + assert res.weight.sum() == pytest.approx(1) + assert res.deduction.sum() == random_recoup_amount + assert res.remaining_balance.sum() == avail_balance - random_recoup_amount + + def test_recoup_loop(self, business_payout_event_manager, request): + # TODO: Generate this file at random + fp = os.path.join( + request.config.rootpath, "data/pytest_recoup_proportional.csv" + ) + df = pd.read_csv(fp, index_col=0) + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=USDCent(1416089) + ) + + res = res[res["remaining_balance"] > 0] + + assert int(res.deduction.sum()) == 1416089 + + def test_recoup_loop_single_profitable_account(self, business_payout_event_manager): + res = [{"product_id": uuid4().hex, "available_balance": 0} for i in range(1000)] + for x in range(100): + item = rand_choice(res) + item["available_balance"] = randint(8, 12) + + df = pd.DataFrame(res) + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=USDCent(500) + ) + # res = res[res["remaining_balance"] > 0] + assert int(res.deduction.sum()) == 500 + + def test_recoup_loop_assertions(self, business_payout_event_manager): + df = pd.DataFrame( + [ + { + "product_id": uuid4().hex, + "available_balance": randint(0, 999_999), + } + for i in range(10_000) + ] + ) + available_balance = int(df.available_balance.sum()) + + # Exact amount + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=available_balance + ) + assert res.remaining_balance.sum() == 0 + assert int(res.deduction.sum()) == available_balance + + # Slightly less + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=available_balance - 1 + ) + assert res.remaining_balance.sum() == 1 + assert int(res.deduction.sum()) == available_balance - 1 + + # Slightly less + with pytest.raises(expected_exception=Exception) as cm: + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=available_balance + 1 + ) + + # Don't pull anything + res = business_payout_event_manager.recoup_proportional(df=df, target_amount=0) + assert res.remaining_balance.sum() == available_balance + assert int(res.deduction.sum()) == 0 + + def test_distribute_amount(self, business_payout_event_manager): + import io + + df = pd.read_csv( + io.StringIO( + "product_id,available_balance,weight,raw_deduction,deduction,remainder,remaining_balance\n0,15faf,11768,0.0019298788807489663,222.3374860933269,223,0.33748609332690194,11545\n5,793e3,202,3.312674489388946e-05,3.8164660257352168,3,0.8164660257352168,199\n6,c703b,22257,0.0036500097084321667,420.5103184890531,421,0.510318489053077,21836\n13,72a70,1424,0.00023352715212326036,26.90419614181658,27,0.9041961418165805,1397\n14,86173,45634,0.007483692457860156,862.1812406851528,863,0.18124068515282943,44771\n17,4f230,143676,0.02356197128403199,2714.5275876907576,2715,0.5275876907576276,140961\n18,e1ee6,129,2.1155198471840298e-05,2.437248105543777,2,0.43724810554377713,127\n22,4524a,85613,0.014040000052478012,1617.5203260458868,1618,0.5203260458868044,83995\n25,4f5e2,30282,0.004966059845924558,572.1298227292765,573,0.12982272927649774,29709\n28,0f3b3,135,2.213916119146078e-05,2.5506084825458135,2,0.5506084825458135,133\n29,15c6f,1226,0.00020105638237578454,23.163303700749385,23,0.16330370074938472,1203\n31,c0c04,376,6.166166376288335e-05,7.103916958794265,7,0.10391695879426521,369\n33,2934c,37649,0.006174202071831903,711.3174722916099,712,0.31747229160987445,36937\n38,5585d,16471,0.0027011416591448184,311.1931282667562,312,0.19312826675621864,16159\n42,0a749,663,0.00010872788051806293,12.526321658724994,12,0.5263216587249939,651\n43,9e336,322,5.280599928629904e-05,6.08367356577594,6,0.08367356577593998,316\n46,043a5,11656,0.001911511576649384,220.22142572262223,221,0.2214257226222287,11435\n47,d6f4e,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n48,3e123,3012,0.0004939492852494804,56.90690925502214,57,0.9069092550221427,2955\n49,76ccc,76,1.24635277818594e-05,1.4358981086924578,1,0.4358981086924578,75\n51,8acb9,7710,0.0012643920947123155,145.66808444761645,146,0.6680844476164509,7564\n56,fef3e,212,3.476668275992359e-05,4.005399987405277,4,0.005399987405277251,208\n57,6d7c2,455709,0.07473344449925481,8609.890673870148,8610,0.8906738701480208,447099\n58,f51f6,257,4.2146403157077186e-05,4.855602814920548,4,0.8556028149205481,253\n61,06acf,84310,0.013826316148533765,1592.902230840278,1593,0.90223084027798,82717\n62,6eca7,40,6.559751464136527e-06,0.755735846680241,0,0.755735846680241,40\n68,4a415,1955,0.00032060785280967275,36.93658950649678,37,0.9365895064967802,1918\n69,57a16,409,6.707345872079598e-05,7.727399032305463,7,0.7273990323054633,402\n70,c3ef6,593,9.724831545582401e-05,11.203783927034571,11,0.20378392703457138,582\n71,385da,825,0.00013529487394781586,15.587051837779969,15,0.5870518377799687,810\n72,a8435,748,0.00012266735237935304,14.132260332920506,14,0.13226033292050587,734\n75,c9374,263383,0.043193175496966774,4976.199362654548,4977,0.19936265454816748,258406\n76,7fcc7,136,2.230315497806419e-05,2.569501878712819,2,0.5695018787128192,134\n77,26aec,356,5.838178803081509e-05,6.726049035454145,6,0.7260490354541451,350\n78,76cc5,413,6.772943386720964e-05,7.802972616973489,7,0.8029726169734888,406\n82,9476a,13973,0.002291485180209492,263.99742464157515,264,0.9974246415751509,13709\n85,ee099,2397,0.0003930931064883814,45.28747061231344,46,0.2874706123134416,2351\n87,24bd5,122295,0.020055620132664414,2310.567884244002,2311,0.5678842440020162,119984\n91,1fa8f,4,6.559751464136527e-07,0.0755735846680241,0,0.0755735846680241,4\n92,53b5a,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n93,f4f9e,201,3.296275110728605e-05,3.797572629568211,3,0.7975726295682111,198\n95,ff5d7,21317,0.0034958555490249587,402.75052609206745,403,0.7505260920674459,20914\n96,6290d,80,1.3119502928273054e-05,1.511471693360482,1,0.5114716933604819,79\n100,9c34b,1870,0.0003066683809483826,35.33065083230127,36,0.33065083230126646,1834\n101,d32a9,11577,0.0018985560675077143,218.72884742542874,219,0.7288474254287394,11358\n102,a6001,8981,0.0014728281974852537,169.6815909758811,170,0.6815909758811074,8811\n106,0a1ee,9,1.4759440794307186e-06,0.17004056550305424,0,0.17004056550305424,9\n108,8ad36,51,8.363683116774072e-06,0.9635632045173074,0,0.9635632045173074,51\n111,389ab,75,1.2299533995255988e-05,1.417004712525452,1,0.4170047125254519,74\n114,be86b,4831,0.000792253983081089,91.2739968828061,92,0.27399688280610235,4739\n118,99e96,271,4.444231616952497e-05,5.120110361258632,5,0.12011036125863228,266\n121,3b729,12417,0.0020363108482545815,234.59930020571383,235,0.5993002057138312,12182\n122,f16e2,2697,0.0004422912424694053,50.95548946241525,51,0.955489462415251,2646\n125,241c5,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n127,523af,569,9.33124645773421e-05,10.750342419026428,10,0.7503424190264276,559\n130,dce2e,31217,0.005119394036398749,589.7951481454271,590,0.7951481454271061,30627\n133,214b8,37360,0.006126807867503516,705.8572807993451,706,0.8572807993450624,36654\n136,03c88,35,5.739782531119461e-06,0.6612688658452109,0,0.6612688658452109,35\n137,08021,44828,0.007351513465857806,846.9531633745461,847,0.9531633745460795,43981\n144,bc3d9,1174,0.00019252870547240705,22.180847100065073,22,0.18084710006507265,1152\n148,bac2f,3745,0.0006141567308297824,70.75576864543757,71,0.7557686454375698,3674\n150,d9e69,8755,0.0014357656017128823,165.41168344213776,166,0.41168344213775754,8589\n151,36b18,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n152,e15d7,5259,0.0008624433237473498,99.36037044228468,100,0.3603704422846761,5159\n155,bad23,16504,0.002706553454102731,311.81661034026746,312,0.8166103402674594,16192\n159,9f77f,2220,0.00036406620625957726,41.943339490753374,42,0.9433394907533739,2178\n160,945ab,131,2.1483186045047126e-05,2.4750348978777894,2,0.4750348978777894,129\n161,3e2dc,354,5.8053800457608265e-05,6.688262243120133,6,0.6882622431201328,348\n162,62811,1608,0.0002637020088582884,30.38058103654569,31,0.38058103654568853,1577\n164,59a3d,2541,0.0004167082117592729,48.00811966036231,49,0.008119660362311265,2492\n165,871ca,8,1.3119502928273053e-06,0.1511471693360482,0,0.1511471693360482,8\n167,55d38,683,0.00011200775625013119,12.904189582065115,12,0.9041895820651149,671\n169,1e5af,6,9.83962719620479e-07,0.11336037700203613,0,0.11336037700203613,6\n170,b2901,15994,0.0026229166229349904,302.18097829509435,303,0.1809782950943486,15691\n174,dc880,2435,0.00039932487037931105,46.00541966665967,47,0.005419666659669531,2388\n176,d189e,34579,0.005670741146959424,653.3147460589013,654,0.31474605890127805,33925\n177,d35f5,41070,0.006735224815802179,775.9517805789375,776,0.9517805789374734,40294\n178,bd7a6,12563,0.0020602539410986796,237.35773604609668,238,0.35773604609667586,12325\n180,662cd,9675,0.0015866398853880224,182.79360791578327,183,0.7936079157832694,9492\n181,e995f,1011,0.00016579771825605072,19.101223524843093,19,0.1012235248430926,992\n189,0faf8,626,0.00010266011041373665,11.827266000545771,11,0.8272660005457713,615\n190,1fd84,213,3.4930676546527005e-05,4.024293383572283,4,0.02429338357228339,209\n193,c9b44,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n194,bbd8d,16686,0.0027364003232645522,315.25520844266254,316,0.2552084426625356,16370\n196,d4945,2654,0.00043523950964545856,50.14307342723399,51,0.1430734272339933,2603\n197,bbde8,3043,0.0004990330926341863,57.49260453619934,58,0.4926045361993374,2985\n200,579ad,20833,0.0034164825563089067,393.6061223472365,394,0.6061223472365214,20439\n202,2f932,15237,0.0024987733264762065,287.8786773966708,288,0.8786773966708097,14949\n208,b649d,103551,0.01698172059657004,1956.430066489641,1957,0.43006648964092165,101594\n209,0f939,26025,0.0042679382963538275,491.7006352463318,492,0.7006352463317853,25533\n211,638d8,6218,0.0010197133651000231,117.47913736644347,118,0.47913736644346727,6100\n215,d2a81,8301,0.0013613124225949327,156.834081582317,157,0.8340815823169976,8144\n216,62293,4,6.559751464136527e-07,0.0755735846680241,0,0.0755735846680241,4\n218,c8ae9,2829,0.00046393842230105583,53.44941775646004,54,0.44941775646004345,2775\n219,83a9f,6556,0.0010751432649719768,123.8651052708915,124,0.8651052708915046,6432\n221,d256a,72,1.1807552635445749e-05,1.360324524024434,1,0.36032452402443393,71\n222,6fdc2,7,1.1479565062238922e-06,0.1322537731690422,0,0.1322537731690422,7\n224,56b88,146928,0.02409527907806629,2775.9689120258613,2776,0.9689120258613002,144152\n230,f50f8,5798,0.0009508359747265895,109.54391097630092,110,0.5439109763009213,5688\n231,fa3be,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n232,94934,537381,0.08812714503872877,10152.952125621865,10153,0.9521256218649796,527228\n234,4e20c,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n235,a5d68,31101,0.005100370757152753,587.6035141900543,588,0.603514190054284,30513\n236,e5a29,3208,0.0005260920674237494,60.61001490375533,61,0.610014903755328,3147\n237,0ce0f,294,4.821417326140347e-05,5.554658473099771,5,0.5546584730997708,289\n240,66d2b,18633,0.0030556962257813976,352.04065077982324,353,0.04065077982323828,18280\n244,1bd17,1815,0.0002976487226851949,34.29151404311594,35,0.29151404311593865,1780\n245,32aca,224,3.673460819916455e-05,4.23212074140935,4,0.23212074140935002,220\n247,cbf8e,4747,0.0007784785050064023,89.6869516047776,90,0.6869516047776045,4657\n249,8f24b,807633,0.13244679385587438,15258.930226547576,15259,0.9302265475762397,792374\n251,c97c3,20526,0.0033661364638216586,387.80584972396565,388,0.8058497239656504,20138\n252,88fee,13821,0.0022665581246457734,261.12562842419027,262,0.12562842419026765,13559\n253,a9ad3,178,2.9190894015407545e-05,3.3630245177270726,3,0.36302451772707256,175\n254,83738,104,1.705535380675497e-05,1.9649132013686266,1,0.9649132013686266,103\n255,21f6c,6288,0.001031192930162262,118.8016750981339,119,0.8016750981338987,6169\n256,97e28,6,9.83962719620479e-07,0.11336037700203613,0,0.11336037700203613,6\n257,7f689,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n258,e7a50,28031,0.004596909832280275,529.6007879573459,530,0.600787957345915,27501\n259,2eb98,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n260,349ba,19518,0.0032008307269254183,368.76130638762356,369,0.7613063876235628,19149\n261,a3d04,235,3.853853985180209e-05,4.439948099246416,4,0.43994809924641576,231\n262,40971,2249,0.0003688220260710762,42.49124797959655,43,0.491247979596551,2206\n264,4f588,6105,0.0010011820672138373,115.34418359957178,116,0.34418359957177813,5989\n269,f182e,1020,0.00016727366233548145,19.271264090346147,19,0.27126409034614696,1001\n270,3798e,5168,0.0008475198891664393,97.64107139108714,98,0.641071391087138,5070\n274,81dc3,274,4.493429752933521e-05,5.176790549759651,5,0.1767905497596507,269\n285,8520b,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n287,37b89,3742,0.0006136647494699721,70.69908845693654,71,0.6990884569365363,3671\n291,63706,4740,0.0007773305485001784,89.55469783160856,90,0.5546978316085642,4650\n293,21c9d,241,3.9522502571422574e-05,4.553308476248452,4,0.5533084762484517,237\n296,3a42a,610,0.00010003620982808204,11.524971661873675,11,0.5249716618736748,599\n302,31a8f,148533,0.02435848910556477,2806.292812873906,2807,0.2928128739058593,145726\n303,38467,2399,0.0003934210940615882,45.32525740464745,46,0.3252574046474521,2353\n304,a49a6,26573,0.004357806891412498,502.0542163458511,503,0.05421634585110269,26070\n305,3c4e7,2286,0.0003748897961754025,43.19030363777577,44,0.19030363777577008,2242\n306,63986,34132,0.005597435924347699,644.8693979722497,645,0.8693979722496579,33487\n307,640f4,50,8.199689330170658e-06,0.9446698083503012,0,0.9446698083503012,50\n309,ba81a,3015,0.0004944412666092907,56.96358944352316,57,0.963589443523162,2958\n310,b7e8e,11409,0.0018710051113583408,215.55475686937172,216,0.5547568693717153,11193\n311,98301,5694,0.0009337806209198346,107.5789977749323,108,0.5789977749323043,5586\n312,e70bd,19,3.11588194546485e-06,0.35897452717311445,0,0.35897452717311445,19\n314,adf9c,5023,0.0008237407901089443,94.90152894687125,95,0.9015289468712524,4928\n316,ab5f4,55,9.019658263187724e-06,1.0391367891853314,1,0.039136789185331367,54\n318,d4a4c,2242,0.0003676740695648523,42.358994206427504,43,0.35899420642750357,2199\n319,b4062,222,3.640662062595773e-05,4.194333949075338,4,0.19433394907533774,218\n321,cb691,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n322,b83ed,3716,0.0006094009110182833,70.20786015659439,71,0.20786015659439272,3645\n323,8759e,5964,0.0009780589433027562,112.68021474002393,113,0.6802147400239278,5851\n325,2f217,3625,0.0005944774764373728,68.48856110539684,69,0.4885611053968404,3556\n326,d683f,28156,0.004617409055605701,531.9624624782216,532,0.962462478221596,27624\n327,97cf7,928,0.00015218623396796743,17.533071642981593,17,0.5330716429815929,911\n328,75135,2841,0.0004659063477402968,53.67613851046411,54,0.6761385104641136,2787\n329,d7c10,29913,0.0049055461386678986,565.1581595436512,566,0.15815954365120888,29347\n330,8598e,66,1.082358991582527e-05,1.2469641470223976,1,0.24696414702239755,65\n331,5e72b,10825,0.0017752327399819475,204.52101350784022,205,0.5210135078402232,10620\n332,30bff,975,0.00015989394193832783,18.42106126283087,18,0.42106126283087164,957\n333,f79d7,696,0.00011413967547597557,13.149803732236194,13,0.1498037322361938,683\n334,d0a0d,2121,0.0003478308213858393,40.07289327021978,41,0.07289327021977954,2080\n335,22ec5,3942,0.0006464635067906547,74.47776769033774,75,0.4777676903377426,3867\n336,efeac,515,8.445680010075779e-05,9.730099026008103,9,0.7300990260081033,506\n337,c3854,105882,0.017363990113142592,2000.4705729549316,2001,0.4705729549316402,103881\n338,cd2a3,42924,0.007039269296164907,810.9801370725665,811,0.9801370725665493,42113\n339,d7333,5089,0.0008345643800247697,96.14849309389366,97,0.14849309389366283,4992\n340,7d48c,33,5.411794957912635e-06,0.6234820735111988,0,0.6234820735111988,33\n341,6a148,100,1.6399378660341316e-05,1.8893396167006025,1,0.8893396167006025,99\n342,ba8ff,400,6.559751464136527e-05,7.55735846680241,7,0.5573584668024099,393\n343,38810,3194,0.0005237961544113017,60.34550735741724,61,0.34550735741724026,3133\n344,04d6e,43,7.051732823946766e-06,0.812416035181259,0,0.812416035181259,43\n345,d4f05,1394,0.00022860733852515797,26.337394256806398,27,0.3373942568063981,1367\n346,53f71,1325,0.00021729176724952245,25.033749921282983,26,0.0337499212829826,1299\n347,e4d06,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n348,8ab28,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n349,f21e0,65,1.0659596129221857e-05,1.2280707508553916,1,0.22807075085539164,64\n350,02fc2,2202,0.0003611143181007158,41.603258359747265,42,0.6032583597472652,2160\n351,310c0,18370,0.0030125658599047,347.07168758790067,348,0.07168758790066931,18022\n352,fceae,169,2.7714949935976826e-05,3.192983952224018,3,0.1929839522240182,166\n353,25c52,256,4.198240937047377e-05,4.836709418753542,4,0.836709418753542,252\n354,2e0cc,7029,0.0011527123260353911,132.80168165788535,133,0.8016816578853536,6896\n355,4baa9,405,6.641748357438233e-05,7.65182544763744,7,0.6518254476374397,398\n356,f2cb1,2236,0.00036669010684523185,42.24563382942547,43,0.24563382942547207,2193\n357,c4f9f,69,1.1315571275635508e-05,1.3036443355234157,1,0.30364433552341574,68\n358,0fe2b,33,5.411794957912635e-06,0.6234820735111988,0,0.6234820735111988,33\n359,ba2cc,911,0.0001493983395957094,17.21188390814249,17,0.21188390814248947,894\n360,22c5c,84,1.3775478074686707e-05,1.587045278028506,1,0.587045278028506,83\n361,b0369,42,6.8877390373433535e-06,0.793522639014253,0,0.793522639014253,42\n362,a7d03,12344,0.002024339301832532,233.22008228552235,234,0.2200822855223521,12110\n363,f6d7e,11703,0.0019192192846197442,221.1094153424715,222,0.1094153424714932,11481\n364,c281d,50376,0.008261350993933542,951.7737253090955,952,0.7737253090955392,49424\n365,dfbd6,3932,0.0006448235689246206,74.2888337286677,75,0.2888337286676972,3857\n366,a34ba,3909,0.0006410517118327421,73.85428561682654,74,0.8542856168265445,3835\n367,157ee,253,4.149042801066353e-05,4.780029230252524,4,0.7800292302525236,249\n369,33a29,6954,0.0011404127920401352,131.3846769453599,132,0.3846769453598995,6822\n370,2bd95,11279,0.001849685919099897,213.09861536766095,214,0.09861536766095469,11065\n371,138b7,2094,0.00034340298914754716,39.56277157371061,40,0.5627715737106129,2054\n372,eea6f,3884,0.0006369518671676568,73.3819507126514,74,0.38195071265140257,3810\n373,8f39b,1046,0.00017153750078717018,19.7624923906883,19,0.7624923906883012,1027\n374,bacfb,773,0.00012676719704443837,14.604595237095657,14,0.6045952370956567,759\n375,23403,855,0.00014021468754591825,16.15385372279015,16,0.1538537227901493,839\n376,0ad32,1630,0.00026730987216356345,30.79623575221982,31,0.7962357522198182,1599\n377,6be35,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n378,12402,2994,0.000490997397090619,56.566828124016034,57,0.566828124016034,2937\n379,bae61,40980,0.006720465375007872,774.251374923907,775,0.25137492390695115,40205\n380,384f7,23545,0.003861233705577363,444.84501275215683,445,0.8450127521568334,23100\n381,03ed6,115,1.8859285459392513e-05,2.1727405592056925,2,0.17274055920569253,113\n382,cc4c8,3529,0.0005787340729234451,66.67479507336427,67,0.6747950733642654,3462\n383,263fa,4019,0.0006590910283591176,75.93255919519721,76,0.9325591951972143,3943\n384,704a1,11504,0.001886584521085665,217.3496295052373,218,0.34962950523728864,11286\n385,a0efb,7960,0.0013053905413631687,150.39143348936796,151,0.3914334893679552,7809\n386,b3197,981,0.00016087790465794832,18.53442163983291,18,0.5344216398329102,963\n387,1c91d,174,2.8534918868993892e-05,3.2874509330590485,3,0.28745093305904845,171\n388,1afff,797,0.0001307030479229203,15.058036745103802,15,0.05803674510380219,782\n389,20304,10605,0.0017391541069291968,200.3644663510989,201,0.3644663510989119,10404\n390,638fd,79,1.295550914166964e-05,1.4925782971934758,1,0.4925782971934758,78\n391,8c258,11,1.8039316526375448e-06,0.20782735783706627,0,0.20782735783706627,11\n392,c1847,20,3.2798757320682634e-06,0.3778679233401205,0,0.3778679233401205,20\n393,d72ec,2445,0.0004009648082453452,46.19435362832973,47,0.19435362832972913,2398\n394,d83e3,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n395,04a4f,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n396,94c2e,52,8.527676903377485e-06,0.9824566006843133,0,0.9824566006843133,52\n397,2fc45,8,1.3119502928273053e-06,0.1511471693360482,0,0.1511471693360482,8\n398,e7986,30,4.919813598102395e-06,0.5668018850101807,0,0.5668018850101807,30\n399,612a5,20340,0.003335633619513424,384.2916780369025,385,0.2916780369025105,19955\n400,f9e3a,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n401,75d14,667,0.00010938385566447659,12.601895243393018,12,0.6018952433930185,655\n402,e02b6,86,1.4103465647893532e-05,1.624832070362518,1,0.6248320703625181,85\n403,b4c46,479,7.855302378303491e-05,9.049936763995886,9,0.04993676399588587,470\n404,f9fb4,299,4.9034142194420536e-05,5.6491254539348015,5,0.6491254539348015,294\n405,08461,34,5.575788744516048e-06,0.6423754696782048,0,0.6423754696782048,34\n406,42032,124,2.0335229538823232e-05,2.342781124708747,2,0.3427811247087469,122\n407,b7ea6,2861,0.0004691862234723651,54.05400643380424,55,0.05400643380423986,2806\n408,cf91f,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n409,64f3d,11,1.8039316526375448e-06,0.20782735783706627,0,0.20782735783706627,11\n410,49d78,408,6.690946493419258e-05,7.708505636138459,7,0.708505636138459,401\n411,aa802,3907,0.0006407237242595352,73.81649882449253,74,0.8164988244925269,3833\n412,7feff,450,7.379720397153592e-05,8.50202827515271,8,0.5020282751527105,442\n413,08bc7,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n414,c7cad,16,2.6239005856546107e-06,0.3022943386720964,0,0.3022943386720964,16\n415,ec21a,84,1.3775478074686707e-05,1.587045278028506,1,0.587045278028506,83\n416,a1ba3,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n417,7f276,240,3.935850878481916e-05,4.534415080081446,4,0.5344150800814456,236\n419,e86cf,188,3.0830831881441676e-05,3.5519584793971326,3,0.5519584793971326,185\n420,4f7b0,14,2.2959130124477845e-06,0.2645075463380844,0,0.2645075463380844,14\n421,c61f3,363,5.952974453703898e-05,6.858302808623187,6,0.8583028086231872,357\n422,0c672,32,5.247801171309221e-06,0.6045886773441927,0,0.6045886773441927,32\n423,9f766,14028,0.00230050483847268,265.0365614307605,266,0.03656143076051421,13762\n424,2f990,4092,0.0006710625747811667,77.31177711538865,78,0.3117771153886508,4014\n425,660cb,1043,0.00017104551942735994,19.705812202187285,19,0.7058122021872855,1024\n426,0045c,81,1.3283496714876466e-05,1.5303650895274878,1,0.5303650895274878,80\n427,934e9,818,0.00013414691744159196,15.454798064610927,15,0.45479806461092664,803\n428,820eb,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n429,8736b,14,2.2959130124477845e-06,0.2645075463380844,0,0.2645075463380844,14\n430,3de8b,18,2.9518881588614373e-06,0.3400811310061085,0,0.3400811310061085,18\n431,16140,21927,0.0035958917588530407,414.2754977539411,415,0.2754977539411243,21512\n432,44e80,17490,0.0028682513276936964,330.44549896093537,331,0.44549896093536745,17159\n434,c10c0,58575,0.009605936050294927,1106.680680482378,1107,0.6806804823779657,57468\n435,05214,129705,0.021270814091395706,2450.5679498415166,2451,0.5679498415165654,127254\n436,d70bb,130766,0.021444811498981926,2470.61384317471,2471,0.6138431747099276,128295\n437,a7eba,155,2.5419036923529042e-05,2.928476405885934,2,0.928476405885934,153\n438,2afa7,603,9.888825332185814e-05,11.392717888704633,11,0.39271788870463276,592\n439,ff2e5,1855,0.00030420847414933144,35.04724988979618,36,0.04724988979617706,1819\n440,b784f,148,2.4271080417305148e-05,2.7962226327168915,2,0.7962226327168915,146\n441,c2ce4,16469,0.0027008136715716115,311.15534147442224,312,0.15534147442224366,16157\n442,02ab5,491,8.052094922227586e-05,9.276657517999958,9,0.27665751799995775,482\n443,56013,65781,0.010787675276559121,1242.8264932618233,1243,0.8264932618233161,64538\n444,b4b2c,2429,0.0003983409076596906,45.89205928965763,46,0.8920592896576309,2383\n445,298a3,231,3.788256470538844e-05,4.364374514578392,4,0.3643745145783921,227\n446,3df28,3918,0.0006425276559121728,74.0243261823296,75,0.024326182329602375,3843\n447,99b56,594,9.741230924242742e-05,11.222677323201578,11,0.2226773232015784,583\n448,7c7e3,285,4.6738229181972755e-05,5.384617907596717,5,0.3846179075967173,280\n449,230a9,29334,0.004810593736224522,554.2188831629547,555,0.21888316295473942,28779\n450,11677,332788,0.05457516425617666,6287.4955236256,6288,0.4955236256000717,326500\n451,2e832,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n452,b961f,105,1.7219347593358382e-05,1.9838065975356325,1,0.9838065975356325,104\n453,caff0,2624,0.00043031969604735617,49.57627154222381,50,0.5762715422238074,2574\n454,c0a9c,76675,0.012574223587816704,1448.6511511051867,1449,0.6511511051867274,75226\n455,91d53,2145,0.00035176667226432124,40.52633477822792,41,0.5263347782279197,2104\n456,77494,1404,0.0002302472763911921,26.526328218476458,27,0.5263282184764577,1377\n457,9346a,160758,0.026363313146791495,3037.2645810155545,3038,0.2645810155545405,157720\n458,48d36,690,0.00011315571275635509,13.036443355234157,13,0.03644335523415698,677\n459,de96c,20110,0.003297915048594639,379.9461969184912,380,0.9461969184911823,19730\n460,1eca8,1577,0.00025861820147358256,29.7948857553685,30,0.7948857553685009,1547\n461,29179,1443,0.0002366430340687252,27.26317066898969,28,0.2631706689896909,1415\n462,f6416,12243,0.0020077759293855874,231.31184927265477,232,0.3118492726547686,12011\n463,eb5ab,24654,0.004043102814920548,465.79778910136656,466,0.7977891013665612,24188\n464,bb24e,275,4.5098291315938624e-05,5.195683945926657,5,0.19568394592665683,270\n465,089a1,1595,0.000261570089632444,30.13496688637461,31,0.13496688637460963,1564\n466,8f9ad,8142,0.00133523741052499,153.83003159176306,154,0.8300315917630599,7988\n467,01aa2,67287,0.011034649919183862,1271.2799478893344,1272,0.2799478893343803,66015\n468,b8427,76545,0.012552904395558262,1446.1950096034761,1447,0.19500960347613727,75098\n469,4ca48,20684,0.003392047482104998,390.7910063183526,391,0.7910063183525722,20293\n470,a6fe0,673,0.00011036781838409706,12.715255620395054,12,0.7152556203950535,661\n471,8131b,524,8.59327441801885e-05,9.900139591511158,9,0.9001395915111576,515\n472,9ed2f,2084,0.000341763051281513,39.37383761204055,40,0.3738376120405533,2044\n473,62851,47540,0.007796264615126262,898.1920537794664,899,0.19205377946639146,46641\n474,c8364,85303,0.013989161978630954,1611.663373234115,1612,0.6633732341149425,83691\n475,63dcc,5712,0.000936732509078696,107.9190789059384,108,0.9190789059384059,5604\n476,53490,53415,0.008759728111421314,1009.1907562606268,1010,0.19075626062681295,52405\n477,ea1ca,9513,0.0015600728919582694,179.7328777367283,180,0.7328777367283124,9333\n478,b98cd,86047,0.014111173355863893,1625.7200599823675,1626,0.720059982367502,84421\n479,5cc45,458,7.510915426436323e-05,8.65317544448876,8,0.6531754444887596,450\n480,935e4,1124,0.0001843290161422364,21.23617729171477,21,0.2361772917147711,1103\n481,25830,51,8.363683116774072e-06,0.9635632045173074,0,0.9635632045173074,51\n482,e1bce,401,6.576150842796868e-05,7.576251862969416,7,0.576251862969416,394\n483,522f3,25806,0.00423202365708768,487.5629814857574,488,0.5629814857574047,25318\n484,a091d,297,4.870615462121371e-05,5.611338661600789,5,0.6113386616007892,292\n485,5d15a,3101,0.0005085447322571842,58.588421513885685,59,0.5884215138856845,3042\n486,3807a,606,9.938023468166839e-05,11.449398077205652,11,0.44939807720565206,595\n487,6fc74,14406,0.00236249448980877,272.1782651818888,273,0.17826518188877571,14133\n488,c4b61,186,3.0502844308234848e-05,3.5141716870631203,3,0.5141716870631203,183\n489,1d557,935,0.0001533341904741913,17.665325416150633,17,0.6653254161506332,918\n490,4810f,267,4.378634102311132e-05,5.044536776590609,5,0.04453677659060862,262\n492,20171,1303,0.00021368390394424736,24.61809520560885,24,0.6180952056088493,1279\n493,645a8,507,8.314484980793048e-05,9.578951856672054,9,0.5789518566720542,498\n494,32855,5841,0.0009578877075505363,110.3563270114822,111,0.35632701148219326,5730\n495,9f0b8,1978,0.00032437970990155125,37.37113761833792,38,0.3711376183379187,1940\n496,bfa97,171,2.804293750918365e-05,3.23077074455803,3,0.23077074455803004,168\n497,c7ae8,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n498,0443b,107,1.754733516656521e-05,2.0215933898696448,2,0.021593389869644763,105\n499,8b6d1,36,5.9037763177228745e-06,0.680162262012217,0,0.680162262012217,36\n500,0df28,803,0.00013168701064254076,15.171397122105837,15,0.17139712210583724,788\n501,b8da7,1772,0.00029059698986124815,33.479098007934674,34,0.47909800793467383,1738\n502,be111,30,4.919813598102395e-06,0.5668018850101807,0,0.5668018850101807,30\n503,fff89,160,2.6239005856546107e-05,3.022943386720964,3,0.022943386720963854,157\n504,69f64,6756,0.0011079420222926595,127.64378450429271,128,0.6437845042927108,6628\n505,cd02f,1582,0.00025943817040659966,29.889352736203534,30,0.8893527362035343,1552\n506,49604,1205,0.00019761251285711287,22.76654238124226,22,0.7665423812422603,1183\n507,f5741,2674,0.0004385193853775268,50.52094135057411,51,0.5209413505741125,2623\n508,35c93,984,0.00016136988601775856,18.59110182833393,18,0.5911018283339295,966\n509,7af3c,15579,0.0025548592014945737,294.34021888578684,295,0.34021888578683956,15284\n510,b061f,733,0.00012020744558030186,13.848859390415416,13,0.8488593904154165,720\n511,10963,255,4.181841558387036e-05,4.817816022586537,4,0.8178160225865367,251\n512,7429f,596,9.774029681563425e-05,11.26046411553559,11,0.2604641155355907,585\n513,8e06f,114,1.86952916727891e-05,2.153847163038687,2,0.15384716303868684,112\n514,adfd4,10604,0.0017389901131425933,200.3455729549319,201,0.345572954931896,10403\n515,f5fd5,4215,0.0006912338105333865,79.63566484393039,80,0.6356648439303854,4135\n516,565d2,285,4.6738229181972755e-05,5.384617907596717,5,0.3846179075967173,280\n517,8f6cf,119,1.9515260605806166e-05,2.2483141438737166,2,0.24831414387371664,117\n518,4f72e,3024,0.0004959172106887215,57.13363000902623,58,0.13363000902622701,2966\n519,d9867,1147,0.00018810087323411492,21.670725403555913,21,0.6707254035559131,1126\n520,2c9f1,82,1.344749050147988e-05,1.549258485694494,1,0.549258485694494,81\n521,dccb0,2004,0.00032864354835324,37.86236591868007,38,0.8623659186800694,1966\n522,6e283,2322,0.00038079357249312537,43.87046589978799,44,0.8704658997879875,2278\n523,43593,923,0.00015136626503495037,17.438604662146563,17,0.43860466214656313,906\n524,17b6e,903,0.0001480863893028821,17.06073673880644,17,0.06073673880644037,886\n525,72ae9,357,5.85457818174185e-05,6.74494243162115,6,0.7449424316211504,351\n526,d902d,873,0.0001431665757047797,16.49393485379626,16,0.49393485379626156,857\n527,64690,318,5.215002413988539e-05,6.008099981107916,6,0.00809998110791632,312\n528,f476f,25,4.099844665085329e-06,0.4723349041751506,0,0.4723349041751506,25\n529,82153,127,2.0827210898633473e-05,2.3994613132097653,2,0.3994613132097653,125\n530,de876,128,2.0991204685236885e-05,2.418354709376771,2,0.418354709376771,126\n531,0f108,641,0.00010512001721278785,12.110666943050862,12,0.11066694305086244,629\n532,c4fd2,1020,0.00016727366233548145,19.271264090346147,19,0.27126409034614696,1001\n533,5a3a4,147,2.4107086630701735e-05,2.7773292365498854,2,0.7773292365498854,145\n534,59b0c,157,2.5747024496735867e-05,2.966263198219946,2,0.9662631982199459,155\n535,1b5c6,24,3.935850878481916e-06,0.45344150800814453,0,0.45344150800814453,24\n536,433d0,129,2.1155198471840298e-05,2.437248105543777,2,0.43724810554377713,127\n537,69d01,1266,0.00020761613383992107,23.919039547429627,23,0.9190395474296267,1243\n538,d05da,277,4.542627888914545e-05,5.233470738260669,5,0.2334707382606691,272\n539,c552a,924,0.00015153025882155377,17.45749805831357,17,0.4574980583135684,907\n540,ede77,196,3.214278217426898e-05,3.703105648733181,3,0.7031056487331808,193\n541,5fb34,336,5.510191229874683e-05,6.348181112114024,6,0.34818111211402414,330\n542,3c8b0,303,4.969011734083419e-05,5.724699038602826,5,0.724699038602826,298\n543,69faf,48,7.871701756963832e-06,0.9068830160162891,0,0.9068830160162891,48\n544,4282e,40,6.559751464136527e-06,0.755735846680241,0,0.755735846680241,40\n545,c8c9d,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n" + ) + ) + + for x in range(100_01, df.remaining_balance.sum(), 12345): + amount = USDCent(x) + res = business_payout_event_manager.distribute_amount(df=df, amount=amount) + assert isinstance(res, pd.Series) + assert res.sum() == amount + + def test_ach_payment_min_amount( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + ): + """Test having a Business with three products.. one that lost money + and two that gained money. Ensure that the Business balance + reflects that to compensate for the Product in the negative and only + assigns Brokerage Product payments from the 2 accounts that have + a positive balance. + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p1) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("5.00"), + started=start + timedelta(days=1), + ) + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("5.00"), + started=start + timedelta(days=6), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=u1.product, + amount=USDCent(475), # 95% of $5.00 + created=start + timedelta(days=1, minutes=1), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + with pytest.raises(expected_exception=AssertionError) as cm: + business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(500), + pm=product_manager, + thl_lm=thl_lm, + ) + assert "Must issue Supplier Payouts at least $100 minimum." in str(cm) + + def test_ach_payment( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + rm_ledger_collection, + rm_pop_ledger_merge, + ): + """Test having a Business with three products.. one that lost money + and two that gained money. Ensure that the Business balance + reflects that to compensate for the Product in the negative and only + assigns Brokerage Product payments from the 2 accounts that have + a positive balance. + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + p3: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + u3: User = user_factory(product=p3) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p3) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + # Product 1: Complete, Payout, Recon.. + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("5.00"), + started=start + timedelta(days=1), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=u1.product, + amount=USDCent(475), # 95% of $5.00 + ext_ref_id=ach_id1, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + adj_to_fail_with_tx_factory( + session=s1, + created=start + timedelta(days=1, minutes=2), + ) + + # Product 2: Complete x10 + for idx in range(15): + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("7.50"), + started=start + timedelta(days=1, hours=2, minutes=1 + idx), + ) + + # Product 3: Complete x5 + for idx in range(10): + session_with_tx_factory( + user=u3, + wall_req_cpi=Decimal("7.50"), + started=start + timedelta(days=1, hours=3, minutes=1 + idx), + ) + + # Now that we paid out the business, let's confirm the updated balances + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + bb1 = business.balance + pb1 = bb1.product_balances[0] + pb2 = bb1.product_balances[1] + pb3 = bb1.product_balances[2] + assert bb1.payout == 25 * 712 + 475 # $7.50 * .95% = $7.125 = $7.12 + assert bb1.adjustment == -475 + assert bb1.net == ((25 * 7.12 + 4.75) - 4.75) * 100 + # The balance is lower because the single $4.75 payout + assert bb1.balance == bb1.net - 475 + assert ( + bb1.available_balance + == (pb2.available_balance + pb3.available_balance) - 475 + ) + + assert pb1.available_balance == 0 + assert pb2.available_balance == 8010 + assert pb3.available_balance == 5340 + + assert bb1.recoup_usd_str == "$4.75" + assert pb1.recoup_usd_str == "$4.75" + assert pb2.recoup_usd_str == "$0.00" + assert pb3.recoup_usd_str == "$0.00" + + assert business.payouts is None + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 1 + assert business.payouts[0].ext_ref_id == ach_id1 + + bp1 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(bb1.available_balance), + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=1, hours=5), + ) + assert isinstance(bp1, BusinessPayoutEvent) + assert len(bp1.bp_payouts) == 2 + assert bp1.bp_payouts[0].status == PayoutStatus.COMPLETE + assert bp1.bp_payouts[1].status == PayoutStatus.COMPLETE + bp1_tx = brokerage_product_payout_event_manager.check_for_ledger_tx( + thl_ledger_manager=thl_lm, + payout_event=bp1.bp_payouts[0], + product_id=bp1.bp_payouts[0].product_id, + amount=bp1.bp_payouts[0].amount, + ) + assert bp1_tx + + bp2_tx = brokerage_product_payout_event_manager.check_for_ledger_tx( + thl_ledger_manager=thl_lm, + payout_event=bp1.bp_payouts[1], + product_id=bp1.bp_payouts[1].product_id, + amount=bp1.bp_payouts[1].amount, + ) + assert bp2_tx + + # Now that we paid out the business, let's confirm the updated balances + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 2 + assert len(business.payouts[0].bp_payouts) == 2 + assert len(business.payouts[1].bp_payouts) == 1 + + bb2 = business.balance + + # Okay os we have the balance before, and after the Business Payout + # of bb1.available_balance worth.. + assert bb1.payout == bb2.payout + assert bb1.adjustment == bb2.adjustment + assert bb1.net == bb2.net + assert bb1.available_balance > bb2.available_balance + + # This is the ultimate test. Confirm that the second time we get the + # Business balance, it is equal to the first time we Business balance + # minus the amount that was just paid out across any children + # Brokerage Products. + # + # This accounts for all the net positive and net negative children + # Brokerage Products under the Business in thi + assert bb2.balance == bb1.balance - bb1.available_balance + + def test_ach_payment_partial_amount( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + rm_ledger_collection, + rm_pop_ledger_merge, + ): + """There are valid instances when we want issue a ACH or Wire to a + Business, but not for the full Available Balance amount in their + account. + + To test this, we'll create a Business with multiple Products, and + a cumulative Available Balance of $100 (for example), but then only + issue a payout of $75 (for example). We want to confirm the sum + of the Product payouts equals the $75 number and isn't greedy and + takes the full $100 amount that is available to the Business. + + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + p3: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + u3: User = user_factory(product=p3) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p3) + + # Product 1, 2, 3: Complete, and Payout multiple times. + for idx in range(5): + for u in [u1, u2, u3]: + session_with_tx_factory( + user=u, + wall_req_cpi=Decimal("50.00"), + started=start + timedelta(days=1, hours=2, minutes=1 + idx), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # Now that we paid out the business, let's confirm the updated balances + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + + # Confirm the initial amounts. + assert len(business.payouts) == 0 + bb1 = business.balance + assert bb1.payout == 3 * 5 * 4750 + assert bb1.adjustment == 0 + assert bb1.payout == bb1.net + + assert bb1.balance_usd_str == "$712.50" + assert bb1.available_balance_usd_str == "$534.39" + + for x in range(2): + assert bb1.product_balances[x].balance == 5 * 4750 + assert bb1.product_balances[x].available_balance_usd_str == "$178.13" + + assert business.payouts_total_str == "$0.00" + assert business.balance.payment_usd_str == "$0.00" + assert business.balance.available_balance_usd_str == "$534.39" + + # This is the important part, even those the Business has $534.39 + # available to it, we are only trying to issue out a $250.00 ACH or + # Wire to the Business + bp1 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(250_00), + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=1, hours=3), + ) + assert isinstance(bp1, BusinessPayoutEvent) + assert len(bp1.bp_payouts) == 3 + + # Now that we paid out the business, let's confirm the updated + # balances. Clear and rebuild the parquet files. + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # Now rebuild and confirm the payouts, balance.payment, and the + # balance.available_balance are reflective of having a $250 ACH/Wire + # sent to the Business + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 1 + assert len(business.payouts[0].bp_payouts) == 3 + assert business.payouts_total_str == "$250.00" + assert business.balance.payment_usd_str == "$250.00" + assert business.balance.available_balance_usd_str == "$346.88" + + def test_ach_tx_id_reference( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + rm_ledger_collection, + rm_pop_ledger_merge, + ): + + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + p3: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + u3: User = user_factory(product=p3) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p3) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + for idx in range(15): + for iidx, u in enumerate([u1, u2, u3]): + session_with_tx_factory( + user=u, + wall_req_cpi=Decimal("7.50"), + started=start + timedelta(days=1, hours=1 + iidx, minutes=1 + idx), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + bp1 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_01), + transaction_id=ach_id1, + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=2, hours=1), + ) + + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + bp2 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_02), + transaction_id=ach_id2, + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=4, hours=1), + ) + + assert isinstance(bp1, BusinessPayoutEvent) + assert isinstance(bp2, BusinessPayoutEvent) + + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert business.payouts[0].ext_ref_id == ach_id2 + assert business.payouts[1].ext_ref_id == ach_id1 diff --git a/tests/managers/thl/test_product.py b/tests/managers/thl/test_product.py new file mode 100644 index 0000000..78d5dde --- /dev/null +++ b/tests/managers/thl/test_product.py @@ -0,0 +1,362 @@ +from uuid import uuid4 + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.product import ( + Product, + SourceConfig, + UserCreateConfig, + SourcesConfig, + UserHealthConfig, + ProfilingConfig, + SupplyPolicy, + SupplyConfig, +) +from test_utils.models.conftest import product_factory + + +class TestProductManagerGetMethods: + def test_get_by_uuid(self, product_manager): + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + + instance = product_manager.get_by_uuid(product_uuid=product.id) + assert isinstance(instance, Product) + # self.assertEqual(instance.model_dump(mode="json"), product.model_dump(mode="json")) + assert instance.id == product.id + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuid(product_uuid="abc123") + assert "invalid uuid" in str(cm.value) + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuid(product_uuid=uuid4().hex) + assert "product not found" in str(cm.value) + + def test_get_by_uuids(self, product_manager): + cnt = 5 + + product_uuids = [uuid4().hex for idx in range(cnt)] + for product_id in product_uuids: + product_manager.create_dummy( + product_id=product_id, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + + res = product_manager.get_by_uuids(product_uuids=product_uuids) + assert isinstance(res, list) + assert cnt == len(res) + for instance in res: + assert isinstance(instance, Product) + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuids(product_uuids=product_uuids + [uuid4().hex]) + assert "incomplete product response" in str(cm.value) + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuids(product_uuids=product_uuids + ["abc123"]) + assert "invalid uuid" in str(cm.value) + + def test_get_by_uuid_if_exists(self, product_manager): + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + instance = product_manager.get_by_uuid_if_exists(product_uuid=product.id) + assert isinstance(instance, Product) + + instance = product_manager.get_by_uuid_if_exists(product_uuid="abc123") + assert instance == None + + def test_get_by_uuids_if_exists(self, product_manager): + product_uuids = [uuid4().hex for _ in range(2)] + for product_id in product_uuids: + product_manager.create_dummy( + product_id=product_id, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + + res = product_manager.get_by_uuids_if_exists(product_uuids=product_uuids) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + res = product_manager.get_by_uuids_if_exists( + product_uuids=product_uuids + [uuid4().hex] + ) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + # # This will raise an error b/c abc123 isn't a uuid. + # res = product_manager.get_by_uuids_if_exists( + # product_uuids=product_uuids + ["abc123"] + # ) + # assert isinstance(res, list) + # assert 2 == len(res) + # for instance in res: + # assert isinstance(instance, Product) + + def test_get_by_business_ids(self, product_manager): + business_ids = [uuid4().hex for i in range(5)] + + product_manager.fetch_uuids(business_uuids=business_ids) + + for business_id in business_ids: + product_manager.create( + product_id=uuid4().hex, + team_id=None, + business_id=business_id, + redirect_url="https://www.example.com", + name=f"Test Product ID #{uuid4().hex[:6]}", + user_create_config=None, + ) + + +class TestProductManagerCreation: + + def test_base(self, product_manager): + instance = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"New Test Product {uuid4().hex[:6]}", + ) + + assert isinstance(instance, Product) + + +class TestProductManagerCreate: + + def test_create_simple(self, product_manager): + # Always required: product_id, team_id, name, redirect_url + # Required internally - if not passed use default: harmonizer_domain, + # commission_pct, sources + product_id = uuid4().hex + team_id = uuid4().hex + business_id = uuid4().hex + + product_manager.create( + product_id=product_id, + team_id=team_id, + business_id=business_id, + name=f"Test Product ID #{uuid4().hex[:6]}", + redirect_url="https://www.example.com", + ) + + instance = product_manager.get_by_uuid(product_uuid=product_id) + + assert team_id == instance.team_id + assert business_id == instance.business_id + + +class TestProductManager: + sources = [ + SourceConfig.model_validate(x) + for x in [ + {"name": "d", "active": True}, + { + "name": "f", + "active": False, + "banned_countries": ["ps", "in", "in"], + }, + { + "name": "s", + "active": True, + "supplier_id": "3488", + "allow_pii_only_buyers": True, + "allow_unhashed_buyers": True, + }, + {"name": "e", "active": True, "withhold_profiling": True}, + ] + ] + + def test_get_by_uuid1(self, product_manager, team, product, product_factory): + p1 = product_factory(team=team) + instance = product_manager.get_by_uuid(product_uuid=p1.uuid) + assert instance.id == p1.id + + # No Team and no user_create_config + assert instance.team_id == team.uuid + + # user_create_config can't be None, so ensure the default was set. + assert isinstance(instance.user_create_config, UserCreateConfig) + assert 0 == instance.user_create_config.min_hourly_create_limit + assert instance.user_create_config.max_hourly_create_limit is None + + def test_get_by_uuid2(self, product_manager, product_factory): + p2 = product_factory() + instance = product_manager.get_by_uuid(p2.id) + assert instance.id, p2.id + + # Team and default user_create_config + assert instance.team_id is not None + assert instance.team_id == p2.team_id + + assert 0 == instance.user_create_config.min_hourly_create_limit + assert instance.user_create_config.max_hourly_create_limit is None + + def test_get_by_uuid3(self, product_manager, product_factory): + p3 = product_factory() + instance = product_manager.get_by_uuid(p3.id) + assert instance.id == p3.id + + # Team and default user_create_config + assert instance.team_id is not None + assert instance.team_id == p3.team_id + + assert ( + p3.user_create_config.min_hourly_create_limit + == instance.user_create_config.min_hourly_create_limit + ) + assert instance.user_create_config.max_hourly_create_limit is None + assert not instance.user_wallet_config.enabled + + def test_sources(self, product_manager): + user_defined = [SourceConfig(name=Source.DYNATA, active=False)] + sources_config = SourcesConfig(user_defined=user_defined) + p = product_manager.create_dummy(sources_config=sources_config) + + p2 = product_manager.get_by_uuid(p.id) + + assert p == p2 + assert p2.sources_config.user_defined == user_defined + + # Assert d is off and everything else is on + dynata = p2.sources_dict[Source.DYNATA] + assert not dynata.active + assert all(x.active is True for x in p2.sources if x.name != Source.DYNATA) + + def test_global_sources(self, product_manager): + sources_config = SupplyConfig( + policies=[ + SupplyPolicy( + name=Source.DYNATA, + active=True, + address=["https://www.example.com"], + distribute_harmonizer_active=True, + ) + ] + ) + p1 = product_manager.create_dummy(sources_config=sources_config) + p2 = product_manager.get_by_uuid(p1.id) + assert p1 == p2 + + p1.sources_config.policies.append( + SupplyPolicy( + name=Source.CINT, + active=True, + address=["https://www.example.com"], + distribute_harmonizer_active=True, + ) + ) + product_manager.update(p1) + p2 = product_manager.get_by_uuid(p1.id) + assert p1 == p2 + + def test_user_health_config(self, product_manager): + p = product_manager.create_dummy( + user_health_config=UserHealthConfig(banned_countries=["ng", "in"]) + ) + + p2 = product_manager.get_by_uuid(p.id) + + assert p == p2 + assert p2.user_health_config.banned_countries == ["in", "ng"] + assert p2.user_health_config.allow_ban_iphist + + def test_profiling_config(self, product_manager): + p = product_manager.create_dummy( + profiling_config=ProfilingConfig(max_questions=1) + ) + p2 = product_manager.get_by_uuid(p.id) + + assert p == p2 + assert p2.profiling_config.max_questions == 1 + + # def test_user_create_config(self): + # # -- Product 1 --- + # instance1 = PM.get_by_uuid(self.product_id1) + # self.assertEqual(instance1.user_create_config.min_hourly_create_limit, 0) + # + # self.assertEqual(instance1.user_create_config.max_hourly_create_limit, None) + # self.assertEqual(60, instance1.user_create_config.clip_hourly_create_limit(60)) + # + # # -- Product 2 --- + # instance2 = PM.get_by_uuid(self.product2.id) + # self.assertEqual(instance2.user_create_config.min_hourly_create_limit, 200) + # self.assertEqual(instance2.user_create_config.max_hourly_create_limit, None) + # + # self.assertEqual(200, instance2.user_create_config.clip_hourly_create_limit(60)) + # self.assertEqual(300, instance2.user_create_config.clip_hourly_create_limit(300)) + # + # def test_create_and_cache(self): + # PM.uuid_cache.clear() + # + # # Hit it once, should hit mysql + # with self.assertLogs(level='INFO') as cm: + # p = PM.get_by_uuid(product_id3) + # self.assertEqual(cm.output, [f"INFO:root:Product.get_by_uuid:{product_id3}"]) + # self.assertIsNotNone(p) + # self.assertEqual(p.id, product_id3) + # + # # Hit it again, should be pulled from cachetools cache + # with self.assertLogs(level='INFO') as cm: + # logger.info("nothing") + # p = PM.get_by_uuid(product_id3) + # self.assertEqual(cm.output, [f"INFO:root:nothing"]) + # self.assertEqual(len(cm.output), 1) + # self.assertIsNotNone(p) + # self.assertEqual(p.id, product_id3) + + +class TestProductManagerUpdate: + + def test_update(self, product_manager): + p = product_manager.create_dummy() + p.name = "new name" + p.enabled = False + p.user_create_config = UserCreateConfig(min_hourly_create_limit=200) + p.sources_config = SourcesConfig( + user_defined=[SourceConfig(name=Source.DYNATA, active=False)] + ) + product_manager.update(new_product=p) + # We cleared the cache in the update + # PM.id_cache.clear() + p2 = product_manager.get_by_uuid(p.id) + + assert p2.name == "new name" + assert not p2.enabled + assert p2.user_create_config.min_hourly_create_limit == 200 + assert not p2.sources_dict[Source.DYNATA].active + + +class TestProductManagerCacheClear: + + def test_cache_clear(self, product_manager): + p = product_manager.create_dummy() + product_manager.get_by_uuid(product_uuid=p.id) + product_manager.get_by_uuid(product_uuid=p.id) + product_manager.pg_config.execute_write( + query=""" + UPDATE userprofile_brokerageproduct + SET name = 'test-6d9a5ddfd' + WHERE id = %s""", + params=[p.id], + ) + + product_manager.cache_clear(p.id) + + # Calling this with or without kwargs hits different internal keys in the cache! + p2 = product_manager.get_by_uuid(product_uuid=p.id) + assert p2.name == "test-6d9a5ddfd" + p2 = product_manager.get_by_uuid(product_uuid=p.id) + assert p2.name == "test-6d9a5ddfd" diff --git a/tests/managers/thl/test_product_prod.py b/tests/managers/thl/test_product_prod.py new file mode 100644 index 0000000..7b4f677 --- /dev/null +++ b/tests/managers/thl/test_product_prod.py @@ -0,0 +1,82 @@ +import logging +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.product import Product +from test_utils.models.conftest import product_factory + +logger = logging.getLogger() + + +class TestProductManagerGetMethods: + + def test_get_by_uuid(self, product_manager, product_factory): + # Just test that we load properly + for p in [product_factory(), product_factory(), product_factory()]: + instance = product_manager.get_by_uuid(product_uuid=p.id) + assert isinstance(instance, Product) + assert instance.id == p.id + assert instance.uuid == p.uuid + + with pytest.raises(expected_exception=AssertionError) as cm: + product_manager.get_by_uuid(product_uuid=uuid4().hex) + assert "product not found" in str(cm.value) + + def test_get_by_uuids(self, product_manager, product_factory): + products = [product_factory(), product_factory(), product_factory()] + cnt = len(products) + res = product_manager.get_by_uuids(product_uuids=[p.id for p in products]) + assert isinstance(res, list) + assert cnt == len(res) + for instance in res: + assert isinstance(instance, Product) + + with pytest.raises(expected_exception=AssertionError) as cm: + product_manager.get_by_uuids( + product_uuids=[p.id for p in products] + [uuid4().hex] + ) + assert "incomplete product response" in str(cm.value) + with pytest.raises(expected_exception=AssertionError) as cm: + product_manager.get_by_uuids( + product_uuids=[p.id for p in products] + ["abc123"] + ) + assert "invalid uuid passed" in str(cm.value) + + def test_get_by_uuid_if_exists(self, product_factory, product_manager): + products = [product_factory(), product_factory(), product_factory()] + + instance = product_manager.get_by_uuid_if_exists(product_uuid=products[0].id) + assert isinstance(instance, Product) + + instance = product_manager.get_by_uuid_if_exists(product_uuid="abc123") + assert instance is None + + def test_get_by_uuids_if_exists(self, product_manager, product_factory): + products = [product_factory(), product_factory(), product_factory()] + + res = product_manager.get_by_uuids_if_exists( + product_uuids=[p.id for p in products[:2]] + ) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + res = product_manager.get_by_uuids_if_exists( + product_uuids=[p.id for p in products[:2]] + [uuid4().hex] + ) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + +class TestProductManagerGetAll: + + @pytest.mark.skip(reason="TODO") + def test_get_ALL_by_ids(self, product_manager): + products = product_manager.get_all(rand_limit=50) + logger.info(f"Fetching {len(products)} product uuids") + # todo: once timebucks stops spamming broken accounts, fetch more + pass diff --git a/tests/managers/thl/test_profiling/__init__.py b/tests/managers/thl/test_profiling/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/thl/test_profiling/__init__.py diff --git a/tests/managers/thl/test_profiling/test_question.py b/tests/managers/thl/test_profiling/test_question.py new file mode 100644 index 0000000..998466e --- /dev/null +++ b/tests/managers/thl/test_profiling/test_question.py @@ -0,0 +1,49 @@ +from uuid import uuid4 + +from generalresearch.managers.thl.profiling.question import QuestionManager +from generalresearch.models import Source + + +class TestQuestionManager: + + def test_get_multi_upk(self, question_manager: QuestionManager, upk_data): + qs = question_manager.get_multi_upk( + question_ids=[ + "8a22de34f985476aac85e15547100db8", + "0565f87d4bf044298ba169de1339ff7e", + "b2b32d68403647e3a87e778a6348d34c", + uuid4().hex, + ] + ) + assert len(qs) == 3 + + def test_get_questions_ranked(self, question_manager: QuestionManager, upk_data): + qs = question_manager.get_questions_ranked(country_iso="mx", language_iso="spa") + assert len(qs) >= 40 + assert qs[0].importance.task_score > qs[40].importance.task_score + assert all(q.country_iso == "mx" and q.language_iso == "spa" for q in qs) + + def test_lookup_by_property(self, question_manager: QuestionManager, upk_data): + q = question_manager.lookup_by_property( + property_code="i:industry", country_iso="us", language_iso="eng" + ) + assert q.source == Source.INNOVATE + + q.explanation_template = "You work in the {answer} industry." + q.explanation_fragment_template = "you work in the {answer} industry" + question_manager.update_question_explanation(q) + + q = question_manager.lookup_by_property( + property_code="i:industry", country_iso="us", language_iso="eng" + ) + assert q.explanation_template + + def test_filter_by_property(self, question_manager: QuestionManager, upk_data): + lookup = [ + ("i:industry", "us", "eng"), + ("i:industry", "mx", "eng"), + ("m:age", "us", "eng"), + (f"m:{uuid4().hex}", "us", "eng"), + ] + qs = question_manager.filter_by_property(lookup) + assert len(qs) == 3 diff --git a/tests/managers/thl/test_profiling/test_schema.py b/tests/managers/thl/test_profiling/test_schema.py new file mode 100644 index 0000000..ae61527 --- /dev/null +++ b/tests/managers/thl/test_profiling/test_schema.py @@ -0,0 +1,44 @@ +from generalresearch.models.thl.profiling.upk_property import PropertyType + + +class TestUpkSchemaManager: + + def test_get_props_info(self, upk_schema_manager, upk_data): + props = upk_schema_manager.get_props_info() + assert ( + len(props) == 16955 + ) # ~ 70 properties x each country they are available in + + gender = [ + x + for x in props + if x.country_iso == "us" + and x.property_id == "73175402104741549f21de2071556cd7" + ] + assert len(gender) == 1 + gender = gender[0] + assert len(gender.allowed_items) == 3 + assert gender.allowed_items[0].label == "female" + assert gender.allowed_items[1].label == "male" + assert gender.prop_type == PropertyType.UPK_ITEM + assert gender.categories[0].label == "Demographic" + + age = [ + x + for x in props + if x.country_iso == "us" + and x.property_id == "94f7379437874076b345d76642d4ce6d" + ] + assert len(age) == 1 + age = age[0] + assert age.allowed_items is None + assert age.prop_type == PropertyType.UPK_NUMERICAL + assert age.gold_standard + + cars = [ + x + for x in props + if x.country_iso == "us" and x.property_label == "household_auto_type" + ][0] + assert not cars.gold_standard + assert cars.categories[0].label == "Autos & Vehicles" diff --git a/tests/managers/thl/test_profiling/test_uqa.py b/tests/managers/thl/test_profiling/test_uqa.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/managers/thl/test_profiling/test_uqa.py @@ -0,0 +1 @@ + diff --git a/tests/managers/thl/test_profiling/test_user_upk.py b/tests/managers/thl/test_profiling/test_user_upk.py new file mode 100644 index 0000000..53bb8fe --- /dev/null +++ b/tests/managers/thl/test_profiling/test_user_upk.py @@ -0,0 +1,59 @@ +from datetime import datetime, timezone + +from generalresearch.managers.thl.profiling.user_upk import UserUpkManager + +now = datetime.now(tz=timezone.utc) +base = { + "country_iso": "us", + "language_iso": "eng", + "timestamp": now, +} +upk_ans_dict = [ + {"pred": "gr:gender", "obj": "gr:male"}, + {"pred": "gr:age_in_years", "obj": "43"}, + {"pred": "gr:home_postal_code", "obj": "33143"}, + {"pred": "gr:ethnic_group", "obj": "gr:caucasians"}, + {"pred": "gr:ethnic_group", "obj": "gr:asian"}, +] +for a in upk_ans_dict: + a.update(base) + + +class TestUserUpkManager: + + def test_user_upk_empty(self, user_upk_manager: UserUpkManager, upk_data, user): + res = user_upk_manager.get_user_upk_mysql(user_id=user.user_id) + assert len(res) == 0 + + def test_user_upk(self, user_upk_manager: UserUpkManager, upk_data, user): + for x in upk_ans_dict: + x["user_id"] = user.user_id + user_upk = user_upk_manager.populate_user_upk_from_dict(upk_ans_dict) + user_upk_manager.set_user_upk(upk_ans=user_upk) + + d = user_upk_manager.get_user_upk_simple(user_id=user.user_id) + assert d["gender"] == "male" + assert d["age_in_years"] == 43 + assert d["home_postal_code"] == "33143" + assert d["ethnic_group"] == {"caucasians", "asian"} + + # Change my answers. age 43->44, gender male->female, + # ethnic->remove asian, add black_or_african_american + for x in upk_ans_dict: + if x["pred"] == "age_in_years": + x["obj"] = "44" + if x["pred"] == "gender": + x["obj"] = "female" + upk_ans_dict[-1]["obj"] = "black_or_african_american" + user_upk = user_upk_manager.populate_user_upk_from_dict(upk_ans_dict) + user_upk_manager.set_user_upk(upk_ans=user_upk) + + d = user_upk_manager.get_user_upk_simple(user_id=user.user_id) + assert d["gender"] == "female" + assert d["age_in_years"] == 44 + assert d["home_postal_code"] == "33143" + assert d["ethnic_group"] == {"caucasians", "black_or_african_american"} + + age, gender = user_upk_manager.get_age_gender(user_id=user.user_id) + assert age == 44 + assert gender == "female" diff --git a/tests/managers/thl/test_session_manager.py b/tests/managers/thl/test_session_manager.py new file mode 100644 index 0000000..6bedc2b --- /dev/null +++ b/tests/managers/thl/test_session_manager.py @@ -0,0 +1,137 @@ +from datetime import timedelta +from decimal import Decimal +from uuid import uuid4 + +from faker import Faker + +from generalresearch.models import DeviceType +from generalresearch.models.legacy.bucket import Bucket +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + SessionStatusCode2, +) +from test_utils.models.conftest import user + +fake = Faker() + + +class TestSessionManager: + def test_create_session(self, session_manager, user, utc_hour_ago): + bucket = Bucket( + loi_min=timedelta(seconds=60), + loi_max=timedelta(seconds=120), + user_payout_min=Decimal("1"), + user_payout_max=Decimal("2"), + ) + + s1 = session_manager.create( + started=utc_hour_ago, + user=user, + country_iso="us", + device_type=DeviceType.MOBILE, + ip=fake.ipv4_public(), + bucket=bucket, + url_metadata={"foo": "bar"}, + uuid_id=uuid4().hex, + ) + + assert s1.id is not None + s2 = session_manager.get_from_uuid(session_uuid=s1.uuid) + assert s1 == s2 + + def test_finish_with_status(self, session_manager, user, utc_hour_ago): + uuid_1 = uuid4().hex + session = session_manager.create( + started=utc_hour_ago, user=user, uuid_id=uuid_1 + ) + session_manager.finish_with_status( + session=session, + status=Status.FAIL, + status_code_1=StatusCode1.SESSION_CONTINUE_FAIL, + status_code_2=SessionStatusCode2.USER_IS_BLOCKED, + ) + + s2 = session_manager.get_from_uuid(session_uuid=uuid_1) + assert s2.status == Status.FAIL + assert s2.status_code_1 == StatusCode1.SESSION_CONTINUE_FAIL + assert s2.status_code_2 == SessionStatusCode2.USER_IS_BLOCKED + + +class TestSessionManagerFilter: + + def test_base(self, session_manager, user, utc_now): + uuid_id = uuid4().hex + session_manager.create(started=utc_now, user=user, uuid_id=uuid_id) + res = session_manager.filter(limit=1) + assert len(res) != 0 + assert isinstance(res, list) + assert res[0].uuid == uuid_id + + def test_user(self, session_manager, user, utc_hour_ago): + session_manager.create(started=utc_hour_ago, user=user, uuid_id=uuid4().hex) + session_manager.create(started=utc_hour_ago, user=user, uuid_id=uuid4().hex) + + res = session_manager.filter(user=user) + assert len(res) == 2 + + def test_product( + self, product_factory, user_factory, session_manager, user, utc_hour_ago + ): + from generalresearch.models.thl.session import Session + from generalresearch.models.thl.user import User + + p1 = product_factory() + + for n in range(5): + u = user_factory(product=p1) + session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex) + + res = session_manager.filter( + product_uuids=[p1.uuid], started_since=utc_hour_ago + ) + assert isinstance(res[0], Session) + assert isinstance(res[0].user, User) + assert len(res) == 5 + + def test_team( + self, + product_factory, + user_factory, + team, + session_manager, + user, + utc_hour_ago, + thl_web_rr, + ): + p1 = product_factory(team=team) + + for n in range(5): + u = user_factory(product=p1) + session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex) + + team.prefetch_products(thl_pg_config=thl_web_rr) + assert len(team.product_uuids) == 1 + res = session_manager.filter(product_uuids=team.product_uuids) + assert len(res) == 5 + + def test_business( + self, + product_factory, + business, + user_factory, + session_manager, + user, + utc_hour_ago, + thl_web_rr, + ): + p1 = product_factory(business=business) + + for n in range(5): + u = user_factory(product=p1) + session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex) + + business.prefetch_products(thl_pg_config=thl_web_rr) + assert len(business.product_uuids) == 1 + res = session_manager.filter(product_uuids=business.product_uuids) + assert len(res) == 5 diff --git a/tests/managers/thl/test_survey.py b/tests/managers/thl/test_survey.py new file mode 100644 index 0000000..58c4577 --- /dev/null +++ b/tests/managers/thl/test_survey.py @@ -0,0 +1,376 @@ +import uuid +from datetime import datetime, timezone +from decimal import Decimal + +import pytest + +from generalresearch.models import Source +from generalresearch.models.legacy.bucket import ( + SurveyEligibilityCriterion, + TopNPlusBucket, + DurationSummary, + PayoutSummary, +) +from generalresearch.models.thl.profiling.user_question_answer import ( + UserQuestionAnswer, +) +from generalresearch.models.thl.survey.model import ( + Survey, + SurveyStat, + SurveyCategoryModel, + SurveyEligibilityDefinition, +) + + +@pytest.fixture(scope="session") +def surveys_fixture(): + return [ + Survey(source=Source.TESTING, survey_id="a", buyer_code="buyer1"), + Survey(source=Source.TESTING, survey_id="b", buyer_code="buyer2"), + Survey(source=Source.TESTING, survey_id="c", buyer_code="buyer2"), + ] + + +ssa = SurveyStat( + survey_source=Source.TESTING, + survey_survey_id="a", + survey_is_live=True, + quota_id="__all__", + cpi=Decimal(1), + country_iso="us", + version=1, + complete_too_fast_cutoff=300, + prescreen_conv_alpha=10, + prescreen_conv_beta=10, + conv_alpha=10, + conv_beta=10, + dropoff_alpha=10, + dropoff_beta=10, + completion_time_mu=1, + completion_time_sigma=0.4, + mobile_eligible_alpha=10, + mobile_eligible_beta=10, + desktop_eligible_alpha=10, + desktop_eligible_beta=10, + tablet_eligible_alpha=10, + tablet_eligible_beta=10, + long_fail_rate=1, + user_report_coeff=1, + recon_likelihood=0, + score_x0=0, + score_x1=1, + score=100, +) +ssb = ssa.model_dump() +ssb["survey_source"] = Source.TESTING +ssb["survey_survey_id"] = "b" +ssb["completion_time_mu"] = 2 +ssb["score"] = 90 +ssb = SurveyStat.model_validate(ssb) + + +class TestSurvey: + + def test( + self, + delete_buyers_surveys, + buyer_manager, + survey_manager, + surveys_fixture, + ): + survey_manager.create_or_update(surveys_fixture) + survey_ids = {s.survey_id for s in surveys_fixture} + res = survey_manager.filter_by_natural_key( + source=Source.TESTING, survey_ids=survey_ids + ) + assert len(res) == len(surveys_fixture) + assert res[0].id is not None + surveys2 = surveys_fixture.copy() + surveys2.append( + Survey(source=Source.TESTING, survey_id="d", buyer_code="buyer2") + ) + survey_manager.create_or_update(surveys2) + survey_ids = {s.survey_id for s in surveys2} + res2 = survey_manager.filter_by_natural_key( + source=Source.TESTING, survey_ids=survey_ids + ) + assert res2[0].id == res[0].id + assert res2[0] == res[0] + assert len(res2) == len(surveys2) + + def test_category(self, survey_manager): + survey1 = Survey(id=562289, survey_id="a", source=Source.TESTING) + survey2 = Survey(id=562290, survey_id="a", source=Source.TESTING) + categories = list(survey_manager.category_manager.categories.values()) + sc = [SurveyCategoryModel(category=c, strength=1 / 2) for c in categories[:2]] + sc2 = [SurveyCategoryModel(category=c, strength=1 / 2) for c in categories[2:4]] + survey1.categories = sc + survey2.categories = sc2 + surveys = [survey1, survey2] + survey_manager.update_surveys_categories(surveys) + + def test_survey_eligibility( + self, survey_manager, upk_data, question_manager, uqa_manager + ): + bucket = TopNPlusBucket( + id="c82cf98c578a43218334544ab376b00e", + contents=[], + duration=DurationSummary(max=1, min=1, q1=1, q2=1, q3=1), + quality_score=1, + payout=PayoutSummary(max=1, min=1, q1=1, q2=1, q3=1), + uri="https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142", + ) + + survey1 = Survey( + survey_id="a", + source=Source.TESTING, + eligibility_criteria=SurveyEligibilityDefinition( + # "5d6d9f3c03bb40bf9d0a24f306387d7c", # gr:gender + # "c1309f099ab84a39b01200a56dac65cf", # d:600 + # "90e86550ddf84b08a9f7f5372dd9651b", # i:gender + # "1a8216ddb09440a8bc748cf8ca89ecec", # i:adhoc_13126 + property_codes=("i:adhoc_13126", "i:gender", "i:something"), + ), + ) + # might have a couple surveys in the bucket ... merge them all together + qualifying_questions = set(survey1.eligibility_criteria.property_codes) + + uqas = [ + UserQuestionAnswer( + question_id="5d6d9f3c03bb40bf9d0a24f306387d7c", + answer=("1",), + calc_answers={"i:gender": ("1",), "d:1": ("2",)}, + country_iso="us", + language_iso="eng", + property_code="gr:gender", + ), + UserQuestionAnswer( + question_id="c1309f099ab84a39b01200a56dac65cf", + answer=("50796", "50784"), + country_iso="us", + language_iso="eng", + property_code="d:600", + calc_answers={"d:600": ("50796", "50784")}, + ), + UserQuestionAnswer( + question_id="1a8216ddb09440a8bc748cf8ca89ecec", + answer=("3", "4"), + country_iso="us", + language_iso="eng", + property_code="i:adhoc_13126", + calc_answers={"i:adhoc_13126": ("3", "4")}, + ), + ] + uqad = dict() + for uqa in uqas: + for k, v in uqa.calc_answers.items(): + if k in qualifying_questions: + uqad[k] = uqa + uqad[uqa.property_code] = uqa + + question_ids = {uqa.question_id for uqa in uqad.values()} + qs = question_manager.get_multi_upk(question_ids) + # Sort question by LOWEST task_count (rarest) + qs = sorted(qs, key=lambda x: x.importance.task_count if x.importance else 0) + # qd = {q.id: q for q in qs} + + q = [x for x in qs if x.ext_question_id == "i:adhoc_13126"][0] + q.explanation_template = "You have been diagnosed with: {answer}." + q = [x for x in qs if x.ext_question_id == "gr:gender"][0] + q.explanation_template = "Your gender is {answer}." + + ecs = [] + for q in qs: + answer_code = uqad[q.ext_question_id].answer + answer_label = tuple([q.choices_text_lookup[ans] for ans in answer_code]) + explanation = None + if q.explanation_template: + explanation = q.explanation_template.format( + answer=", ".join(answer_label) + ) + sec = SurveyEligibilityCriterion( + question_id=q.id, + question_text=q.text, + qualifying_answer=answer_code, + qualifying_answer_label=answer_label, + explanation=explanation, + property_code=q.ext_question_id, + ) + print(sec) + ecs.append(sec) + bucket.eligibility_criteria = tuple(ecs) + print(bucket) + + +class TestSurveyStat: + def test( + self, + delete_buyers_surveys, + surveystat_manager, + survey_manager, + surveys_fixture, + ): + survey_manager.create_or_update(surveys_fixture) + ss = [ssa, ssb] + surveystat_manager.update_or_create(ss) + keys = [s.unique_key for s in ss] + res = surveystat_manager.filter_by_unique_keys(keys) + assert len(res) == 2 + assert res[0].conv_alpha == 10 + + ssa.conv_alpha = 11 + surveystat_manager.update_or_create([ssa]) + res = surveystat_manager.filter_by_unique_keys([ssa.unique_key]) + assert len(res) == 1 + assert res[0].conv_alpha == 11 + + @pytest.mark.skip() + def test_big( + self, + # delete_buyers_surveys, + surveystat_manager, + survey_manager, + surveys_fixture, + ): + survey = surveys_fixture[0].model_copy() + surveys = [] + for idx in range(20_000): + s = survey.model_copy() + s.survey_id = uuid.uuid4().hex + surveys.append(s) + + survey_manager.create_or_update(surveys) + # Realistically, we're going to have like, say 20k surveys + # and 99% of them will be updated each time + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey__survey_id = survey.survey_id + ss.quota_id = "__all__" + ss.score = 10 + survey_stats.append(ss) + print(len(survey_stats)) + print(survey_stats[12].natural_key, survey_stats[2000].natural_key) + print(f"----a-----: {datetime.now().isoformat()}") + res = surveystat_manager.update_or_create(survey_stats) + print(f"----b-----: {datetime.now().isoformat()}") + assert len(res) == 20_000 + return + + # 1,000 of the 20,000 are "new" + now = datetime.now(tz=timezone.utc) + for s in ss[:1000]: + s.survey__survey_id = "b" + s.updated_at = now + # 18,000 need to be updated (scores update or whatever) + for s in ss[1000:-1000]: + s.score = 20 + s.conv_alpha = 20 + s.conv_beta = 20 + s.updated_at = now + # and 1,000 don't change + print(f"----c-----: {datetime.now().isoformat()}") + res2 = surveystat_manager.update_or_create(ss) + print(f"----d-----: {datetime.now().isoformat()}") + assert len(res2) == 20_000 + + def test_ymsp( + self, + delete_buyers_surveys, + surveys_fixture, + survey_manager, + surveystat_manager, + ): + source = Source.TESTING + survey = surveys_fixture[0].model_copy() + surveys = [] + for idx in range(100): + s = survey.model_copy() + s.survey_id = uuid.uuid4().hex + surveys.append(s) + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey_survey_id = survey.survey_id + survey_stats.append(ss) + + surveystat_manager.update_surveystats_for_source( + source=source, surveys=surveys, survey_stats=survey_stats + ) + # UPDATE ------- + since = datetime.now(tz=timezone.utc) + print(f"{since=}") + + # 10 survey disappear + surveys = surveys[10:] + + # and 2 new ones are created + for idx in range(2): + s = survey.model_copy() + s.survey_id = uuid.uuid4().hex + surveys.append(s) + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey_survey_id = survey.survey_id + survey_stats.append(ss) + surveystat_manager.update_surveystats_for_source( + source=source, surveys=surveys, survey_stats=survey_stats + ) + + live_surveys = survey_manager.filter_by_source_live(source=source) + assert len(live_surveys) == 92 # 100 - 10 + 2 + + ss = surveystat_manager.filter_by_updated_since(since=since) + assert len(ss) == 102 # 92 existing + 10 not live + + ss = surveystat_manager.filter_by_live() + assert len(ss) == 92 + + def test_filter( + self, + delete_buyers_surveys, + surveys_fixture, + survey_manager, + surveystat_manager, + ): + surveys = [] + survey = surveys_fixture[0].model_copy() + survey.source = Source.TESTING + survey.survey_id = "a" + surveys.append(survey.model_copy()) + survey.source = Source.TESTING + survey.survey_id = "b" + surveys.append(survey.model_copy()) + survey.source = Source.TESTING2 + survey.survey_id = "b" + surveys.append(survey.model_copy()) + survey.source = Source.TESTING2 + survey.survey_id = "c" + surveys.append(survey.model_copy()) + # 4 surveys t:a, t:b, u:b, u:c + + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey_survey_id = survey.survey_id + ss.survey_source = survey.source + survey_stats.append(ss) + + surveystat_manager.update_surveystats_for_source( + source=Source.TESTING, surveys=surveys[:2], survey_stats=survey_stats[:2] + ) + surveystat_manager.update_surveystats_for_source( + source=Source.TESTING2, surveys=surveys[2:], survey_stats=survey_stats[2:] + ) + + survey_keys = [f"{s.source.value}:{s.survey_id}" for s in surveys] + res = surveystat_manager.filter(survey_keys=survey_keys, min_score=0.01) + assert len(res) == 4 + res = survey_manager.filter_by_keys(survey_keys) + assert len(res) == 4 + + res = surveystat_manager.filter(survey_keys=survey_keys[:2]) + assert len(res) == 2 + res = survey_manager.filter_by_keys(survey_keys[:2]) + assert len(res) == 2 diff --git a/tests/managers/thl/test_survey_penalty.py b/tests/managers/thl/test_survey_penalty.py new file mode 100644 index 0000000..4c7dc08 --- /dev/null +++ b/tests/managers/thl/test_survey_penalty.py @@ -0,0 +1,101 @@ +import uuid + +import pytest +from cachetools.keys import _HashedTuple + +from generalresearch.models import Source +from generalresearch.models.thl.survey.penalty import ( + BPSurveyPenalty, + TeamSurveyPenalty, +) + + +@pytest.fixture +def product_uuid() -> str: + # Nothing touches the db here, we don't need actual products or teams + return uuid.uuid4().hex + + +@pytest.fixture +def team_uuid() -> str: + # Nothing touches the db here, we don't need actual products or teams + return uuid.uuid4().hex + + +@pytest.fixture +def penalties(product_uuid, team_uuid): + return [ + BPSurveyPenalty( + source=Source.TESTING, survey_id="a", penalty=0.1, product_id=product_uuid + ), + BPSurveyPenalty( + source=Source.TESTING, survey_id="b", penalty=0.2, product_id=product_uuid + ), + TeamSurveyPenalty( + source=Source.TESTING, survey_id="a", penalty=1, team_id=team_uuid + ), + TeamSurveyPenalty( + source=Source.TESTING, survey_id="c", penalty=1, team_id=team_uuid + ), + # Source.TESTING:a is different from Source.TESTING2:a ! + BPSurveyPenalty( + source=Source.TESTING2, survey_id="a", penalty=0.5, product_id=product_uuid + ), + # For a random BP, should not do anything + BPSurveyPenalty( + source=Source.TESTING, survey_id="b", penalty=1, product_id=uuid.uuid4().hex + ), + ] + + +class TestSurveyPenalty: + def test(self, surveypenalty_manager, penalties, product_uuid, team_uuid): + surveypenalty_manager.set_penalties(penalties) + + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_uuid + ) + assert res == {"t:a": 1.0, "t:b": 0.2, "t:c": 1, "u:a": 0.5} + + # We can update penalties for a marketplace and not erase them for another. + # But remember, marketplace is batched, so it'll overwrite the previous + # values for that marketplace + penalties = [ + BPSurveyPenalty( + source=Source.TESTING2, + survey_id="b", + penalty=0.1, + product_id=product_uuid, + ) + ] + surveypenalty_manager.set_penalties(penalties) + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_uuid + ) + assert res == {"t:a": 1.0, "t:b": 0.2, "t:c": 1, "u:b": 0.1} + + # Team id doesn't exist, so it should return the product's penalties + team_id_random = uuid.uuid4().hex + surveypenalty_manager.cache.clear() + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_id_random + ) + assert res == {"t:a": 0.1, "t:b": 0.2, "u:b": 0.1} + + # Assert it is cached (no redis lookup needed) + assert surveypenalty_manager.cache.currsize == 1 + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_id_random + ) + assert res == {"t:a": 0.1, "t:b": 0.2, "u:b": 0.1} + assert surveypenalty_manager.cache.currsize == 1 + cached_key = tuple(list(list(surveypenalty_manager.cache.keys())[0])[1:]) + assert cached_key == tuple( + ["product_id", product_uuid, "team_id", team_id_random] + ) + + # Both don't exist, return nothing + res = surveypenalty_manager.get_penalties_for( + product_id=uuid.uuid4().hex, team_id=uuid.uuid4().hex + ) + assert res == {} diff --git a/tests/managers/thl/test_task_adjustment.py b/tests/managers/thl/test_task_adjustment.py new file mode 100644 index 0000000..839bbe1 --- /dev/null +++ b/tests/managers/thl/test_task_adjustment.py @@ -0,0 +1,346 @@ +import logging +from random import randint + +import pytest +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallAdjustedStatus, +) + + +@pytest.fixture() +def session_complete(session_with_tx_factory, user): + return session_with_tx_factory( + user=user, final_status=Status.COMPLETE, wall_req_cpi=Decimal("1.23") + ) + + +@pytest.fixture() +def session_complete_with_wallet(session_with_tx_factory, user_with_wallet): + return session_with_tx_factory( + user=user_with_wallet, + final_status=Status.COMPLETE, + wall_req_cpi=Decimal("1.23"), + ) + + +@pytest.fixture() +def session_fail(user, session_manager, wall_manager): + session = session_manager.create_dummy( + started=datetime.now(timezone.utc), user=user + ) + wall1 = wall_manager.create_dummy( + session_id=session.id, + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="72723", + req_cpi=Decimal("3.22"), + started=datetime.now(timezone.utc), + ) + wall_manager.finish( + wall=wall1, + status=Status.FAIL, + status_code_1=StatusCode1.PS_FAIL, + finished=wall1.started + timedelta(seconds=randint(a=60 * 2, b=60 * 10)), + ) + session.wall_events.append(wall1) + return session + + +class TestHandleRecons: + + def test_complete_to_recon( + self, + session_complete, + thl_lm, + task_adjustment_manager, + wall_manager, + session_manager, + caplog, + ): + print(wall_manager.pg_config.dsn) + mid = session_complete.uuid + wall_uuid = session_complete.wall_events[-1].uuid + s = session_complete + ledger_manager = thl_lm + + revenue_account = ledger_manager.get_account_task_complete_revenue() + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert ( + current_amount == 123 + ), "this is the amount of revenue from this task complete" + + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 117, "this is the amount paid to the BP" + + # Do the work here !! ----v + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + assert ( + len(task_adjustment_manager.filter_by_wall_uuid(wall_uuid=wall_uuid)) == 1 + ) + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0, "after recon, it should be zeroed" + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0, "this is the amount paid to the BP" + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + assert ledger_manager.get_account_balance(commission_account) == 0 + + # Now, say we get the exact same *adjust to incomplete* msg again. It should do nothing! + adjusted_timestamp = datetime.now(tz=timezone.utc) + wall = wall_manager.get_from_uuid(wall_uuid=wall_uuid) + with pytest.raises(match=" is already "): + wall_manager.adjust_status( + wall, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adjusted_timestamp, + adjusted_cpi=Decimal(0), + ) + + session = session_manager.get_from_id(wall.session_id) + user = session.user + with caplog.at_level(logging.INFO): + ledger_manager.create_tx_task_adjustment( + wall, user=user, created=adjusted_timestamp + ) + assert "No transactions needed" in caplog.text + + session.wall_events = wall_manager.get_wall_events(session.id) + session.user.prefetch_product(wall_manager.pg_config) + + with caplog.at_level(logging.INFO, logger="Wall"): + session_manager.adjust_status(session) + assert "is already f" in caplog.text or "is already Status.FAIL" in caplog.text + + with caplog.at_level(logging.INFO, logger="LedgerManager"): + ledger_manager.create_tx_bp_adjustment(session, created=adjusted_timestamp) + assert "No transactions needed" in caplog.text + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0, "after recon, it should be zeroed" + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0, "this is the amount paid to the BP" + + # And if we get an adj to fail, and handle it, it should do nothing at all + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + assert ( + len(task_adjustment_manager.filter_by_wall_uuid(wall_uuid=wall_uuid)) == 1 + ) + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0, "after recon, it should be zeroed" + + def test_fail_to_complete(self, session_fail, thl_lm, task_adjustment_manager): + s = session_fail + mid = session_fail.uuid + wall_uuid = session_fail.wall_events[-1].uuid + ledger_manager = thl_lm + + revenue_account = ledger_manager.get_account_task_complete_revenue() + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", mid + ) + assert ( + current_amount == 0 + ), "this is the amount of revenue from this task complete" + + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0, "this is the amount paid to the BP" + + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + ) + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 322, "after recon, we should be paid" + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 306, "this is the amount paid to the BP" + + # Now reverse it back to fail + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0 + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0 + + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + assert ledger_manager.get_account_balance(commission_account) == 0 + + def test_complete_already_complete( + self, session_complete, thl_lm, task_adjustment_manager + ): + s = session_complete + mid = session_complete.uuid + wall_uuid = session_complete.wall_events[-1].uuid + ledger_manager = thl_lm + + for _ in range(4): + # just run it 4 times to make sure nothing happens 4 times + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + ) + + revenue_account = ledger_manager.get_account_task_complete_revenue() + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 123 + assert ledger_manager.get_account_balance(commission_account) == 6 + + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 117 + + def test_incomplete_already_incomplete( + self, session_fail, thl_lm, task_adjustment_manager + ): + s = session_fail + mid = session_fail.uuid + wall_uuid = session_fail.wall_events[-1].uuid + ledger_manager = thl_lm + + for _ in range(4): + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + + revenue_account = ledger_manager.get_account_task_complete_revenue() + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", mid + ) + assert current_amount == 0 + assert ledger_manager.get_account_balance(commission_account) == 0 + + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0 + + def test_complete_to_recon_user_wallet( + self, + session_complete_with_wallet, + user_with_wallet, + thl_lm, + task_adjustment_manager, + ): + s = session_complete_with_wallet + mid = s.uuid + wall_uuid = s.wall_events[-1].uuid + ledger_manager = thl_lm + + revenue_account = ledger_manager.get_account_task_complete_revenue() + amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert amount == 123, "this is the amount of revenue from this task complete" + + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + user_wallet_account = ledger_manager.get_account_or_create_user_wallet(s.user) + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + amount = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert amount == 70, "this is the amount paid to the BP" + amount = ledger_manager.get_account_filtered_balance( + user_wallet_account, "thl_session", mid + ) + assert amount == 47, "this is the amount paid to the user" + assert ( + ledger_manager.get_account_balance(commission_account) == 6 + ), "earned commission" + + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + + amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert amount == 0 + amount = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert amount == 0 + amount = ledger_manager.get_account_filtered_balance( + user_wallet_account, "thl_session", mid + ) + assert amount == 0 + assert ( + ledger_manager.get_account_balance(commission_account) == 0 + ), "earned commission" diff --git a/tests/managers/thl/test_task_status.py b/tests/managers/thl/test_task_status.py new file mode 100644 index 0000000..55c89c0 --- /dev/null +++ b/tests/managers/thl/test_task_status.py @@ -0,0 +1,696 @@ +import pytest +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +from generalresearch.managers.thl.session import SessionManager +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + Status, + WallAdjustedStatus, + StatusCode1, +) +from generalresearch.models.thl.product import ( + PayoutConfig, + UserWalletConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, +) +from generalresearch.models.thl.session import Session, WallOut +from generalresearch.models.thl.task_status import TaskStatusResponse +from generalresearch.models.thl.user import User + + +start1 = datetime(2023, 2, 1, tzinfo=timezone.utc) +finish1 = start1 + timedelta(minutes=5) +recon1 = start1 + timedelta(days=20) +start2 = datetime(2023, 2, 2, tzinfo=timezone.utc) +finish2 = start2 + timedelta(minutes=5) +start3 = datetime(2023, 2, 3, tzinfo=timezone.utc) +finish3 = start3 + timedelta(minutes=5) + + +@pytest.fixture(scope="session") +def bp1(product_manager): + # user wallet disabled, payout xform NULL + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=False), + payout_config=PayoutConfig(), + ) + + +@pytest.fixture(scope="session") +def bp2(product_manager): + # user wallet disabled, payout xform 40% + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=False), + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.4), + ) + ), + ) + + +@pytest.fixture(scope="session") +def bp3(product_manager): + # user wallet enabled, payout xform 50% + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=True), + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.5), + ) + ), + ) + + +class TestTaskStatus: + + def test_task_status_complete_1( + self, + bp1, + user_factory, + finished_session_factory, + session_manager: SessionManager, + ): + # User Payout xform NULL + user1: User = user_factory(product=bp1) + s1: Session = finished_session_factory( + user=user1, started=start1, wall_req_cpi=Decimal(1), wall_count=2 + ) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s1.uuid, + "product_id": user1.product_id, + "product_user_id": user1.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s1.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "status_code_1": 14, + "payout_transformation": None, + } + ) + w1 = s1.wall_events[0] + wo1 = WallOut( + uuid=w1.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w1.req_survey_id, + req_cpi=w1.req_cpi, + started=w1.started, + survey_id=w1.survey_id, + cpi=w1.cpi, + finished=w1.finished, + status=w1.status, + status_code_1=w1.status_code_1, + ) + w2 = s1.wall_events[1] + wo2 = WallOut( + uuid=w2.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w2.req_survey_id, + req_cpi=w2.req_cpi, + started=w2.started, + survey_id=w2.survey_id, + cpi=w2.cpi, + finished=w2.finished, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + ) + expected_tsr.wall_events = [wo1, wo2] + tsr = session_manager.get_task_status_response(s1.uuid) + assert tsr == expected_tsr + + def test_task_status_complete_2( + self, bp2, user_factory, finished_session_factory, session_manager + ): + # User Payout xform 40% + user2: User = user_factory(product=bp2) + s2: Session = finished_session_factory( + user=user2, started=start1, wall_req_cpi=Decimal(1), wall_count=2 + ) + payout = 95 # 1.00 - 5% commission + user_payout = round(95 * 0.40) # 38 + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s2.uuid, + "product_id": user2.product_id, + "product_user_id": user2.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s2.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": payout, + "user_payout": user_payout, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.38", + "status_code_1": 14, + "status_code_2": None, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + w1 = s2.wall_events[0] + wo1 = WallOut( + uuid=w1.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w1.req_survey_id, + req_cpi=w1.req_cpi, + started=w1.started, + survey_id=w1.survey_id, + cpi=w1.cpi, + finished=w1.finished, + status=w1.status, + status_code_1=w1.status_code_1, + user_cpi=Decimal("0.38"), + user_cpi_string="$0.38", + ) + w2 = s2.wall_events[1] + wo2 = WallOut( + uuid=w2.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w2.req_survey_id, + req_cpi=w2.req_cpi, + started=w2.started, + survey_id=w2.survey_id, + cpi=w2.cpi, + finished=w2.finished, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + user_cpi=Decimal("0.38"), + user_cpi_string="$0.38", + ) + expected_tsr.wall_events = [wo1, wo2] + + tsr = session_manager.get_task_status_response(s2.uuid) + assert tsr == expected_tsr + + def test_task_status_complete_3( + self, bp3, user_factory, finished_session_factory, session_manager + ): + # Wallet enabled User Payout xform 50% (the response is identical + # to the user wallet disabled w same xform) + user3: User = user_factory(product=bp3) + s3: Session = finished_session_factory( + user=user3, started=start1, wall_req_cpi=Decimal(1) + ) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s3.uuid, + "product_id": user3.product_id, + "product_user_id": user3.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s3.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": 48, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.48", + "kwargs": {}, + "status_code_1": 14, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.5"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s3.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_fail( + self, bp1, user_factory, finished_session_factory, session_manager + ): + # User Payout xform NULL: user payout is None always + user1: User = user_factory(product=bp1) + s1: Session = finished_session_factory( + user=user1, started=start1, final_status=Status.FAIL + ) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s1.uuid, + "product_id": user1.product_id, + "product_user_id": user1.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s1.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": s1.status_code_1.value, + "status_code_2": None, + "adjusted_status": None, + "adjusted_timestamp": None, + "adjusted_payout": None, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s1.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_fail_xform( + self, bp2, user_factory, finished_session_factory, session_manager + ): + # User Payout xform 40%: user_payout is 0 (not None) + + user: User = user_factory(product=bp2) + s: Session = finished_session_factory( + user=user, started=start1, final_status=Status.FAIL + ) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": 0, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.00", + "kwargs": {}, + "status_code_1": s.status_code_1.value, + "status_code_2": None, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_abandon( + self, bp1, user_factory, session_factory, session_manager + ): + # User Payout xform NULL: all payout fields are None + user: User = user_factory(product=bp1) + s = session_factory(user=user, started=start1) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": None, + "adjusted_timestamp": None, + "adjusted_payout": None, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_abandon_xform( + self, bp2, user_factory, session_factory, session_manager + ): + # User Payout xform 40%: all payout fields are None (same as when payout xform is null) + user: User = user_factory(product=bp2) + s = session_factory(user=user, started=start1) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": "${payout/100:.2f}", + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": None, + "adjusted_timestamp": None, + "adjusted_payout": None, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_fail( + self, + bp1, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # Complete -> Fail + # User Payout xform NULL: adjusted_user_* and user_* is still all None + user: User = user_factory(product=bp1) + s: Session = finished_session_factory( + user=user, started=start1, wall_req_cpi=Decimal(1) + ) + wall_manager.adjust_status( + wall=s.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": StatusCode1.COMPLETE.value, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_FAIL.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 0, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_fail_xform( + self, + bp2, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # Complete -> Fail + # User Payout xform 40%: adjusted_user_payout is 0 (not null) + user: User = user_factory(product=bp2) + s: Session = finished_session_factory( + user=user, started=start1, wall_req_cpi=Decimal(1) + ) + wall_manager.adjust_status( + wall=s.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": 38, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.38", + "kwargs": {}, + "status_code_1": StatusCode1.COMPLETE.value, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_FAIL.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 0, + "adjusted_user_payout": 0, + "adjusted_user_payout_string": "$0.00", + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_abandon( + self, + bp1, + user_factory, + session_factory, + wall_manager, + session_manager, + ): + # User Payout xform NULL + user: User = user_factory(product=bp1) + s: Session = session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.ABANDON, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_COMPLETE.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_abandon_xform( + self, + bp2, + user_factory, + session_factory, + wall_manager, + session_manager, + ): + # User Payout xform 40% + user: User = user_factory(product=bp2) + s: Session = session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.ABANDON, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": "${payout/100:.2f}", + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_COMPLETE.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": 38, + "adjusted_user_payout_string": "$0.38", + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_fail( + self, + bp1, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # User Payout xform NULL + user: User = user_factory(product=bp1) + s: Session = finished_session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.FAIL, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": s.status_code_1.value, + "status_code_2": None, + "adjusted_status": "ac", + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_fail_xform( + self, + bp2, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # User Payout xform 40% + user: User = user_factory(product=bp2) + s: Session = finished_session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.FAIL, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": 0, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.00", + "kwargs": {}, + "status_code_1": s.status_code_1.value, + "status_code_2": None, + "adjusted_status": "ac", + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": 38, + "adjusted_user_payout_string": "$0.38", + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr diff --git a/tests/managers/thl/test_user_manager/__init__.py b/tests/managers/thl/test_user_manager/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/managers/thl/test_user_manager/__init__.py diff --git a/tests/managers/thl/test_user_manager/test_base.py b/tests/managers/thl/test_user_manager/test_base.py new file mode 100644 index 0000000..3c7ee38 --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_base.py @@ -0,0 +1,274 @@ +import logging +from datetime import datetime, timezone +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.managers.thl.user_manager import ( + UserCreateNotAllowedError, + get_bp_user_create_limit_hourly, +) +from generalresearch.managers.thl.user_manager.rate_limit import ( + RateLimitItemPerHourConstantKey, +) +from generalresearch.models.thl.product import UserCreateConfig, Product +from generalresearch.models.thl.user import User +from test_utils.models.conftest import ( + user, + product, + user_manager, + product_manager, +) + +logger = logging.getLogger() + + +class TestUserManager: + + def test_copying_lru_cache(self, user_manager, user): + # Before adding the deepcopy_return decorator, this would fail b/c the returned user + # is mutable, and it would mutate in the cache + # user_manager = self.get_user_manager() + + user_manager.clear_user_inmemory_cache(user) + u = user_manager.get_user(user_id=user.user_id) + assert not u.blocked + + u.blocked = True + u = user_manager.get_user(user_id=user.user_id) + assert not u.blocked + + def test_get_user_no_inmemory(self, user, user_manager): + user_manager.clear_user_inmemory_cache(user) + user_manager.get_user.__wrapped__.cache_clear() + u = user_manager.get_user(user_id=user.user_id) + # this should hit mysql + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 0, cache_info + assert cache_info.misses == 1, cache_info + + # this should hit the lru cache + u = user_manager.get_user(user_id=user.user_id) + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 1, cache_info + assert cache_info.misses == 1, cache_info + + def test_get_user_with_inmemory(self, user_manager, user): + # user_manager = self.get_user_manager() + + user_manager.set_user_inmemory_cache(user) + user_manager.get_user.__wrapped__.cache_clear() + u = user_manager.get_user(user_id=user.user_id) + # this should hit inmemory cache + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 0, cache_info + assert cache_info.misses == 1, cache_info + + # this should hit the lru cache + u = user_manager.get_user(user_id=user.user_id) + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 1, cache_info + assert cache_info.misses == 1, cache_info + + +class TestBlockUserManager: + + def test_block_user(self, product, user_manager): + product_user_id = f"user-{uuid4().hex[:10]}" + + # mysql_user_manager to skip user creation limit check + user: User = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + assert not user.blocked + + # get_user to make sure caches are populated + user = user_manager.get_user( + product_id=product.id, product_user_id=product_user_id + ) + assert not user.blocked + + assert user_manager.block_user(user) is True + assert user.blocked + + user = user_manager.get_user( + product_id=product.id, product_user_id=product_user_id + ) + assert user.blocked + + user = user_manager.get_user(user_id=user.user_id) + assert user.blocked + + def test_block_user_whitelist(self, product, user_manager, thl_web_rw): + product_user_id = f"user-{uuid4().hex[:10]}" + + # mysql_user_manager to skip user creation limit check + user: User = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + assert not user.blocked + + now = datetime.now(tz=timezone.utc) + # Adds user to whitelist + thl_web_rw.execute_write( + """ + INSERT INTO userprofile_userstat + (user_id, key, value, date) + VALUES (%(user_id)s, 'USER_HEALTH.access_control', 1, %(date)s) + ON CONFLICT (user_id, key) DO UPDATE SET value=1""", + params={"user_id": user.user_id, "date": now}, + ) + assert user_manager.is_whitelisted(user) + assert user_manager.block_user(user) is False + assert not user.blocked + + +class TestCreateUserManager: + + def test_create_user(self, product_manager, thl_web_rw, user_manager): + product: Product = product_manager.create_dummy( + user_create_config=UserCreateConfig( + min_hourly_create_limit=10, max_hourly_create_limit=69 + ), + ) + + product_user_id = f"user-{uuid4().hex[:10]}" + + user: User = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + assert isinstance(user, User) + assert user.product_id == product.id + assert user.product_user_id == product_user_id + + assert user.user_id is not None + assert user.uuid is not None + + # make sure thl_user row is created + res_thl_user = thl_web_rw.execute_sql_query( + query=f""" + SELECT * + FROM thl_user AS u + WHERE u.id = %s + """, + params=[user.user_id], + ) + + assert len(res_thl_user) == 1 + + u2 = user_manager.get_user( + product_id=product.id, product_user_id=product_user_id + ) + assert u2.user_id == user.user_id + assert u2.uuid == user.uuid + + def test_create_user_integrity_error(self, product_manager, user_manager, caplog): + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + user_create_config=UserCreateConfig( + min_hourly_create_limit=10, max_hourly_create_limit=69 + ), + ) + + product_user_id = f"user-{uuid4().hex[:10]}" + rand_msg = f"log-{uuid4().hex}" + + with caplog.at_level(logging.INFO): + logger.info(rand_msg) + user1 = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + + assert len(caplog.records) == 1 + assert caplog.records[0].getMessage() == rand_msg + + # Should cause a constraint error, triggering a lookup instead + with caplog.at_level(logging.INFO): + user2 = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + + assert len(caplog.records) == 3 + assert caplog.records[0].getMessage() == rand_msg + assert ( + caplog.records[1].getMessage() + == f"mysql_user_manager.create_user_new integrity error: {product.id} {product_user_id}" + ) + assert ( + caplog.records[2].getMessage() + == f"get_user_from_mysql: {product.id}, {product_user_id}, None, None" + ) + + assert user1 == user2 + + def test_raise_allow_user_create(self, product_manager, user_manager): + rand_num = randint(25, 200) + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + user_create_config=UserCreateConfig( + min_hourly_create_limit=rand_num, + max_hourly_create_limit=rand_num, + ), + ) + + instance: Product = user_manager.product_manager.get_by_uuid( + product_uuid=product.id + ) + + # get_bp_user_create_limit_hourly is dynamically generated, make sure + # we use this value in our tests and not the + # UserCreateConfig.max_hourly_create_limit value + rl_value: int = get_bp_user_create_limit_hourly(product=instance) + + # This is a randomly generated product_id, which means we'll always + # use the global defaults + assert rand_num == instance.user_create_config.min_hourly_create_limit + assert rand_num == instance.user_create_config.max_hourly_create_limit + assert rand_num == rl_value + + rl = RateLimitItemPerHourConstantKey(rl_value) + assert str(rl) == f"{rand_num} per 1 hour" + + key = rl.key_for("thl-grpc", "allow_user_create", instance.id) + assert key == f"LIMITER/thl-grpc/allow_user_create/{instance.id}" + + # make sure we clear the key or subsequent tests will fail + user_manager.user_manager_limiter.storage.clear(key=key) + + n = 0 + with pytest.raises(expected_exception=UserCreateNotAllowedError) as cm: + for n, _ in enumerate(range(rl_value + 5)): + user_manager.user_manager_limiter.raise_allow_user_create( + product=product + ) + assert rl_value == n + + +class TestUserManagerMethods: + + def test_audit_log(self, user_manager, user, audit_log_manager): + from generalresearch.models.thl.userhealth import AuditLog + + res = audit_log_manager.filter_by_user_id(user_id=user.user_id) + assert len(res) == 0 + + msg = uuid4().hex + user_manager.audit_log(user=user, level=30, event_type=msg) + + res = audit_log_manager.filter_by_user_id(user_id=user.user_id) + assert len(res) == 1 + assert isinstance(res[0], AuditLog) + assert res[0].event_type == msg diff --git a/tests/managers/thl/test_user_manager/test_mysql.py b/tests/managers/thl/test_user_manager/test_mysql.py new file mode 100644 index 0000000..0313bbf --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_mysql.py @@ -0,0 +1,25 @@ +from test_utils.models.conftest import user, user_manager + + +class TestUserManagerMysqlNew: + + def test_get_notset(self, user_manager): + assert ( + user_manager.mysql_user_manager.get_user_from_mysql(user_id=-3105) is None + ) + + def test_get_user_id(self, user, user_manager): + assert ( + user_manager.mysql_user_manager.get_user_from_mysql(user_id=user.user_id) + == user + ) + + def test_get_uuid(self, user, user_manager): + u = user_manager.mysql_user_manager.get_user_from_mysql(user_uuid=user.uuid) + assert u == user + + def test_get_ubp(self, user, user_manager): + u = user_manager.mysql_user_manager.get_user_from_mysql( + product_id=user.product_id, product_user_id=user.product_user_id + ) + assert u == user diff --git a/tests/managers/thl/test_user_manager/test_redis.py b/tests/managers/thl/test_user_manager/test_redis.py new file mode 100644 index 0000000..a69519e --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_redis.py @@ -0,0 +1,80 @@ +import pytest + +from generalresearch.managers.base import Permission + + +class TestUserManagerRedis: + + def test_get_notset(self, user_manager, user): + user_manager.clear_user_inmemory_cache(user=user) + assert user_manager.redis_user_manager.get_user(user_id=user.user_id) is None + + def test_get_user_id(self, user_manager, user): + user_manager.redis_user_manager.set_user(user=user) + + assert user_manager.redis_user_manager.get_user(user_id=user.user_id) == user + + def test_get_uuid(self, user_manager, user): + user_manager.redis_user_manager.set_user(user=user) + + assert user_manager.redis_user_manager.get_user(user_uuid=user.uuid) == user + + def test_get_ubp(self, user_manager, user): + user_manager.redis_user_manager.set_user(user=user) + + assert ( + user_manager.redis_user_manager.get_user( + product_id=user.product_id, product_user_id=user.product_user_id + ) + == user + ) + + @pytest.mark.skip(reason="TODO") + def test_set(self): + # I mean, the sets are implicitly tested by the get tests above. no point + pass + + def test_get_with_cache_prefix(self, settings, user, thl_web_rw, thl_web_rr): + """ + Confirm the prefix functionality is working; we do this so it + is easier to migrate between any potentially breaking versions + if we don't want any broken keys; not as important after + pydantic usage... + """ + from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, + ) + + um1 = UserManager( + pg_config=thl_web_rw, + pg_config_rr=thl_web_rr, + sql_permissions=[Permission.UPDATE, Permission.CREATE], + redis=settings.redis, + redis_timeout=settings.redis_timeout, + ) + + um2 = UserManager( + pg_config=thl_web_rw, + pg_config_rr=thl_web_rr, + sql_permissions=[Permission.UPDATE, Permission.CREATE], + redis=settings.redis, + redis_timeout=settings.redis_timeout, + cache_prefix="user-lookup-v2", + ) + + um1.get_or_create_user( + product_id=user.product_id, product_user_id=user.product_user_id + ) + um2.get_or_create_user( + product_id=user.product_id, product_user_id=user.product_user_id + ) + + res1 = um1.redis_user_manager.client.get(f"user-lookup:user_id:{user.user_id}") + assert res1 is not None + + res2 = um2.redis_user_manager.client.get( + f"user-lookup-v2:user_id:{user.user_id}" + ) + assert res2 is not None + + assert res1 == res2 diff --git a/tests/managers/thl/test_user_manager/test_user_fetch.py b/tests/managers/thl/test_user_manager/test_user_fetch.py new file mode 100644 index 0000000..a4b3d57 --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_user_fetch.py @@ -0,0 +1,48 @@ +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.user import User +from test_utils.models.conftest import product, user_manager, user_factory + + +class TestUserManagerFetch: + + def test_fetch(self, user_factory, product, user_manager): + user1: User = user_factory(product=product) + user2: User = user_factory(product=product) + res = user_manager.fetch_by_bpuids( + product_id=product.uuid, + product_user_ids=[user1.product_user_id, user2.product_user_id], + ) + assert len(res) == 2 + + res = user_manager.fetch(user_ids=[user1.user_id, user2.user_id]) + assert len(res) == 2 + + res = user_manager.fetch(user_uuids=[user1.uuid, user2.uuid]) + assert len(res) == 2 + + # filter including bogus values + res = user_manager.fetch(user_uuids=[user1.uuid, uuid4().hex]) + assert len(res) == 1 + + res = user_manager.fetch(user_uuids=[uuid4().hex]) + assert len(res) == 0 + + def test_fetch_invalid(self, user_manager): + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_uuids=[], user_ids=None) + assert "Must pass ONE of user_ids, user_uuids" in str(e.value) + + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_uuids=uuid4().hex) + assert "must pass a collection of user_uuids" in str(e.value) + + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_uuids=[uuid4().hex], user_ids=[1, 2, 3]) + assert "Must pass ONE of user_ids, user_uuids" in str(e.value) + + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_ids=list(range(501))) + assert "limit 500 user_ids" in str(e.value) diff --git a/tests/managers/thl/test_user_manager/test_user_metadata.py b/tests/managers/thl/test_user_manager/test_user_metadata.py new file mode 100644 index 0000000..91dc16a --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_user_metadata.py @@ -0,0 +1,88 @@ +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.user_profile import UserMetadata +from test_utils.models.conftest import user, user_manager, user_factory + + +class TestUserMetadataManager: + + def test_get_notset(self, user, user_manager, user_metadata_manager): + # The row in the db won't exist. It just returns the default obj with everything None (except for the user_id) + um1 = user_metadata_manager.get(user_id=user.user_id) + assert um1 == UserMetadata(user_id=user.user_id) + + def test_create(self, user_factory, product, user_metadata_manager): + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + email_address = f"{uuid4().hex}@example.com" + um = UserMetadata(user_id=u1.user_id, email_address=email_address) + # This happens in the model itself, nothing to do with the manager (a model_validator) + assert um.email_sha256 is not None + + user_metadata_manager.update(um) + um2 = user_metadata_manager.get(email_address=email_address) + assert um == um2 + + def test_create_no_email(self, product, user_factory, user_metadata_manager): + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + um = UserMetadata(user_id=u1.user_id) + assert um.email_sha256 is None + + user_metadata_manager.update(um) + um2 = user_metadata_manager.get(user_id=u1.user_id) + assert um == um2 + + def test_update(self, product, user_factory, user_metadata_manager): + from generalresearch.models.thl.user import User + + u: User = user_factory(product=product) + + email_address = f"{uuid4().hex}@example1.com" + um = UserMetadata(user_id=u.user_id, email_address=email_address) + user_metadata_manager.update(user_metadata=um) + + um.email_address = email_address.replace("example1", "example2") + user_metadata_manager.update(user_metadata=um) + + um2 = user_metadata_manager.get(email_address=um.email_address) + assert um2.email_address != email_address + + assert um2 == UserMetadata( + user_id=u.user_id, + email_address=email_address.replace("example1", "example2"), + ) + + def test_filter(self, user_factory, product, user_metadata_manager): + from generalresearch.models.thl.user import User + + user1: User = user_factory(product=product) + user2: User = user_factory(product=product) + + email_address = f"{uuid4().hex}@example.com" + res = user_metadata_manager.filter(email_addresses=[email_address]) + assert len(res) == 0 + + # Create 2 user metadata with the same email address + user_metadata_manager.update( + user_metadata=UserMetadata( + user_id=user1.user_id, email_address=email_address + ) + ) + user_metadata_manager.update( + user_metadata=UserMetadata( + user_id=user2.user_id, email_address=email_address + ) + ) + + res = user_metadata_manager.filter(email_addresses=[email_address]) + assert len(res) == 2 + + with pytest.raises(expected_exception=ValueError) as e: + res = user_metadata_manager.get(email_address=email_address) + assert "More than 1 result returned!" in str(e.value) diff --git a/tests/managers/thl/test_user_streak.py b/tests/managers/thl/test_user_streak.py new file mode 100644 index 0000000..7728f9f --- /dev/null +++ b/tests/managers/thl/test_user_streak.py @@ -0,0 +1,225 @@ +import copy +from datetime import datetime, timezone, timedelta, date +from decimal import Decimal +from zoneinfo import ZoneInfo + +import pytest + +from generalresearch.managers.thl.user_streak import compute_streaks_from_days +from generalresearch.models.thl.definitions import StatusCode1, Status +from generalresearch.models.thl.user_streak import ( + UserStreak, + StreakState, + StreakPeriod, + StreakFulfillment, +) + + +def test_compute_streaks_from_days(): + days = [ + date(2026, 1, 1), + date(2026, 1, 4), + date(2026, 1, 5), + date(2026, 1, 6), + date(2026, 1, 8), + date(2026, 1, 9), + date(2026, 2, 11), + date(2026, 2, 12), + ] + + # Active + today = date(2026, 2, 12) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today) + assert res == (2, 3, StreakState.ACTIVE, date(2026, 2, 12)) + + # At Risk + today = date(2026, 2, 13) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today) + assert res == (2, 3, StreakState.AT_RISK, date(2026, 2, 12)) + + # Broken + today = date(2026, 2, 14) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today) + assert res == (0, 3, StreakState.BROKEN, date(2026, 2, 12)) + + # Monthly, active + today = date(2026, 2, 14) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today) + assert res == (2, 2, StreakState.ACTIVE, date(2026, 2, 1)) + + # monthly, at risk + today = date(2026, 3, 1) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today) + assert res == (2, 2, StreakState.AT_RISK, date(2026, 2, 1)) + + # monthly, broken + today = date(2026, 4, 1) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today) + assert res == (0, 2, StreakState.BROKEN, date(2026, 2, 1)) + + +@pytest.fixture +def broken_active_streak(user): + return [ + UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.ACTIVE, + country_iso="us", + user_id=user.user_id, + last_fulfilled_period_start=date(2025, 2, 11), + current_streak=0, + longest_streak=1, + state=StreakState.BROKEN, + ), + UserStreak( + period=StreakPeriod.WEEK, + fulfillment=StreakFulfillment.ACTIVE, + country_iso="us", + user_id=user.user_id, + last_fulfilled_period_start=date(2025, 2, 10), + current_streak=0, + longest_streak=1, + state=StreakState.BROKEN, + ), + UserStreak( + period=StreakPeriod.MONTH, + fulfillment=StreakFulfillment.ACTIVE, + country_iso="us", + user_id=user.user_id, + last_fulfilled_period_start=date(2025, 2, 1), + current_streak=0, + longest_streak=1, + state=StreakState.BROKEN, + ), + ] + + +def create_session_fail(session_manager, start, user): + session = session_manager.create_dummy(started=start, country_iso="us", user=user) + session_manager.finish_with_status( + session, + finished=start + timedelta(minutes=1), + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + ) + + +def create_session_complete(session_manager, start, user): + session = session_manager.create_dummy(started=start, country_iso="us", user=user) + session_manager.finish_with_status( + session, + finished=start + timedelta(minutes=1), + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal(1), + ) + + +def test_user_streak_empty(user_streak_manager, user): + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + assert streaks == [] + + +def test_user_streaks_active_broken( + user_streak_manager, user, session_manager, broken_active_streak +): + # Testing active streak, but broken (not today or yesterday) + start1 = datetime(2025, 2, 12, tzinfo=timezone.utc) + end1 = start1 + timedelta(minutes=1) + + # abandon counts as inactive + session = session_manager.create_dummy(started=start1, country_iso="us", user=user) + streak = user_streak_manager.get_user_streaks(user_id=user.user_id) + assert streak == [] + + # session start fail counts as inactive + session_manager.finish_with_status( + session, + finished=end1, + status=Status.FAIL, + status_code_1=StatusCode1.SESSION_START_QUALITY_FAIL, + ) + streak = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + assert streak == [] + + # Create another as a real failure + create_session_fail(session_manager, start1, user) + + # This is going to be the day before b/c midnight utc is 8pm US + last_active_day = start1.astimezone(ZoneInfo("America/New_York")).date() + assert last_active_day == date(2025, 2, 11) + expected_streaks = broken_active_streak.copy() + streaks = user_streak_manager.get_user_streaks(user_id=user.user_id) + assert streaks == expected_streaks + + # Create another the next day + start2 = start1 + timedelta(days=1) + create_session_fail(session_manager, start2, user) + + last_active_day = start2.astimezone(ZoneInfo("America/New_York")).date() + expected_streaks = copy.deepcopy(broken_active_streak) + expected_streaks[0].longest_streak = 2 + expected_streaks[0].last_fulfilled_period_start = last_active_day + + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + assert streaks == expected_streaks + + +def test_user_streak_complete_active(user_streak_manager, user, session_manager): + """Testing active streak that is today""" + + # They completed yesterday NY time. Today isn't over so streak is pending + start1 = datetime.now(tz=ZoneInfo("America/New_York")) - timedelta(days=1) + create_session_complete(session_manager, start1.astimezone(tz=timezone.utc), user) + + last_complete_day = start1.date() + expected_streak = UserStreak( + longest_streak=1, + current_streak=1, + state=StreakState.AT_RISK, + last_fulfilled_period_start=last_complete_day, + country_iso="us", + user_id=user.user_id, + fulfillment=StreakFulfillment.COMPLETE, + period=StreakPeriod.DAY, + ) + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + streak = [ + s + for s in streaks + if s.fulfillment == StreakFulfillment.COMPLETE and s.period == StreakPeriod.DAY + ][0] + assert streak == expected_streak + + # And now they complete today + start2 = datetime.now(tz=ZoneInfo("America/New_York")) + create_session_complete(session_manager, start2.astimezone(tz=timezone.utc), user) + last_complete_day = start2.date() + expected_streak = UserStreak( + longest_streak=2, + current_streak=2, + state=StreakState.ACTIVE, + last_fulfilled_period_start=last_complete_day, + country_iso="us", + user_id=user.user_id, + fulfillment=StreakFulfillment.COMPLETE, + period=StreakPeriod.DAY, + ) + + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + streak = [ + s + for s in streaks + if s.fulfillment == StreakFulfillment.COMPLETE and s.period == StreakPeriod.DAY + ][0] + assert streak == expected_streak diff --git a/tests/managers/thl/test_userhealth.py b/tests/managers/thl/test_userhealth.py new file mode 100644 index 0000000..1cda8de --- /dev/null +++ b/tests/managers/thl/test_userhealth.py @@ -0,0 +1,367 @@ +from datetime import timezone, datetime +from uuid import uuid4 + +import faker +import pytest + +from generalresearch.managers.thl.userhealth import ( + IPRecordManager, + UserIpHistoryManager, +) +from generalresearch.models.thl.ipinfo import GeoIPInformation +from generalresearch.models.thl.user_iphistory import ( + IPRecord, +) +from generalresearch.models.thl.userhealth import AuditLogLevel, AuditLog + +fake = faker.Faker() + + +class TestAuditLog: + + def test_init(self, thl_web_rr, audit_log_manager): + from generalresearch.managers.thl.userhealth import AuditLogManager + + alm = AuditLogManager(pg_config=thl_web_rr) + + assert isinstance(alm, AuditLogManager) + assert isinstance(audit_log_manager, AuditLogManager) + assert alm.pg_config.db == thl_web_rr.db + assert audit_log_manager.pg_config.db == thl_web_rr.db + + @pytest.mark.parametrize( + argnames="level", + argvalues=list(AuditLogLevel), + ) + def test_create(self, audit_log_manager, user, level): + instance = audit_log_manager.create( + user_id=user.user_id, level=level, event_type=uuid4().hex + ) + assert isinstance(instance, AuditLog) + assert instance.id != 1 + + def test_get_by_id(self, audit_log, audit_log_manager): + from generalresearch.models.thl.userhealth import AuditLog + + with pytest.raises(expected_exception=Exception) as cm: + audit_log_manager.get_by_id(auditlog_id=999_999_999_999) + assert "No AuditLog with id of " in str(cm.value) + + assert isinstance(audit_log, AuditLog) + res = audit_log_manager.get_by_id(auditlog_id=audit_log.id) + assert isinstance(res, AuditLog) + assert res.id == audit_log.id + assert res.created.tzinfo == timezone.utc + + def test_filter_by_product( + self, + user_factory, + product_factory, + audit_log_factory, + audit_log_manager, + ): + p1 = product_factory() + p2 = product_factory() + + audit_log_factory(user_id=user_factory(product=p1).user_id) + audit_log_factory(user_id=user_factory(product=p1).user_id) + audit_log_factory(user_id=user_factory(product=p1).user_id) + + res = audit_log_manager.filter_by_product(product=p2) + assert isinstance(res, list) + assert len(res) == 0 + + res = audit_log_manager.filter_by_product(product=p1) + assert isinstance(res, list) + assert len(res) == 3 + + audit_log_factory(user_id=user_factory(product=p2).user_id) + res = audit_log_manager.filter_by_product(product=p2) + assert isinstance(res, list) + assert isinstance(res[0], AuditLog) + assert len(res) == 1 + + def test_filter_by_user_id( + self, user_factory, product, audit_log_factory, audit_log_manager + ): + u1 = user_factory(product=product) + u2 = user_factory(product=product) + + audit_log_factory(user_id=u1.user_id) + audit_log_factory(user_id=u1.user_id) + audit_log_factory(user_id=u1.user_id) + + res = audit_log_manager.filter_by_user_id(user_id=u1.user_id) + assert isinstance(res, list) + assert len(res) == 3 + + res = audit_log_manager.filter_by_user_id(user_id=u2.user_id) + assert isinstance(res, list) + assert len(res) == 0 + + audit_log_factory(user_id=u2.user_id) + + res = audit_log_manager.filter_by_user_id(user_id=u2.user_id) + assert isinstance(res, list) + assert isinstance(res[0], AuditLog) + assert len(res) == 1 + + def test_filter( + self, + user_factory, + product_factory, + audit_log_factory, + audit_log_manager, + ): + p1 = product_factory() + p2 = product_factory() + p3 = product_factory() + + u1 = user_factory(product=p1) + u2 = user_factory(product=p2) + u3 = user_factory(product=p3) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[]) + assert "must pass at least 1 user_id" in str(cm.value) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[u1, u2, u3]) + assert "must pass user_id as int" in str(cm.value) + + res = audit_log_manager.filter(user_ids=[u1.user_id, u2.user_id, u3.user_id]) + assert isinstance(res, list) + assert len(res) == 0 + + audit_log_factory(user_id=u1.user_id) + + res = audit_log_manager.filter(user_ids=[u1.user_id, u2.user_id, u3.user_id]) + assert isinstance(res, list) + assert isinstance(res[0], AuditLog) + assert len(res) == 1 + + def test_filter_count( + self, + user_factory, + product_factory, + audit_log_factory, + audit_log_manager, + ): + p1 = product_factory() + p2 = product_factory() + p3 = product_factory() + + u1 = user_factory(product=p1) + u2 = user_factory(product=p2) + u3 = user_factory(product=p3) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[]) + assert "must pass at least 1 user_id" in str(cm.value) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[u1, u2, u3]) + assert "must pass user_id as int" in str(cm.value) + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id, u2.user_id, u3.user_id] + ) + assert isinstance(res, int) + assert res == 0 + + audit_log_factory(user_id=u1.user_id, level=20) + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id, u2.user_id, u3.user_id] + ) + assert isinstance(res, int) + assert res == 1 + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id, u2.user_id, u3.user_id], + created_after=datetime.now(tz=timezone.utc), + ) + assert isinstance(res, int) + assert res == 0 + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id], event_type_like="offerwall-enter.%%" + ) + assert res == 1 + + audit_log_factory(user_id=u1.user_id, level=50) + res = audit_log_manager.filter_count( + user_ids=[u1.user_id], + event_type_like="offerwall-enter.%%", + level_ge=10, + ) + assert res == 2 + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id], event_type_like="poop.%", level_ge=10 + ) + assert res == 0 + + +class TestIPRecordManager: + + def test_init(self, thl_web_rr, thl_redis_config, ip_record_manager): + instance = IPRecordManager(pg_config=thl_web_rr, redis_config=thl_redis_config) + assert isinstance(instance, IPRecordManager) + assert isinstance(ip_record_manager, IPRecordManager) + + def test_create(self, ip_record_manager, user, ip_information): + instance = ip_record_manager.create_dummy( + user_id=user.user_id, ip=ip_information.ip + ) + assert isinstance(instance, IPRecord) + + assert isinstance(instance.forwarded_ips, list) + assert isinstance(instance.forwarded_ip_records[0], IPRecord) + assert isinstance(instance.forwarded_ips[0], str) + + assert instance.created == instance.forwarded_ip_records[0].created + + ipr1 = ip_record_manager.filter_ip_records(filter_ips=[instance.ip]) + assert isinstance(ipr1, list) + assert instance.model_dump_json() == ipr1[0].model_dump_json() + + def test_prefetch_info( + self, + ip_record_factory, + ip_information_factory, + ip_geoname, + user, + thl_web_rr, + thl_redis_config, + ): + + ip = fake.ipv4_public() + ip_information_factory(ip=ip, geoname=ip_geoname) + ipr: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip) + + assert ipr.information is None + assert len(ipr.forwarded_ip_records) >= 1 + fipr = ipr.forwarded_ip_records[0] + assert fipr.information is None + + ipr.prefetch_ipinfo( + pg_config=thl_web_rr, + redis_config=thl_redis_config, + include_forwarded=True, + ) + assert isinstance(ipr.information, GeoIPInformation) + assert ipr.information.ip == ipr.ip == ip + assert fipr.information is None, "the ipinfo doesn't exist in the db yet" + + ip_information_factory(ip=fipr.ip, geoname=ip_geoname) + ipr.prefetch_ipinfo( + pg_config=thl_web_rr, + redis_config=thl_redis_config, + include_forwarded=True, + ) + assert fipr.information is not None + + +@pytest.mark.usefixtures("user_iphistory_manager_clear_cache") +class TestUserIpHistoryManager: + def test_init(self, thl_web_rr, thl_redis_config, user_iphistory_manager): + instance = UserIpHistoryManager( + pg_config=thl_web_rr, redis_config=thl_redis_config + ) + assert isinstance(instance, UserIpHistoryManager) + assert isinstance(user_iphistory_manager, UserIpHistoryManager) + + def test_latest_record( + self, + user_iphistory_manager, + user, + ip_record_factory, + ip_information_factory, + ip_geoname, + ): + ip = fake.ipv4_public() + ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True) + ipr1: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip) + + ipr = user_iphistory_manager.get_user_latest_ip_record(user=user) + assert ipr.ip == ipr1.ip + assert ipr.is_anonymous + assert ipr.information.lookup_prefix == "/32" + + ip = fake.ipv6() + ip_information_factory(ip=ip, geoname=ip_geoname) + ipr2: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip) + + ipr = user_iphistory_manager.get_user_latest_ip_record(user=user) + assert ipr.ip == ipr2.ip + assert ipr.information.lookup_prefix == "/64" + assert ipr.information is not None + assert not ipr.is_anonymous + + country_iso = user_iphistory_manager.get_user_latest_country(user=user) + assert country_iso == ip_geoname.country_iso + + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert iph.ips[0].information is not None + assert iph.ips[1].information is not None + assert iph.ips[0].country_iso == country_iso + assert iph.ips[0].is_anonymous + assert iph.ips[0].ip == ipr1.ip + assert iph.ips[1].ip == ipr2.ip + + def test_virgin(self, user, user_iphistory_manager, ip_record_factory): + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 0 + + ip_record_factory(user_id=user.user_id, ip=fake.ipv4_public()) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + + def test_out_of_order( + self, + ip_record_factory, + user, + user_iphistory_manager, + ip_information_factory, + ip_geoname, + ): + # Create the user-ip association BEFORE the ip even exists in the ipinfo table + ip = fake.ipv4_public() + ip_record_factory(user_id=user.user_id, ip=ip) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is None + assert not ipr.is_anonymous + + ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is not None + assert ipr.is_anonymous + + def test_out_of_order_ipv6( + self, + ip_record_factory, + user, + user_iphistory_manager, + ip_information_factory, + ip_geoname, + ): + # Create the user-ip association BEFORE the ip even exists in the ipinfo table + ip = fake.ipv6() + ip_record_factory(user_id=user.user_id, ip=ip) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is None + assert not ipr.is_anonymous + + ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is not None + assert ipr.is_anonymous diff --git a/tests/managers/thl/test_wall_manager.py b/tests/managers/thl/test_wall_manager.py new file mode 100644 index 0000000..ee44e23 --- /dev/null +++ b/tests/managers/thl/test_wall_manager.py @@ -0,0 +1,283 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.session import ( + ReportValue, + Status, + StatusCode1, +) +from test_utils.models.conftest import user, session + + +class TestWallManager: + + @pytest.mark.parametrize("wall_count", [1, 2, 5, 10, 50, 99]) + def test_get_wall_events( + self, wall_manager, session_factory, user, wall_count, utc_hour_ago + ): + from generalresearch.models.thl.session import Session + + s1: Session = session_factory( + user=user, wall_count=wall_count, started=utc_hour_ago + ) + + assert len(s1.wall_events) == wall_count + assert len(wall_manager.get_wall_events(session_id=s1.id)) == wall_count + + db_wall_events = wall_manager.get_wall_events(session_id=s1.id) + assert [w.uuid for w in s1.wall_events] == [w.uuid for w in db_wall_events] + assert [w.source for w in s1.wall_events] == [w.source for w in db_wall_events] + assert [w.buyer_id for w in s1.wall_events] == [ + w.buyer_id for w in db_wall_events + ] + assert [w.req_survey_id for w in s1.wall_events] == [ + w.req_survey_id for w in db_wall_events + ] + assert [w.started for w in s1.wall_events] == [ + w.started for w in db_wall_events + ] + + assert sum([w.req_cpi for w in s1.wall_events]) == sum( + [w.req_cpi for w in db_wall_events] + ) + assert sum([w.cpi for w in s1.wall_events]) == sum( + [w.cpi for w in db_wall_events] + ) + + assert [w.session_id for w in s1.wall_events] == [ + w.session_id for w in db_wall_events + ] + assert [w.user_id for w in s1.wall_events] == [ + w.user_id for w in db_wall_events + ] + assert [w.survey_id for w in s1.wall_events] == [ + w.survey_id for w in db_wall_events + ] + + assert [w.finished for w in s1.wall_events] == [ + w.finished for w in db_wall_events + ] + + def test_get_wall_events_list_input( + self, wall_manager, session_factory, user, utc_hour_ago + ): + from generalresearch.models.thl.session import Session + + session_ids = [] + for idx in range(10): + s: Session = session_factory(user=user, wall_count=5, started=utc_hour_ago) + session_ids.append(s.id) + + session_ids.sort() + res = wall_manager.get_wall_events(session_ids=session_ids) + + assert isinstance(res, list) + assert len(res) == 50 + + res1 = list(set([w.session_id for w in res])) + res1.sort() + + assert session_ids == res1 + + def test_create_wall(self, wall_manager, session_manager, user, session): + w = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=datetime.now(tz=timezone.utc), + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + + assert w is not None + w2 = wall_manager.get_from_uuid(wall_uuid=w.uuid) + assert w == w2 + + def test_report_wall_abandon( + self, wall_manager, session_manager, user, session, utc_hour_ago + ): + w1 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago, + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + wall_manager.report( + wall=w1, + report_value=ReportValue.REASON_UNKNOWN, + report_timestamp=utc_hour_ago + timedelta(minutes=1), + ) + w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid) + + # I Reported a session with no status. It should be marked as an abandon with a finished ts + assert ReportValue.REASON_UNKNOWN == w2.report_value + assert Status.ABANDON == w2.status + assert utc_hour_ago + timedelta(minutes=1) == w2.finished + assert w2.report_notes is None + + # There is nothing stopping it from being un-abandoned... + finished = w1.started + timedelta(minutes=10) + wall_manager.finish( + wall=w1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=finished, + ) + w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid) + assert ReportValue.REASON_UNKNOWN == w2.report_value + assert w2.report_notes is None + assert finished == w2.finished + assert Status.COMPLETE == w2.status + # the status and finished get updated + + def test_report_wall( + self, wall_manager, session_manager, user, session, utc_hour_ago + ): + w1 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago, + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + + finish_ts = utc_hour_ago + timedelta(minutes=10) + report_ts = utc_hour_ago + timedelta(minutes=11) + wall_manager.finish( + wall=w1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=finish_ts, + ) + + # I Reported a session that was already completed. Only update the report values + wall_manager.report( + wall=w1, + report_value=ReportValue.TECHNICAL_ERROR, + report_timestamp=report_ts, + report_notes="This survey blows!", + ) + w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid) + + assert ReportValue.TECHNICAL_ERROR == w2.report_value + assert finish_ts == w2.finished + assert Status.COMPLETE == w2.status + assert "This survey blows!" == w2.report_notes + + def test_filter_wall_attempts( + self, wall_manager, session_manager, user, session, utc_hour_ago + ): + res = wall_manager.filter_wall_attempts(user_id=user.user_id) + assert len(res) == 0 + w1 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago, + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + res = wall_manager.filter_wall_attempts(user_id=user.user_id) + assert len(res) == 1 + w2 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago + timedelta(minutes=1), + source=Source.DYNATA, + buyer_id="123", + req_survey_id="555", + req_cpi=Decimal("1"), + ) + res = wall_manager.filter_wall_attempts(user_id=user.user_id) + assert len(res) == 2 + + +class TestWallCacheManager: + + def test_get_attempts_none(self, wall_cache_manager, user): + attempts = wall_cache_manager.get_attempts(user.user_id) + assert len(attempts) == 0 + + def test_get_wall_events( + self, wall_cache_manager, wall_manager, session_manager, user + ): + start1 = datetime.now(timezone.utc) - timedelta(hours=3) + start2 = datetime.now(timezone.utc) - timedelta(hours=2) + start3 = datetime.now(timezone.utc) - timedelta(hours=1) + + session = session_manager.create_dummy(started=start1, user=user) + wall1 = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=start1, + req_cpi=Decimal("1.23"), + req_survey_id="11111", + source=Source.DYNATA, + ) + # The flag never got set, so no results! + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 0 + + wall_cache_manager.set_flag(user_id=user.user_id) + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 1 + + wall2 = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=start2, + req_cpi=Decimal("1.23"), + req_survey_id="22222", + source=Source.DYNATA, + ) + + # We haven't set the flag, so the cache won't update! + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 1 + + # Now set the flag + wall_cache_manager.set_flag(user_id=user.user_id) + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 2 + # It is in desc order + assert attempts[0].req_survey_id == "22222" + assert attempts[1].req_survey_id == "11111" + + # Test the trim. Fill up the cache with 6000 events, then add another, + # and it should be first in the list, with only 5k others + attempts10000 = [attempts[0]] * 6000 + wall_cache_manager.update_attempts_redis_(attempts10000, user_id=user.user_id) + + session = session_manager.create_dummy(started=start3, user=user) + wall3 = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=start3, + req_cpi=Decimal("1.23"), + req_survey_id="33333", + source=Source.DYNATA, + ) + wall_cache_manager.set_flag(user_id=user.user_id) + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + + redis_key = wall_cache_manager.get_cache_key_(user_id=user.user_id) + assert wall_cache_manager.redis_client.llen(redis_key) == 5000 + + assert len(attempts) == 5000 + assert attempts[0].req_survey_id == "33333" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/__init__.py diff --git a/tests/models/admin/__init__.py b/tests/models/admin/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/admin/__init__.py diff --git a/tests/models/admin/test_report_request.py b/tests/models/admin/test_report_request.py new file mode 100644 index 0000000..cf4c405 --- /dev/null +++ b/tests/models/admin/test_report_request.py @@ -0,0 +1,163 @@ +from datetime import timezone, datetime + +import pandas as pd +import pytest +from pydantic import ValidationError + + +class TestReportRequest: + def test_base(self, utc_60days_ago): + from generalresearch.models.admin.request import ( + ReportRequest, + ReportType, + ) + + rr = ReportRequest() + + assert isinstance(rr.start, datetime), "rr.start incorrect type" + assert isinstance(rr.start_floor, datetime), "rr.start_floor incorrect type" + + assert rr.report_type == ReportType.POP_SESSION + assert rr.start != rr.start_floor, "rr.start != rr.start_floor" + assert rr.start_floor.tzinfo == timezone.utc, "rr.start_floor.tzinfo not utc" + + rr1 = ReportRequest.model_validate( + { + "start": datetime( + year=datetime.now().year, + month=1, + day=1, + hour=0, + minute=30, + second=25, + microsecond=35, + tzinfo=timezone.utc, + ), + "interval": "1h", + } + ) + + assert isinstance(rr1.start, datetime), "rr1.start incorrect type" + assert isinstance(rr1.start_floor, datetime), "rr1.start_floor incorrect type" + + rr2 = ReportRequest.model_validate( + { + "start": datetime( + year=datetime.now().year, + month=1, + day=1, + hour=6, + minute=30, + second=25, + microsecond=35, + tzinfo=timezone.utc, + ), + "interval": "1d", + } + ) + + assert isinstance(rr2.start, datetime), "rr2.start incorrect type" + assert isinstance(rr2.start_floor, datetime), "rr2.start_floor incorrect type" + + assert rr1.start != rr2.start, "rr1.start != rr2.start" + assert rr1.start_floor == rr2.start_floor, "rr1.start_floor == rr2.start_floor" + + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) + + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) = + # ReportRequest(report_type=<ReportType.POP_SESSION: 'pop_session'>, + # index0='started', index1='product_id', + # start=datetime.datetime(2025, 7, 9, 0, 46, 23, 145756, tzinfo=datetime.timezone.utc), + # end=datetime.datetime(2025, 9, 7, 0, 46, 23, 149195, tzinfo=datetime.timezone.utc), + # interval='1h', include_open_bucket=True, + # start_floor=datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc)).start_floor + + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) = + # ReportRequest(report_type=<ReportType.POP_SESSION: 'pop_session'>, + # index0='started', index1='product_id', + # start=datetime.datetime(2025, 7, 9, 0, 46, 23, 145756, tzinfo=datetime.timezone.utc), + # end=datetime.datetime(2025, 9, 7, 0, 46, 23, 149267, tzinfo=datetime.timezone.utc), + # interval='1d', include_open_bucket=True, + # start_floor=datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc)).start_floor + + def test_start_end_range(self, utc_90days_ago, utc_30days_ago): + from generalresearch.models.admin.request import ReportRequest + + with pytest.raises(expected_exception=ValidationError) as cm: + ReportRequest.model_validate( + {"start": utc_30days_ago, "end": utc_90days_ago} + ) + + with pytest.raises(expected_exception=ValidationError) as cm: + ReportRequest.model_validate( + { + "start": datetime(year=1990, month=1, day=1), + "end": datetime(year=1950, month=1, day=1), + } + ) + + def test_start_end_range_tz(self): + from generalresearch.models.admin.request import ReportRequest + from zoneinfo import ZoneInfo + + pacific_tz = ZoneInfo("America/Los_Angeles") + + with pytest.raises(expected_exception=ValidationError) as cm: + ReportRequest.model_validate( + { + "start": datetime(year=2000, month=1, day=1, tzinfo=pacific_tz), + "end": datetime(year=2000, month=6, day=1, tzinfo=pacific_tz), + } + ) + + def test_start_floor_naive(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert rr.start_floor_naive.tzinfo is None + + def test_end_naive(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert rr.end_naive.tzinfo is None + + def test_pd_interval(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert isinstance(rr.pd_interval, pd.Interval) + + def test_interval_timedelta(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert isinstance(rr.interval_timedelta, pd.Timedelta) + + def test_buckets(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert isinstance(rr.buckets(), pd.DatetimeIndex) + + def test_bucket_ranges(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + assert isinstance(rr.bucket_ranges(), list) + + rr = ReportRequest.model_validate( + { + "interval": "1d", + "start": datetime(year=2000, month=1, day=1, tzinfo=timezone.utc), + "end": datetime(year=2000, month=1, day=10, tzinfo=timezone.utc), + } + ) + + assert len(rr.bucket_ranges()) == 10 diff --git a/tests/models/custom_types/__init__.py b/tests/models/custom_types/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/custom_types/__init__.py diff --git a/tests/models/custom_types/test_aware_datetime.py b/tests/models/custom_types/test_aware_datetime.py new file mode 100644 index 0000000..a23413c --- /dev/null +++ b/tests/models/custom_types/test_aware_datetime.py @@ -0,0 +1,82 @@ +import logging +from datetime import datetime, timezone +from typing import Optional + +import pytest +import pytz +from pydantic import BaseModel, ValidationError, Field + +from generalresearch.models.custom_types import AwareDatetimeISO + +logger = logging.getLogger() + + +class AwareDatetimeISOModel(BaseModel): + dt_optional: Optional[AwareDatetimeISO] = Field(default=None) + dt: AwareDatetimeISO + + +class TestAwareDatetimeISO: + def test_str(self): + dt = "2023-10-10T01:01:01.0Z" + t = AwareDatetimeISOModel(dt=dt, dt_optional=dt) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + t = AwareDatetimeISOModel(dt=dt, dt_optional=None) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + def test_dt(self): + dt = datetime(2023, 10, 10, 1, 1, 1, tzinfo=timezone.utc) + t = AwareDatetimeISOModel(dt=dt, dt_optional=dt) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + t = AwareDatetimeISOModel(dt=dt, dt_optional=None) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + dt = datetime(2023, 10, 10, 1, 1, 1, microsecond=123, tzinfo=timezone.utc) + t = AwareDatetimeISOModel(dt=dt, dt_optional=dt) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + t = AwareDatetimeISOModel(dt=dt, dt_optional=None) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + def test_no_tz(self): + dt = datetime(2023, 10, 10, 1, 1, 1) + + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=None) + + dt = "2023-10-10T01:01:01.0" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=None) + + def test_non_utc_tz(self): + dt = datetime( + year=2023, + month=10, + day=10, + hour=1, + second=1, + minute=1, + tzinfo=pytz.timezone("US/Central"), + ) + + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + + def test_invalid_format(self): + dt = "2023-10-10T01:01:01Z" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + + dt = "2023-10-10T01:01:01" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + dt = "2023-10-10" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + + def test_required(self): + dt = "2023-10-10T01:01:01.0Z" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=None, dt_optional=dt) diff --git a/tests/models/custom_types/test_dsn.py b/tests/models/custom_types/test_dsn.py new file mode 100644 index 0000000..b37f2c4 --- /dev/null +++ b/tests/models/custom_types/test_dsn.py @@ -0,0 +1,112 @@ +from typing import Optional +from uuid import uuid4 + +import pytest +from pydantic import BaseModel, ValidationError, Field +from pydantic import MySQLDsn +from pydantic_core import Url + +from generalresearch.models.custom_types import DaskDsn, SentryDsn + + +# --- Test Pydantic Models --- + + +class SettingsModel(BaseModel): + dask: Optional["DaskDsn"] = Field(default=None) + sentry: Optional["SentryDsn"] = Field(default=None) + db: Optional["MySQLDsn"] = Field(default=None) + + +# --- Pytest themselves --- + + +class TestDaskDsn: + + def test_base(self): + from dask.distributed import Client + + m = SettingsModel(dask="tcp://dask-scheduler.internal") + + assert m.dask.scheme == "tcp" + assert m.dask.host == "dask-scheduler.internal" + assert m.dask.port == 8786 + + with pytest.raises(expected_exception=TypeError) as cm: + Client(m.dask) + assert "Scheduler address must be a string or a Cluster instance" in str( + cm.value + ) + + # todo: this requires vpn connection. maybe do this part with a localhost dsn + # client = Client(str(m.dask)) + # self.assertIsInstance(client, Client) + + def test_str(self): + m = SettingsModel(dask="tcp://dask-scheduler.internal") + assert isinstance(m.dask, Url) + assert "tcp://dask-scheduler.internal:8786" == str(m.dask) + + def test_auth(self): + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="tcp://test:password@dask-scheduler.internal") + assert "User & Password are not supported" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="tcp://test:@dask-scheduler.internal") + assert "User & Password are not supported" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="tcp://:password@dask-scheduler.internal") + assert "User & Password are not supported" in str(cm.value) + + def test_invalid_schema(self): + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="dask-scheduler.internal") + assert "relative URL without a base" in str(cm.value) + + # I look forward to the day we use infiniband interfaces + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="ucx://dask-scheduler.internal") + assert "URL scheme should be 'tcp'" in str(cm.value) + + def test_port(self): + m = SettingsModel(dask="tcp://dask-scheduler.internal") + assert m.dask.port == 8786 + + +class TestSentryDsn: + def test_base(self): + m = SettingsModel( + sentry=f"https://{uuid4().hex}@12345.ingest.us.sentry.io/9876543" + ) + + assert m.sentry.scheme == "https" + assert m.sentry.host == "12345.ingest.us.sentry.io" + assert m.sentry.port == 443 + + def test_str(self): + test_url: str = f"https://{uuid4().hex}@12345.ingest.us.sentry.io/9876543" + m = SettingsModel(sentry=test_url) + assert isinstance(m.sentry, Url) + assert test_url == str(m.sentry) + + def test_auth(self): + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel( + sentry="https://0123456789abc:password@12345.ingest.us.sentry.io/9876543" + ) + assert "Sentry password is not supported" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(sentry="https://test:@12345.ingest.us.sentry.io/9876543") + assert "Sentry user key seems bad" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(sentry="https://:password@12345.ingest.us.sentry.io/9876543") + assert "Sentry URL requires a user key" in str(cm.value) + + def test_port(self): + test_url: str = f"https://{uuid4().hex}@12345.ingest.us.sentry.io/9876543" + m = SettingsModel(sentry=test_url) + assert m.sentry.port == 443 diff --git a/tests/models/custom_types/test_therest.py b/tests/models/custom_types/test_therest.py new file mode 100644 index 0000000..13e9bae --- /dev/null +++ b/tests/models/custom_types/test_therest.py @@ -0,0 +1,42 @@ +import json +from uuid import UUID + +import pytest +from pydantic import TypeAdapter, ValidationError + + +class TestAll: + + def test_comma_sep_str(self): + from generalresearch.models.custom_types import AlphaNumStrSet + + t = TypeAdapter(AlphaNumStrSet) + assert {"a", "b", "c"} == t.validate_python(["a", "b", "c"]) + assert '"a,b,c"' == t.dump_json({"c", "b", "a"}).decode() + assert '""' == t.dump_json(set()).decode() + assert {"a", "b", "c"} == t.validate_json('"c,b,a"') + assert set() == t.validate_json('""') + + with pytest.raises(ValidationError): + t.validate_python({"", "b", "a"}) + + with pytest.raises(ValidationError): + t.validate_python({""}) + + with pytest.raises(ValidationError): + t.validate_json('",b,a"') + + def test_UUIDStrCoerce(self): + from generalresearch.models.custom_types import UUIDStrCoerce + + t = TypeAdapter(UUIDStrCoerce) + uuid_str = "18e70590176e49c693b07682f3c112be" + assert uuid_str == t.validate_python("18e70590-176e-49c6-93b0-7682f3c112be") + assert uuid_str == t.validate_python( + UUID("18e70590-176e-49c6-93b0-7682f3c112be") + ) + assert ( + json.dumps(uuid_str) + == t.dump_json("18e70590176e49c693b07682f3c112be").decode() + ) + assert uuid_str == t.validate_json('"18e70590-176e-49c6-93b0-7682f3c112be"') diff --git a/tests/models/custom_types/test_uuid_str.py b/tests/models/custom_types/test_uuid_str.py new file mode 100644 index 0000000..91af9ae --- /dev/null +++ b/tests/models/custom_types/test_uuid_str.py @@ -0,0 +1,51 @@ +from typing import Optional +from uuid import uuid4 + +import pytest +from pydantic import BaseModel, ValidationError, Field + +from generalresearch.models.custom_types import UUIDStr + + +class UUIDStrModel(BaseModel): + uuid_optional: Optional[UUIDStr] = Field(default_factory=lambda: uuid4().hex) + uuid: UUIDStr + + +class TestUUIDStr: + def test_str(self): + v = "58889cd67f9f4c699b25437112dce638" + + t = UUIDStrModel(uuid=v, uuid_optional=v) + UUIDStrModel.model_validate_json(t.model_dump_json()) + + t = UUIDStrModel(uuid=v, uuid_optional=None) + t2 = UUIDStrModel.model_validate_json(t.model_dump_json()) + + assert t2.uuid_optional is None + assert t2.uuid == v + + def test_uuid(self): + v = uuid4() + + with pytest.raises(ValidationError) as cm: + UUIDStrModel(uuid=v, uuid_optional=None) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + UUIDStrModel(uuid="58889cd67f9f4c699b25437112dce638", uuid_optional=v) + assert "Input should be a valid string" in str(cm.value) + + def test_invalid_format(self): + v = "x" + with pytest.raises(ValidationError): + UUIDStrModel(uuid=v, uuid_optional=None) + + with pytest.raises(ValidationError): + UUIDStrModel(uuid="58889cd67f9f4c699b25437112dce638", uuid_optional=v) + + def test_required(self): + v = "58889cd67f9f4c699b25437112dce638" + + with pytest.raises(ValidationError): + UUIDStrModel(uuid=None, uuid_optional=v) diff --git a/tests/models/dynata/__init__.py b/tests/models/dynata/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/dynata/__init__.py diff --git a/tests/models/dynata/test_eligbility.py b/tests/models/dynata/test_eligbility.py new file mode 100644 index 0000000..23437f5 --- /dev/null +++ b/tests/models/dynata/test_eligbility.py @@ -0,0 +1,324 @@ +from datetime import datetime, timezone + + +class TestEligibility: + + def test_evaluate_task_criteria(self): + from generalresearch.models.dynata.survey import ( + DynataQuotaGroup, + DynataFilterGroup, + DynataSurvey, + DynataRequirements, + ) + + filters = [[["a", "b"], ["c", "d"]], [["e"], ["f"]]] + filters = [DynataFilterGroup.model_validate(f) for f in filters] + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": True, + } + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": [], "status": "OPEN"}] + ) + ] + task = DynataSurvey.model_validate( + { + "survey_id": "1", + "filters": filters, + "quotas": quotas, + "allowed_devices": set("1"), + "calculation_type": "COMPLETES", + "client_id": "", + "country_iso": "us", + "language_iso": "eng", + "group_id": "g1", + "project_id": "p1", + "status": "OPEN", + "project_exclusions": set(), + "created": datetime.now(tz=timezone.utc), + "category_exclusions": set(), + "category_ids": set(), + "cpi": 1, + "days_in_field": 0, + "expected_count": 0, + "order_number": "", + "live_link": "", + "bid_ir": 0.5, + "bid_loi": 500, + "requirements": DynataRequirements(), + } + ) + assert task.determine_eligibility(criteria_evaluation) + + # task status + task.status = "CLOSED" + assert not task.determine_eligibility(criteria_evaluation) + task.status = "OPEN" + + # one quota with no space left (count = 0) + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 0, "condition_hashes": [], "status": "OPEN"}] + ) + ] + task.quotas = quotas + assert not task.determine_eligibility(criteria_evaluation) + + # we pass 'a' and 'b' + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}] + ) + ] + task.quotas = quotas + assert task.determine_eligibility(criteria_evaluation) + + # make 'f' false, we still pass the 2nd filtergroup b/c 'e' is True + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": False, + } + assert task.determine_eligibility(criteria_evaluation) + + # make 'e' false, we don't pass the 2nd filtergroup + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": False, + "f": False, + } + assert not task.determine_eligibility(criteria_evaluation) + + # We fail quota 'c','d', but we pass 'a','b', so we pass the first quota group + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": True, + } + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ) + ] + task.quotas = quotas + assert task.determine_eligibility(criteria_evaluation) + + # we pass the first qg, but then fall into a full 2nd qg + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ), + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["f"], "status": "CLOSED"}] + ), + ] + task.quotas = quotas + assert not task.determine_eligibility(criteria_evaluation) + + def test_soft_pair(self): + from generalresearch.models.dynata.survey import ( + DynataQuotaGroup, + DynataFilterGroup, + DynataSurvey, + DynataRequirements, + ) + + filters = [[["a", "b"], ["c", "d"]], [["e"], ["f"]]] + filters = [DynataFilterGroup.model_validate(f) for f in filters] + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": True, + } + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": [], "status": "OPEN"}] + ) + ] + task = DynataSurvey.model_validate( + { + "survey_id": "1", + "filters": filters, + "quotas": quotas, + "allowed_devices": set("1"), + "calculation_type": "COMPLETES", + "client_id": "", + "country_iso": "us", + "language_iso": "eng", + "group_id": "g1", + "project_id": "p1", + "status": "OPEN", + "project_exclusions": set(), + "created": datetime.now(tz=timezone.utc), + "category_exclusions": set(), + "category_ids": set(), + "cpi": 1, + "days_in_field": 0, + "expected_count": 0, + "order_number": "", + "live_link": "", + "bid_ir": 0.5, + "bid_loi": 500, + "requirements": DynataRequirements(), + } + ) + assert task.passes_filters(criteria_evaluation) + passes, condition_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes + + # make 'e' & 'f' None, we don't pass the 2nd filtergroup + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": None, + "f": None, + } + assert not task.passes_filters(criteria_evaluation) + passes, conditional_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes is None + assert {"e", "f"} == conditional_hashes + + # 1st filtergroup unknown + criteria_evaluation = { + "a": True, + "b": None, + "c": None, + "d": None, + "e": None, + "f": None, + } + assert not task.passes_filters(criteria_evaluation) + passes, conditional_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes is None + assert {"b", "c", "d", "e", "f"} == conditional_hashes + + # 1st filtergroup unknown, 2nd cell False + criteria_evaluation = { + "a": True, + "b": None, + "c": None, + "d": False, + "e": None, + "f": None, + } + assert not task.passes_filters(criteria_evaluation) + passes, conditional_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes is None + assert {"b", "e", "f"} == conditional_hashes + + # we pass the first qg, unknown 2nd + criteria_evaluation = { + "a": True, + "b": True, + "c": None, + "d": False, + "e": None, + "f": None, + } + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ), + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["f"], "status": "OPEN"}] + ), + ] + task.quotas = quotas + passes, conditional_hashes = task.passes_quotas_soft(criteria_evaluation) + assert passes is None + assert {"f"} == conditional_hashes + + # both quota groups unknown + criteria_evaluation = { + "a": True, + "b": None, + "c": None, + "d": False, + "e": None, + "g": None, + } + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ), + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["g"], "status": "OPEN"}] + ), + ] + task.quotas = quotas + passes, conditional_hashes = task.passes_quotas_soft(criteria_evaluation) + assert passes is None + assert {"b", "g"} == conditional_hashes + + passes, conditional_hashes = task.determine_eligibility_soft( + criteria_evaluation + ) + assert passes is None + assert {"b", "e", "f", "g"} == conditional_hashes + + # def x(self): + # # ---- + # c1 = DynataCondition(question_id='gender', values=['male'], value_type=ConditionValueType.LIST) # 718f759 + # c2 = DynataCondition(question_id='age', values=['18-24'], value_type=ConditionValueType.RANGE) # 7a7b290 + # obj1 = DynataFilterObject(cells=[c1.criterion_hash, c2.criterion_hash]) + # + # c3 = DynataCondition(question_id='gender', values=['female'], value_type=ConditionValueType.LIST) # 38fa4e1 + # c4 = DynataCondition(question_id='age', values=['35-45'], value_type=ConditionValueType.RANGE) # e4f06fa + # obj2 = DynataFilterObject(cells=[c3.criterion_hash, c4.criterion_hash]) + # + # grp1 = DynataFilterGroup(objects=[obj1, obj2]) + # + # # ----- + # c5 = DynataCondition(question_id='ethnicity', values=['white'], value_type=ConditionValueType.LIST) # eb9b9a4 + # obj3 = DynataFilterObject(cells=[c5.criterion_hash]) + # + # c6 = DynataCondition(question_id='ethnicity', values=['black'], value_type=ConditionValueType.LIST) # 039fe2d + # obj4 = DynataFilterObject(cells=[c6.criterion_hash]) + # + # grp2 = DynataFilterGroup(objects=[obj3, obj4]) + # # ----- + # q1 = DynataQuota(count=5, status=DynataStatus.OPEN, + # condition_hashes=[c1.criterion_hash, c2.criterion_hash]) + # q2 = DynataQuota(count=10, status=DynataStatus.CLOSED, + # condition_hashes=[c3.criterion_hash, c4.criterion_hash]) + # qg1 = DynataQuotaGroup(cells=[q1, q2]) + # # ---- + # + # s = DynataSurvey(survey_id='123', status=DynataStatus.OPEN, country_iso='us', + # language_iso='eng', group_id='123', client_id='123', project_id='12', + # filters=[grp1, grp2], + # quotas=[qg1]) + # ce = {'718f759': True, '7a7b290': True, 'eb9b9a4': True} + # s.passes_filters(ce) + # s.passes_quotas(ce) diff --git a/tests/models/dynata/test_survey.py b/tests/models/dynata/test_survey.py new file mode 100644 index 0000000..ad953a3 --- /dev/null +++ b/tests/models/dynata/test_survey.py @@ -0,0 +1,164 @@ +class TestDynataCondition: + + def test_condition_create(self): + from generalresearch.models.dynata.survey import DynataCondition + + cell = { + "tag": "90606986-5508-461b-a821-216e9a72f1a0", + "attribute_id": 120, + "negate": False, + "kind": "VALUE", + "value": "45398", + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"120": {"45398"}}) + assert not c.evaluate_criterion({"120": {"11111"}}) + + cell = { + "tag": "aa7169c0-cb34-499a-aadd-31e0013df8fd", + "attribute_id": 231302, + "negate": False, + "operator": "OR", + "kind": "LIST", + "list": ["514802", "514804", "514808", "514810"], + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"231302": {"514804", "123445"}}) + assert not c.evaluate_criterion({"231302": {"123445"}}) + + cell = { + "tag": "aa7169c0-cb34-499a-aadd-31e0013df8fd", + "attribute_id": 231302, + "negate": False, + "operator": "AND", + "kind": "LIST", + "list": ["514802", "514804"], + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"231302": {"514802", "514804"}}) + assert not c.evaluate_criterion({"231302": {"514802"}}) + + cell = { + "tag": "75a36c67-0328-4c1b-a4dd-67d34688ff68", + "attribute_id": 80, + "negate": False, + "kind": "RANGE", + "range": {"from": 18, "to": 99}, + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"80": {"20"}}) + assert not c.evaluate_criterion({"80": {"120"}}) + + cell = { + "tag": "dd64b622-ed10-4a3b-e1h8-a4e63b59vha2", + "attribute_id": 83, + "negate": False, + "kind": "INEFFABLE", + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"83": {"20"}}) + + cell = { + "tag": "kei35kkjj-d00k-52kj-b3j4-a4jinx9832", + "attribute_id": 8, + "negate": False, + "kind": "ANSWERED", + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"8": {"20"}}) + assert not c.evaluate_criterion({"81": {"20"}}) + + def test_condition_range(self): + from generalresearch.models.dynata.survey import DynataCondition + + cell = { + "tag": "75a36c67-0328-4c1b-a4dd-67d34688ff68", + "attribute_id": 80, + "negate": False, + "kind": "RANGE", + "range": {"from": 18, "to": None}, + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"80": {"20"}}) + + def test_recontact(self): + from generalresearch.models.dynata.survey import DynataCondition + + cell = { + "tag": "d559212d-7984-4239-89c2-06c29588d79e", + "attribute_id": 238384, + "negate": False, + "operator": "OR", + "kind": "INVITE_COLLECTIONS", + "invite_collections": ["621041", "621042"], + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"80": {"20"}}, user_groups={"621041", "a"}) + + +class TestDynataSurvey: + pass + + # def test_survey_eligibility(self): + # d = {'survey_id': 29333264, 'survey_name': '#29333264', 'survey_status': 22, + # 'field_end_date': datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + # 'category': 'Exciting New', 'category_code': 232, + # 'crtd_on': datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + # 'mod_on': datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + # 'soft_launch': False, 'click_balancing': 0, 'price_type': 1, 'pii': False, + # 'buyer_message': '', 'buyer_id': 4726, 'incl_excl': 0, + # 'cpi': Decimal('1.20000'), 'last_complete_date': None, 'project_last_complete_date': None, + # 'quotas': [], 'qualifications': [], + # 'country_iso': 'fr', 'language_iso': 'fre', 'overall_ir': 0.4, 'overall_loi': 600, + # 'last_block_ir': None, 'last_block_loi': None, 'survey_exclusions': set(), 'exclusion_period': 0} + # s = DynataSurvey.from_api(d) + # s.qualifications = ['a', 'b', 'c'] + # s.quotas = [ + # SpectrumQuota(remaining_count=10, condition_hashes=['a', 'b']), + # SpectrumQuota(remaining_count=0, condition_hashes=['d']), + # SpectrumQuota(remaining_count=10, condition_hashes=['e']) + # ] + # + # self.assertTrue(s.passes_qualifications({'a': True, 'b': True, 'c': True})) + # self.assertFalse(s.passes_qualifications({'a': True, 'b': True, 'c': False})) + # + # # we do NOT match a full quota, so we pass + # self.assertTrue(s.passes_quotas({'a': True, 'b': True, 'd': False})) + # # We dont pass any + # self.assertFalse(s.passes_quotas({})) + # # we only pass a full quota + # self.assertFalse(s.passes_quotas({'d': True})) + # # we only dont pass a full quota, but we haven't passed any open + # self.assertFalse(s.passes_quotas({'d': False})) + # # we pass a quota, but also pass a full quota, so fail + # self.assertFalse(s.passes_quotas({'e': True, 'd': True})) + # # we pass a quota, but are unknown in a full quota, so fail + # self.assertFalse(s.passes_quotas({'e': True})) + # + # # # Soft Pair + # self.assertEqual((True, set()), s.passes_qualifications_soft({'a': True, 'b': True, 'c': True})) + # self.assertEqual((False, set()), s.passes_qualifications_soft({'a': True, 'b': True, 'c': False})) + # self.assertEqual((None, set('c')), s.passes_qualifications_soft({'a': True, 'b': True, 'c': None})) + # + # # we do NOT match a full quota, so we pass + # self.assertEqual((True, set()), s.passes_quotas_soft({'a': True, 'b': True, 'd': False})) + # # We dont pass any + # self.assertEqual((None, {'a', 'b', 'd', 'e'}), s.passes_quotas_soft({})) + # # we only pass a full quota + # self.assertEqual((False, set()), s.passes_quotas_soft({'d': True})) + # # we only dont pass a full quota, but we haven't passed any open + # self.assertEqual((None, {'a', 'b', 'e'}), s.passes_quotas_soft({'d': False})) + # # we pass a quota, but also pass a full quota, so fail + # self.assertEqual((False, set()), s.passes_quotas_soft({'e': True, 'd': True})) + # # we pass a quota, but are unknown in a full quota, so fail + # self.assertEqual((None, {'d'}), s.passes_quotas_soft({'e': True})) + # + # self.assertEqual(True, s.determine_eligibility({'a': True, 'b': True, 'c': True, 'd': False})) + # self.assertEqual(False, s.determine_eligibility({'a': True, 'b': True, 'c': False, 'd': False})) + # self.assertEqual(False, s.determine_eligibility({'a': True, 'b': True, 'c': None, 'd': False})) + # self.assertEqual((True, set()), s.determine_eligibility_soft({'a': True, 'b': True, 'c': True, 'd': False})) + # self.assertEqual((False, set()), s.determine_eligibility_soft({'a': True, 'b': True, 'c': False, 'd': False})) + # self.assertEqual((None, set('c')), s.determine_eligibility_soft({'a': True, 'b': True, 'c': None, + # 'd': False})) + # self.assertEqual((None, {'c', 'd'}), s.determine_eligibility_soft({'a': True, 'b': True, 'c': None, + # 'd': None})) diff --git a/tests/models/gr/__init__.py b/tests/models/gr/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/gr/__init__.py diff --git a/tests/models/gr/test_authentication.py b/tests/models/gr/test_authentication.py new file mode 100644 index 0000000..e906d8c --- /dev/null +++ b/tests/models/gr/test_authentication.py @@ -0,0 +1,313 @@ +import binascii +import json +import os +from datetime import datetime, timezone +from random import randint +from uuid import uuid4 + +import pytest + +SSO_ISSUER = "" + + +class TestGRUser: + + def test_init(self, gr_user): + from generalresearch.models.gr.authentication import GRUser + + assert isinstance(gr_user, GRUser) + assert not gr_user.is_superuser + + assert gr_user.teams is None + assert gr_user.businesses is None + assert gr_user.products is None + + @pytest.mark.skip(reason="TODO") + def test_businesses(self): + pass + + def test_teams(self, gr_user, membership, gr_db, gr_redis_config): + from generalresearch.models.gr.team import Team + + assert gr_user.teams is None + + gr_user.prefetch_teams(pg_config=gr_db, redis_config=gr_redis_config) + + assert isinstance(gr_user.teams, list) + assert len(gr_user.teams) == 1 + assert isinstance(gr_user.teams[0], Team) + + def test_prefetch_team_duplicates( + self, + gr_user_token, + gr_user, + membership, + product_factory, + membership_factory, + team, + thl_web_rr, + gr_redis_config, + gr_db, + ): + product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + gr_user.prefetch_teams( + pg_config=gr_db, + redis_config=gr_redis_config, + ) + + assert len(gr_user.teams) == 1 + + def test_products( + self, + gr_user, + product_factory, + team, + membership, + gr_db, + thl_web_rr, + gr_redis_config, + ): + from generalresearch.models.thl.product import Product + + assert gr_user.products is None + + # Create a new Team membership, and then create a Product that + # is part of that team + membership.prefetch_team(pg_config=gr_db, redis_config=gr_redis_config) + p: Product = product_factory(team=team) + assert p.id_int + assert team.uuid == membership.team.uuid + assert p.team_id == team.uuid + assert p.team_uuid == membership.team.uuid + assert gr_user.id == membership.user_id + + gr_user.prefetch_products( + pg_config=gr_db, + thl_pg_config=thl_web_rr, + redis_config=gr_redis_config, + ) + assert isinstance(gr_user.products, list) + assert len(gr_user.products) == 1 + assert isinstance(gr_user.products[0], Product) + + +class TestGRUserMethods: + + def test_cache_key(self, gr_user, gr_redis): + assert isinstance(gr_user.cache_key, str) + assert ":" in gr_user.cache_key + assert str(gr_user.id) in gr_user.cache_key + + def test_to_redis( + self, + gr_user, + gr_redis, + team, + business, + product_factory, + membership_factory, + ): + product_factory(team=team, business=business) + membership_factory(team=team, gr_user=gr_user) + + res = gr_user.to_redis() + assert isinstance(res, str) + + from generalresearch.models.gr.authentication import GRUser + + instance = GRUser.from_redis(res) + assert isinstance(instance, GRUser) + + def test_set_cache( + self, + gr_user, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + gr_redis_config, + ): + assert gr_redis.get(name=gr_user.cache_key) is None + assert gr_redis.get(name=f"{gr_user.cache_key}:team_uuids") is None + assert gr_redis.get(name=f"{gr_user.cache_key}:business_uuids") is None + assert gr_redis.get(name=f"{gr_user.cache_key}:product_uuids") is None + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + + assert gr_redis.get(name=gr_user.cache_key) is not None + assert gr_redis.get(name=f"{gr_user.cache_key}:team_uuids") is not None + assert gr_redis.get(name=f"{gr_user.cache_key}:business_uuids") is not None + assert gr_redis.get(name=f"{gr_user.cache_key}:product_uuids") is not None + + def test_set_cache_gr_user( + self, + gr_user, + gr_user_token, + gr_redis, + gr_redis_config, + gr_db, + thl_web_rr, + product_factory, + team, + membership_factory, + thl_redis_config, + ): + from generalresearch.models.gr.authentication import GRUser + + p1 = product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + + res: str = gr_redis.get(name=gr_user.cache_key) + gru2 = GRUser.from_redis(res) + + assert gr_user.model_dump_json( + exclude={"businesses", "teams", "products"} + ) == gru2.model_dump_json(exclude={"businesses", "teams", "products"}) + + gru2.prefetch_products( + pg_config=gr_db, + thl_pg_config=thl_web_rr, + redis_config=thl_redis_config, + ) + assert gru2.product_uuids == [p1.uuid] + + def test_set_cache_team_uuids( + self, + gr_user, + membership, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + gr_redis_config, + ): + product_factory(team=team) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + res = json.loads(gr_redis.get(name=f"{gr_user.cache_key}:team_uuids")) + assert len(res) == 1 + assert gr_user.team_uuids == res + + @pytest.mark.skip + def test_set_cache_business_uuids( + self, + gr_user, + membership, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + business, + team, + gr_redis_config, + ): + product_factory(team=team, business=business) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + res = json.loads(gr_redis.get(name=f"{gr_user.cache_key}:business_uuids")) + assert len(res) == 1 + assert gr_user.business_uuids == res + + def test_set_cache_product_uuids( + self, + gr_user, + membership, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + gr_redis_config, + ): + product_factory(team=team) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + res = json.loads(gr_redis.get(name=f"{gr_user.cache_key}:product_uuids")) + assert len(res) == 1 + assert gr_user.product_uuids == res + + +class TestGRToken: + + @pytest.fixture + def gr_token(self, gr_user): + from generalresearch.models.gr.authentication import GRToken + + now = datetime.now(tz=timezone.utc) + token = binascii.hexlify(os.urandom(20)).decode() + + gr_token = GRToken(key=token, created=now, user_id=gr_user.id) + + return gr_token + + def test_init(self, gr_token): + from generalresearch.models.gr.authentication import GRToken + + assert isinstance(gr_token, GRToken) + assert gr_token.created + + def test_user(self, gr_token, gr_db, gr_redis_config): + from generalresearch.models.gr.authentication import GRUser + + assert gr_token.user is None + + gr_token.prefetch_user(pg_config=gr_db, redis_config=gr_redis_config) + + assert isinstance(gr_token.user, GRUser) + + def test_auth_header(self, gr_token): + assert isinstance(gr_token.auth_header, dict) + + +class TestClaims: + + def test_init(self): + from generalresearch.models.gr.authentication import Claims + + d = { + "iss": SSO_ISSUER, + "sub": f"{uuid4().hex}{uuid4().hex}", + "aud": uuid4().hex, + "exp": randint(a=1_500_000_000, b=2_000_000_000), + "iat": randint(a=1_500_000_000, b=2_000_000_000), + "auth_time": randint(a=1_500_000_000, b=2_000_000_000), + "acr": "goauthentik.io/providers/oauth2/default", + "amr": ["pwd", "mfa"], + "sid": f"{uuid4().hex}{uuid4().hex}", + "email": "max@g-r-l.com", + "email_verified": True, + "name": "Max Nanis", + "given_name": "Max Nanis", + "preferred_username": "nanis", + "nickname": "nanis", + "groups": [ + "authentik Admins", + "Developers", + "Systems Admin", + "Customer Support", + "admin", + ], + "azp": uuid4().hex, + "uid": uuid4().hex, + } + instance = Claims.model_validate(d) + + assert isinstance(instance, Claims) diff --git a/tests/models/gr/test_business.py b/tests/models/gr/test_business.py new file mode 100644 index 0000000..9a7718d --- /dev/null +++ b/tests/models/gr/test_business.py @@ -0,0 +1,1432 @@ +import os +from datetime import datetime, timedelta, timezone +from decimal import Decimal +from typing import Optional +from uuid import uuid4 + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) +from pytest import approx + +from generalresearch.currency import USDCent +from generalresearch.models.thl.finance import ( + ProductBalances, + BusinessBalances, +) + +# from test_utils.incite.conftest import mnt_filepath +from test_utils.managers.conftest import ( + business_bank_account_manager, + lm, + thl_lm, +) + + +class TestBusinessBankAccount: + + def test_init(self, business, business_bank_account_manager): + from generalresearch.models.gr.business import ( + BusinessBankAccount, + TransferMethod, + ) + + instance = business_bank_account_manager.create( + business_id=business.id, + uuid=uuid4().hex, + transfer_method=TransferMethod.ACH, + ) + assert isinstance(instance, BusinessBankAccount) + + def test_business(self, business_bank_account, business, gr_db, gr_redis_config): + from generalresearch.models.gr.business import Business + + assert business_bank_account.business is None + + business_bank_account.prefetch_business( + pg_config=gr_db, redis_config=gr_redis_config + ) + assert isinstance(business_bank_account.business, Business) + assert business_bank_account.business.uuid == business.uuid + + +class TestBusinessAddress: + + def test_init(self, business_address): + from generalresearch.models.gr.business import BusinessAddress + + assert isinstance(business_address, BusinessAddress) + + +class TestBusinessContact: + + def test_init(self): + from generalresearch.models.gr.business import BusinessContact + + bc = BusinessContact(name="abc", email="test@abc.com") + assert isinstance(bc, BusinessContact) + + +class TestBusiness: + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_init(self, business): + from generalresearch.models.gr.business import Business + + assert isinstance(business, Business) + assert isinstance(business.id, int) + assert isinstance(business.uuid, str) + + def test_str_and_repr( + self, + business, + product_factory, + thl_web_rr, + lm, + thl_lm, + business_payout_event_manager, + bp_payout_factory, + start, + user_factory, + session_with_tx_factory, + pop_ledger_merge, + client_no_amm, + ledger_collection, + mnt_filepath, + create_main_accounts, + ): + create_main_accounts() + p1 = product_factory(business=business) + u1 = user_factory(product=p1) + p2 = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + + res1 = repr(business) + + assert business.uuid in res1 + assert "<Business: " in res1 + + res2 = str(business) + + assert business.uuid in res2 + assert "Name:" in res2 + assert "Not Loaded" in res2 + + business.prefetch_products(thl_pg_config=thl_web_rr) + business.prefetch_bp_accounts(lm=lm, thl_pg_config=thl_web_rr) + res3 = str(business) + assert "Products: 2" in res3 + assert "Ledger Accounts: 2" in res3 + + # -- need some tx to make these interesting + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=5), + ) + bp_payout_factory( + product=p1, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + res4 = str(business) + assert "Payouts: 1" in res4 + assert "Available Balance: 141" in res4 + + def test_addresses(self, business, business_address, gr_db): + from generalresearch.models.gr.business import BusinessAddress + + assert business.addresses is None + + business.prefetch_addresses(pg_config=gr_db) + assert isinstance(business.addresses, list) + assert len(business.addresses) == 1 + assert isinstance(business.addresses[0], BusinessAddress) + + def test_teams(self, business, team, team_manager, gr_db): + assert business.teams is None + + business.prefetch_teams(pg_config=gr_db) + assert isinstance(business.teams, list) + assert len(business.teams) == 0 + + team_manager.add_business(team=team, business=business) + assert len(business.teams) == 0 + business.prefetch_teams(pg_config=gr_db) + assert len(business.teams) == 1 + + def test_products(self, business, product_factory, thl_web_rr): + from generalresearch.models.thl.product import Product + + p1 = product_factory(business=business) + assert business.products is None + + business.prefetch_products(thl_pg_config=thl_web_rr) + assert isinstance(business.products, list) + assert len(business.products) == 1 + assert isinstance(business.products[0], Product) + + assert business.products[0].uuid == p1.uuid + + # Add two more, but list is still one until we prefetch + p2 = product_factory(business=business) + p3 = product_factory(business=business) + assert len(business.products) == 1 + + business.prefetch_products(thl_pg_config=thl_web_rr) + assert len(business.products) == 3 + + def test_bank_accounts(self, business, business_bank_account, gr_db): + assert business.products is None + + # It's an empty list after prefetch + business.prefetch_bank_accounts(pg_config=gr_db) + assert isinstance(business.bank_accounts, list) + assert len(business.bank_accounts) == 1 + + def test_balance( + self, + business, + mnt_filepath, + client_no_amm, + thl_web_rr, + lm, + pop_ledger_merge, + ): + assert business.balance is None + + with pytest.raises(expected_exception=AssertionError) as cm: + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert "Cannot build Business Balance" in str(cm.value) + assert business.balance is None + + # TODO: Add parquet building so that this doesn't fail and we can + # properly assign a business.balance + + def test_payouts_no_accounts( + self, + business, + product_factory, + thl_web_rr, + thl_lm, + business_payout_event_manager, + ): + assert business.payouts is None + + with pytest.raises(expected_exception=AssertionError) as cm: + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert "Must provide product_uuids" in str(cm.value) + + p = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p) + + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert isinstance(business.payouts, list) + assert len(business.payouts) == 0 + + def test_payouts( + self, + business, + product_factory, + bp_payout_factory, + thl_lm, + thl_web_rr, + business_payout_event_manager, + create_main_accounts, + ): + create_main_accounts() + p = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=p, amount=USDCent(123), skip_wallet_balance_check=True + ) + + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 1 + assert sum([p.amount for p in business.payouts]) == 123 + + # Add another! + bp_payout_factory( + product=p, + amount=USDCent(123), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 1 + assert len(business.payouts[0].bp_payouts) == 2 + assert sum([p.amount for p in business.payouts]) == 246 + + def test_payouts_totals( + self, + business, + product_factory, + bp_payout_factory, + thl_lm, + thl_web_rr, + business_payout_event_manager, + create_main_accounts, + ): + from generalresearch.models.thl.product import Product + + create_main_accounts() + + p1: Product = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=p1, + amount=USDCent(1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(25), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(50), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + + assert len(business.payouts) == 1 + assert len(business.payouts[0].bp_payouts) == 3 + assert business.payouts_total == USDCent(76) + assert business.payouts_total_str == "$0.76" + + def test_pop_financial( + self, + business, + thl_web_rr, + lm, + mnt_filepath, + client_no_amm, + pop_ledger_merge, + ): + assert business.pop_financial is None + business.prebuild_pop_financial( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert business.pop_financial == [] + + def test_bp_accounts(self, business, lm, thl_web_rr, product_factory, thl_lm): + assert business.bp_accounts is None + business.prefetch_bp_accounts(lm=lm, thl_pg_config=thl_web_rr) + assert business.bp_accounts == [] + + from generalresearch.models.thl.product import Product + + p1: Product = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p1) + + business.prefetch_bp_accounts(lm=lm, thl_pg_config=thl_web_rr) + assert len(business.bp_accounts) == 1 + + +class TestBusinessBalance: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + @pytest.mark.skip + def test_product_ordering(self): + # Assert that the order of business.balance.product_balances is always + # consistent and in the same order based off product.created ASC + pass + + def test_single_product( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p1) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("1.25"), + started=start + timedelta(days=2), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert isinstance(business.balance, BusinessBalances) + assert business.balance.payout == 190 + assert business.balance.adjustment == 0 + assert business.balance.net == 190 + assert business.balance.retainer == 47 + assert business.balance.available_balance == 143 + + assert len(business.balance.product_balances) == 1 + pb = business.balance.product_balances[0] + assert isinstance(pb, ProductBalances) + assert pb.balance == business.balance.balance + assert pb.available_balance == business.balance.available_balance + assert pb.adjustment_percent == 0.0 + + def test_multi_product( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("1.25"), + started=start + timedelta(days=2), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert isinstance(business.balance, BusinessBalances) + assert business.balance.payout == 190 + assert business.balance.balance == 190 + assert business.balance.adjustment == 0 + assert business.balance.net == 190 + assert business.balance.retainer == 46 + assert business.balance.available_balance == 144 + + assert len(business.balance.product_balances) == 2 + + pb1 = business.balance.product_balances[0] + pb2 = business.balance.product_balances[1] + assert isinstance(pb1, ProductBalances) + assert pb1.product_id == u1.product_id + assert isinstance(pb2, ProductBalances) + assert pb2.product_id == u2.product_id + + for pb in [pb1, pb2]: + assert pb.balance != business.balance.balance + assert pb.available_balance != business.balance.available_balance + assert pb.adjustment_percent == 0.0 + + assert pb1.product_id in [u1.product_id, u2.product_id] + assert pb1.payout == 71 + assert pb1.adjustment == 0 + assert pb1.expense == 0 + assert pb1.net == 71 + assert pb1.retainer == 17 + assert pb1.available_balance == 54 + + assert pb2.product_id in [u1.product_id, u2.product_id] + assert pb2.payout == 119 + assert pb2.adjustment == 0 + assert pb2.expense == 0 + assert pb2.net == 119 + assert pb2.retainer == 29 + assert pb2.available_balance == 90 + + def test_multi_product_multi_payout( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("1.25"), + started=start + timedelta(days=2), + ) + + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=u1.product, + amount=USDCent(5), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=u2.product, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + assert business.balance.payout == 190 + assert business.balance.net == 190 + + assert business.balance.balance == 135 + + def test_multi_product_multi_payout_adjustment( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + task_adj_collection, + pop_ledger_merge, + wall_manager, + session_manager, + adj_to_fail_with_tx_factory, + delete_df_collection, + ): + """ + - Product 1 $2.50 Complete + - Product 2 $2.50 Complete + - $2.50 Payout on Product 1 + - $0.50 Payout on Product 2 + - Product 3 $2.50 Complete + - Complete -> Failure $2.50 Adjustment on Product 1 + ==== + - Net: $7.50 * .95 = $7.125 + - $2.50 = $2.375 = $2.38 + - $2.50 = $2.375 = $2.38 + - $2.50 = $2.375 = $2.38 + ==== + - $7.14 + - Balance: $2 + """ + + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + delete_df_collection(coll=task_adj_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + u3: User = user_factory(product=product_factory(business=business)) + + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=2), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=u1.product, + amount=USDCent(250), + created=start + timedelta(days=3), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=u2.product, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + adj_to_fail_with_tx_factory(session=s1, created=start + timedelta(days=5)) + + session_with_tx_factory( + user=u3, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=6), + ) + + # Build and prepare the Business with the db transactions now in place + + # This isn't needed for Business Balance... but good to also check + # task_adj_collection.initial_load(client=None, sync=True) + # These are the only two that are needed for Business Balance + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + df = client_no_amm.compute(ledger_collection.ddf(), sync=True) + assert df.shape == (24, 24) + + df = client_no_amm.compute(pop_ledger_merge.ddf(), sync=True) + assert df.shape == (20, 28) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + assert business.balance.payout == 714 + assert business.balance.adjustment == -238 + + assert business.balance.product_balances[0].adjustment == -238 + assert business.balance.product_balances[1].adjustment == 0 + assert business.balance.product_balances[2].adjustment == 0 + + assert business.balance.expense == 0 + assert business.balance.net == 714 - 238 + assert business.balance.balance == business.balance.payout - (250 + 50 + 238) + + predicted_retainer = sum( + [ + pb.balance * 0.25 + for pb in business.balance.product_balances + if pb.balance > 0 + ] + ) + assert business.balance.retainer == approx(predicted_retainer, rel=0.01) + + def test_neg_balance_cache( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + ): + """Test having a Business with two products.. one that lost money + and one that gained money. Ensure that the Business balance + reflects that to compensate for the Product in the negative. + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + + # Product 1: Complete, Payout, Recon.. + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=u1.product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + adj_to_fail_with_tx_factory( + session=s1, + created=start + timedelta(days=1, minutes=2), + ) + + # Product 2: Complete, Complete. + s2 = session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1, minutes=3), + ) + s3 = session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1, minutes=4), + ) + + # Finally, process everything: + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + # Check Product 1 + pb1 = business.balance.product_balances[0] + assert pb1.product_id == p1.uuid + assert pb1.payout == 71 + assert pb1.adjustment == -71 + assert pb1.net == 0 + assert pb1.balance == 71 - (71 * 2) + assert pb1.retainer == 0 + assert pb1.available_balance == 0 + + # Check Product 2 + pb2 = business.balance.product_balances[1] + assert pb2.product_id == p2.uuid + assert pb2.payout == 71 * 2 + assert pb2.adjustment == 0 + assert pb2.net == 71 * 2 + assert pb2.balance == (71 * 2) + assert pb2.retainer == pytest.approx((71 * 2) * 0.25, rel=1) + assert pb2.available_balance == 107 + + # Check Business + bb1 = business.balance + assert bb1.payout == (71 * 3) # Raw total of completes + assert bb1.adjustment == -71 # 1 Complete >> Failure + assert bb1.expense == 0 + assert bb1.net == (71 * 3) - 71 # How much the Business actually earned + assert ( + bb1.balance == (71 * 3) - 71 - 71 + ) # 3 completes, but 1 payout and 1 recon leaves only one complete + # worth of activity on the account + assert bb1.retainer == pytest.approx((71 * 2) * 0.25, rel=1) + assert bb1.available_balance_usd_str == "$0.36" + + # Confirm that the debt from the pb1 in the red is covered when + # calculating the Business balance by the profit of pb2 + assert pb2.available_balance + pb1.balance == bb1.available_balance + + def test_multi_product_multi_payout_adjustment_at_timestamp( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + task_adj_collection, + pop_ledger_merge, + wall_manager, + session_manager, + adj_to_fail_with_tx_factory, + delete_df_collection, + ): + """ + This test measures a complex Business situation, but then makes + various assertions based off the query which uses an at_timestamp. + + The goal here is a feature that allows us to look back and see + what the balance was of an account at any specific point in time. + + - Day 1: Product 1 $2.50 Complete + - Total Payout: $2.38 + - Smart Retainer: $0.59 + - Available Balance: $1.79 + - Day 2: Product 2 $2.50 Complete + - Total Payout: $4.76 + - Smart Retainer: $1.18 + - Available Balance: $3.58 + - Day 3: $2.50 Payout on Product 1 + - Total Payout: $4.76 + - Smart Retainer: $0.59 + - Available Balance: $1.67 + - Day 4: $0.50 Payout on Product 2 + - Total Payout: $4.76 + - Smart Retainer: $0.47 + - Available Balance: $1.29 + - Day 5: Product 3 $2.50 Complete + - Total Payout: $7.14 + - Smart Retainer: $1.06 + - Available Balance: $3.08 + - Day 6: Complete -> Failure $2.50 Adjustment on Product 1 + - Total Payout: $7.18 + - Smart Retainer: $1.06 + - Available Balance: $0.70 + """ + + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + delete_df_collection(coll=task_adj_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + u3: User = user_factory(product=product_factory(business=business)) + + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=2), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=u1.product, + amount=USDCent(250), + created=start + timedelta(days=3), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=u2.product, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + session_with_tx_factory( + user=u3, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=5), + ) + + adj_to_fail_with_tx_factory(session=s1, created=start + timedelta(days=6)) + + # Build and prepare the Business with the db transactions now in place + + # This isn't needed for Business Balance... but good to also check + # task_adj_collection.initial_load(client=None, sync=True) + # These are the only two that are needed for Business Balance + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + df = client_no_amm.compute(ledger_collection.ddf(), sync=True) + assert df.shape == (24, 24) + + df = client_no_amm.compute(pop_ledger_merge.ddf(), sync=True) + assert df.shape == (20, 28) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=1, hours=1), + ) + day1_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=2, hours=1), + ) + day2_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=3, hours=1), + ) + day3_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=4, hours=1), + ) + day4_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=5, hours=1), + ) + day5_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=6, hours=1), + ) + day6_bal = business.balance + + assert day1_bal.payout == 238 + assert day1_bal.retainer == 59 + assert day1_bal.available_balance == 179 + + assert day2_bal.payout == 476 + assert day2_bal.retainer == 118 + assert day2_bal.available_balance == 358 + + assert day3_bal.payout == 476 + assert day3_bal.retainer == 59 + assert day3_bal.available_balance == 167 + + assert day4_bal.payout == 476 + assert day4_bal.retainer == 47 + assert day4_bal.available_balance == 129 + + assert day5_bal.payout == 714 + assert day5_bal.retainer == 106 + assert day5_bal.available_balance == 308 + + assert day6_bal.payout == 714 + assert day6_bal.retainer == 106 + assert day6_bal.available_balance == 70 + + +class TestBusinessMethods: + + @pytest.fixture(scope="function") + def start(self, utc_90days_ago) -> "datetime": + s = utc_90days_ago.replace(microsecond=0) + return s + + @pytest.fixture(scope="function") + def offset(self) -> str: + return "15d" + + @pytest.fixture(scope="function") + def duration( + self, + ) -> Optional["timedelta"]: + return None + + def test_cache_key(self, business, gr_redis): + assert isinstance(business.cache_key, str) + assert ":" in business.cache_key + assert str(business.uuid) in business.cache_key + + def test_set_cache( + self, + business, + gr_redis, + gr_db, + thl_web_rr, + client_no_amm, + mnt_filepath, + lm, + thl_lm, + business_payout_event_manager, + product_factory, + membership_factory, + team, + session_with_tx_factory, + user_factory, + ledger_collection, + pop_ledger_merge, + utc_60days_ago, + delete_ledger_db, + create_main_accounts, + gr_redis_config, + mnt_gr_api_dir, + ): + assert gr_redis.get(name=business.cache_key) is None + + p1 = product_factory(team=team, business=business) + u1 = user_factory(product=p1) + + # Business needs tx & incite to build balance + delete_ledger_db() + create_main_accounts() + thl_lm.get_account_or_create_bp_wallet(product=p1) + session_with_tx_factory(user=u1, started=utc_60days_ago) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + lm=lm, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + pop_ledger=pop_ledger_merge, + mnt_gr_api=mnt_gr_api_dir, + ) + + assert gr_redis.hgetall(name=business.cache_key) is not None + from generalresearch.models.gr.business import Business + + # We're going to pull only a specific year, but make sure that + # it's being assigned to the field regardless + year = datetime.now(tz=timezone.utc).year + res = Business.from_redis( + uuid=business.uuid, + fields=[f"pop_financial:{year}"], + gr_redis_config=gr_redis_config, + ) + assert len(res.pop_financial) > 0 + + def test_set_cache_business( + self, + gr_user, + business, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + membership_factory, + client_no_amm, + mnt_filepath, + lm, + thl_lm, + business_payout_event_manager, + user_factory, + delete_ledger_db, + create_main_accounts, + session_with_tx_factory, + ledger_collection, + team_manager, + pop_ledger_merge, + gr_redis_config, + utc_60days_ago, + mnt_gr_api_dir, + ): + from generalresearch.models.gr.business import Business + + p1 = product_factory(team=team, business=business) + u1 = user_factory(product=p1) + team_manager.add_business(team=team, business=business) + + # Business needs tx & incite to build balance + delete_ledger_db() + create_main_accounts() + thl_lm.get_account_or_create_bp_wallet(product=p1) + session_with_tx_factory(user=u1, started=utc_60days_ago) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + lm=lm, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + pop_ledger=pop_ledger_merge, + mnt_gr_api=mnt_gr_api_dir, + ) + + # keys: List = Business.required_fields() + ["products", "bp_accounts"] + business2 = Business.from_redis( + uuid=business.uuid, + fields=[ + "id", + "tax_number", + "contact", + "addresses", + "teams", + "products", + "bank_accounts", + "balance", + "payouts_total_str", + "payouts_total", + "payouts", + "pop_financial", + "bp_accounts", + ], + gr_redis_config=gr_redis_config, + ) + + assert business.model_dump_json() == business2.model_dump_json() + assert p1.uuid in [p.uuid for p in business2.products] + assert len(business2.teams) == 1 + assert team.uuid in [t.uuid for t in business2.teams] + + assert business2.balance.payout == 48 + assert business2.balance.balance == 48 + assert business2.balance.net == 48 + assert business2.balance.retainer == 12 + assert business2.balance.available_balance == 36 + assert len(business2.balance.product_balances) == 1 + + assert len(business2.payouts) == 0 + + assert len(business2.bp_accounts) == 1 + assert len(business2.bp_accounts) == len(business2.product_uuids) + + assert len(business2.pop_financial) == 1 + assert business2.pop_financial[0].payout == business2.balance.payout + assert business2.pop_financial[0].net == business2.balance.net + + def test_prebuild_enriched_session_parquet( + self, + event_report_request, + enriched_session_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(business=business) + p2 = product_factory(business=business) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + business.prebuild_enriched_session_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_session=enriched_session_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_session", f"{business.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) + + def test_prebuild_enriched_wall_parquet( + self, + event_report_request, + enriched_session_merge, + enriched_wall_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(business=business) + p2 = product_factory(business=business) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_wall_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + business.prebuild_enriched_wall_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_event", f"{business.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) diff --git a/tests/models/gr/test_team.py b/tests/models/gr/test_team.py new file mode 100644 index 0000000..d728bbe --- /dev/null +++ b/tests/models/gr/test_team.py @@ -0,0 +1,296 @@ +import os +from datetime import timedelta +from decimal import Decimal + +import pandas as pd + + +class TestTeam: + + def test_init(self, team): + from generalresearch.models.gr.team import Team + + assert isinstance(team, Team) + assert isinstance(team.id, int) + assert isinstance(team.uuid, str) + + def test_memberships_none(self, team, gr_user_factory, gr_db): + assert team.memberships is None + + team.prefetch_memberships(pg_config=gr_db) + assert isinstance(team.memberships, list) + assert len(team.memberships) == 0 + + def test_memberships( + self, + team, + membership, + gr_user, + gr_user_factory, + membership_factory, + membership_manager, + gr_db, + ): + assert team.memberships is None + + team.prefetch_memberships(pg_config=gr_db) + assert isinstance(team.memberships, list) + assert len(team.memberships) == 1 + assert team.memberships[0].user_id == gr_user.id + + # Create another new Membership + membership_manager.create(team=team, gr_user=gr_user_factory()) + assert len(team.memberships) == 1 + team.prefetch_memberships(pg_config=gr_db) + assert len(team.memberships) == 2 + + def test_gr_users( + self, team, gr_user_factory, membership_manager, gr_db, gr_redis_config + ): + assert team.gr_users is None + + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert isinstance(team.gr_users, list) + assert len(team.gr_users) == 0 + + # Create a new Membership + membership_manager.create(team=team, gr_user=gr_user_factory()) + assert len(team.gr_users) == 0 + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert len(team.gr_users) == 1 + + # Create another Membership + membership_manager.create(team=team, gr_user=gr_user_factory()) + assert len(team.gr_users) == 1 + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert len(team.gr_users) == 2 + + def test_businesses(self, team, business, team_manager, gr_db, gr_redis_config): + from generalresearch.models.gr.business import Business + + assert team.businesses is None + + team.prefetch_businesses(pg_config=gr_db, redis_config=gr_redis_config) + assert isinstance(team.businesses, list) + assert len(team.businesses) == 0 + + team_manager.add_business(team=team, business=business) + assert len(team.businesses) == 0 + team.prefetch_businesses(pg_config=gr_db, redis_config=gr_redis_config) + assert len(team.businesses) == 1 + assert isinstance(team.businesses[0], Business) + assert team.businesses[0].uuid == business.uuid + + def test_products(self, team, product_factory, thl_web_rr): + from generalresearch.models.thl.product import Product + + assert team.products is None + + team.prefetch_products(thl_pg_config=thl_web_rr) + assert isinstance(team.products, list) + assert len(team.products) == 0 + + product_factory(team=team) + assert len(team.products) == 0 + team.prefetch_products(thl_pg_config=thl_web_rr) + assert len(team.products) == 1 + assert isinstance(team.products[0], Product) + + +class TestTeamMethods: + + def test_cache_key(self, team, gr_redis): + assert isinstance(team.cache_key, str) + assert ":" in team.cache_key + assert str(team.uuid) in team.cache_key + + def test_set_cache( + self, + team, + gr_redis, + gr_db, + thl_web_rr, + gr_redis_config, + client_no_amm, + mnt_filepath, + mnt_gr_api_dir, + enriched_wall_merge, + enriched_session_merge, + ): + assert gr_redis.get(name=team.cache_key) is None + + team.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + enriched_session=enriched_session_merge, + ) + + assert gr_redis.hgetall(name=team.cache_key) is not None + + def test_set_cache_team( + self, + gr_user, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + membership_factory, + gr_redis_config, + client_no_amm, + mnt_filepath, + mnt_gr_api_dir, + enriched_wall_merge, + enriched_session_merge, + ): + from generalresearch.models.gr.team import Team + + p1 = product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + team.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + enriched_session=enriched_session_merge, + ) + + team2 = Team.from_redis( + uuid=team.uuid, + fields=["id", "memberships", "gr_users", "businesses", "products"], + gr_redis_config=gr_redis_config, + ) + + assert team.model_dump_json() == team2.model_dump_json() + assert p1.uuid in [p.uuid for p in team2.products] + assert len(team2.gr_users) == 1 + assert gr_user.id in [gru.id for gru in team2.gr_users] + + def test_prebuild_enriched_session_parquet( + self, + event_report_request, + enriched_session_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + team, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(team=team) + p2 = product_factory(team=team) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + team.prebuild_enriched_session_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_session=enriched_session_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_session", f"{team.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) + + def test_prebuild_enriched_wall_parquet( + self, + event_report_request, + enriched_session_merge, + enriched_wall_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + team, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(team=team) + p2 = product_factory(team=team) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_wall_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + team.prebuild_enriched_wall_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_event", f"{team.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) diff --git a/tests/models/innovate/__init__.py b/tests/models/innovate/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/innovate/__init__.py diff --git a/tests/models/innovate/test_question.py b/tests/models/innovate/test_question.py new file mode 100644 index 0000000..330f919 --- /dev/null +++ b/tests/models/innovate/test_question.py @@ -0,0 +1,85 @@ +from generalresearch.models import Source +from generalresearch.models.innovate.question import ( + InnovateQuestion, + InnovateQuestionType, + InnovateQuestionOption, +) +from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionSelectorTE, + UpkQuestion, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestionChoice, +) + + +class TestInnovateQuestion: + + def test_text_entry(self): + + q = InnovateQuestion( + question_id="3", + country_iso="us", + language_iso="eng", + question_key="ZIPCODES", + question_text="postal code", + question_type=InnovateQuestionType.TEXT_ENTRY, + tags=None, + options=None, + is_live=True, + category_id=None, + ) + assert Source.INNOVATE == q.source + assert "i:zipcodes" == q.external_id + assert "zipcodes" == q.internal_id + assert ("zipcodes", "us", "eng") == q._key + + upk = q.to_upk_question() + expected_upk = UpkQuestion( + ext_question_id="i:zipcodes", + type=UpkQuestionType.TEXT_ENTRY, + country_iso="us", + language_iso="eng", + text="postal code", + selector=UpkQuestionSelectorTE.SINGLE_LINE, + choices=None, + ) + assert expected_upk == upk + + def test_mc(self): + + text = "Have you purchased or received any of the following in past 18 months?" + q = InnovateQuestion( + question_key="dynamic_profiling-_1_14715", + country_iso="us", + language_iso="eng", + question_id="14715", + question_text=text, + question_type=InnovateQuestionType.MULTI_SELECT, + tags="Dynamic Profiling- 1", + options=[ + InnovateQuestionOption(id="1", text="aaa", order=0), + InnovateQuestionOption(id="2", text="bbb", order=1), + ], + is_live=True, + category_id=None, + ) + assert "i:dynamic_profiling-_1_14715" == q.external_id + assert "dynamic_profiling-_1_14715" == q.internal_id + assert ("dynamic_profiling-_1_14715", "us", "eng") == q._key + assert 2 == q.num_options + + upk = q.to_upk_question() + expected_upk = UpkQuestion( + ext_question_id="i:dynamic_profiling-_1_14715", + type=UpkQuestionType.MULTIPLE_CHOICE, + country_iso="us", + language_iso="eng", + text=text, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + choices=[ + UpkQuestionChoice(id="1", text="aaa", order=0), + UpkQuestionChoice(id="2", text="bbb", order=1), + ], + ) + assert expected_upk == upk diff --git a/tests/models/legacy/__init__.py b/tests/models/legacy/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/legacy/__init__.py diff --git a/tests/models/legacy/data.py b/tests/models/legacy/data.py new file mode 100644 index 0000000..20f3231 --- /dev/null +++ b/tests/models/legacy/data.py @@ -0,0 +1,265 @@ +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/45b7228a7/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_45b7228a7 = ( + '{"info": {"success": true}, "offerwall": {"availability_count": 9, "buckets": [{"category": [{"adwords_id": ' + 'null, "adwords_label": null, "id": "c82cf98c578a43218334544ab376b00e", "label": "Social Research", "p": 1.0}], ' + '"description": "", "duration": {"max": 719, "min": 72, "q1": 144, "q2": 621, "q3": 650}, ' + '"id": "5503c471d95645dd947080704d3760b3", "name": "", "payout": {"max": 132, "min": 68, "q1": 113, "q2": 124, ' + '"q3": 128}, "quality_score": 1.0, "uri": ' + '"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=5503c471d95645dd947080704d3760b3&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=16fa868", "x": 0, "y": 0}, ' + '{"category": [{"adwords_id": null, "adwords_label": null, "id": "c82cf98c578a43218334544ab376b00e", ' + '"label": "Social Research", "p": 0.6666666666666666}, {"adwords_id": "5000", "adwords_label": "World ' + 'Localities", "id": "cd3f9374ba5d4e5692ee5691320ecc8b", "label": "World Localities", "p": 0.16666666666666666}, ' + '{"adwords_id": "14", "adwords_label": "People & Society", "id": "c8642a1b86d9460cbe8f7e8ae6e56ee4", ' + '"label": "People & Society", "p": 0.16666666666666666}], "description": "", "duration": {"max": 1180, ' + '"min": 144, "q1": 457, "q2": 621, "q3": 1103}, "id": "56f437aa5da443748872390a5cbf6103", "name": "", "payout": {' + '"max": 113, "min": 24, "q1": 40, "q2": 68, "q3": 68}, "quality_score": 0.17667012, ' + '"uri": "https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=56f437aa5da443748872390a5cbf6103&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=16fa868", "x": 1, "y": 0}, ' + '{"category": [{"adwords_id": null, "adwords_label": null, "id": "c82cf98c578a43218334544ab376b00e", ' + '"label": "Social Research", "p": 0.6666666666666666}, {"adwords_id": "5000", "adwords_label": "World ' + 'Localities", "id": "cd3f9374ba5d4e5692ee5691320ecc8b", "label": "World Localities", "p": 0.16666666666666666}, ' + '{"adwords_id": "14", "adwords_label": "People & Society", "id": "c8642a1b86d9460cbe8f7e8ae6e56ee4", ' + '"label": "People & Society", "p": 0.16666666666666666}], "description": "", "duration": {"max": 1180, ' + '"min": 144, "q1": 457, "q2": 1103, "q3": 1128}, "id": "a5b8403e4a4a4ed1a21ef9b3a721ab02", "name": "", ' + '"payout": {"max": 68, "min": 14, "q1": 24, "q2": 40, "q3": 68}, "quality_score": 0.01, ' + '"uri": "https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=a5b8403e4a4a4ed1a21ef9b3a721ab02&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=16fa868", "x": 2, "y": 0}], ' + '"id": "2fba2999baf0423cad0c49eceea4eb33", "payout_format": "${payout/100:.2f}"}}' +) + +RESPONSE_b145b803 = ( + '{"info":{"success":true},"offerwall":{"availability_count":10,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{' + '"id":"b9d6fdb95ae2402dbb8e8673be382f04","id_code":"m:b9d6fdb95ae2402dbb8e8673be382f04","loi":954,"payout":166,' + '"source":"m"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":71,"payout":132,"source":"o"},{"id":"ejqjbv4",' + '"id_code":"o:ejqjbv4","loi":650,"payout":128,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":624,' + '"payout":113,"source":"o"},{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],' + '"currency":"USD","description":"","duration":{"max":954,"min":72,"q1":625,"q2":650,"q3":719},' + '"id":"c190231a7d494012a2f641a89a85e6a6","name":"","payout":{"max":166,"min":113,"q1":124,"q2":128,"q3":132},' + '"quality_score":1.0,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120' + '/?i=c190231a7d494012a2f641a89a85e6a6&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=9dbbf8c","x":0,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":0.5},{"adwords_id":"5000","adwords_label":"World Localities",' + '"id":"cd3f9374ba5d4e5692ee5691320ecc8b","label":"World Localities","p":0.25},{"adwords_id":"14",' + '"adwords_label":"People & Society","id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society",' + '"p":0.25}],"contents":[{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":143,"payout":68,"source":"o"},' + '{"id":"g6xkrbm","id_code":"o:g6xkrbm","loi":536,"payout":68,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g",' + '"loi":457,"payout":68,"source":"o"}],"currency":"USD","description":"","duration":{"max":537,"min":144,"q1":301,' + '"q2":457,"q3":497},"id":"ee12be565f744ef2b194703f3d32f8cd","name":"","payout":{"max":68,"min":68,"q1":68,' + '"q2":68,"q3":68},"quality_score":0.01,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=ee12be565f744ef2b194703f3d32f8cd&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=9dbbf8c","x":1,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":1.0}],"contents":[{"id":"a019660a8bc0411dba19a3e5c5df5b6c",' + '"id_code":"m:a019660a8bc0411dba19a3e5c5df5b6c","loi":1180,"payout":24,"source":"m"},{"id":"ejqa3kw",' + '"id_code":"o:ejqa3kw","loi":143,"payout":68,"source":"o"},{"id":"8c54725047cc4e0590665a034d37e7f5",' + '"id_code":"m:8c54725047cc4e0590665a034d37e7f5","loi":1128,"payout":14,"source":"m"}],"currency":"USD",' + '"description":"","duration":{"max":1180,"min":144,"q1":636,"q2":1128,"q3":1154},' + '"id":"dc2551eb48d84b329fdcb5b2bd60ed71","name":"","payout":{"max":68,"min":14,"q1":19,"q2":24,"q3":46},' + '"quality_score":0.06662835156312356,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=dc2551eb48d84b329fdcb5b2bd60ed71&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=9dbbf8c","x":2,"y":0}],' + '"id":"7cb9eb4c1a5b41a38cfebd13c9c338cb"}}' +) + +# This is a blocked user. Otherwise, the format is identical to RESPONSE_b145b803 +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/1e5f0af8/?bpuid=00051c62-a872-4832-a008 +# -c37ec51d33d3&duration=1200&format=json&country_iso=us +RESPONSE_d48cce47 = ( + '{"info":{"success":true},"offerwall":{"availability_count":0,"buckets":[],' + '"id":"168680387a7f4c8c9cc8e7ab63f502ff","payout_format":"${payout/100:.2f}"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/1e5f0af8/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_1e5f0af8 = ( + '{"info":{"success":true},"offerwall":{"availability_count":9,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{' + '"id":"x94r9bg","id_code":"o:x94r9bg","loi":71,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9",' + '"loi":604,"payout":113,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":473,"payout":128,"source":"o"},' + '{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],"currency":"USD","description":"",' + '"duration":{"max":719,"min":72,"q1":373,"q2":539,"q3":633},"id":"2a4a897a76464af2b85703b72a125da0",' + '"is_recontact":false,"name":"","payout":{"max":132,"min":113,"q1":121,"q2":126,"q3":129},"quality_score":1.0,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142","x":0,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":0.5},{"adwords_id":"5000","adwords_label":"World Localities",' + '"id":"cd3f9374ba5d4e5692ee5691320ecc8b","label":"World Localities","p":0.25},{"adwords_id":"14",' + '"adwords_label":"People & Society","id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society",' + '"p":0.25}],"contents":[{"id":"775ed98f65604dac91b3a60814438829","id_code":"m:775ed98f65604dac91b3a60814438829",' + '"loi":1121,"payout":32,"source":"m"},{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":143,"payout":68,"source":"o"},' + '{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"}],"currency":"USD","description":"",' + '"duration":{"max":1121,"min":144,"q1":301,"q2":457,"q3":789},"id":"0aa83eb711c042e28bb9284e604398ac",' + '"is_recontact":false,"name":"","payout":{"max":68,"min":32,"q1":50,"q2":68,"q3":68},"quality_score":0.01,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=0aa83eb711c042e28bb9284e604398ac&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142","x":1,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":1.0}],"contents":[{"id":"775ed98f65604dac91b3a60814438829",' + '"id_code":"m:775ed98f65604dac91b3a60814438829","loi":1121,"payout":32,"source":"m"},' + '{"id":"8c54725047cc4e0590665a034d37e7f5","id_code":"m:8c54725047cc4e0590665a034d37e7f5","loi":1128,"payout":14,' + '"source":"m"},{"id":"a019660a8bc0411dba19a3e5c5df5b6c","id_code":"m:a019660a8bc0411dba19a3e5c5df5b6c",' + '"loi":1180,"payout":24,"source":"m"}],"currency":"USD","description":"","duration":{"max":1180,"min":1121,' + '"q1":1125,"q2":1128,"q3":1154},"id":"d3a87b2bb6cf4428a55916bdf65e775e","is_recontact":false,"name":"",' + '"payout":{"max":32,"min":14,"q1":19,"q2":24,"q3":28},"quality_score":0.0825688889614345,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=d3a87b2bb6cf4428a55916bdf65e775e&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142","x":2,"y":0}],' + '"id":"391a54bbe84c4dbfa50b40841201a606"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/5fl8bpv5/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_5fl8bpv5 = ( + '{"info":{"success":true},"offerwall":{"availability_count":9,"buckets":[{' + '"id":"a1097b20f2ae472a9a9ad2987ba3bf95",' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=a1097b20f2ae472a9a9ad2987ba3bf95&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e7baf5e"}],' + '"id":"fddaa544d7ff428a8ccccd0667fdc249","payout_format":"${payout/100:.2f}"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/37d1da64/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_37d1da64 = ( + '{"info":{"success":true},"offerwall":{"availability_count":18,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{' + '"id":"0qvwx4z","id_code":"o:0qvwx4z","loi":700,"payout":106,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg",' + '"loi":72,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},' + '{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"}],"eligibility":"conditional",' + '"id":"2cfc47e8d8c4417cb1f499dbf7e9afb8","loi":700,"missing_questions":["7ca8b59f4c864f80a1a7c7287adfc637"],' + '"payout":106,"uri":null},{"category":[{"adwords_id":null,"adwords_label":null,' + '"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{"id":"yxqdnb9",' + '"id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,' + '"payout":132,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"}],' + '"eligibility":"unconditional","id":"8964a87ebbe9433cae0ddce1b34a637a","loi":605,"payout":113,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=8964a87ebbe9433cae0ddce1b34a637a&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"a019660a8bc0411dba19a3e5c5df5b6c","id_code":"m:a019660a8bc0411dba19a3e5c5df5b6c",' + '"loi":1180,"payout":24,"source":"m"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},' + '{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},' + '{"id":"775ed98f65604dac91b3a60814438829","id_code":"m:775ed98f65604dac91b3a60814438829","loi":1121,"payout":32,' + '"source":"m"},{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"ejqjbv4",' + '"id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,' + '"payout":68,"source":"o"},{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],' + '"eligibility":"unconditional","id":"ba98df6041344e818f02873534ae09a3","loi":1180,"payout":24,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=ba98df6041344e818f02873534ae09a3&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"775ed98f65604dac91b3a60814438829","id_code":"m:775ed98f65604dac91b3a60814438829",' + '"loi":1121,"payout":32,"source":"m"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},' + '{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},{"id":"ejqa3kw",' + '"id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,' + '"payout":128,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},' + '{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],"eligibility":"unconditional",' + '"id":"6456abe0f2584be8a30d9e0e93bef496","loi":1121,"payout":32,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=6456abe0f2584be8a30d9e0e93bef496&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"x94r9bg",' + '"id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,' + '"payout":132,"source":"o"}],"eligibility":"unconditional","id":"e0a0fb27bf174040be7971795a967ce5","loi":144,' + '"payout":68,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=e0a0fb27bf174040be7971795a967ce5&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"},' + '{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg",' + '"loi":72,"payout":132,"source":"o"}],"eligibility":"unconditional","id":"99e3f8cfd367411b822cf11c3b54b558",' + '"loi":474,"payout":128,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709' + "/00ff1d9b71b94bf4b20d22cd56774120/?i=99e3f8cfd367411b822cf11c3b54b558&b=379fb74f-05b2-42dc-b283-47e1c8678b04" + '&66482fb=ec8aa9a"},{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e",' + '"label":"Social Research","p":1.0}],"contents":[{"id":"8c54725047cc4e0590665a034d37e7f5",' + '"id_code":"m:8c54725047cc4e0590665a034d37e7f5","loi":1128,"payout":14,"source":"m"},{"id":"x94r9bg",' + '"id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,' + '"payout":113,"source":"o"},{"id":"775ed98f65604dac91b3a60814438829",' + '"id_code":"m:775ed98f65604dac91b3a60814438829","loi":1121,"payout":32,"source":"m"},{"id":"ejqa3kw",' + '"id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,' + '"payout":128,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},' + '{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],"eligibility":"unconditional",' + '"id":"db7bc77afbb6443e8f35e9f06764fd06","loi":1128,"payout":14,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=db7bc77afbb6443e8f35e9f06764fd06&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":0.5},{"adwords_id":"5000","adwords_label":"World Localities","id":"cd3f9374ba5d4e5692ee5691320ecc8b",' + '"label":"World Localities","p":0.25},{"adwords_id":"14","adwords_label":"People & Society",' + '"id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society","p":0.25}],"contents":[{"id":"yr5od0g",' + '"id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,' + '"payout":132,"source":"o"},{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"}],' + '"eligibility":"unconditional","id":"d5af3c782a354d6a8616e47531eb15e7","loi":457,"payout":68,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=d5af3c782a354d6a8616e47531eb15e7&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"},' + '{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4",' + '"loi":474,"payout":128,"source":"o"}],"eligibility":"unconditional","id":"ceafb146bb7042889b85630a44338a75",' + '"loi":719,"payout":124,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709' + "/00ff1d9b71b94bf4b20d22cd56774120/?i=ceafb146bb7042889b85630a44338a75&b=379fb74f-05b2-42dc-b283-47e1c8678b04" + '&66482fb=ec8aa9a"}],"id":"37ed0858e0f64329812e2070fb658eb3","question_info":{' + '"7ca8b59f4c864f80a1a7c7287adfc637":{"choices":[{"choice_id":"0","choice_text":"Single, never married",' + '"order":0},{"choice_id":"1","choice_text":"Living with a Partner","order":1},{"choice_id":"2",' + '"choice_text":"Civil Union / Domestic Partnership","order":2},{"choice_id":"3","choice_text":"Married",' + '"order":3},{"choice_id":"4","choice_text":"Separated","order":4},{"choice_id":"5","choice_text":"Divorced",' + '"order":5},{"choice_id":"6","choice_text":"Widowed","order":6}],"country_iso":"us","language_iso":"eng",' + '"question_id":"7ca8b59f4c864f80a1a7c7287adfc637","question_text":"What is your relationship status?",' + '"question_type":"MC","selector":"SA"}}}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/5fa23085/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_5fa23085 = ( + '{"info":{"success":true},"offerwall":{"availability_count":7,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":0.6666666666666666},' + '{"adwords_id":"5000","adwords_label":"World Localities","id":"cd3f9374ba5d4e5692ee5691320ecc8b","label":"World ' + 'Localities","p":0.16666666666666666},{"adwords_id":"14","adwords_label":"People & Society",' + '"id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society","p":0.16666666666666666}],"contents":[{' + '"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9",' + '"loi":605,"payout":113,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":640,"payout":128,"source":"o"},' + '{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},{"id":"vyjrv0v","id_code":"o:vyjrv0v",' + '"loi":719,"payout":124,"source":"o"}],"duration":{"max":719,"min":72,"q1":457,"q2":605,"q3":640},' + '"id":"30049c28363b4f689c84fbacf1cc57e2","payout":{"max":132,"min":68,"q1":113,"q2":124,"q3":128},' + '"source":"pollfish","uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120' + '/?i=30049c28363b4f689c84fbacf1cc57e2&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e3c1af1"}],' + '"id":"1c33756fd7ca49fa84ee48e42145e68c"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/1705e4f8/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_1705e4f8 = ( + '{"info":{"success":true},"offerwall":{"availability_count":7,"buckets":[{"currency":"USD","duration":719,' + '"id":"d06abd6ac75b453a93d0e85e4e391c00","min_payout":124,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=d06abd6ac75b453a93d0e85e4e391c00&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":705,"id":"bcf5ca85a4044e9abed163f72039c7d1","min_payout":123,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=bcf5ca85a4044e9abed163f72039c7d1&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":679,"id":"b85034bd59e04a64ac1be7686a4c906d","min_payout":121,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=b85034bd59e04a64ac1be7686a4c906d&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1180,"id":"324cf03b9ba24ef19284683bf9b62afb","min_payout":24,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=324cf03b9ba24ef19284683bf9b62afb&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1062,"id":"6bbbef3157c346009e677e556ecea7e7","min_payout":17,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=6bbbef3157c346009e677e556ecea7e7&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1128,"id":"bc12ba86b5024cf5b8ade997415190f7","min_payout":14,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=bc12ba86b5024cf5b8ade997415190f7&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1126,"id":"e574e715dcc744dbacf84a8426a0cd37","min_payout":10,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=e574e715dcc744dbacf84a8426a0cd37&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1302,"id":"a6f562e3371347d19d281d40b3ca317d","min_payout":10,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=a6f562e3371347d19d281d40b3ca317d&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":953,"id":"462c8e1fbdad475792a360097c8e740f","min_payout":4,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=462c8e1fbdad475792a360097c8e740f&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"}],' + '"id":"7f5158dc25174ada89f54e2b26a61b20"}}' +) + +# Blocked user +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/0af0f7ec/?bpuid=00051c62-a872-4832-a008 +# -c37ec51d33d3&duration=1200&format=json&country_iso=us +RESPONSE_0af0f7ec = ( + '{"info":{"success":true},"offerwall":{"availability_count":0,"buckets":[],' + '"id":"34e71f4ccdec47b3b0991f5cfda60238","payout_format":"${payout/100:.2f}"}}' +) diff --git a/tests/models/legacy/test_offerwall_parse_response.py b/tests/models/legacy/test_offerwall_parse_response.py new file mode 100644 index 0000000..7fb5315 --- /dev/null +++ b/tests/models/legacy/test_offerwall_parse_response.py @@ -0,0 +1,186 @@ +import json + +from generalresearch.models import Source +from generalresearch.models.legacy.bucket import ( + TopNPlusBucket, + SurveyEligibilityCriterion, + DurationSummary, + PayoutSummary, + BucketTask, +) + + +class TestOfferwallTopNAndStarwall: + def test_45b7228a7(self): + from generalresearch.models.legacy.offerwall import ( + TopNOfferWall, + TopNOfferWallResponse, + StarwallOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_45b7228a7, + ) + + res = json.loads(RESPONSE_45b7228a7) + assert TopNOfferWallResponse.model_validate(res) + offerwall = TopNOfferWall.model_validate(res["offerwall"]) + assert offerwall + offerwall.censor() + # Format is identical to starwall + assert StarwallOfferWallResponse.model_validate(res) + + def test_b145b803(self): + from generalresearch.models.legacy.offerwall import ( + TopNPlusOfferWallResponse, + StarwallPlusOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_b145b803, + ) + + res = json.loads(RESPONSE_b145b803) + assert TopNPlusOfferWallResponse.model_validate(res) + assert StarwallPlusOfferWallResponse.model_validate(res) + + def test_d48cce47(self): + from generalresearch.models.legacy.offerwall import ( + TopNPlusBlockOfferWallResponse, + StarwallPlusBlockOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_b145b803, + RESPONSE_d48cce47, + ) + + res = json.loads(RESPONSE_d48cce47) # this is a blocked user's response + assert TopNPlusBlockOfferWallResponse.model_validate(res) + assert StarwallPlusBlockOfferWallResponse.model_validate(res) + # otherwise it is identical to the plus's response + res = json.loads(RESPONSE_b145b803) + assert TopNPlusBlockOfferWallResponse.model_validate(res) + assert StarwallPlusBlockOfferWallResponse.model_validate(res) + + def test_1e5f0af8(self): + from generalresearch.models.legacy.offerwall import ( + TopNPlusBlockRecontactOfferWallResponse, + StarwallPlusBlockRecontactOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_d48cce47, + RESPONSE_1e5f0af8, + ) + + res = json.loads(RESPONSE_1e5f0af8) + assert TopNPlusBlockRecontactOfferWallResponse.model_validate(res) + assert StarwallPlusBlockRecontactOfferWallResponse.model_validate(res) + + res = json.loads(RESPONSE_d48cce47) # this is a blocked user's response + assert TopNPlusBlockRecontactOfferWallResponse.model_validate(res) + assert StarwallPlusBlockRecontactOfferWallResponse.model_validate(res) + + def test_eligibility_criteria(self): + b = TopNPlusBucket( + id="c82cf98c578a43218334544ab376b00e", + contents=[ + BucketTask( + id="12345", + payout=10, + source=Source.TESTING, + id_code="t:12345", + loi=120, + ) + ], + duration=DurationSummary(max=1, min=1, q1=1, q2=1, q3=1), + quality_score=1, + payout=PayoutSummary(max=1, min=1, q1=1, q2=1, q3=1), + uri="https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142", + eligibility_criteria=( + SurveyEligibilityCriterion( + question_id="71a367fb71b243dc89f0012e0ec91749", + question_text="what is something", + qualifying_answer=("1",), + qualifying_answer_label=("abc",), + property_code="t:123", + ), + SurveyEligibilityCriterion( + question_id="81a367fb71b243dc89f0012e0ec91749", + question_text="what is something 2", + qualifying_answer=("2",), + qualifying_answer_label=("ddd",), + property_code="t:124", + ), + ), + ) + assert b.eligibility_criteria[0].rank == 0 + assert b.eligibility_criteria[1].rank == 1 + print(b.model_dump_json()) + b.censor() + print(b.model_dump_json()) + + +class TestOfferwallSingle: + def test_5fl8bpv5(self): + from generalresearch.models.legacy.offerwall import ( + SingleEntryOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_5fl8bpv5, + ) + + res = json.loads(RESPONSE_5fl8bpv5) + assert SingleEntryOfferWallResponse.model_validate(res) + + +class TestOfferwallSoftPair: + def test_37d1da64(self): + from generalresearch.models.legacy.offerwall import ( + SoftPairOfferwallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_37d1da64, + ) + + res = json.loads(RESPONSE_37d1da64) + assert SoftPairOfferwallResponse.model_validate(res) + + +class TestMarketplace: + def test_5fa23085(self): + from generalresearch.models.legacy.offerwall import ( + MarketplaceOfferwallResponse, + ) + + from tests.models.legacy.data import ( + RESPONSE_5fa23085, + ) + + res = json.loads(RESPONSE_5fa23085) + assert MarketplaceOfferwallResponse.model_validate(res) + + +class TestTimebucks: + def test_1705e4f8(self): + from generalresearch.models.legacy.offerwall import ( + TimeBucksOfferwallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_1705e4f8, + ) + + res = json.loads(RESPONSE_1705e4f8) + assert TimeBucksOfferwallResponse.model_validate(res) + + def test_0af0f7ec(self): + from generalresearch.models.legacy.offerwall import ( + TimeBucksBlockOfferwallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_1705e4f8, + RESPONSE_0af0f7ec, + ) + + res = json.loads(RESPONSE_0af0f7ec) + assert TimeBucksBlockOfferwallResponse.model_validate(res) + + res = json.loads(RESPONSE_1705e4f8) + assert TimeBucksBlockOfferwallResponse.model_validate(res) diff --git a/tests/models/legacy/test_profiling_questions.py b/tests/models/legacy/test_profiling_questions.py new file mode 100644 index 0000000..1afaa6b --- /dev/null +++ b/tests/models/legacy/test_profiling_questions.py @@ -0,0 +1,81 @@ +class TestUpkQuestionResponse: + + def test_init(self): + from generalresearch.models.legacy.questions import UpkQuestionResponse + + s = ( + '{"status": "success", "count": 7, "questions": [{"selector": "SL", "validation": {"patterns": [{' + '"message": "Must input a value between 13 and 120", "pattern": "^(1[01][0-9]|120|1[3-9]|[2-9][' + '0-9])$"}]}, "country_iso": "us", "question_id": "c5a4ef644c374f8994ecb3226b84263e", "language_iso": ' + '"eng", "configuration": {"max_length": 3, "type": "TE"}, "question_text": "What is your age (in ' + 'years)?", "question_type": "TE", "task_score": 20.28987136651298, "task_count": 21131, "p": 1.0}, ' + '{"choices": [{"order": 0, "choice_id": "0", "choice_text": "Male"}, {"order": 1, "choice_id": "1", ' + '"choice_text": "Female"}, {"order": 2, "choice_id": "2", "choice_text": "Other"}], "selector": "SA", ' + '"country_iso": "us", "question_id": "5d6d9f3c03bb40bf9d0a24f306387d7c", "language_iso": "eng", ' + '"question_text": "What is your gender?", "question_type": "MC", "task_score": 16.598347505339095, ' + '"task_count": 4842, "p": 0.8180607558081178}, {"choices": [{"order": 0, "choice_id": "ara", ' + '"choice_text": "Arabic"}, {"order": 1, "choice_id": "zho", "choice_text": "Chinese - Mandarin"}, ' + '{"order": 2, "choice_id": "dut", "choice_text": "Dutch"}, {"order": 3, "choice_id": "eng", ' + '"choice_text": "English"}, {"order": 4, "choice_id": "fre", "choice_text": "French"}, {"order": 5, ' + '"choice_id": "ger", "choice_text": "German"}, {"order": 6, "choice_id": "hat", "choice_text": "Haitian ' + 'Creole"}, {"order": 7, "choice_id": "hin", "choice_text": "Hindi"}, {"order": 8, "choice_id": "ind", ' + '"choice_text": "Indonesian"}, {"order": 9, "choice_id": "ita", "choice_text": "Italian"}, {"order": 10, ' + '"choice_id": "jpn", "choice_text": "Japanese"}, {"order": 11, "choice_id": "kor", "choice_text": ' + '"Korean"}, {"order": 12, "choice_id": "may", "choice_text": "Malay"}, {"order": 13, "choice_id": "pol", ' + '"choice_text": "Polish"}, {"order": 14, "choice_id": "por", "choice_text": "Portuguese"}, {"order": 15, ' + '"choice_id": "pan", "choice_text": "Punjabi"}, {"order": 16, "choice_id": "rus", "choice_text": ' + '"Russian"}, {"order": 17, "choice_id": "spa", "choice_text": "Spanish"}, {"order": 18, "choice_id": ' + '"tgl", "choice_text": "Tagalog/Filipino"}, {"order": 19, "choice_id": "tur", "choice_text": "Turkish"}, ' + '{"order": 20, "choice_id": "vie", "choice_text": "Vietnamese"}, {"order": 21, "choice_id": "zul", ' + '"choice_text": "Zulu"}, {"order": 22, "choice_id": "xxx", "choice_text": "Other"}], "selector": "MA", ' + '"country_iso": "us", "question_id": "f15663d012244d5fa43f5784f7bd1898", "language_iso": "eng", ' + '"question_text": "Which language(s) do you speak fluently at home? (Select all that apply)", ' + '"question_type": "MC", "task_score": 15.835933296975051, "task_count": 147, "p": 0.780484657143325}, ' + '{"selector": "SL", "validation": {"patterns": [{"message": "Must enter a valid zip code: XXXXX", ' + '"pattern": "^[0-9]{5}$"}]}, "country_iso": "us", "question_id": "543de254e9ca4d9faded4377edab82a9", ' + '"language_iso": "eng", "configuration": {"max_length": 5, "min_length": 5, "type": "TE"}, ' + '"question_text": "What is your zip code?", "question_type": "TE", "task_score": 3.9114103408096685, ' + '"task_count": 4116, "p": 0.19277649769949645}, {"selector": "SL", "validation": {"patterns": [{' + '"message": "Must input digits only (range between 1 and 999999)", "pattern": "^[0-9]{1,6}$"}]}, ' + '"country_iso": "us", "question_id": "9ffacedc92584215912062a9d75338fa", "language_iso": "eng", ' + '"configuration": {"max_length": 6, "type": "TE"}, "question_text": "What is your current annual ' + 'household income before taxes (in USD)?", "question_type": "TE", "task_score": 2.4630414657197686, ' + '"task_count": 3267, "p": 0.12139266046727369}, {"choices": [{"order": 0, "choice_id": "0", ' + '"choice_text": "Employed full-time"}, {"order": 1, "choice_id": "1", "choice_text": "Employed ' + 'part-time"}, {"order": 2, "choice_id": "2", "choice_text": "Self-employed full-time"}, {"order": 3, ' + '"choice_id": "3", "choice_text": "Self-employed part-time"}, {"order": 4, "choice_id": "4", ' + '"choice_text": "Active military"}, {"order": 5, "choice_id": "5", "choice_text": "Inactive ' + 'military/Veteran"}, {"order": 6, "choice_id": "6", "choice_text": "Temporarily unemployed"}, ' + '{"order": 7, "choice_id": "7", "choice_text": "Full-time homemaker"}, {"order": 8, "choice_id": "8", ' + '"choice_text": "Retired"}, {"order": 9, "choice_id": "9", "choice_text": "Student"}, {"order": 10, ' + '"choice_id": "10", "choice_text": "Disabled"}], "selector": "SA", "country_iso": "us", "question_id": ' + '"b546d26651f040c9a6900ffb126e7d69", "language_iso": "eng", "question_text": "What is your current ' + 'employment status?", "question_type": "MC", "task_score": 1.6940674222375414, "task_count": 1134, ' + '"p": 0.0834932559027201}, {"choices": [{"order": 0, "choice_id": "0", "choice_text": "No"}, ' + '{"order": 1, "choice_id": "1", "choice_text": "Yes, Mexican"}, {"order": 2, "choice_id": "2", ' + '"choice_text": "Yes, Puerto Rican"}, {"order": 3, "choice_id": "3", "choice_text": "Yes, Cuban"}, ' + '{"order": 4, "choice_id": "4", "choice_text": "Yes, Salvadoran"}, {"order": 5, "choice_id": "5", ' + '"choice_text": "Yes, Dominican"}, {"order": 6, "choice_id": "6", "choice_text": "Yes, Guatemalan"}, ' + '{"order": 7, "choice_id": "7", "choice_text": "Yes, Colombian"}, {"order": 8, "choice_id": "8", ' + '"choice_text": "Yes, Honduran"}, {"order": 9, "choice_id": "9", "choice_text": "Yes, Ecuadorian"}, ' + '{"order": 10, "choice_id": "10", "choice_text": "Yes, Argentinian"}, {"order": 11, "choice_id": "11", ' + '"choice_text": "Yes, Peruvian"}, {"order": 12, "choice_id": "12", "choice_text": "Yes, Nicaraguan"}, ' + '{"order": 13, "choice_id": "13", "choice_text": "Yes, Spaniard"}, {"order": 14, "choice_id": "14", ' + '"choice_text": "Yes, Venezuelan"}, {"order": 15, "choice_id": "15", "choice_text": "Yes, Panamanian"}, ' + '{"order": 16, "choice_id": "16", "choice_text": "Yes, Other"}], "selector": "SA", "country_iso": "us", ' + '"question_id": "7d452b8069c24a1aacbccbf767910345", "language_iso": "eng", "question_text": "Are you of ' + 'Hispanic, Latino, or Spanish origin?", "question_type": "MC", "task_score": 1.6342156553005878, ' + '"task_count": 2516, "p": 0.08054342118687588}], "special_questions": [{"selector": "HIDDEN", ' + '"country_iso": "xx", "question_id": "1d1e2e8380ac474b87fb4e4c569b48df", "language_iso": "xxx", ' + '"question_type": "HIDDEN"}, {"selector": "HIDDEN", "country_iso": "xx", "question_id": ' + '"2fbedb2b9f7647b09ff5e52fa119cc5e", "language_iso": "xxx", "question_type": "HIDDEN"}, {"selector": ' + '"HIDDEN", "country_iso": "xx", "question_id": "4030c52371b04e80b64e058d9c5b82e9", "language_iso": ' + '"xxx", "question_type": "HIDDEN"}, {"selector": "HIDDEN", "country_iso": "xx", "question_id": ' + '"a91cb1dea814480dba12d9b7b48696dd", "language_iso": "xxx", "question_type": "HIDDEN"}, {"selector": ' + '"HIDDEN", "task_count": 40.0, "task_score": 0.45961204566189584, "country_iso": "us", "question_id": ' + '"59f39a785f154752b6435c260cbce3c6", "language_iso": "eng", "question_text": "Core-Based Statistical ' + 'Area (2020)", "question_type": "HIDDEN"}], "consent_questions": []}' + ) + instance = UpkQuestionResponse.model_validate_json(s) + + assert isinstance(instance, UpkQuestionResponse) diff --git a/tests/models/legacy/test_user_question_answer_in.py b/tests/models/legacy/test_user_question_answer_in.py new file mode 100644 index 0000000..253c46e --- /dev/null +++ b/tests/models/legacy/test_user_question_answer_in.py @@ -0,0 +1,304 @@ +import json +from decimal import Decimal +from uuid import uuid4 + +import pytest + + +class TestUserQuestionAnswers: + """This is for the GRS POST submission that may contain multiple + Question+Answer(s) combinations for a single GRS Survey. It is + responsible for making sure the same question isn't submitted + more than once per submission, and other "list validation" + checks that aren't possible on an individual level. + """ + + def test_json_init( + self, + product_manager, + user_manager, + session_manager, + wall_manager, + user_factory, + product, + session_factory, + utc_hour_ago, + ): + from generalresearch.models.thl.session import Session, Wall + from generalresearch.models.thl.user import User + from generalresearch.models import Source + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + u: User = user_factory(product=product) + + s1: Session = session_factory( + user=u, + wall_count=1, + started=utc_hour_ago, + wall_req_cpi=Decimal("0.00"), + wall_source=Source.GRS, + ) + assert isinstance(s1, Session) + w1 = s1.wall_events[0] + assert isinstance(w1, Wall) + + instance = UserQuestionAnswers.model_validate_json( + json.dumps( + { + "product_id": product.uuid, + "product_user_id": u.product_user_id, + "session_id": w1.uuid, + "answers": [ + {"question_id": uuid4().hex, "answer": ["a", "b"]}, + {"question_id": uuid4().hex, "answer": ["a", "b"]}, + ], + } + ) + ) + assert isinstance(instance, UserQuestionAnswers) + + def test_simple_validation_errors( + self, product_manager, user_manager, session_manager, wall_manager + ): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": [{"question_id": uuid4().hex, "answer": ["a", "b"]}], + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "session_id": uuid4().hex, + "answers": [{"question_id": uuid4().hex, "answer": ["a", "b"]}], + } + ) + + # user is validated only if a session_id is passed + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "answers": [{"question_id": uuid4().hex, "answer": ["a", "b"]}], + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": [], + } + ) + + with pytest.raises(ValueError): + answers = [ + {"question_id": uuid4().hex, "answer": ["a"]} for i in range(101) + ] + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": answers, + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": "aaa", + } + ) + + def test_no_duplicate_questions(self): + # TODO: depending on if or how many of these types of errors actually + # occur, we could get fancy and just drop one of them. I don't + # think this is worth exploring yet unless we see if it's a problem. + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + consistent_qid = uuid4().hex + with pytest.raises(ValueError) as cm: + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": [ + {"question_id": consistent_qid, "answer": ["aaa"]}, + {"question_id": consistent_qid, "answer": ["bbb"]}, + ], + } + ) + + assert "Don't provide answers to duplicate questions" in str(cm.value) + + def test_allow_answer_failures_silent( + self, + product_manager, + user_manager, + session_manager, + wall_manager, + product, + user_factory, + utc_hour_ago, + session_factory, + ): + """ + There are many instances where suppliers may be submitting answers + manually, and they're just totally broken. We want to silently remove + that one QuestionAnswerIn without "loosing" any of the other + QuestionAnswerIn items that they provided. + """ + from generalresearch.models.thl.session import Session, Wall + from generalresearch.models.thl.user import User + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + u: User = user_factory(product=product) + + s1: Session = session_factory(user=u, wall_count=1, started=utc_hour_ago) + assert isinstance(s1, Session) + w1 = s1.wall_events[0] + assert isinstance(w1, Wall) + + data = { + "product_id": product.uuid, + "product_user_id": u.product_user_id, + "session_id": w1.uuid, + "answers": [ + {"question_id": uuid4().hex, "answer": ["aaa"]}, + {"question_id": f"broken-{uuid4().hex[:6]}", "answer": ["bbb"]}, + ], + } + # load via .model_validate() + instance = UserQuestionAnswers.model_validate(data) + assert isinstance(instance, UserQuestionAnswers) + + # One of the QuestionAnswerIn items was invalid, so it was dropped + assert 1 == len(instance.answers) + + # Confirm that this also works via model_validate_json + json_data = json.dumps(data) + instance = UserQuestionAnswers.model_validate_json(json_data) + assert isinstance(instance, UserQuestionAnswers) + + # One of the QuestionAnswerIn items was invalid, so it was dropped + assert 1 == len(instance.answers) + + assert instance.user is None + instance.prefetch_user(um=user_manager) + assert isinstance(instance.user, User) + + +class TestUserQuestionAnswerIn: + """This is for the individual Question+Answer(s) that may come back from + a GRS POST. + """ + + def test_simple_validation_errors(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": f"test-{uuid4().hex[:6]}", "answer": ["123"]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate({"answer": ["123"]}) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": [123]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": [""]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": [" "]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": ["a" * 5_001]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": []} + ) + + def test_only_single_answers(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + for qid in { + "2fbedb2b9f7647b09ff5e52fa119cc5e", + "4030c52371b04e80b64e058d9c5b82e9", + "a91cb1dea814480dba12d9b7b48696dd", + "1d1e2e8380ac474b87fb4e4c569b48df", + }: + # This is the UserAgent question which only allows a single answer + with pytest.raises(ValueError) as cm: + UserQuestionAnswerIn.model_validate( + {"question_id": qid, "answer": ["a", "b"]} + ) + + assert "Too many answer values provided" in str(cm.value) + + def test_answer_item_limit(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + answer = [uuid4().hex[:6] for i in range(11)] + with pytest.raises(ValueError) as cm: + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": answer} + ) + assert "List should have at most 10 items after validation" in str(cm.value) + + def test_disallow_duplicate_answer_values(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + answer = ["aaa" for i in range(5)] + with pytest.raises(ValueError) as cm: + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": answer} + ) diff --git a/tests/models/morning/__init__.py b/tests/models/morning/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/morning/__init__.py diff --git a/tests/models/morning/test.py b/tests/models/morning/test.py new file mode 100644 index 0000000..cf4982d --- /dev/null +++ b/tests/models/morning/test.py @@ -0,0 +1,199 @@ +from datetime import datetime, timezone + +from generalresearch.models.morning.question import MorningQuestion + +bid = { + "buyer_account_id": "ab180f06-aa2b-4b8b-9b87-1031bfe8b16b", + "buyer_id": "5f3b4daa-6ff0-4826-a551-9d4572ea1c84", + "country_id": "us", + "end_date": "2024-07-19T09:01:13.520243Z", + "exclusions": [ + {"group_id": "66070689-5198-5782-b388-33daa74f3269", "lockout_period": 28} + ], + "id": "5324c2ac-eca8-4ed0-8b0e-042ba3aa2a85", + "language_ids": ["en"], + "name": "Ad-Hoc Survey", + "published_at": "2024-06-19T09:01:13.520243Z", + "quotas": [ + { + "cost_per_interview": 154, + "id": "b8ade883-a83d-4d8e-9ef7-953f4b692bd8", + "qualifications": [ + { + "id": "age", + "response_ids": [ + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + "26", + "27", + "28", + "29", + "30", + "31", + "32", + "33", + "34", + ], + }, + {"id": "gender", "response_ids": ["1"]}, + {"id": "hispanic", "response_ids": ["1"]}, + ], + "statistics": { + "length_of_interview": 1353, + "median_length_of_interview": 1353, + "num_available": 3, + "num_completes": 7, + "num_failures": 0, + "num_in_progress": 4, + "num_over_quotas": 0, + "num_qualified": 27, + "num_quality_terminations": 14, + "num_timeouts": 1, + "qualified_conversion": 30, + }, + } + ], + "state": "active", + "statistics": { + "earnings_per_click": 26, + "estimated_length_of_interview": 1140, + "incidence_rate": 77, + "length_of_interview": 1198, + "median_length_of_interview": 1198, + "num_available": 70, + "num_completes": 360, + "num_entrants": 1467, + "num_failures": 0, + "num_in_progress": 48, + "num_over_quotas": 10, + "num_qualified": 1121, + "num_quality_terminations": 584, + "num_screenouts": 380, + "num_timeouts": 85, + "qualified_conversion": 34, + "system_conversion": 25, + }, + "supplier_exclusive": False, + "survey_type": "ad_hoc", + "timeout": 21600, + "topic_id": "general", +} + +bid = { + "_experimental_single_use_qualifications": [ + { + "id": "electric_car_test", + "name": "Electric Car Test", + "text": "What kind of vehicle do you drive?", + "language_ids": ["en"], + "responses": [{"id": "1", "text": "electric"}, {"id": "2", "text": "gas"}], + "type": "multiple_choice", + } + ], + "buyer_account_id": "0b6f207c-96e1-4dce-b032-566a815ad263", + "buyer_id": "9020f6f3-db41-470a-a5d7-c04fa2da9156", + "closed_at": "2022-01-01T00:00:00Z", + "country_id": "us", + "end_date": "2022-01-01T00:00:00Z", + "exclusions": [ + {"group_id": "0bbae805-5a80-42e3-8d5f-cb056a0f825d", "lockout_period": 7} + ], + "id": "000f09a3-bc25-4adc-a443-a9975800e7ac", + "language_ids": ["en", "es"], + "name": "My Example Survey", + "published_at": "2021-12-30T00:00:00Z", + "quotas": [ + { + "_experimental_single_use_qualifications": [ + {"id": "electric_car_test", "response_ids": ["1"]} + ], + "cost_per_interview": 100, + "id": "6a7d0190-e6ad-4a59-9945-7ba460517f2b", + "qualifications": [ + {"id": "gender", "response_ids": ["1"]}, + {"id": "age", "response_ids": ["18", "19", "20", "21"]}, + ], + "statistics": { + "length_of_interview": 600, + "median_length_of_interview": 600, + "num_available": 500, + "num_completes": 100, + "num_failures": 0, + "num_in_progress": 0, + "num_over_quotas": 0, + "num_qualified": 100, + "num_quality_terminations": 0, + "num_timeouts": 0, + "qualified_conversion": 100, + }, + } + ], + "state": "active", + "statistics": { + "earnings_per_click": 50, + "estimated_length_of_interview": 720, + "incidence_rate": 100, + "length_of_interview": 600, + "median_length_of_interview": 600, + "num_available": 500, + "num_completes": 100, + "num_entrants": 100, + "num_failures": 0, + "num_in_progress": 0, + "num_over_quotas": 0, + "num_qualified": 100, + "num_quality_terminations": 0, + "num_screenouts": 0, + "num_timeouts": 0, + "qualified_conversion": 100, + "system_conversion": 100, + }, + "supplier_exclusive": False, + "survey_type": "ad_hoc", + "timeout": 3600, + "topic_id": "general", +} + +# what gets run in MorningAPI._format_bid +bid["language_isos"] = ("eng",) +bid["country_iso"] = "us" +bid["end_date"] = datetime(2024, 7, 19, 9, 1, 13, 520243, tzinfo=timezone.utc) +bid["published_at"] = datetime(2024, 6, 19, 9, 1, 13, 520243, tzinfo=timezone.utc) +bid.update(bid["statistics"]) +bid["qualified_conversion"] /= 100 +bid["system_conversion"] /= 100 +for quota in bid["quotas"]: + quota.update(quota["statistics"]) + quota["qualified_conversion"] /= 100 + quota["cost_per_interview"] /= 100 +if "_experimental_single_use_qualifications" in bid: + bid["experimental_single_use_qualifications"] = [ + MorningQuestion.from_api(q, bid["country_iso"], "eng") + for q in bid["_experimental_single_use_qualifications"] + ] + + +class TestMorningBid: + + def test_model_validate(self): + from generalresearch.models.morning.survey import MorningBid + + s = MorningBid.model_validate(bid) + d = s.model_dump(mode="json") + d = s.to_mysql() + + def test_manager(self): + # todo: credentials n stuff + pass + # sql_helper = SqlHelper(host="localhost", user="root", password="", db="300large-morning") + # m = MorningSurveyManager(sql_helper=sql_helper) + # s = MorningBid.model_validate(bid) + # m.create(s) + # res = m.get_survey_library()[0] + # MorningBid.model_validate(res) diff --git a/tests/models/precision/__init__.py b/tests/models/precision/__init__.py new file mode 100644 index 0000000..8006fa3 --- /dev/null +++ b/tests/models/precision/__init__.py @@ -0,0 +1,115 @@ +survey_json = { + "cpi": "1.44", + "country_isos": "ca", + "language_isos": "eng", + "country_iso": "ca", + "language_iso": "eng", + "buyer_id": "7047", + "bid_loi": 1200, + "bid_ir": 0.45, + "source": "e", + "used_question_ids": ["age", "country_iso", "gender", "gender_1"], + "survey_id": "0000", + "group_id": "633473", + "status": "open", + "name": "beauty survey", + "survey_guid": "c7f375c5077d4c6c8209ff0b539d7183", + "category_id": "-1", + "global_conversion": None, + "desired_count": 96, + "achieved_count": 0, + "allowed_devices": "1,2,3", + "entry_link": "https://www.opinionetwork.com/survey/entry.aspx?mid=[%MID%]&project=633473&key=%%key%%", + "excluded_surveys": "470358,633286", + "quotas": [ + { + "name": "25-34,Male,Quebec", + "id": "2324110", + "guid": "23b5760d24994bc08de451b3e62e77c7", + "status": "open", + "desired_count": 48, + "achieved_count": 0, + "termination_count": 0, + "overquota_count": 0, + "condition_hashes": ["b41e1a3", "bc89ee8", "4124366", "9f32c61"], + }, + { + "name": "25-34,Female,Quebec", + "id": "2324111", + "guid": "0706f1a88d7e4f11ad847c03012e68d2", + "status": "open", + "desired_count": 48, + "achieved_count": 0, + "termination_count": 4, + "overquota_count": 0, + "condition_hashes": ["b41e1a3", "0cdc304", "500af2c", "9f32c61"], + }, + ], + "conditions": { + "b41e1a3": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "country_iso", + "values": ["ca"], + "criterion_hash": "b41e1a3", + "value_len": 1, + "sizeof": 2, + }, + "bc89ee8": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender", + "values": ["male"], + "criterion_hash": "bc89ee8", + "value_len": 1, + "sizeof": 4, + }, + "4124366": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender_1", + "values": ["male"], + "criterion_hash": "4124366", + "value_len": 1, + "sizeof": 4, + }, + "9f32c61": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "age", + "values": ["25", "26", "27", "28", "29", "30", "31", "32", "33", "34"], + "criterion_hash": "9f32c61", + "value_len": 10, + "sizeof": 20, + }, + "0cdc304": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender", + "values": ["female"], + "criterion_hash": "0cdc304", + "value_len": 1, + "sizeof": 6, + }, + "500af2c": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender_1", + "values": ["female"], + "criterion_hash": "500af2c", + "value_len": 1, + "sizeof": 6, + }, + }, + "expected_end_date": "2024-06-28T10:40:33.000000Z", + "created": None, + "updated": None, + "is_live": True, + "all_hashes": ["0cdc304", "b41e1a3", "9f32c61", "bc89ee8", "4124366", "500af2c"], +} diff --git a/tests/models/precision/test_survey.py b/tests/models/precision/test_survey.py new file mode 100644 index 0000000..ff2d6d1 --- /dev/null +++ b/tests/models/precision/test_survey.py @@ -0,0 +1,88 @@ +class TestPrecisionQuota: + + def test_quota_passes(self): + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + q = s.quotas[0] + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + assert q.matches(ce) + + ce["b41e1a3"] = False + assert not q.matches(ce) + + ce.pop("b41e1a3") + assert not q.matches(ce) + assert not q.matches({}) + + def test_quota_passes_closed(self): + from generalresearch.models.precision import PrecisionStatus + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + q = s.quotas[0] + q.status = PrecisionStatus.CLOSED + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + # We still match, but the quota is not open + assert q.matches(ce) + assert not q.is_open + + +class TestPrecisionSurvey: + + def test_passes(self): + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + assert s.determine_eligibility(ce) + + def test_elig_closed_quota(self): + from generalresearch.models.precision import PrecisionStatus + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + q = s.quotas[0] + q.status = PrecisionStatus.CLOSED + # We match a closed quota + assert not s.determine_eligibility(ce) + + s.quotas[0].status = PrecisionStatus.OPEN + s.quotas[1].status = PrecisionStatus.CLOSED + # Now me match an open quota and dont match the closed quota, so we should be eligible + assert s.determine_eligibility(ce) + + def test_passes_sp(self): + from generalresearch.models.precision import PrecisionStatus + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + passes, hashes = s.determine_eligibility_soft(ce) + + # We don't know if we match the 2nd quota, but it is open so it doesn't matter + assert passes + assert (True, []) == s.quotas[0].matches_soft(ce) + assert (None, ["0cdc304", "500af2c"]) == s.quotas[1].matches_soft(ce) + + # Now don't know if we match either + ce.pop("9f32c61") # age + passes, hashes = s.determine_eligibility_soft(ce) + assert passes is None + assert {"500af2c", "9f32c61", "0cdc304"} == hashes + + ce["9f32c61"] = False + ce["0cdc304"] = False + # We know we don't match either + assert (False, set()) == s.determine_eligibility_soft(ce) + + # We pass 1st quota, 2nd is unknown but closed, so we don't know for sure we pass + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + s.quotas[1].status = PrecisionStatus.CLOSED + assert (None, {"0cdc304", "500af2c"}) == s.determine_eligibility_soft(ce) diff --git a/tests/models/precision/test_survey_manager.py b/tests/models/precision/test_survey_manager.py new file mode 100644 index 0000000..8532ab0 --- /dev/null +++ b/tests/models/precision/test_survey_manager.py @@ -0,0 +1,63 @@ +# from decimal import Decimal +# +# from datetime import timezone, datetime +# from pymysql import IntegrityError +# from generalresearch.models.precision.survey import PrecisionSurvey +# from tests.models.precision import survey_json + + +# def delete_survey(survey_id: str): +# db_name = sql_helper.db +# # TODO: what is the precision specific db name... +# +# sql_helper.execute_sql_query( +# query=""" +# DELETE FROM `300large-precision`.precision_survey +# WHERE survey_id = %s +# """, +# params=[survey_id], commit=True) +# sql_helper.execute_sql_query(""" +# DELETE FROM `300large-precision`.precision_survey_country WHERE survey_id = %s +# """, [survey_id], commit=True) +# sql_helper.execute_sql_query(""" +# DELETE FROM `300large-precision`.precision_survey_language WHERE survey_id = %s +# """, [survey_id], commit=True) +# +# +# class TestPrecisionSurvey: +# def test_survey_create(self): +# now = datetime.now(tz=timezone.utc) +# s = PrecisionSurvey.model_validate(survey_json) +# self.assertEqual(s.survey_id, '0000') +# delete_survey(s.survey_id) +# +# sm.create(s) +# +# surveys = sm.get_survey_library(updated_since=now) +# self.assertEqual(len(surveys), 1) +# self.assertEqual('0000', surveys[0].survey_id) +# self.assertTrue(s.is_unchanged(surveys[0])) +# +# with self.assertRaises(IntegrityError) as context: +# sm.create(s) +# +# def test_survey_update(self): +# # There's extra complexity here with the country/lang join tables +# now = datetime.now(tz=timezone.utc) +# s = PrecisionSurvey.model_validate(survey_json) +# self.assertEqual(s.survey_id, '0000') +# delete_survey(s.survey_id) +# sm.create(s) +# s.cpi = Decimal('0.50') +# # started out at only 'ca' and 'eng' +# s.country_isos = ['us'] +# s.country_iso = 'us' +# s.language_isos = ['eng', 'spa'] +# s.language_iso = 'eng' +# sm.update([s]) +# surveys = sm.get_survey_library(updated_since=now) +# self.assertEqual(len(surveys), 1) +# s2 = surveys[0] +# self.assertEqual('0000', s2.survey_id) +# self.assertEqual(Decimal('0.50'), s2.cpi) +# self.assertTrue(s.is_unchanged(s2)) diff --git a/tests/models/prodege/__init__.py b/tests/models/prodege/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/prodege/__init__.py diff --git a/tests/models/prodege/test_survey_participation.py b/tests/models/prodege/test_survey_participation.py new file mode 100644 index 0000000..b85cc91 --- /dev/null +++ b/tests/models/prodege/test_survey_participation.py @@ -0,0 +1,120 @@ +from datetime import timezone, datetime, timedelta + + +class TestProdegeParticipation: + + def test_exclude(self): + from generalresearch.models.prodege import ProdegePastParticipationType + from generalresearch.models.prodege.survey import ( + ProdegePastParticipation, + ProdegeUserPastParticipation, + ) + + now = datetime.now(tz=timezone.utc) + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "exclude", + "in_past_days": 7, + "participation_types": ["complete"], + } + ) + # User has no history, so is eligible + assert pp.is_eligible([]) + + # user abandoned. its a click, not complete, so he's eligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), survey_id="152677146" + ) + ] + assert pp.is_eligible(upps) + + # user completes. ineligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert not pp.is_eligible(upps) + + # user completed. but too long ago + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(days=100), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert pp.is_eligible(upps) + + # remove day filter, should be ineligble again + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "exclude", + "in_past_days": 0, + "participation_types": ["complete"], + } + ) + assert not pp.is_eligible(upps) + + # I almost forgot this.... a "complete" IS ALSO A "click"!!! + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "exclude", + "in_past_days": 0, + "participation_types": ["click"], + } + ) + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert { + ProdegePastParticipationType.COMPLETE, + ProdegePastParticipationType.CLICK, + } == upps[0].participation_types + assert not pp.is_eligible(upps) + + def test_include(self): + from generalresearch.models.prodege.survey import ( + ProdegePastParticipation, + ProdegeUserPastParticipation, + ) + + now = datetime.now(tz=timezone.utc) + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "include", + "in_past_days": 7, + "participation_types": ["complete"], + } + ) + # User has no history, so is IN-eligible + assert not pp.is_eligible([]) + + # user abandoned. its a click, not complete, so he's INeligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), survey_id="152677146" + ) + ] + assert not pp.is_eligible(upps) + + # user completes, eligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert pp.is_eligible(upps) diff --git a/tests/models/spectrum/__init__.py b/tests/models/spectrum/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/spectrum/__init__.py diff --git a/tests/models/spectrum/test_question.py b/tests/models/spectrum/test_question.py new file mode 100644 index 0000000..ba118d7 --- /dev/null +++ b/tests/models/spectrum/test_question.py @@ -0,0 +1,216 @@ +from datetime import datetime, timezone + +from generalresearch.models import Source +from generalresearch.models.spectrum.question import ( + SpectrumQuestionOption, + SpectrumQuestion, + SpectrumQuestionType, + SpectrumQuestionClass, +) +from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestionChoice, +) + + +class TestSpectrumQuestion: + + def test_parse_from_api_1(self): + + example_1 = { + "qualification_code": 213, + "text": "My household earns approximately $%%213%% per year", + "cat": None, + "desc": "Income", + "type": 5, + "class": 1, + "condition_codes": [], + "format": {"min": 0, "max": 999999, "regex": "/^([0-9]{1,6})$/i"}, + "crtd_on": 1502869927688, + "mod_on": 1706557247467, + } + q = SpectrumQuestion.from_api(example_1, "us", "eng") + + expected_q = SpectrumQuestion( + question_id="213", + country_iso="us", + language_iso="eng", + question_name="Income", + question_text="My household earns approximately $___ per year", + question_type=SpectrumQuestionType.TEXT_ENTRY, + tags=None, + options=None, + class_num=SpectrumQuestionClass.CORE, + created=datetime(2017, 8, 16, 7, 52, 7, 688000, tzinfo=timezone.utc), + is_live=True, + source=Source.SPECTRUM, + category_id=None, + ) + assert "My household earns approximately $___ per year" == q.question_text + assert "213" == q.question_id + assert expected_q == q + q.to_upk_question() + assert "s:213" == q.external_id + + def test_parse_from_api_2(self): + + example_2 = { + "qualification_code": 211, + "text": "I'm a %%211%%", + "cat": None, + "desc": "Gender", + "type": 1, + "class": 1, + "condition_codes": [ + {"id": "111", "text": "Male"}, + {"id": "112", "text": "Female"}, + ], + "format": {"min": None, "max": None, "regex": ""}, + "crtd_on": 1502869927688, + "mod_on": 1706557249817, + } + q = SpectrumQuestion.from_api(example_2, "us", "eng") + expected_q = SpectrumQuestion( + question_id="211", + country_iso="us", + language_iso="eng", + question_name="Gender", + question_text="I'm a", + question_type=SpectrumQuestionType.SINGLE_SELECT, + tags=None, + options=[ + SpectrumQuestionOption(id="111", text="Male", order=0), + SpectrumQuestionOption(id="112", text="Female", order=1), + ], + class_num=SpectrumQuestionClass.CORE, + created=datetime(2017, 8, 16, 7, 52, 7, 688000, tzinfo=timezone.utc), + is_live=True, + source=Source.SPECTRUM, + category_id=None, + ) + assert expected_q == q + q.to_upk_question() + + def test_parse_from_api_3(self): + + example_3 = { + "qualification_code": 220, + "text": "My child is a %%230%% %%221%% old %%220%%", + "cat": None, + "desc": "Child Dependent", + "type": 6, + "class": 4, + "condition_codes": [ + {"id": "111", "text": "Boy"}, + {"id": "112", "text": "Girl"}, + ], + "format": {"min": None, "max": None, "regex": ""}, + "crtd_on": 1502869927688, + "mod_on": 1706556781278, + } + q = SpectrumQuestion.from_api(example_3, "us", "eng") + # This fails because the text has variables from other questions in it + assert q is None + + def test_parse_from_api_4(self): + + example_4 = { + "qualification_code": 1039, + "text": "Do you suffer from any of the following ailments or medical conditions? (Select all that apply) " + " %%1039%%", + "cat": "Ailments, Illness", + "desc": "Standard Ailments", + "type": 3, + "class": 2, + "condition_codes": [ + {"id": "111", "text": "Allergies (Food, Nut, Skin)"}, + {"id": "999", "text": "None of the above"}, + {"id": "130", "text": "Other"}, + { + "id": "129", + "text": "Women's Health Conditions (Reproductive Issues)", + }, + ], + "format": {"min": None, "max": None, "regex": ""}, + "crtd_on": 1502869927688, + "mod_on": 1706557241693, + } + q = SpectrumQuestion.from_api(example_4, "us", "eng") + expected_q = SpectrumQuestion( + question_id="1039", + country_iso="us", + language_iso="eng", + question_name="Standard Ailments", + question_text="Do you suffer from any of the following ailments or medical conditions? (Select all that " + "apply)", + question_type=SpectrumQuestionType.MULTI_SELECT, + tags="Ailments, Illness", + options=[ + SpectrumQuestionOption( + id="111", text="Allergies (Food, Nut, Skin)", order=0 + ), + SpectrumQuestionOption( + id="129", + text="Women's Health Conditions (Reproductive Issues)", + order=1, + ), + SpectrumQuestionOption(id="130", text="Other", order=2), + SpectrumQuestionOption(id="999", text="None of the above", order=3), + ], + class_num=SpectrumQuestionClass.EXTENDED, + created=datetime(2017, 8, 16, 7, 52, 7, 688000, tzinfo=timezone.utc), + is_live=True, + source=Source.SPECTRUM, + category_id=None, + ) + assert expected_q == q + + # todo: we should have something that infers that if the choice text is "None of the above", + # then the choice is exclusive + u = UpkQuestion( + id=None, + ext_question_id="s:1039", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + country_iso="us", + language_iso="eng", + text="Do you suffer from any of the following ailments or medical conditions? (Select all " + "that apply)", + choices=[ + UpkQuestionChoice( + id="111", + text="Allergies (Food, Nut, Skin)", + order=0, + group=None, + exclusive=False, + importance=None, + ), + UpkQuestionChoice( + id="129", + text="Women's Health Conditions (Reproductive Issues)", + order=1, + group=None, + exclusive=False, + importance=None, + ), + UpkQuestionChoice( + id="130", + text="Other", + order=2, + group=None, + exclusive=False, + importance=None, + ), + UpkQuestionChoice( + id="999", + text="None of the above", + order=3, + group=None, + exclusive=False, + importance=None, + ), + ], + ) + assert u == q.to_upk_question() diff --git a/tests/models/spectrum/test_survey.py b/tests/models/spectrum/test_survey.py new file mode 100644 index 0000000..65dec60 --- /dev/null +++ b/tests/models/spectrum/test_survey.py @@ -0,0 +1,413 @@ +from datetime import timezone, datetime +from decimal import Decimal + + +class TestSpectrumCondition: + + def test_condition_create(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + ) + from generalresearch.models.thl.survey.condition import ConditionValueType + + c = SpectrumCondition.from_api( + { + "qualification_code": 212, + "range_sets": [ + {"units": 311, "to": 28, "from": 25}, + {"units": 311, "to": 42, "from": 40}, + ], + } + ) + assert ( + SpectrumCondition( + question_id="212", + values=["25-28", "40-42"], + value_type=ConditionValueType.RANGE, + negate=False, + logical_operator=LogicalOperator.OR, + ) + == c + ) + + # These equal each other b/c age ranges get automatically converted + assert ( + SpectrumCondition( + question_id="212", + values=["25", "26", "27", "28", "40", "41", "42"], + value_type=ConditionValueType.LIST, + negate=False, + logical_operator=LogicalOperator.OR, + ) + == c + ) + + c = SpectrumCondition.from_api( + { + "condition_codes": ["111", "117", "112", "113", "118"], + "qualification_code": 1202, + } + ) + assert ( + SpectrumCondition( + question_id="1202", + values=["111", "112", "113", "117", "118"], + value_type=ConditionValueType.LIST, + negate=False, + logical_operator=LogicalOperator.OR, + ) + == c + ) + + +class TestSpectrumQuota: + + def test_quota_create(self): + from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + SpectrumQuota, + ) + + d = { + "quota_id": "a846b545-4449-4d76-93a2-f8ebdf6e711e", + "quantities": {"currently_open": 57, "remaining": 57, "achieved": 0}, + "criteria": [{"qualification_code": 211, "condition_codes": ["111"]}], + "crtd_on": 1716227282077, + "mod_on": 1716227284146, + "last_complete_date": None, + } + criteria = [SpectrumCondition.from_api(q) for q in d["criteria"]] + d["condition_hashes"] = [x.criterion_hash for x in criteria] + q = SpectrumQuota.from_api(d) + assert SpectrumQuota(remaining_count=57, condition_hashes=["c23c0b9"]) == q + assert q.is_open + + def test_quota_passes(self): + from generalresearch.models.spectrum.survey import ( + SpectrumQuota, + ) + + q = SpectrumQuota(remaining_count=57, condition_hashes=["a"]) + assert q.passes({"a": True}) + assert not q.passes({"a": False}) + assert not q.passes({}) + + # We have to match all + q = SpectrumQuota(remaining_count=57, condition_hashes=["a", "b", "c"]) + assert not q.passes({"a": True, "b": False}) + assert q.passes({"a": True, "b": True, "c": True}) + + # Quota must be open, even if we match + q = SpectrumQuota(remaining_count=0, condition_hashes=["a"]) + assert not q.passes({"a": True}) + + def test_quota_passes_soft(self): + from generalresearch.models.spectrum.survey import ( + SpectrumQuota, + ) + + q = SpectrumQuota(remaining_count=57, condition_hashes=["a", "b", "c"]) + # Pass if we match all + assert (True, set()) == q.matches_soft({"a": True, "b": True, "c": True}) + # Fail if we don't match any + assert (False, set()) == q.matches_soft({"a": True, "b": False, "c": None}) + # Unknown if any are unknown AND we don't fail any + assert (None, {"c", "b"}) == q.matches_soft({"a": True, "b": None, "c": None}) + assert (None, {"a", "c", "b"}) == q.matches_soft( + {"a": None, "b": None, "c": None} + ) + assert (False, set()) == q.matches_soft({"a": None, "b": False, "c": None}) + + +class TestSpectrumSurvey: + def test_survey_create(self): + from generalresearch.models import ( + LogicalOperator, + Source, + TaskCalculationType, + ) + from generalresearch.models.spectrum import SpectrumStatus + from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + SpectrumQuota, + SpectrumSurvey, + ) + from generalresearch.models.thl.survey.condition import ConditionValueType + + # Note: d is the raw response after calling SpectrumAPI.preprocess_survey() on it! + d = { + "survey_id": 29333264, + "survey_name": "Exciting New Survey #29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20000"), + "last_complete_date": None, + "project_last_complete_date": None, + "survey_performance": { + "overall": {"ir": 40, "loi": 10}, + "last_block": {"ir": None, "loi": None}, + }, + "supplier_completes": { + "needed": 495, + "achieved": 0, + "remaining": 495, + "guaranteed_allocation": 0, + "guaranteed_allocation_remaining": 0, + }, + "pds": {"enabled": False, "buyer_name": None}, + "quotas": [ + { + "quota_id": "c2bc961e-4f26-4223-b409-ebe9165cfdf5", + "quantities": { + "currently_open": 491, + "remaining": 495, + "achieved": 0, + }, + "criteria": [ + { + "qualification_code": 212, + "range_sets": [{"units": 311, "to": 64, "from": 18}], + } + ], + "crtd_on": 1716227293496, + "mod_on": 1716229289847, + "last_complete_date": None, + } + ], + "qualifications": [ + { + "range_sets": [{"units": 311, "to": 64, "from": 18}], + "qualification_code": 212, + } + ], + "country_iso": "fr", + "language_iso": "fre", + "bid_ir": 0.4, + "bid_loi": 600, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, + } + s = SpectrumSurvey.from_api(d) + expected_survey = SpectrumSurvey( + cpi=Decimal("1.20000"), + country_isos=["fr"], + language_isos=["fre"], + buyer_id="4726", + source=Source.SPECTRUM, + used_question_ids={"212"}, + survey_id="29333264", + survey_name="Exciting New Survey #29333264", + status=SpectrumStatus.LIVE, + field_end_date=datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + category_code="232", + calculation_type=TaskCalculationType.COMPLETES, + requires_pii=False, + survey_exclusions=set(), + exclusion_period=0, + bid_ir=0.40, + bid_loi=600, + last_block_loi=None, + last_block_ir=None, + overall_loi=None, + overall_ir=None, + project_last_complete_date=None, + country_iso="fr", + language_iso="fre", + include_psids=None, + exclude_psids=None, + qualifications=["77f493d"], + quotas=[SpectrumQuota(remaining_count=491, condition_hashes=["77f493d"])], + conditions={ + "77f493d": SpectrumCondition( + logical_operator=LogicalOperator.OR, + value_type=ConditionValueType.RANGE, + negate=False, + question_id="212", + values=["18-64"], + ) + }, + created_api=datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + modified_api=datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + updated=None, + ) + assert expected_survey.model_dump_json() == s.model_dump_json() + + def test_survey_properties(self): + from generalresearch.models.spectrum.survey import ( + SpectrumSurvey, + ) + + d = { + "survey_id": 29333264, + "survey_name": "#29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20000"), + "last_complete_date": None, + "project_last_complete_date": None, + "quotas": [ + { + "quota_id": "c2bc961e-4f26-4223-b409-ebe9165cfdf5", + "quantities": { + "currently_open": 491, + "remaining": 495, + "achieved": 0, + }, + "criteria": [ + { + "qualification_code": 214, + "range_sets": [{"units": 311, "to": 64, "from": 18}], + } + ], + } + ], + "qualifications": [ + { + "range_sets": [{"units": 311, "to": 64, "from": 18}], + "qualification_code": 212, + }, + {"condition_codes": ["111", "117", "112"], "qualification_code": 1202}, + ], + "country_iso": "fr", + "language_iso": "fre", + "overall_ir": 0.4, + "overall_loi": 600, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, + } + s = SpectrumSurvey.from_api(d) + assert {"212", "1202", "214"} == s.used_question_ids + assert s.is_live + assert s.is_open + assert {"38cea5e", "83955ef", "77f493d"} == s.all_hashes + + def test_survey_eligibility(self): + from generalresearch.models.spectrum.survey import ( + SpectrumQuota, + SpectrumSurvey, + ) + + d = { + "survey_id": 29333264, + "survey_name": "#29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20000"), + "last_complete_date": None, + "project_last_complete_date": None, + "quotas": [], + "qualifications": [], + "country_iso": "fr", + "language_iso": "fre", + "overall_ir": 0.4, + "overall_loi": 600, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, + } + s = SpectrumSurvey.from_api(d) + s.qualifications = ["a", "b", "c"] + s.quotas = [ + SpectrumQuota(remaining_count=10, condition_hashes=["a", "b"]), + SpectrumQuota(remaining_count=0, condition_hashes=["d"]), + SpectrumQuota(remaining_count=10, condition_hashes=["e"]), + ] + + assert s.passes_qualifications({"a": True, "b": True, "c": True}) + assert not s.passes_qualifications({"a": True, "b": True, "c": False}) + + # we do NOT match a full quota, so we pass + assert s.passes_quotas({"a": True, "b": True, "d": False}) + # We dont pass any + assert not s.passes_quotas({}) + # we only pass a full quota + assert not s.passes_quotas({"d": True}) + # we only dont pass a full quota, but we haven't passed any open + assert not s.passes_quotas({"d": False}) + # we pass a quota, but also pass a full quota, so fail + assert not s.passes_quotas({"e": True, "d": True}) + # we pass a quota, but are unknown in a full quota, so fail + assert not s.passes_quotas({"e": True}) + + # # Soft Pair + assert (True, set()) == s.passes_qualifications_soft( + {"a": True, "b": True, "c": True} + ) + assert (False, set()) == s.passes_qualifications_soft( + {"a": True, "b": True, "c": False} + ) + assert (None, set("c")) == s.passes_qualifications_soft( + {"a": True, "b": True, "c": None} + ) + + # we do NOT match a full quota, so we pass + assert (True, set()) == s.passes_quotas_soft({"a": True, "b": True, "d": False}) + # We dont pass any + assert (None, {"a", "b", "d", "e"}) == s.passes_quotas_soft({}) + # we only pass a full quota + assert (False, set()) == s.passes_quotas_soft({"d": True}) + # we only dont pass a full quota, but we haven't passed any open + assert (None, {"a", "b", "e"}) == s.passes_quotas_soft({"d": False}) + # we pass a quota, but also pass a full quota, so fail + assert (False, set()) == s.passes_quotas_soft({"e": True, "d": True}) + # we pass a quota, but are unknown in a full quota, so fail + assert (None, {"d"}) == s.passes_quotas_soft({"e": True}) + + assert s.determine_eligibility({"a": True, "b": True, "c": True, "d": False}) + assert not s.determine_eligibility( + {"a": True, "b": True, "c": False, "d": False} + ) + assert not s.determine_eligibility( + {"a": True, "b": True, "c": None, "d": False} + ) + assert (True, set()) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": True, "d": False} + ) + assert (False, set()) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": False, "d": False} + ) + assert (None, set("c")) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": None, "d": False} + ) + assert (None, {"c", "d"}) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": None, "d": None} + ) diff --git a/tests/models/spectrum/test_survey_manager.py b/tests/models/spectrum/test_survey_manager.py new file mode 100644 index 0000000..582093c --- /dev/null +++ b/tests/models/spectrum/test_survey_manager.py @@ -0,0 +1,130 @@ +import copy +import logging +from datetime import timezone, datetime +from decimal import Decimal + +from pymysql import IntegrityError + + +logger = logging.getLogger() + +example_survey_api_response = { + "survey_id": 29333264, + "survey_name": "#29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20"), + "last_complete_date": None, + "project_last_complete_date": None, + "quotas": [ + { + "quota_id": "c2bc961e-4f26-4223-b409-ebe9165cfdf5", + "quantities": {"currently_open": 491, "remaining": 495, "achieved": 0}, + "criteria": [ + { + "qualification_code": 214, + "range_sets": [{"units": 311, "to": 64, "from": 18}], + } + ], + } + ], + "qualifications": [ + { + "range_sets": [{"units": 311, "to": 64, "from": 18}], + "qualification_code": 212, + }, + {"condition_codes": ["111", "117", "112"], "qualification_code": 1202}, + ], + "country_iso": "fr", + "language_iso": "fre", + "bid_ir": 0.4, + "bid_loi": 600, + "overall_ir": None, + "overall_loi": None, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, +} + + +class TestSpectrumSurvey: + + def test_survey_create(self, settings, spectrum_manager, spectrum_rw): + from generalresearch.models.spectrum.survey import SpectrumSurvey + + assert settings.debug, "CRITICAL: Do not run this on production." + + now = datetime.now(tz=timezone.utc) + spectrum_rw.execute_sql_query( + query=f""" + DELETE FROM `{spectrum_rw.db}`.spectrum_survey + WHERE survey_id = '29333264'""", + commit=True, + ) + + d = example_survey_api_response.copy() + s = SpectrumSurvey.from_api(d) + spectrum_manager.create(s) + + surveys = spectrum_manager.get_survey_library(updated_since=now) + assert len(surveys) == 1 + assert "29333264" == surveys[0].survey_id + assert s.is_unchanged(surveys[0]) + + try: + spectrum_manager.create(s) + except IntegrityError as e: + print(e.args) + + def test_survey_update(self, settings, spectrum_manager, spectrum_rw): + from generalresearch.models.spectrum.survey import SpectrumSurvey + + assert settings.debug, "CRITICAL: Do not run this on production." + + now = datetime.now(tz=timezone.utc) + spectrum_rw.execute_sql_query( + query=f""" + DELETE FROM `{spectrum_rw.db}`.spectrum_survey + WHERE survey_id = '29333264' + """, + commit=True, + ) + d = copy.deepcopy(example_survey_api_response) + s = SpectrumSurvey.from_api(d) + print(s) + + spectrum_manager.create(s) + s.cpi = Decimal("0.50") + spectrum_manager.update([s]) + surveys = spectrum_manager.get_survey_library(updated_since=now) + assert len(surveys) == 1 + assert "29333264" == surveys[0].survey_id + assert Decimal("0.50") == surveys[0].cpi + assert s.is_unchanged(surveys[0]) + + # --- Updating bid/overall/last block + assert 600 == s.bid_loi + assert s.overall_loi is None + assert s.last_block_loi is None + + # now the last block is set + s.bid_loi = None + s.overall_loi = 1000 + s.last_block_loi = 1000 + spectrum_manager.update([s]) + surveys = spectrum_manager.get_survey_library(updated_since=now) + assert 600 == surveys[0].bid_loi + assert 1000 == surveys[0].overall_loi + assert 1000 == surveys[0].last_block_loi diff --git a/tests/models/test_currency.py b/tests/models/test_currency.py new file mode 100644 index 0000000..40cff88 --- /dev/null +++ b/tests/models/test_currency.py @@ -0,0 +1,410 @@ +"""These were taken from the wxet project's first use of this idea. Not all +functionality is the same, but pasting here so the tests are in the +correct spot... +""" + +from decimal import Decimal +from random import randint + +import pytest + + +class TestUSDCentModel: + + def test_construct_int(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDCent(int_val) + assert int_val == instance + + def test_construct_float(self): + from generalresearch.currency import USDCent + + with pytest.warns(expected_warning=Warning) as record: + float_val: float = 10.6789 + instance = USDCent(float_val) + + assert len(record) == 1 + assert "USDCent init with a float. Rounding behavior may be unexpected" in str( + record[0].message + ) + assert instance == USDCent(10) + assert instance == 10 + + def test_construct_decimal(self): + from generalresearch.currency import USDCent + + with pytest.warns(expected_warning=Warning) as record: + decimal_val: Decimal = Decimal("10.0") + instance = USDCent(decimal_val) + + assert len(record) == 1 + assert ( + "USDCent init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDCent(10) + assert instance == 10 + + # Now with rounding + with pytest.warns(Warning) as record: + decimal_val: Decimal = Decimal("10.6789") + instance = USDCent(decimal_val) + + assert len(record) == 1 + assert ( + "USDCent init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDCent(10) + assert instance == 10 + + def test_construct_negative(self): + from generalresearch.currency import USDCent + + with pytest.raises(expected_exception=ValueError) as cm: + USDCent(-1) + assert "USDCent not be less than zero" in str(cm.value) + + def test_operation_add(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDCent(int_val1) + instance2 = USDCent(int_val2) + + assert int_val1 + int_val2 == instance1 + instance2 + + def test_operation_subtract(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val1 = randint(500_000, 999_999) + int_val2 = randint(0, 499_999) + + instance1 = USDCent(int_val1) + instance2 = USDCent(int_val2) + + assert int_val1 - int_val2 == instance1 - instance2 + + def test_operation_subtract_to_neg(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDCent(int_val) + + with pytest.raises(expected_exception=ValueError) as cm: + instance - USDCent(1_000_000) + + assert "USDCent not be less than zero" in str(cm.value) + + def test_operation_multiply(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDCent(int_val1) + instance2 = USDCent(int_val2) + + assert int_val1 * int_val2 == instance1 * instance2 + + def test_operation_div(self): + from generalresearch.currency import USDCent + + with pytest.raises(ValueError) as cm: + USDCent(10) / 2 + assert "Division not allowed for USDCent" in str(cm.value) + + def test_operation_result_type(self): + from generalresearch.currency import USDCent + + int_val = randint(1, 999_999) + instance = USDCent(int_val) + + res_add = instance + USDCent(1) + assert isinstance(res_add, USDCent) + + res_sub = instance - USDCent(1) + assert isinstance(res_sub, USDCent) + + res_multipy = instance * USDCent(2) + assert isinstance(res_multipy, USDCent) + + def test_operation_partner_add(self): + from generalresearch.currency import USDCent + + int_val = randint(1, 999_999) + instance = USDCent(int_val) + + with pytest.raises(expected_exception=AssertionError): + instance + 0.10 + + with pytest.raises(expected_exception=AssertionError): + instance + Decimal(".10") + + with pytest.raises(expected_exception=AssertionError): + instance + "9.9" + + with pytest.raises(expected_exception=AssertionError): + instance + True + + def test_abs(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = abs(randint(0, 999_999)) + instance = abs(USDCent(int_val)) + + assert int_val == instance + + def test_str(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDCent(int_val) + + assert str(int_val) == str(instance) + + def test_operation_result_type_unsupported(self): + """There is no correct answer here, but we at least want to make sure + that a USDCent is returned + """ + from generalresearch.currency import USDCent + + res = USDCent(10) // 1.2 + assert not isinstance(res, USDCent) + assert isinstance(res, float) + + res = USDCent(10) % 1 + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = pow(USDCent(10), 2) + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = pow(USDCent(10), USDCent(2)) + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = float(USDCent(10)) + assert not isinstance(res, USDCent) + assert isinstance(res, float) + + +class TestUSDMillModel: + + def test_construct_int(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDMill(int_val) + assert int_val == instance + + def test_construct_float(self): + from generalresearch.currency import USDMill + + with pytest.warns(expected_warning=Warning) as record: + float_val: float = 10.6789 + instance = USDMill(float_val) + + assert len(record) == 1 + assert "USDMill init with a float. Rounding behavior may be unexpected" in str( + record[0].message + ) + assert instance == USDMill(10) + assert instance == 10 + + def test_construct_decimal(self): + from generalresearch.currency import USDMill + + with pytest.warns(expected_warning=Warning) as record: + decimal_val: Decimal = Decimal("10.0") + instance = USDMill(decimal_val) + + assert len(record) == 1 + assert ( + "USDMill init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDMill(10) + assert instance == 10 + + # Now with rounding + with pytest.warns(expected_warning=Warning) as record: + decimal_val: Decimal = Decimal("10.6789") + instance = USDMill(decimal_val) + + assert len(record) == 1 + assert ( + "USDMill init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDMill(10) + assert instance == 10 + + def test_construct_negative(self): + from generalresearch.currency import USDMill + + with pytest.raises(expected_exception=ValueError) as cm: + USDMill(-1) + assert "USDMill not be less than zero" in str(cm.value) + + def test_operation_add(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDMill(int_val1) + instance2 = USDMill(int_val2) + + assert int_val1 + int_val2 == instance1 + instance2 + + def test_operation_subtract(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val1 = randint(500_000, 999_999) + int_val2 = randint(0, 499_999) + + instance1 = USDMill(int_val1) + instance2 = USDMill(int_val2) + + assert int_val1 - int_val2 == instance1 - instance2 + + def test_operation_subtract_to_neg(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDMill(int_val) + + with pytest.raises(expected_exception=ValueError) as cm: + instance - USDMill(1_000_000) + + assert "USDMill not be less than zero" in str(cm.value) + + def test_operation_multiply(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDMill(int_val1) + instance2 = USDMill(int_val2) + + assert int_val1 * int_val2 == instance1 * instance2 + + def test_operation_div(self): + from generalresearch.currency import USDMill + + with pytest.raises(ValueError) as cm: + USDMill(10) / 2 + assert "Division not allowed for USDMill" in str(cm.value) + + def test_operation_result_type(self): + from generalresearch.currency import USDMill + + int_val = randint(1, 999_999) + instance = USDMill(int_val) + + res_add = instance + USDMill(1) + assert isinstance(res_add, USDMill) + + res_sub = instance - USDMill(1) + assert isinstance(res_sub, USDMill) + + res_multipy = instance * USDMill(2) + assert isinstance(res_multipy, USDMill) + + def test_operation_partner_add(self): + from generalresearch.currency import USDMill + + int_val = randint(1, 999_999) + instance = USDMill(int_val) + + with pytest.raises(expected_exception=AssertionError): + instance + 0.10 + + with pytest.raises(expected_exception=AssertionError): + instance + Decimal(".10") + + with pytest.raises(expected_exception=AssertionError): + instance + "9.9" + + with pytest.raises(expected_exception=AssertionError): + instance + True + + def test_abs(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = abs(randint(0, 999_999)) + instance = abs(USDMill(int_val)) + + assert int_val == instance + + def test_str(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDMill(int_val) + + assert str(int_val) == str(instance) + + def test_operation_result_type_unsupported(self): + """There is no correct answer here, but we at least want to make sure + that a USDMill is returned + """ + from generalresearch.currency import USDCent, USDMill + + res = USDMill(10) // 1.2 + assert not isinstance(res, USDMill) + assert isinstance(res, float) + + res = USDMill(10) % 1 + assert not isinstance(res, USDMill) + assert isinstance(res, int) + + res = pow(USDMill(10), 2) + assert not isinstance(res, USDMill) + assert isinstance(res, int) + + res = pow(USDMill(10), USDMill(2)) + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = float(USDMill(10)) + assert not isinstance(res, USDMill) + assert isinstance(res, float) + + +class TestNegativeFormatting: + + def test_pos(self): + from generalresearch.currency import format_usd_cent + + assert "-$987.65" == format_usd_cent(-98765) + + def test_neg(self): + from generalresearch.currency import format_usd_cent + + assert "-$123.45" == format_usd_cent(-12345) diff --git a/tests/models/test_device.py b/tests/models/test_device.py new file mode 100644 index 0000000..480e0c0 --- /dev/null +++ b/tests/models/test_device.py @@ -0,0 +1,27 @@ +import pytest + +iphone_ua_string = ( + "Mozilla/5.0 (iPhone; CPU iPhone OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) " + "Version/5.1 Mobile/9B179 Safari/7534.48.3" +) +ipad_ua_string = ( + "Mozilla/5.0(iPad; U; CPU iPhone OS 3_2 like Mac OS X; en-us) AppleWebKit/531.21.10 (KHTML, " + "like Gecko) Version/4.0.4 Mobile/7B314 Safari/531.21.10" +) +windows_ie_ua_string = "Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; Trident/5.0)" +chromebook_ua_string = ( + "Mozilla/5.0 (X11; CrOS i686 0.12.433) AppleWebKit/534.30 (KHTML, like Gecko) " + "Chrome/12.0.742.77 Safari/534.30" +) + + +class TestDeviceUA: + def test_device_ua(self): + from generalresearch.models import DeviceType + from generalresearch.models.device import parse_device_from_useragent + + assert parse_device_from_useragent(iphone_ua_string) == DeviceType.MOBILE + assert parse_device_from_useragent(ipad_ua_string) == DeviceType.TABLET + assert parse_device_from_useragent(windows_ie_ua_string) == DeviceType.DESKTOP + assert parse_device_from_useragent(chromebook_ua_string) == DeviceType.DESKTOP + assert parse_device_from_useragent("greg bot") == DeviceType.UNKNOWN diff --git a/tests/models/test_finance.py b/tests/models/test_finance.py new file mode 100644 index 0000000..888bf49 --- /dev/null +++ b/tests/models/test_finance.py @@ -0,0 +1,929 @@ +from datetime import timezone, timedelta +from itertools import product as iter_product +from random import randint +from uuid import uuid4 + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) +from faker import Faker + +from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, +) +from generalresearch.models.thl.finance import ( + POPFinancial, + ProductBalances, + BusinessBalances, +) +from test_utils.conftest import delete_df_collection +from test_utils.incite.collections.conftest import ledger_collection +from test_utils.incite.mergers.conftest import pop_ledger_merge +from test_utils.managers.ledger.conftest import ( + create_main_accounts, + session_with_tx_factory, +) + +fake = Faker() + + +class TestProductBalanceInitialize: + + def test_unknown_fields(self): + with pytest.raises(expected_exception=ValueError): + ProductBalances.model_validate( + { + "bp_payment.DEBIT": 1, + } + ) + + def test_payout(self): + val = randint(1, 1_000) + instance = ProductBalances.model_validate({"bp_payment.CREDIT": val}) + assert instance.payout == val + + def test_adjustment(self): + instance = ProductBalances.model_validate( + {"bp_adjustment.CREDIT": 90, "bp_adjustment.DEBIT": 147} + ) + + assert -57 == instance.adjustment + + def test_plug(self): + instance = ProductBalances.model_validate( + { + "bp_adjustment.CREDIT": 1000, + "bp_adjustment.DEBIT": 200, + "plug.DEBIT": 50, + } + ) + assert 750 == instance.adjustment + + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 789, + "bp_adjustment.CREDIT": 23, + "bp_adjustment.DEBIT": 101, + "plug.DEBIT": 17, + } + ) + assert 694 == instance.net + assert 694 == instance.balance + + def test_expense(self): + instance = ProductBalances.model_validate( + {"user_bonus.CREDIT": 0, "user_bonus.DEBIT": 999} + ) + + assert -999 == instance.expense + + def test_payment(self): + instance = ProductBalances.model_validate( + {"bp_payout.CREDIT": 1, "bp_payout.DEBIT": 100} + ) + + assert 99 == instance.payment + + def test_balance(self): + instance = ProductBalances.model_validate( + { + # Payouts from surveys: 1000 + "bp_payment.CREDIT": 1000, + # Adjustments: -200 + "bp_adjustment.CREDIT": 100, + "bp_adjustment.DEBIT": 300, + # Expense: -50 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 50, + # Prior supplier Payouts = 99 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 100, + } + ) + + # Supplier payments aren't considered in the net + assert 750 == instance.net + + # Confirm any Supplier payments are taken out of their balance + assert 651 == instance.balance + + def test_retainer(self): + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 1000, + } + ) + + assert 1000 == instance.balance + assert 250 == instance.retainer + + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 1000, + # 1001 worth of adjustments, making it negative + "bp_adjustment.DEBIT": 1001, + } + ) + + assert -1 == instance.balance + assert 0 == instance.retainer + + def test_available_balance(self): + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 1000, + } + ) + + assert 750 == instance.available_balance + + instance = ProductBalances.model_validate( + { + # Payouts from surveys: $188.37 + "bp_payment.CREDIT": 18_837, + # Adjustments: -$7.53 + $.17 + "bp_adjustment.CREDIT": 17, + "bp_adjustment.DEBIT": 753, + # $.15 of those marketplace Failure >> Completes were never + # actually paid out, so plug those positive adjustments + "plug.DEBIT": 15, + # Expense: -$27.45 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 2_745, + # Prior supplier Payouts = $100 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + + assert 18837 == instance.payout + assert -751 == instance.adjustment + assert 15341 == instance.net + + # Confirm any Supplier payments are taken out of their balance + assert 5341 == instance.balance + assert 1335 == instance.retainer + assert 4006 == instance.available_balance + + def test_json_schema(self): + instance = ProductBalances.model_validate( + { + # Payouts from surveys: 1000 + "bp_payment.CREDIT": 1000, + # Adjustments: -200 + "bp_adjustment.CREDIT": 100, + "bp_adjustment.DEBIT": 300, + # $.80 of those marketplace Failure >> Completes were never + # actually paid out, so plug those positive adjustments + "plug.DEBIT": 80, + # Expense: -50 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 50, + # Prior supplier Payouts = 99 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 100, + } + ) + + assert isinstance(instance.model_json_schema(), dict) + openapi_fields = list(instance.model_json_schema()["properties"].keys()) + + # Ensure the SkipJsonSchema is working.. + assert "mp_payment_credit" not in openapi_fields + assert "mp_payment_debit" not in openapi_fields + assert "mp_adjustment_credit" not in openapi_fields + assert "mp_adjustment_debit" not in openapi_fields + assert "bp_payment_debit" not in openapi_fields + assert "plug_credit" not in openapi_fields + assert "plug_debit" not in openapi_fields + + # Confirm the @property computed fields show up in openapi. I don't + # know how to do that yet... so this is check to confirm they're + # known computed fields for now + computed_fields = list(instance.model_computed_fields.keys()) + assert "payout" in computed_fields + assert "adjustment" in computed_fields + assert "expense" in computed_fields + assert "payment" in computed_fields + assert "net" in computed_fields + assert "balance" in computed_fields + assert "retainer" in computed_fields + assert "available_balance" in computed_fields + + def test_repr(self): + instance = ProductBalances.model_validate( + { + # Payouts from surveys: 1000 + "bp_payment.CREDIT": 1000, + # Adjustments: -200 + "bp_adjustment.CREDIT": 100, + "bp_adjustment.DEBIT": 300, + # $.80 of those marketplace Failure >> Completes were never + # actually paid out, so plug those positive adjustments + "plug.DEBIT": 80, + # Expense: -50 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 50, + # Prior supplier Payouts = 99 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 100, + } + ) + + assert "Total Adjustment: -$2.80" in str(instance) + + +class TestBusinessBalanceInitialize: + + def test_validate_product_ids(self): + instance1 = ProductBalances.model_validate( + {"bp_payment.CREDIT": 500, "bp_adjustment.DEBIT": 40} + ) + + instance2 = ProductBalances.model_validate( + {"bp_payment.CREDIT": 500, "bp_adjustment.DEBIT": 40} + ) + + with pytest.raises(expected_exception=ValueError) as cm: + BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert "'product_id' must be set for BusinessBalance children" in str(cm.value) + + # Confirm that once you add them, it successfully initializes + instance1.product_id = uuid4().hex + instance2.product_id = uuid4().hex + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert isinstance(instance, BusinessBalances) + + def test_payout(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.DEBIT": 40, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.DEBIT": 40, + } + ) + + # Confirm the base payouts are as expected. + assert instance1.payout == 500 + assert instance2.payout == 500 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.payout == 1_000 + + def test_adjustment(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + } + ) + + # Confirm the base adjustment are as expected. + assert instance1.adjustment == -30 + assert instance2.adjustment == -30 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.adjustment == -60 + + def test_expense(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + # Confirm the base adjustment are as expected. + assert instance1.expense == -4 + assert instance2.expense == -4 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.expense == -8 + + def test_net(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + # Confirm the simple net + assert instance1.net == 466 + assert instance2.net == 466 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.net == 466 * 2 + + def test_payment(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.payment == 10_000 + assert instance2.payment == 10_000 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.payment == 20_000 + + def test_balance(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 500, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.balance == 39_506 + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 40_000, + "bp_adjustment.CREDIT": 2_000, + "bp_adjustment.DEBIT": 400, + "plug.DEBIT": 983, + "user_bonus.DEBIT": 392, + "user_bonus.CREDIT": 0, + "bp_payout.CREDIT": 0, + "bp_payout.DEBIT": 8_000, + } + ) + assert instance2.balance == 32_225 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.balance == 39_506 + 32_225 + + def test_retainer(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 500, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.balance == 39_506 + assert instance1.retainer == 9_876 + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 40_000, + "bp_adjustment.CREDIT": 2_000, + "bp_adjustment.DEBIT": 400, + "plug.DEBIT": 983, + "user_bonus.DEBIT": 392, + "user_bonus.CREDIT": 0, + "bp_payout.CREDIT": 0, + "bp_payout.DEBIT": 8_000, + } + ) + assert instance2.balance == 32_225 + assert instance2.retainer == 8_056 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.retainer == 9_876 + 8_056 + + def test_available_balance(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 500, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.balance == 39_506 + assert instance1.retainer == 9_876 + assert instance1.available_balance == 39_506 - 9_876 + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 40_000, + "bp_adjustment.CREDIT": 2_000, + "bp_adjustment.DEBIT": 400, + "plug.DEBIT": 983, + "user_bonus.DEBIT": 392, + "user_bonus.CREDIT": 0, + "bp_payout.CREDIT": 0, + "bp_payout.DEBIT": 8_000, + } + ) + assert instance2.balance == 32_225 + assert instance2.retainer == 8_056 + assert instance2.available_balance == 32_225 - 8_056 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.retainer == 9_876 + 8_056 + assert instance.available_balance == instance.balance - (9_876 + 8_056) + + def test_negative_net(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 50_001, + "bp_payout.DEBIT": 4_999, + } + ) + assert 50_000 == instance1.payout + assert -50_001 == instance1.adjustment + assert 4_999 == instance1.payment + + assert -1 == instance1.net + assert -5_000 == instance1.balance + assert 0 == instance1.available_balance + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 10_000, + "bp_payout.DEBIT": 10_000, + } + ) + assert 50_000 == instance2.payout + assert -10_000 == instance2.adjustment + assert 10_000 == instance2.payment + + assert 40_000 == instance2.net + assert 30_000 == instance2.balance + assert 22_500 == instance2.available_balance + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert 100_000 == instance.payout + assert -60_001 == instance.adjustment + assert 14_999 == instance.payment + + assert 39_999 == instance.net + assert 25_000 == instance.balance + + # Compare the retainers together. We can't just calculate the retainer + # on the Business.balance because it'll be "masked" by any Products + # that have a negative balance and actually reduce the Business's + # retainer as a whole. Therefore, we need to sum together each of the + # retainers from the child Products + assert 0 == instance1.retainer + assert 7_500 == instance2.retainer + assert 6_250 == instance.balance * 0.25 + assert 6_250 != instance.retainer + assert 7_500 == instance.retainer + assert 25_000 - 7_500 == instance.available_balance + + def test_str(self): + instance = BusinessBalances.model_validate( + { + "product_balances": [ + ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 50_001, + "bp_payout.DEBIT": 4_999, + } + ), + ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 10_000, + "bp_payout.DEBIT": 10_000, + } + ), + ] + } + ) + + assert "Products: 2" in str(instance) + assert "Total Adjustment: -$600.01" in str(instance) + assert "Available Balance: $175.00" in str(instance) + + def test_from_json(self): + s = '{"product_balances":[{"product_id":"7485124190274248bc14132755c8fc3b","bp_payment_credit":1184,"adjustment_credit":0,"adjustment_debit":0,"supplier_credit":0,"supplier_debit":0,"user_bonus_credit":0,"user_bonus_debit":0,"payout":1184,"adjustment":0,"expense":0,"net":1184,"payment":0,"balance":1184,"retainer":296,"available_balance":888,"adjustment_percent":0.0}],"payout":1184,"adjustment":0,"expense":0,"net":1184,"payment":0,"balance":1184,"retainer":296,"available_balance":888,"adjustment_percent":0.0}' + instance = BusinessBalances.model_validate_json(s) + + assert instance.payout == 1184 + assert instance.available_balance == 888 + assert instance.retainer == 296 + assert len(instance.product_balances) == 1 + assert instance.adjustment_percent == 0.0 + assert instance.expense == 0 + + p = instance.product_balances[0] + assert p.payout == 1184 + assert p.available_balance == 888 + assert p.retainer == 296 + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "2D"], + [timedelta(days=2), timedelta(days=5)], + ) + ), +) +class TestProductFinanceData: + + def test_base( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + session_with_tx_factory, + product, + user_factory, + start, + duration, + delete_df_collection, + thl_lm, + create_main_accounts, + ): + from generalresearch.models.thl.user import User + + # -- Build & Setup + # assert ledger_collection.start is None + # assert ledger_collection.offset is None + u: User = user_factory(product=product, created=ledger_collection.start) + + for item in ledger_collection.items: + + for s_idx in range(3): + rand_item_time = fake.date_time_between( + start_date=item.start, + end_date=item.finish, + tzinfo=timezone.utc, + ) + session_with_tx_factory(started=rand_item_time, user=u) + + item.initial_load(overwrite=True) + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + # -- + account = thl_lm.get_account_or_create_bp_wallet(product=u.product) + + ddf = pop_ledger_merge.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx", "account_id"], + filters=[ + ("account_id", "==", account.uuid), + ("time_idx", ">=", start), + ("time_idx", "<", start + duration), + ], + ) + + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + + df = df.groupby([pd.Grouper(key="time_idx", freq="D"), "account_id"]).sum() + res = POPFinancial.list_from_pandas(df, accounts=[account]) + + assert isinstance(res, list) + assert isinstance(res[0], POPFinancial) + + # On this, we can assert all products are the same, and that there are + # no overlapping time intervals + assert 1 == len(set(list([i.product_id for i in res]))) + assert len(res) == len(set(list([i.time for i in res]))) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "2D"], + [timedelta(days=2), timedelta(days=5)], + ) + ), +) +class TestPOPFinancialData: + + def test_base( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + user_factory, + product, + start, + duration, + create_main_accounts, + session_with_tx_factory, + session_manager, + thl_lm, + delete_df_collection, + delete_ledger_db, + ): + # -- Build & Setup + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + # assert ledger_collection.start is None + # assert ledger_collection.offset is None + + users = [] + for idx in range(5): + u = user_factory(product=product) + + for item in ledger_collection.items: + rand_item_time = fake.date_time_between( + start_date=item.start, + end_date=item.finish, + tzinfo=timezone.utc, + ) + + session_with_tx_factory(started=rand_item_time, user=u) + item.initial_load(overwrite=True) + + users.append(u) + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + accounts = [] + for user in users: + account = thl_lm.get_account_or_create_bp_wallet(product=u.product) + accounts.append(account) + account_ids = [a.uuid for a in accounts] + + # -- + + ddf = pop_ledger_merge.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx", "account_id"], + filters=[ + ("account_id", "in", account_ids), + ("time_idx", ">=", start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + df = df.groupby([pd.Grouper(key="time_idx", freq="D"), "account_id"]).sum() + res = POPFinancial.list_from_pandas(df, accounts=accounts) + + assert isinstance(res, list) + for i in res: + assert isinstance(i, POPFinancial) + + # This does not return the AccountID, it's the Product ID + assert i.product_id in [u.product_id for u in users] + + # 1 Product, multiple Users + assert len(users) == len(accounts) + + # We group on days, and duration is a parameter to parametrize + assert isinstance(duration, timedelta) + + # -- Teardown + delete_df_collection(ledger_collection) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "1D"], + [timedelta(days=2), timedelta(days=3)], + ) + ), +) +class TestBusinessBalanceData: + def test_from_pandas( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + user_factory, + product, + create_main_accounts, + session_factory, + thl_lm, + session_manager, + start, + thl_web_rr, + duration, + delete_df_collection, + delete_ledger_db, + session_with_tx_factory, + offset, + rm_ledger_collection, + ): + from generalresearch.models.thl.user import User + from generalresearch.models.thl.ledger import LedgerAccount + + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + rm_ledger_collection() + + for idx in range(5): + u: User = user_factory(product=product, created=ledger_collection.start) + + for item in ledger_collection.items: + item_time = fake.date_time_between( + start_date=item.start, + end_date=item.finish, + tzinfo=timezone.utc, + ) + session_with_tx_factory(started=item_time, user=u) + item.initial_load(overwrite=True) + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + account: LedgerAccount = thl_lm.get_account_or_create_bp_wallet(product=product) + + ddf = pop_ledger_merge.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["account_id"], + filters=[("account_id", "in", [account.uuid])], + ) + ddf = ddf.groupby("account_id").sum() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + + instance = BusinessBalances.from_pandas( + input_data=df, accounts=[account], thl_pg_config=thl_web_rr + ) + balance: int = thl_lm.get_account_balance(account=account) + + assert instance.balance == balance + assert instance.net == balance + assert instance.payout == balance + + assert instance.payment == 0 + assert instance.adjustment == 0 + assert instance.adjustment_percent == 0.0 + + assert instance.expense == 0 + + # Cleanup + delete_ledger_db() + delete_df_collection(coll=ledger_collection) diff --git a/tests/models/thl/__init__.py b/tests/models/thl/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/models/thl/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/thl/question/__init__.py b/tests/models/thl/question/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/thl/question/__init__.py diff --git a/tests/models/thl/question/test_question_info.py b/tests/models/thl/question/test_question_info.py new file mode 100644 index 0000000..945ee7a --- /dev/null +++ b/tests/models/thl/question/test_question_info.py @@ -0,0 +1,146 @@ +from generalresearch.models.thl.profiling.upk_property import ( + UpkProperty, + ProfilingInfo, +) + + +class TestQuestionInfo: + + def test_init(self): + + s = ( + '[{"property_label": "hispanic", "cardinality": "*", "prop_type": "i", "country_iso": "us", ' + '"property_id": "05170ae296ab49178a075cab2a2073a6", "item_id": "7911ec1468b146ee870951f8ae9cbac1", ' + '"item_label": "panamanian", "gold_standard": 1, "options": [{"id": "c358c11e72c74fa2880358f1d4be85ab", ' + '"label": "not_hispanic"}, {"id": "b1d6c475770849bc8e0200054975dc9c", "label": "yes_hispanic"}, ' + '{"id": "bd1eb44495d84b029e107c188003c2bd", "label": "other_hispanic"}, ' + '{"id": "f290ad5e75bf4f4ea94dc847f57c1bd3", "label": "mexican"}, ' + '{"id": "49f50f2801bd415ea353063bfc02d252", "label": "puerto_rican"}, ' + '{"id": "dcbe005e522f4b10928773926601f8bf", "label": "cuban"}, ' + '{"id": "467ef8ddb7ac4edb88ba9ef817cbb7e9", "label": "salvadoran"}, ' + '{"id": "3c98e7250707403cba2f4dc7b877c963", "label": "dominican"}, ' + '{"id": "981ee77f6d6742609825ef54fea824a8", "label": "guatemalan"}, ' + '{"id": "81c8057b809245a7ae1b8a867ea6c91e", "label": "colombian"}, ' + '{"id": "513656d5f9e249fa955c3b527d483b93", "label": "honduran"}, ' + '{"id": "afc8cddd0c7b4581bea24ccd64db3446", "label": "ecuadorian"}, ' + '{"id": "61f34b36e80747a89d85e1eb17536f84", "label": "argentinian"}, ' + '{"id": "5330cfa681d44aa8ade3a6d0ea198e44", "label": "peruvian"}, ' + '{"id": "e7bceaffd76e486596205d8545019448", "label": "nicaraguan"}, ' + '{"id": "b7bbb2ebf8424714962e6c4f43275985", "label": "spanish"}, ' + '{"id": "8bf539785e7a487892a2f97e52b1932d", "label": "venezuelan"}, ' + '{"id": "7911ec1468b146ee870951f8ae9cbac1", "label": "panamanian"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "ethnic_group", "cardinality": "*", "prop_type": ' + '"i", "country_iso": "us", "property_id": "15070958225d4132b7f6674fcfc979f6", "item_id": ' + '"64b7114cf08143949e3bcc3d00a5d8a0", "item_label": "other_ethnicity", "gold_standard": 1, "options": [{' + '"id": "a72e97f4055e4014a22bee4632cbf573", "label": "caucasians"}, ' + '{"id": "4760353bc0654e46a928ba697b102735", "label": "black_or_african_american"}, ' + '{"id": "20ff0a2969fa4656bbda5c3e0874e63b", "label": "asian"}, ' + '{"id": "107e0a79e6b94b74926c44e70faf3793", "label": "native_hawaiian_or_other_pacific_islander"}, ' + '{"id": "900fa12691d5458c8665bf468f1c98c1", "label": "native_americans"}, ' + '{"id": "64b7114cf08143949e3bcc3d00a5d8a0", "label": "other_ethnicity"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "educational_attainment", "cardinality": "?", ' + '"prop_type": "i", "country_iso": "us", "property_id": "2637783d4b2b4075b93e2a156e16e1d8", "item_id": ' + '"934e7b81d6744a1baa31bbc51f0965d5", "item_label": "other_education", "gold_standard": 1, "options": [{' + '"id": "df35ef9e474b4bf9af520aa86630202d", "label": "3rd_grade_completion"}, ' + '{"id": "83763370a1064bd5ba76d1b68c4b8a23", "label": "8th_grade_completion"}, ' + '{"id": "f0c25a0670c340bc9250099dcce50957", "label": "not_high_school_graduate"}, ' + '{"id": "02ff74c872bd458983a83847e1a9f8fd", "label": "high_school_completion"}, ' + '{"id": "ba8beb807d56441f8fea9b490ed7561c", "label": "vocational_program_completion"}, ' + '{"id": "65373a5f348a410c923e079ddbb58e9b", "label": "some_college_completion"}, ' + '{"id": "2d15d96df85d4cc7b6f58911fdc8d5e2", "label": "associate_academic_degree_completion"}, ' + '{"id": "497b1fedec464151b063cd5367643ffa", "label": "bachelors_degree_completion"}, ' + '{"id": "295133068ac84424ae75e973dc9f2a78", "label": "some_graduate_completion"}, ' + '{"id": "e64f874faeff4062a5aa72ac483b4b9f", "label": "masters_degree_completion"}, ' + '{"id": "cbaec19a636d476385fb8e7842b044f5", "label": "doctorate_degree_completion"}, ' + '{"id": "934e7b81d6744a1baa31bbc51f0965d5", "label": "other_education"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "household_spoken_language", "cardinality": "*", ' + '"prop_type": "i", "country_iso": "us", "property_id": "5a844571073d482a96853a0594859a51", "item_id": ' + '"62b39c1de141422896ad4ab3c4318209", "item_label": "dut", "gold_standard": 1, "options": [{"id": ' + '"f65cd57b79d14f0f8460761ce41ec173", "label": "ara"}, {"id": "6d49de1f8f394216821310abd29392d9", ' + '"label": "zho"}, {"id": "be6dc23c2bf34c3f81e96ddace22800d", "label": "eng"}, ' + '{"id": "ddc81f28752d47a3b1c1f3b8b01a9b07", "label": "fre"}, {"id": "2dbb67b29bd34e0eb630b1b8385542ca", ' + '"label": "ger"}, {"id": "a747f96952fc4b9d97edeeee5120091b", "label": "hat"}, ' + '{"id": "7144b04a3219433baac86273677551fa", "label": "hin"}, {"id": "e07ff3e82c7149eaab7ea2b39ee6a6dc", ' + '"label": "ita"}, {"id": "b681eff81975432ebfb9f5cc22dedaa3", "label": "jpn"}, ' + '{"id": "5cb20440a8f64c9ca62fb49c1e80cdef", "label": "kor"}, {"id": "171c4b77d4204bc6ac0c2b81e38a10ff", ' + '"label": "pan"}, {"id": "8c3ec18e6b6c4a55a00dd6052e8e84fb", "label": "pol"}, ' + '{"id": "3ce074d81d384dd5b96f1fb48f87bf01", "label": "por"}, {"id": "6138dc951990458fa88a666f6ddd907b", ' + '"label": "rus"}, {"id": "e66e5ecc07df4ebaa546e0b436f034bd", "label": "spa"}, ' + '{"id": "5a981b3d2f0d402a96dd2d0392ec2fcb", "label": "tgl"}, {"id": "b446251bd211403487806c4d0a904981", ' + '"label": "vie"}, {"id": "92fb3ee337374e2db875fb23f52eed46", "label": "xxx"}, ' + '{"id": "8b1f590f12f24cc1924d7bdcbe82081e", "label": "ind"}, {"id": "bf3f4be556a34ff4b836420149fd2037", ' + '"label": "tur"}, {"id": "87ca815c43ba4e7f98cbca98821aa508", "label": "zul"}, ' + '{"id": "0adbf915a7a64d67a87bb3ce5d39ca54", "label": "may"}, {"id": "62b39c1de141422896ad4ab3c4318209", ' + '"label": "dut"}], "category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", ' + '"path": "/Demographic", "adwords_vertical_id": null}]}, {"property_label": "gender", "cardinality": ' + '"?", "prop_type": "i", "country_iso": "us", "property_id": "73175402104741549f21de2071556cd7", ' + '"item_id": "093593e316344cd3a0ac73669fca8048", "item_label": "other_gender", "gold_standard": 1, ' + '"options": [{"id": "b9fc5ea07f3a4252a792fd4a49e7b52b", "label": "male"}, ' + '{"id": "9fdb8e5e18474a0b84a0262c21e17b56", "label": "female"}, ' + '{"id": "093593e316344cd3a0ac73669fca8048", "label": "other_gender"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "age_in_years", "cardinality": "?", "prop_type": ' + '"n", "country_iso": "us", "property_id": "94f7379437874076b345d76642d4ce6d", "item_id": null, ' + '"item_label": null, "gold_standard": 1, "category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", ' + '"label": "Demographic", "path": "/Demographic", "adwords_vertical_id": null}]}, {"property_label": ' + '"children_age_gender", "cardinality": "*", "prop_type": "i", "country_iso": "us", "property_id": ' + '"e926142fcea94b9cbbe13dc7891e1e7f", "item_id": "b7b8074e95334b008e8958ccb0a204f1", "item_label": ' + '"female_18", "gold_standard": 1, "options": [{"id": "16a6448ec24c48d4993d78ebee33f9b4", ' + '"label": "male_under_1"}, {"id": "809c04cb2e3b4a3bbd8077ab62cdc220", "label": "female_under_1"}, ' + '{"id": "295e05bb6a0843bc998890b24c99841e", "label": "no_children"}, ' + '{"id": "142cb948d98c4ae8b0ef2ef10978e023", "label": "male_0"}, ' + '{"id": "5a5c1b0e9abc48a98b3bc5f817d6e9d0", "label": "male_1"}, ' + '{"id": "286b1a9afb884bdfb676dbb855479d1e", "label": "male_2"}, ' + '{"id": "942ca3cda699453093df8cbabb890607", "label": "male_3"}, ' + '{"id": "995818d432f643ec8dd17e0809b24b56", "label": "male_4"}, ' + '{"id": "f38f8b57f25f4cdea0f270297a1e7a5c", "label": "male_5"}, ' + '{"id": "975df709e6d140d1a470db35023c432d", "label": "male_6"}, ' + '{"id": "f60bd89bbe0f4e92b90bccbc500467c2", "label": "male_7"}, ' + '{"id": "6714ceb3ed5042c0b605f00b06814207", "label": "male_8"}, ' + '{"id": "c03c2f8271d443cf9df380e84b4dea4c", "label": "male_9"}, ' + '{"id": "11690ee0f5a54cb794f7ddd010d74fa2", "label": "male_10"}, ' + '{"id": "17bef9a9d14b4197b2c5609fa94b0642", "label": "male_11"}, ' + '{"id": "e79c8338fe28454f89ccc78daf6f409a", "label": "male_12"}, ' + '{"id": "3a4f87acb3fa41f4ae08dfe2858238c1", "label": "male_13"}, ' + '{"id": "36ffb79d8b7840a7a8cb8d63bbc8df59", "label": "male_14"}, ' + '{"id": "1401a508f9664347aee927f6ec5b0a40", "label": "male_15"}, ' + '{"id": "6e0943c5ec4a4f75869eb195e3eafa50", "label": "male_16"}, ' + '{"id": "47d4b27b7b5242758a9fff13d3d324cf", "label": "male_17"}, ' + '{"id": "9ce886459dd44c9395eb77e1386ab181", "label": "female_0"}, ' + '{"id": "6499ccbf990d4be5b686aec1c7353fd8", "label": "female_1"}, ' + '{"id": "d85ceaa39f6d492abfc8da49acfd14f2", "label": "female_2"}, ' + '{"id": "18edb45c138e451d8cb428aefbb80f9c", "label": "female_3"}, ' + '{"id": "bac6f006ed9f4ccf85f48e91e99fdfd1", "label": "female_4"}, ' + '{"id": "5a6a1a8ad00c4ce8be52dcb267b034ff", "label": "female_5"}, ' + '{"id": "6bff0acbf6364c94ad89507bcd5f4f45", "label": "female_6"}, ' + '{"id": "d0d56a0a6b6f4516a366a2ce139b4411", "label": "female_7"}, ' + '{"id": "bda6028468044b659843e2bef4db2175", "label": "female_8"}, ' + '{"id": "dbb6d50325464032b456357b1a6e5e9c", "label": "female_9"}, ' + '{"id": "b87a93d7dc1348edac5e771684d63fb8", "label": "female_10"}, ' + '{"id": "11449d0d98f14e27ba47de40b18921d7", "label": "female_11"}, ' + '{"id": "16156501e97b4263962cbbb743840292", "label": "female_12"}, ' + '{"id": "04ee971c89a345cc8141a45bce96050c", "label": "female_13"}, ' + '{"id": "e818d310bfbc4faba4355e5d2ed49d4f", "label": "female_14"}, ' + '{"id": "440d25e078924ba0973163153c417ed6", "label": "female_15"}, ' + '{"id": "78ff804cc9b441c5a524bd91e3d1f8bf", "label": "female_16"}, ' + '{"id": "4b04d804d7d84786b2b1c22e4ed440f5", "label": "female_17"}, ' + '{"id": "28bc848cd3ff44c3893c76bfc9bc0c4e", "label": "male_18"}, ' + '{"id": "b7b8074e95334b008e8958ccb0a204f1", "label": "female_18"}], "category": [{"id": ' + '"e18ba6e9d51e482cbb19acf2e6f505ce", "label": "Parenting", "path": "/People & Society/Family & ' + 'Relationships/Family/Parenting", "adwords_vertical_id": "58"}]}, {"property_label": "home_postal_code", ' + '"cardinality": "?", "prop_type": "x", "country_iso": "us", "property_id": ' + '"f3b32ebe78014fbeb1ed6ff77d6338bf", "item_id": null, "item_label": null, "gold_standard": 1, ' + '"category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "household_income", "cardinality": "?", "prop_type": ' + '"n", "country_iso": "us", "property_id": "ff5b1d4501d5478f98de8c90ef996ac1", "item_id": null, ' + '"item_label": null, "gold_standard": 1, "category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", ' + '"label": "Demographic", "path": "/Demographic", "adwords_vertical_id": null}]}]' + ) + instance_list = ProfilingInfo.validate_json(s) + + assert isinstance(instance_list, list) + for i in instance_list: + assert isinstance(i, UpkProperty) diff --git a/tests/models/thl/question/test_user_info.py b/tests/models/thl/question/test_user_info.py new file mode 100644 index 0000000..0bbbc78 --- /dev/null +++ b/tests/models/thl/question/test_user_info.py @@ -0,0 +1,32 @@ +from generalresearch.models.thl.profiling.user_info import UserInfo + + +class TestUserInfo: + + def test_init(self): + + s = ( + '{"user_profile_knowledge": [], "marketplace_profile_knowledge": [{"source": "d", "question_id": ' + '"1", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "pr", ' + '"question_id": "3", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": ' + '"h", "question_id": "60", "answer": ["58"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "c", "question_id": "43", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "s", "question_id": "211", "answer": ["111"], "created": ' + '"2023-11-07T16:41:05.234096Z"}, {"source": "s", "question_id": "1843", "answer": ["111"], ' + '"created": "2023-11-07T16:41:05.234096Z"}, {"source": "h", "question_id": "13959", "answer": [' + '"244155"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", "question_id": "33092", ' + '"answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", "question_id": "gender", ' + '"answer": ["10682"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "e", "question_id": ' + '"gender", "answer": ["male"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "f", ' + '"question_id": "gender", "answer": ["male"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": ' + '"i", "question_id": "gender", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "c", "question_id": "137510", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "m", "question_id": "gender", "answer": ["1"], "created": ' + '"2023-11-07T16:41:05.234096Z"}, {"source": "o", "question_id": "gender", "answer": ["male"], ' + '"created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", "question_id": "gender_plus", "answer": [' + '"7657644"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "i", "question_id": ' + '"gender_plus", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", ' + '"question_id": "income_level", "answer": ["9071"], "created": "2023-11-07T16:41:05.234096Z"}]}' + ) + instance = UserInfo.model_validate_json(s) + assert isinstance(instance, UserInfo) diff --git a/tests/models/thl/test_adjustments.py b/tests/models/thl/test_adjustments.py new file mode 100644 index 0000000..15d01d0 --- /dev/null +++ b/tests/models/thl/test_adjustments.py @@ -0,0 +1,688 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + WallAdjustedStatus, + SessionAdjustedStatus, +) + +started1 = datetime(2023, 1, 1, tzinfo=timezone.utc) +started2 = datetime(2023, 1, 1, 0, 10, 0, tzinfo=timezone.utc) +finished1 = started1 + timedelta(minutes=10) +finished2 = started2 + timedelta(minutes=10) + +adj_ts = datetime(2023, 2, 2, tzinfo=timezone.utc) +adj_ts2 = datetime(2023, 2, 3, tzinfo=timezone.utc) +adj_ts3 = datetime(2023, 2, 4, tzinfo=timezone.utc) + + +class TestProductAdjustments: + + @pytest.mark.parametrize("payout", [".6", "1", "1.8", "2", "500.0000"]) + def test_determine_bp_payment_no_rounding(self, product_factory, payout): + p1 = product_factory(commission_pct=Decimal("0.05")) + res = p1.determine_bp_payment(thl_net=Decimal(payout)) + assert isinstance(res, Decimal) + assert res == Decimal(payout) * Decimal("0.95") + + @pytest.mark.parametrize("payout", [".01", ".05", ".5"]) + def test_determine_bp_payment_rounding(self, product_factory, payout): + p1 = product_factory(commission_pct=Decimal("0.05")) + res = p1.determine_bp_payment(thl_net=Decimal(payout)) + assert isinstance(res, Decimal) + assert res != Decimal(payout) * Decimal("0.95") + + +class TestSessionAdjustments: + + def test_status_complete(self, session_factory, user): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + # Confirm only the last Wall Event is a complete + assert not s1.wall_events[0].status == Status.COMPLETE + assert s1.wall_events[1].status == Status.COMPLETE + + # Confirm the Session is marked as finished and the simple brokerage + # payout calculation is correct. + status, status_code_1 = s1.determine_session_status() + assert status == Status.COMPLETE + assert status_code_1 == StatusCode1.COMPLETE + + +class TestAdjustments: + + def test_finish_with_status(self, session_factory, user, session_manager): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + status, status_code_1 = s1.determine_session_status() + payout = user.product.determine_bp_payment(Decimal(1)) + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + payout=payout, + finished=finished2, + ) + + assert Decimal("0.95") == payout + + def test_never_adjusted(self, session_factory, user, session_manager): + s1 = session_factory( + user=user, + wall_count=5, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Confirm walls and Session are never adjusted in anyway + for w in s1.wall_events: + w: Wall + assert w.adjusted_status is None + assert w.adjusted_timestamp is None + assert w.adjusted_cpi is None + + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_timestamp is None + + def test_adjustment_wall_values( + self, session_factory, user, session_manager, wall_manager + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=5, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + w: Wall = s1.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Original Session and Wall status is still the same, but the Adjusted + # values have changed + assert s1.status == Status.COMPLETE + assert s1.adjusted_status is None + assert s1.adjusted_timestamp is None + assert s1.adjusted_payout is None + assert s1.adjusted_user_payout is None + + assert w.status == Status.COMPLETE + assert w.status_code_1 == StatusCode1.COMPLETE + assert w.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + assert w.adjusted_cpi == Decimal(0) + assert w.adjusted_timestamp == adj_ts + + # Because the Product doesn't have the Wallet mode enabled, the + # user_payout fields should always be None + assert not user.product.user_wallet_config.enabled + assert s1.adjusted_user_payout is None + + def test_adjustment_session_values( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + wall_source=Source.DYNATA, + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + wall_manager.adjust_status( + wall=s1.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + assert s1.status == Status.COMPLETE # Original status should remain + assert s1.adjusted_status == SessionAdjustedStatus.ADJUSTED_TO_FAIL + assert s1.adjusted_payout == Decimal(0) + assert s1.adjusted_timestamp == adj_ts + + # Because the Product doesn't have the Wallet mode enabled, the + # user_payout fields should always be None + assert not user.product.user_wallet_config.enabled + assert s1.adjusted_user_payout is None + + def test_double_adjustment_session_values( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + w: Wall = s1.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + + # Let's take that back again! Buyers love to do this. + # So now we're going to "un-reconcile" the last Wall Event which has + # already gone from a Complete >> Failure + wall_manager.adjust_status( + wall=w, adjusted_status=None, adjusted_timestamp=adj_ts2 + ) + assert w.adjusted_status is None + assert w.adjusted_cpi is None + assert w.adjusted_timestamp == adj_ts2 + + # Once the wall was unreconciled, "refresh" the Session again + assert s1.adjusted_status is not None + session_manager.adjust_status(session=s1) + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_timestamp == adj_ts2 + assert s1.adjusted_user_payout is None + + def test_double_adjustment_sm_vs_db_values( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + wall_source=Source.DYNATA, + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + wall_manager.adjust_status( + wall=s1.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + + # Let's take that back again! Buyers love to do this. + # So now we're going to "un-reconcile" the last Wall Event which has + # already gone from a Complete >> Failure + # Once the wall was unreconciled, "refresh" the Session again + wall_manager.adjust_status( + wall=s1.wall_events[-1], adjusted_status=None, adjusted_timestamp=adj_ts2 + ) + session_manager.adjust_status(session=s1) + + # Confirm that the sessions wall attributes are still aligned with + # what comes back directly from the database + db_wall_events = wall_manager.get_wall_events(session_id=s1.id) + for idx in range(len(s1.wall_events)): + w_sm: Wall = s1.wall_events[idx] + w_db: Wall = db_wall_events[idx] + + assert w_sm.uuid == w_db.uuid + assert w_sm.session_id == w_db.session_id + assert w_sm.status == w_db.status + assert w_sm.status_code_1 == w_db.status_code_1 + assert w_sm.status_code_2 == w_db.status_code_2 + + assert w_sm.elapsed == w_db.elapsed + + # Decimal("1.000000") vs Decimal(1) - based on mysql or postgres + assert pytest.approx(w_sm.cpi) == w_db.cpi + assert pytest.approx(w_sm.req_cpi) == w_db.req_cpi + + assert w_sm.model_dump_json( + exclude={"cpi", "req_cpi"} + ) == w_db.model_dump_json(exclude={"cpi", "req_cpi"}) + + def test_double_adjustment_double_completes( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(2), + wall_source=Source.DYNATA, + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + wall_manager.adjust_status( + wall=s1.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + + # Let's take that back again! Buyers love to do this. + # So now we're going to "un-reconcile" the last Wall Event which has + # already gone from a Complete >> Failure + # Once the wall was unreconciled, "refresh" the Session again + wall_manager.adjust_status( + wall=s1.wall_events[-1], adjusted_status=None, adjusted_timestamp=adj_ts2 + ) + session_manager.adjust_status(session=s1) + + # Reassign them - we already validated they're equal in previous + # tests so this is safe to do. + s1.wall_events = wall_manager.get_wall_events(session_id=s1.id) + + # The First Wall event was originally a Failure, now let's also set + # that as a complete, so now both Wall Events will b a + # complete (Fail >> Adj to Complete, Complete >> Adj to Fail >> Adj to Complete) + w1: Wall = s1.wall_events[0] + assert w1.status == Status.FAIL + assert w1.adjusted_status is None + assert w1.adjusted_cpi is None + assert w1.adjusted_timestamp is None + + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_timestamp=adj_ts3, + ) + + assert w1.status == Status.FAIL # original status doesn't change + assert w1.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE + assert w1.adjusted_cpi == w1.cpi + assert w1.adjusted_timestamp == adj_ts3 + + session_manager.adjust_status(s1) + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("3.80") == s1.adjusted_payout + assert s1.adjusted_user_payout is None + assert adj_ts3 == s1.adjusted_timestamp + + def test_complete_to_fail( + self, session_factory, user, session_manager, wall_manager, utc_hour_ago + ): + s1 = session_factory( + user=user, + wall_count=1, + wall_req_cpi=Decimal("1"), + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + + status, status_code_1 = s1.determine_session_status() + assert status == Status.COMPLETE + + thl_net = Decimal(sum(w.cpi for w in s1.wall_events if w.is_visible_complete())) + payout = user.product.determine_bp_payment(thl_net=thl_net) + + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=utc_hour_ago + timedelta(minutes=10), + payout=payout, + user_payout=None, + ) + + w1 = s1.wall_events[0] + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts, + ) + assert w1.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + assert w1.adjusted_cpi == Decimal(0) + + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.FAIL == new_status + assert Decimal(0) == new_payout + + assert not user.product.user_wallet_config.enabled + assert new_user_payout is None + + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_FAIL == s1.adjusted_status + assert Decimal(0) == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + assert s1.adjusted_user_payout is None + + # cpi adjustment + w1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("0.69"), + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.CPI_ADJUSTMENT == w1.adjusted_status + assert Decimal("0.69") == w1.adjusted_cpi + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.COMPLETE == new_status + assert Decimal("0.66") == new_payout + + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.33") == new_user_payout + assert new_user_payout is None + + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.66") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.33") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + # adjust cpi again + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("0.50"), + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.CPI_ADJUSTMENT == w1.adjusted_status + assert Decimal("0.50") == w1.adjusted_cpi + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.COMPLETE == new_status + assert Decimal("0.48") == new_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.24") == new_user_payout + assert new_user_payout is None + + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.48") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.24") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + def test_complete_to_fail_to_complete(self, user, session_factory, utc_hour_ago): + # Setup: Complete, then adjust it to fail + s1 = session_factory( + user=user, + wall_count=1, + wall_req_cpi=Decimal("1"), + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + w1 = s1.wall_events[0] + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts, + ) + s1.adjust_status() + + # Test: Adjust back to complete + w1.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=adj_ts, + ) + assert w1.adjusted_status is None + assert w1.adjusted_cpi is None + assert adj_ts == w1.adjusted_timestamp + + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.COMPLETE == new_status + assert Decimal("0.95") == new_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == new_user_payout + assert new_user_payout is None + + s1.adjust_status() + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_user_payout is None + + def test_complete_to_fail_to_complete_adj( + self, user, session_factory, utc_hour_ago + ): + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(2)], + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + + w1 = s1.wall_events[0] + w2 = s1.wall_events[1] + + status, status_code_1 = s1.determine_session_status() + thl_net = Decimal(sum(w.cpi for w in s1.wall_events if w.is_visible_complete())) + payout = user.product.determine_bp_payment(thl_net=thl_net) + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=25), + "payout": payout, + "user_payout": None, + } + ) + + # Test. Adjust first fail to complete. Now we have 2 completes. + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=adj_ts, + ) + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("2.85") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("1.42") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + # Now we have [Fail, Complete ($2)] -> [Complete ($1), Fail] + w2.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts2, + ) + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.95") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + def test_complete_to_fail_to_complete_adj1( + self, user, session_factory, utc_hour_ago + ): + # Same as test_complete_to_fail_to_complete_adj but in opposite order + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(2)], + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + + w1 = s1.wall_events[0] + w2 = s1.wall_events[1] + + status, status_code_1 = s1.determine_session_status() + thl_net = Decimal(sum(w.cpi for w in s1.wall_events if w.is_visible_complete())) + payout = user.product.determine_bp_payment(thl_net) + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=25), + "payout": payout, + "user_payout": None, + } + ) + + # Test. Adjust complete to fail. Now we have 2 fails. + w2.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts, + ) + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_FAIL == s1.adjusted_status + assert Decimal(0) == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal(0) == s.adjusted_user_payout + assert s1.adjusted_user_payout is None + # Now we have [Fail, Complete ($2)] -> [Complete ($1), Fail] + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=adj_ts2, + ) + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.95") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == s.adjusted_user_payout + assert s1.adjusted_user_payout is None + + def test_fail_to_complete_to_fail(self, user, session_factory, utc_hour_ago): + # End with an abandon + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(2)], + final_status=Status.ABANDON, + started=utc_hour_ago, + ) + + w1 = s1.wall_events[0] + w2 = s1.wall_events[1] + + # abandon adjust to complete + w2.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w2.cpi, + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.ADJUSTED_TO_COMPLETE == w2.adjusted_status + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_COMPLETE == s1.adjusted_status + assert Decimal("1.90") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.95") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + # back to fail + w2.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=adj_ts, + ) + assert w2.adjusted_status is None + s1.adjust_status() + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_user_payout is None + + # other is now complete + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.ADJUSTED_TO_COMPLETE == w1.adjusted_status + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_COMPLETE == s1.adjusted_status + assert Decimal("0.95") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None diff --git a/tests/models/thl/test_bucket.py b/tests/models/thl/test_bucket.py new file mode 100644 index 0000000..0aa5843 --- /dev/null +++ b/tests/models/thl/test_bucket.py @@ -0,0 +1,201 @@ +from datetime import timedelta +from decimal import Decimal + +import pytest +from pydantic import ValidationError + + +class TestBucket: + + def test_raises_payout(self): + from generalresearch.models.legacy.bucket import Bucket + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=123) + assert "Must pass a Decimal" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=Decimal(1 / 3)) + assert "Must have 2 or fewer decimal places" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=Decimal(10000)) + assert "should be less than 1000" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=Decimal(1), user_payout_max=Decimal("0.01")) + assert "user_payout_min should be <= user_payout_max" in str(e.value) + + def test_raises_loi(self): + from generalresearch.models.legacy.bucket import Bucket + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(loi_min=123) + assert "Input should be a valid timedelta" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(loi_min=timedelta(seconds=9999)) + assert "should be less than 90 minutes" in str(e.value) + + with pytest.raises(ValidationError) as e: + Bucket(loi_min=timedelta(seconds=0)) + assert "should be greater than 0" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(loi_min=timedelta(seconds=10), loi_max=timedelta(seconds=9)) + assert "loi_min should be <= loi_max" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket( + loi_min=timedelta(seconds=10), + loi_max=timedelta(seconds=90), + loi_q1=timedelta(seconds=20), + ) + assert "loi_q1, q2, and q3 should all be set" in str(e.value) + with pytest.raises(expected_exception=ValidationError) as e: + Bucket( + loi_min=timedelta(seconds=10), + loi_max=timedelta(seconds=90), + loi_q1=timedelta(seconds=200), + loi_q2=timedelta(seconds=20), + loi_q3=timedelta(seconds=12), + ) + assert "loi_q1 should be <= loi_q2" in str(e.value) + + def test_parse_1(self): + from generalresearch.models.legacy.bucket import Bucket + + b1 = Bucket.parse_from_offerwall({"payout": {"min": 123}}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=None, + ) + assert b_exp == b1 + + b2 = Bucket.parse_from_offerwall({"payout": {"min": 123, "max": 230}}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=None, + loi_max=None, + ) + assert b_exp == b2 + + b3 = Bucket.parse_from_offerwall( + {"payout": {"min": 123, "max": 230}, "duration": {"min": 600, "max": 1800}} + ) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=timedelta(seconds=600), + loi_max=timedelta(seconds=1800), + ) + assert b_exp == b3 + + b4 = Bucket.parse_from_offerwall( + { + "payout": {"max": 80, "min": 28, "q1": 43, "q2": 43, "q3": 56}, + "duration": {"max": 1172, "min": 266, "q1": 746, "q2": 918, "q3": 1002}, + } + ) + b_exp = Bucket( + user_payout_min=Decimal("0.28"), + user_payout_max=Decimal("0.80"), + user_payout_q1=Decimal("0.43"), + user_payout_q2=Decimal("0.43"), + user_payout_q3=Decimal("0.56"), + loi_min=timedelta(seconds=266), + loi_max=timedelta(seconds=1172), + loi_q1=timedelta(seconds=746), + loi_q2=timedelta(seconds=918), + loi_q3=timedelta(seconds=1002), + ) + assert b_exp == b4 + + def test_parse_2(self): + from generalresearch.models.legacy.bucket import Bucket + + b1 = Bucket.parse_from_offerwall({"min_payout": 123}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=None, + ) + assert b_exp == b1 + + b2 = Bucket.parse_from_offerwall({"min_payout": 123, "max_payout": 230}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=None, + loi_max=None, + ) + assert b_exp == b2 + + b3 = Bucket.parse_from_offerwall( + { + "min_payout": 123, + "max_payout": 230, + "min_duration": 600, + "max_duration": 1800, + } + ) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=timedelta(seconds=600), + loi_max=timedelta(seconds=1800), + ) + assert b_exp, b3 + + b4 = Bucket.parse_from_offerwall( + { + "min_payout": 28, + "max_payout": 99, + "min_duration": 205, + "max_duration": 1113, + "q1_payout": 43, + "q2_payout": 43, + "q3_payout": 46, + "q1_duration": 561, + "q2_duration": 891, + "q3_duration": 918, + } + ) + b_exp = Bucket( + user_payout_min=Decimal("0.28"), + user_payout_max=Decimal("0.99"), + user_payout_q1=Decimal("0.43"), + user_payout_q2=Decimal("0.43"), + user_payout_q3=Decimal("0.46"), + loi_min=timedelta(seconds=205), + loi_max=timedelta(seconds=1113), + loi_q1=timedelta(seconds=561), + loi_q2=timedelta(seconds=891), + loi_q3=timedelta(seconds=918), + ) + assert b_exp == b4 + + def test_parse_3(self): + from generalresearch.models.legacy.bucket import Bucket + + b1 = Bucket.parse_from_offerwall({"payout": 123}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=None, + ) + assert b_exp == b1 + + b2 = Bucket.parse_from_offerwall({"payout": 123, "duration": 1800}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=timedelta(seconds=1800), + ) + assert b_exp == b2 diff --git a/tests/models/thl/test_buyer.py b/tests/models/thl/test_buyer.py new file mode 100644 index 0000000..eebb828 --- /dev/null +++ b/tests/models/thl/test_buyer.py @@ -0,0 +1,23 @@ +from generalresearch.models import Source +from generalresearch.models.thl.survey.buyer import BuyerCountryStat + + +def test_buyer_country_stat(): + bcs = BuyerCountryStat( + country_iso="us", + source=Source.TESTING, + code="123", + task_count=100, + conversion_alpha=40, + conversion_beta=190, + dropoff_alpha=20, + dropoff_beta=50, + long_fail_rate=1, + loi_excess_ratio=1, + user_report_coeff=1, + recon_likelihood=0.05, + ) + assert bcs.score + print(bcs.score) + print(bcs.conversion_p20) + print(bcs.dropoff_p60) diff --git a/tests/models/thl/test_contest/__init__.py b/tests/models/thl/test_contest/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/models/thl/test_contest/__init__.py diff --git a/tests/models/thl/test_contest/test_contest.py b/tests/models/thl/test_contest/test_contest.py new file mode 100644 index 0000000..d53eee5 --- /dev/null +++ b/tests/models/thl/test_contest/test_contest.py @@ -0,0 +1,23 @@ +import pytest +from generalresearch.models.thl.user import User + + +class TestContest: + """In many of the Contest related tests, we often want a consistent + Product throughout, and multiple different users that may be + involved in the Contest... so redefine the product fixture along with + some users in here that are scoped="class" so they stay around for + each of the test functions + """ + + @pytest.fixture(scope="function") + def user_1(self, user_factory, product) -> User: + return user_factory(product=product) + + @pytest.fixture(scope="function") + def user_2(self, user_factory, product) -> User: + return user_factory(product=product) + + @pytest.fixture(scope="function") + def user_3(self, user_factory, product) -> User: + return user_factory(product=product) diff --git a/tests/models/thl/test_contest/test_leaderboard_contest.py b/tests/models/thl/test_contest/test_leaderboard_contest.py new file mode 100644 index 0000000..98f3215 --- /dev/null +++ b/tests/models/thl/test_contest/test_leaderboard_contest.py @@ -0,0 +1,213 @@ +from datetime import timezone +from uuid import uuid4 + +import pytest + +from generalresearch.currency import USDCent +from generalresearch.managers.leaderboard.manager import LeaderboardManager +from generalresearch.models.thl.contest import ContestPrize +from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, +) +from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, +) +from generalresearch.models.thl.contest.utils import ( + distribute_leaderboard_prizes, +) +from generalresearch.models.thl.leaderboard import LeaderboardRow +from tests.models.thl.test_contest.test_contest import TestContest + + +class TestLeaderboardContest(TestContest): + + @pytest.fixture + def leaderboard_contest( + self, product, thl_redis, user_manager + ) -> "LeaderboardContest": + board_key = f"leaderboard:{product.uuid}:us:weekly:2025-05-26:complete_count" + + c = LeaderboardContest( + uuid=uuid4().hex, + product_id=product.uuid, + contest_type=ContestType.LEADERBOARD, + leaderboard_key=board_key, + name="$15 1st place, $10 2nd, $5 3rd place US weekly", + prizes=[ + ContestPrize( + name="$15 Cash", + estimated_cash_value=USDCent(15_00), + cash_amount=USDCent(15_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=1, + ), + ContestPrize( + name="$10 Cash", + estimated_cash_value=USDCent(10_00), + cash_amount=USDCent(10_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=2, + ), + ContestPrize( + name="$5 Cash", + estimated_cash_value=USDCent(5_00), + cash_amount=USDCent(5_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=3, + ), + ], + ) + c._redis_client = thl_redis + c._user_manager = user_manager + return c + + def test_init(self, leaderboard_contest, thl_redis, user_1, user_2): + model = leaderboard_contest.leaderboard_model + assert leaderboard_contest.end_condition.ends_at is not None + + lbm = LeaderboardManager( + redis_client=thl_redis, + board_code=model.board_code, + country_iso=model.country_iso, + freq=model.freq, + product_id=leaderboard_contest.product_id, + within_time=model.period_start_local, + ) + + lbm.hit_complete_count(product_user_id=user_1.product_user_id) + lbm.hit_complete_count(product_user_id=user_2.product_user_id) + lbm.hit_complete_count(product_user_id=user_2.product_user_id) + + lb = leaderboard_contest.get_leaderboard() + print(lb) + + def test_win(self, leaderboard_contest, thl_redis, user_1, user_2, user_3): + model = leaderboard_contest.leaderboard_model + lbm = LeaderboardManager( + redis_client=thl_redis, + board_code=model.board_code, + country_iso=model.country_iso, + freq=model.freq, + product_id=leaderboard_contest.product_id, + within_time=model.period_start_local.astimezone(tz=timezone.utc), + ) + + lbm.hit_complete_count(product_user_id=user_1.product_user_id) + lbm.hit_complete_count(product_user_id=user_1.product_user_id) + + lbm.hit_complete_count(product_user_id=user_2.product_user_id) + + lbm.hit_complete_count(product_user_id=user_3.product_user_id) + + leaderboard_contest.end_contest() + assert len(leaderboard_contest.all_winners) == 3 + + # Prizes are $15, $10, $5. user 2 and 3 ties for 2nd place, so they split (10 + 5) + assert leaderboard_contest.all_winners[0].awarded_cash_amount == USDCent(15_00) + assert ( + leaderboard_contest.all_winners[0].user.product_user_id + == user_1.product_user_id + ) + assert leaderboard_contest.all_winners[0].prize == leaderboard_contest.prizes[0] + assert leaderboard_contest.all_winners[1].awarded_cash_amount == USDCent( + 15_00 / 2 + ) + assert leaderboard_contest.all_winners[2].awarded_cash_amount == USDCent( + 15_00 / 2 + ) + + +class TestLeaderboardContestPrizes: + + def test_distribute_prizes_1(self): + prizes = [USDCent(15_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=10, rank=2), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # a gets first prize, b gets nothing. + assert result == { + "a": USDCent(15_00), + } + + def test_distribute_prizes_2(self): + prizes = [USDCent(15_00), USDCent(10_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=10, rank=2), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # a gets first prize, b gets 2nd prize + assert result == { + "a": USDCent(15_00), + "b": USDCent(10_00), + } + + def test_distribute_prizes_3(self): + prizes = [USDCent(15_00), USDCent(10_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # A gets first prize, no-one gets $10 + assert result == { + "a": USDCent(15_00), + } + + def test_distribute_prizes_4(self): + prizes = [USDCent(15_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=20, rank=1), + LeaderboardRow(bpuid="c", value=20, rank=1), + LeaderboardRow(bpuid="d", value=20, rank=1), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # 4-way tie for the $15 prize; it gets split + assert result == { + "a": USDCent(3_75), + "b": USDCent(3_75), + "c": USDCent(3_75), + "d": USDCent(3_75), + } + + def test_distribute_prizes_5(self): + prizes = [USDCent(15_00), USDCent(10_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=20, rank=1), + LeaderboardRow(bpuid="c", value=10, rank=3), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # 2-way tie for the $15 prize; the top two prizes get split. Rank 3 + # and below get nothing + assert result == { + "a": USDCent(12_50), + "b": USDCent(12_50), + } + + def test_distribute_prizes_6(self): + prizes = [USDCent(15_00), USDCent(10_00), USDCent(5_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=10, rank=2), + LeaderboardRow(bpuid="c", value=10, rank=2), + LeaderboardRow(bpuid="d", value=10, rank=2), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # A gets first prize, 3 way tie for 2nd rank: they split the 2nd and + # 3rd place prizes (10 + 5)/3 + assert result == { + "a": USDCent(15_00), + "b": USDCent(5_00), + "c": USDCent(5_00), + "d": USDCent(5_00), + } diff --git a/tests/models/thl/test_contest/test_raffle_contest.py b/tests/models/thl/test_contest/test_raffle_contest.py new file mode 100644 index 0000000..e1c0a15 --- /dev/null +++ b/tests/models/thl/test_contest/test_raffle_contest.py @@ -0,0 +1,300 @@ +from collections import Counter +from uuid import uuid4 + +import pytest +from pytest import approx + +from generalresearch.currency import USDCent +from generalresearch.models.thl.contest import ( + ContestPrize, + ContestEndCondition, +) +from generalresearch.models.thl.contest.contest_entry import ContestEntry +from generalresearch.models.thl.contest.definitions import ( + ContestEntryType, + ContestPrizeKind, + ContestType, + ContestStatus, + ContestEndReason, +) +from generalresearch.models.thl.contest.raffle import RaffleContest + +from tests.models.thl.test_contest.test_contest import TestContest + + +class TestRaffleContest(TestContest): + + @pytest.fixture(scope="function") + def raffle_contest(self, product) -> RaffleContest: + return RaffleContest( + product_id=product.uuid, + name=f"Raffle Contest {uuid4().hex}", + contest_type=ContestType.RAFFLE, + entry_type=ContestEntryType.CASH, + prizes=[ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ) + ], + end_condition=ContestEndCondition(target_entry_amount=100), + ) + + @pytest.fixture(scope="function") + def ended_raffle_contest(self, raffle_contest, utc_now) -> RaffleContest: + # Fake ending the contest + raffle_contest = raffle_contest.model_copy() + raffle_contest.update( + status=ContestStatus.COMPLETED, + ended_at=utc_now, + end_reason=ContestEndReason.ENDS_AT, + ) + return raffle_contest + + +class TestRaffleContestUserView(TestRaffleContest): + + def test_user_view(self, raffle_contest, user): + from generalresearch.models.thl.contest.raffle import RaffleUserView + + data = { + "current_amount": USDCent(1_00), + "product_user_id": user.product_user_id, + "user_amount": USDCent(1), + "user_amount_today": USDCent(1), + } + r = RaffleUserView.model_validate(raffle_contest.model_dump() | data) + res = r.model_dump(mode="json") + + assert res["product_user_id"] == user.product_user_id + assert res["user_amount_today"] == 1 + assert res["current_win_probability"] == approx(0.01, rel=0.000001) + assert res["projected_win_probability"] == approx(0.01, rel=0.000001) + + # Now change the amount + r.current_amount = USDCent(1_01) + res = r.model_dump(mode="json") + assert res["current_win_probability"] == approx(0.0099, rel=0.001) + assert res["projected_win_probability"] == approx(0.0099, rel=0.001) + + def test_win_pct(self, raffle_contest, user): + from generalresearch.models.thl.contest.raffle import RaffleUserView + + data = { + "current_amount": USDCent(10), + "product_user_id": user.product_user_id, + "user_amount": USDCent(1), + "user_amount_today": USDCent(1), + } + r = RaffleUserView.model_validate(raffle_contest.model_dump() | data) + r.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + # Raffle has 10 entries, user has 1 entry. + # There are 2 prizes. + assert r.current_win_probability == approx(expected=0.2, rel=0.01) + # He can only possibly win 1 prize + assert r.current_prize_count_probability[1] == approx(expected=0.2, rel=0.01) + # He has a 0 prob of winning 2 prizes + assert r.current_prize_count_probability[2] == 0 + # Contest end when there are 100 entries, so 1/100 * 2 prizes + assert r.projected_win_probability == approx(expected=0.02, rel=0.01) + + # Change to user having 2 entries (out of 10) + # Still with 2 prizes + r.user_amount = USDCent(2) + assert r.current_win_probability == approx(expected=0.3777, rel=0.01) + # 2/10 chance of winning 1st, 8/9 change of not winning 2nd, plus the + # same in the other order + p = (2 / 10) * (8 / 9) * 2 # 0.355555 + assert r.current_prize_count_probability[1] == approx(p, rel=0.01) + p = (2 / 10) * (1 / 9) # 0.02222 + assert r.current_prize_count_probability[2] == approx(p, rel=0.01) + + +class TestRaffleContestWinners(TestRaffleContest): + + def test_winners_1_prize(self, ended_raffle_contest, user_1, user_2, user_3): + ended_raffle_contest.entries = [ + ContestEntry( + user=user_1, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_2, + amount=USDCent(2), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_3, + amount=USDCent(3), + entry_type=ContestEntryType.CASH, + ), + ] + + # There is 1 prize. If we select a winner 1000 times, we'd expect user 1 + # to win ~ 1/6th of the time, user 2 ~2/6th and 3 3/6th. + winners = ended_raffle_contest.select_winners() + assert len(winners) == 1 + + c = Counter( + [ + ended_raffle_contest.select_winners()[0].user.user_id + for _ in range(10000) + ] + ) + assert c[user_1.user_id] == approx( + 10000 * 1 / 6, rel=0.1 + ) # 10% relative tolerance + assert c[user_2.user_id] == approx(10000 * 2 / 6, rel=0.1) + assert c[user_3.user_id] == approx(10000 * 3 / 6, rel=0.1) + + def test_winners_2_prizes(self, ended_raffle_contest, user_1, user_2, user_3): + ended_raffle_contest.prizes.append( + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ) + ) + ended_raffle_contest.entries = [ + ContestEntry( + user=user_3, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_1, + amount=USDCent(9999999), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_2, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ] + # In this scenario, user 1 should win both prizes + winners = ended_raffle_contest.select_winners() + assert len(winners) == 2 + # Two different prizes + assert len({w.prize.name for w in winners}) == 2 + # Same user + assert all(w.user.user_id == user_1.user_id for w in winners) + + def test_winners_2_prizes_1_entry(self, ended_raffle_contest, user_3): + ended_raffle_contest.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + ended_raffle_contest.entries = [ + ContestEntry( + user=user_3, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ] + + # One prize goes unclaimed + winners = ended_raffle_contest.select_winners() + assert len(winners) == 1 + + def test_winners_2_prizes_1_entry_2_pennies(self, ended_raffle_contest, user_3): + ended_raffle_contest.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + ended_raffle_contest.entries = [ + ContestEntry( + user=user_3, + amount=USDCent(2), + entry_type=ContestEntryType.CASH, + ), + ] + # User wins both prizes + winners = ended_raffle_contest.select_winners() + assert len(winners) == 2 + + def test_winners_3_prizes_3_entries( + self, ended_raffle_contest, product, user_1, user_2, user_3 + ): + ended_raffle_contest.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Red", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + ended_raffle_contest.entries = [ + ContestEntry( + user=user_1, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_2, + amount=USDCent(2), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_3, + amount=USDCent(3), + entry_type=ContestEntryType.CASH, + ), + ] + + winners = ended_raffle_contest.select_winners() + assert len(winners) == 3 + + winners = [ended_raffle_contest.select_winners() for _ in range(10000)] + + # There's 3 winners, the 1st should follow the same percentages + c = Counter([w[0].user.user_id for w in winners]) + + assert c[user_1.user_id] == approx(10000 * 1 / 6, rel=0.1) + assert c[user_2.user_id] == approx(10000 * 2 / 6, rel=0.1) + assert c[user_3.user_id] == approx(10000 * 3 / 6, rel=0.1) + + # Assume the 1st user won + ended_raffle_contest.entries.pop(0) + winners = [ended_raffle_contest.select_winners() for _ in range(10000)] + c = Counter([w[0].user.user_id for w in winners]) + assert c[user_2.user_id] == approx(10000 * 2 / 5, rel=0.1) + assert c[user_3.user_id] == approx(10000 * 3 / 5, rel=0.1) diff --git a/tests/models/thl/test_ledger.py b/tests/models/thl/test_ledger.py new file mode 100644 index 0000000..d706357 --- /dev/null +++ b/tests/models/thl/test_ledger.py @@ -0,0 +1,130 @@ +from datetime import datetime, timezone +from decimal import Decimal +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from generalresearch.models.thl.ledger import LedgerAccount, Direction, AccountType +from generalresearch.models.thl.ledger import LedgerTransaction, LedgerEntry + + +class TestLedgerTransaction: + + def test_create(self): + # Can create with nothing ... + t = LedgerTransaction() + assert [] == t.entries + assert {} == t.metadata + t = LedgerTransaction( + created=datetime.now(tz=timezone.utc), + metadata={"a": "b", "user": "1234"}, + ext_description="foo", + ) + + def test_ledger_entry(self): + with pytest.raises(expected_exception=ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=0, + ) + assert "Input should be greater than 0" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=2**65, + ) + assert "Input should be less than 9223372036854775807" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=Decimal("1"), + ) + assert "Input should be a valid integer" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=1.2, + ) + assert "Input should be a valid integer" in str(cm.value) + + def test_entries(self): + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=100, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="5927621462814f9893be807db850a31b", + amount=100, + ), + ] + LedgerTransaction(entries=entries) + + def test_raises_entries(self): + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=100, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="5927621462814f9893be807db850a31b", + amount=100, + ), + ] + with pytest.raises(ValidationError) as e: + LedgerTransaction(entries=entries) + assert "ledger entries must balance" in str(e.value) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=100, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="5927621462814f9893be807db850a31b", + amount=101, + ), + ] + with pytest.raises(ValidationError) as cm: + LedgerTransaction(entries=entries) + assert "ledger entries must balance" in str(cm.value) + + +class TestLedgerAccount: + + def test_initialization(self): + u = uuid4().hex + name = f"test-{u[:8]}" + + with pytest.raises(ValidationError) as cm: + LedgerAccount( + display_name=name, + qualified_name="bad bunny", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + ) + assert "qualified name should start with" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerAccount( + display_name=name, + qualified_name="fish sticks:bp_wallet", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + currency="fish sticks", + ) + assert "Invalid UUID" in str(cm.value) diff --git a/tests/models/thl/test_marketplace_condition.py b/tests/models/thl/test_marketplace_condition.py new file mode 100644 index 0000000..217616d --- /dev/null +++ b/tests/models/thl/test_marketplace_condition.py @@ -0,0 +1,382 @@ +import pytest +from pydantic import ValidationError + + +class TestMarketplaceCondition: + + def test_list_or(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_list_or_negate(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_list_and(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a1", "a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) + user_qas = {"q1": {"a1"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_list_and_negate(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a1", "a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_ranges(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"2", "50"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["10-20"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) is None + # --- negate + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.RANGE, + values=["10-20"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + # --- AND + with pytest.raises(expected_exception=ValidationError): + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.AND, + ) + + def test_ranges_to_list(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"2", "50"}} + MarketplaceCondition._CONVERT_LIST_TO_RANGE = ["q1"] + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-12", "3-5"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + assert ConditionValueType.LIST == c.value_type + assert ["1", "10", "11", "12", "2", "3", "4", "5"] == c.values + + def test_ranges_infinity(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"2", "50"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-inf"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + user_qas = {"q1": {"5", "50"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "60-inf"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + + # need to test negative infinity! + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["inf-40"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion({"q1": {"5", "50"}}) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["inf-40"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion({"q1": {"50"}}) + + def test_answered(self): + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=True, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert c.evaluate_criterion(user_qas) + + def test_invite(self): + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_groups = {"g1", "g2", "g3"} + c = MarketplaceCondition( + question_id=None, + negate=False, + value_type=ConditionValueType.RECONTACT, + values=["g1", "g4"], + ) + assert c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) + c = MarketplaceCondition( + question_id=None, + negate=False, + value_type=ConditionValueType.RECONTACT, + values=["g4"], + ) + assert not c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) + + c = MarketplaceCondition( + question_id=None, + negate=True, + value_type=ConditionValueType.RECONTACT, + values=["g1", "g4"], + ) + assert not c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) + c = MarketplaceCondition( + question_id=None, + negate=True, + value_type=ConditionValueType.RECONTACT, + values=["g4"], + ) + assert c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) diff --git a/tests/models/thl/test_payout.py b/tests/models/thl/test_payout.py new file mode 100644 index 0000000..3a51328 --- /dev/null +++ b/tests/models/thl/test_payout.py @@ -0,0 +1,10 @@ +class TestBusinessPayoutEvent: + + def test_validate(self): + from generalresearch.models.gr.business import Business + + instance = Business.model_validate_json( + json_data='{"id":123,"uuid":"947f6ba5250d442b9a66cde9ee33605a","name":"Example » Demo","kind":"c","tax_number":null,"contact":null,"addresses":[],"teams":[{"id":53,"uuid":"8e4197dcaefe4f1f831a02b212e6b44a","name":"Example » Demo","memberships":null,"gr_users":null,"businesses":null,"products":null}],"products":[{"id":"fc23e741b5004581b30e6478363525df","id_int":1234,"name":"Example","enabled":true,"payments_enabled":true,"created":"2025-04-14T13:25:37.279403Z","team_id":"9e4197dcaefe4f1f831a02b212e6b44a","business_id":"857f6ba6160d442b9a66cde9ee33605a","tags":[],"commission_pct":"0.050000","redirect_url":"https://pam-api-us.reppublika.com/v2/public/4970ef00-0ef7-11f0-9962-05cb6323c84c/grl/status","harmonizer_domain":"https://talk.generalresearch.com/","sources_config":{"user_defined":[{"name":"w","active":false,"banned_countries":[],"allow_mobile_ip":true,"supplier_id":null,"allow_pii_only_buyers":false,"allow_unhashed_buyers":false,"withhold_profiling":false,"pass_unconditional_eligible_unknowns":true,"address":null,"allow_vpn":null,"distribute_harmonizer_active":null}]},"session_config":{"max_session_len":600,"max_session_hard_retry":5,"min_payout":"0.14"},"payout_config":{"payout_format":null,"payout_transformation":null},"user_wallet_config":{"enabled":false,"amt":false,"supported_payout_types":["CASH_IN_MAIL","PAYPAL","TANGO"],"min_cashout":null},"user_create_config":{"min_hourly_create_limit":0,"max_hourly_create_limit":null},"offerwall_config":{},"profiling_config":{"enabled":true,"grs_enabled":true,"n_questions":null,"max_questions":10,"avg_question_count":5.0,"task_injection_freq_mult":1.0,"non_us_mult":2.0,"hidden_questions_expiration_hours":168},"user_health_config":{"banned_countries":[],"allow_ban_iphist":true},"yield_man_config":{},"balance":null,"payouts_total_str":null,"payouts_total":null,"payouts":null,"user_wallet":{"enabled":false,"amt":false,"supported_payout_types":["CASH_IN_MAIL","PAYPAL","TANGO"],"min_cashout":null}}],"bank_accounts":[],"balance":{"product_balances":[{"product_id":"fc14e741b5004581b30e6478363414df","last_event":null,"bp_payment_credit":780251,"adjustment_credit":4678,"adjustment_debit":26446,"supplier_credit":0,"supplier_debit":451513,"user_bonus_credit":0,"user_bonus_debit":0,"issued_payment":0,"payout":780251,"payout_usd_str":"$7,802.51","adjustment":-21768,"expense":0,"net":758483,"payment":451513,"payment_usd_str":"$4,515.13","balance":306970,"retainer":76742,"retainer_usd_str":"$767.42","available_balance":230228,"available_balance_usd_str":"$2,302.28","recoup":0,"recoup_usd_str":"$0.00","adjustment_percent":0.027898714644390074}],"payout":780251,"payout_usd_str":"$7,802.51","adjustment":-21768,"expense":0,"net":758483,"net_usd_str":"$7,584.83","payment":451513,"payment_usd_str":"$4,515.13","balance":306970,"balance_usd_str":"$3,069.70","retainer":76742,"retainer_usd_str":"$767.42","available_balance":230228,"available_balance_usd_str":"$2,302.28","adjustment_percent":0.027898714644390074,"recoup":0,"recoup_usd_str":"$0.00"},"payouts_total_str":"$4,515.13","payouts_total":451513,"payouts":[{"bp_payouts":[{"uuid":"40cf2c3c341e4f9d985be4bca43e6116","debit_account_uuid":"3a058056da85493f9b7cdfe375aad0e0","cashout_method_uuid":"602113e330cf43ae85c07d94b5100291","created":"2025-08-02T09:18:20.433329Z","amount":345735,"status":"COMPLETE","ext_ref_id":null,"payout_type":"ACH","request_data":{},"order_data":null,"product_id":"fc14e741b5004581b30e6478363414df","method":"ACH","amount_usd":345735,"amount_usd_str":"$3,457.35"}],"amount":345735,"amount_usd_str":"$3,457.35","created":"2025-08-02T09:18:20.433329Z","line_items":1,"ext_ref_id":null},{"bp_payouts":[{"uuid":"63ce1787087248978919015c8fcd5ab9","debit_account_uuid":"3a058056da85493f9b7cdfe375aad0e0","cashout_method_uuid":"602113e330cf43ae85c07d94b5100291","created":"2025-06-10T22:16:18.765668Z","amount":105778,"status":"COMPLETE","ext_ref_id":"11175997868","payout_type":"ACH","request_data":{},"order_data":null,"product_id":"fc14e741b5004581b30e6478363414df","method":"ACH","amount_usd":105778,"amount_usd_str":"$1,057.78"}],"amount":105778,"amount_usd_str":"$1,057.78","created":"2025-06-10T22:16:18.765668Z","line_items":1,"ext_ref_id":"11175997868"}]}' + ) + + assert isinstance(instance, Business) diff --git a/tests/models/thl/test_payout_format.py b/tests/models/thl/test_payout_format.py new file mode 100644 index 0000000..dc91f39 --- /dev/null +++ b/tests/models/thl/test_payout_format.py @@ -0,0 +1,46 @@ +import pytest +from pydantic import BaseModel + +from generalresearch.models.thl.payout_format import ( + PayoutFormatType, + PayoutFormatField, + format_payout_format, +) + + +class PayoutFormatTestClass(BaseModel): + payout_format: PayoutFormatType = PayoutFormatField + + +class TestPayoutFormat: + def test_payout_format_cls(self): + # valid + PayoutFormatTestClass(payout_format="{payout*10:,.0f} Points") + PayoutFormatTestClass(payout_format="{payout:.0f}") + PayoutFormatTestClass(payout_format="${payout/100:.2f}") + + # invalid + with pytest.raises(expected_exception=Exception) as e: + PayoutFormatTestClass(payout_format="{payout10:,.0f} Points") + + with pytest.raises(expected_exception=Exception) as e: + PayoutFormatTestClass(payout_format="payout:,.0f} Points") + + with pytest.raises(expected_exception=Exception): + PayoutFormatTestClass(payout_format="payout") + + with pytest.raises(expected_exception=Exception): + PayoutFormatTestClass(payout_format="{payout;import sys:.0f}") + + def test_payout_format(self): + assert "1,230 Points" == format_payout_format( + payout_format="{payout*10:,.0f} Points", payout_int=123 + ) + + assert "123" == format_payout_format( + payout_format="{payout:.0f}", payout_int=123 + ) + + assert "$1.23" == format_payout_format( + payout_format="${payout/100:.2f}", payout_int=123 + ) diff --git a/tests/models/thl/test_product.py b/tests/models/thl/test_product.py new file mode 100644 index 0000000..52f60c2 --- /dev/null +++ b/tests/models/thl/test_product.py @@ -0,0 +1,1130 @@ +import os +import shutil +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Optional +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from generalresearch.currency import USDCent +from generalresearch.models import Source +from generalresearch.models.thl.product import ( + Product, + PayoutConfig, + PayoutTransformation, + ProfilingConfig, + SourcesConfig, + IntegrationMode, + SupplyConfig, + SourceConfig, + SupplyPolicy, +) + + +class TestProduct: + + def test_init(self): + # By default, just a Pydantic instance doesn't have an id_int + instance = Product.model_validate( + dict( + id="968a9acc79b74b6fb49542d82516d284", + name="test-968a9acc", + redirect_url="https://www.google.com/hey", + ) + ) + assert instance.id_int is None + + res = instance.model_dump_json() + # We're not excluding anything here, only in the "*Out" variants + assert "id_int" in res + + def test_init_db(self, product_manager): + # By default, just a Pydantic instance doesn't have an id_int + instance = product_manager.create_dummy() + assert isinstance(instance.id_int, int) + + res = instance.model_dump_json() + + # we json skip & exclude + res = instance.model_dump() + + def test_redirect_url(self): + p = Product.model_validate( + dict( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + sources=[{"name": "d", "active": True}], + name="test-968a9acc", + max_session_len=600, + team_id="8b5e94afd8a246bf8556ad9986486baa", + redirect_url="https://www.google.com/hey", + ) + ) + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = "" + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = None + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = "http://www.example.com/test/?a=1&b=2" + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = "http://www.example.com/test/?a=1&b=2&tsid=" + + p.redirect_url = "https://www.example.com/test/?a=1&b=2" + c = p.generate_bp_redirect(tsid="c6ab6ba1e75b44e2bf5aab00fc68e3b7") + assert ( + c + == "https://www.example.com/test/?a=1&b=2&tsid=c6ab6ba1e75b44e2bf5aab00fc68e3b7" + ) + + def test_harmonizer_domain(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + assert p.harmonizer_domain == "https://profile.generalresearch.com/" + p.harmonizer_domain = "https://profile.generalresearch.com/" + p.harmonizer_domain = "https://profile.generalresearch.com" + assert p.harmonizer_domain == "https://profile.generalresearch.com/" + with pytest.raises(expected_exception=Exception): + p.harmonizer_domain = "" + with pytest.raises(expected_exception=Exception): + p.harmonizer_domain = None + with pytest.raises(expected_exception=Exception): + # no https + p.harmonizer_domain = "http://profile.generalresearch.com" + with pytest.raises(expected_exception=Exception): + # "/a" at the end + p.harmonizer_domain = "https://profile.generalresearch.com/a" + + def test_payout_xform(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.5", "min_payout": "0.10"}, + } + ) + + assert ( + "payout_transformation_percent" == p.payout_config.payout_transformation.f + ) + assert 0.5 == p.payout_config.payout_transformation.kwargs.pct + assert ( + Decimal("0.10") == p.payout_config.payout_transformation.kwargs.min_payout + ) + assert p.payout_config.payout_transformation.kwargs.max_payout is None + + # This calls get_payout_transformation_func + # 50% of $1.00 + assert Decimal("0.50") == p.calculate_user_payment(Decimal(1)) + # with a min + assert Decimal("0.10") == p.calculate_user_payment(Decimal("0.15")) + + with pytest.raises(expected_exception=ValidationError) as cm: + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + {"f": "payout_transformation_percent", "kwargs": {}} + ) + assert "1 validation error for PayoutTransformation\nkwargs.pct" in str( + cm.value + ) + + with pytest.raises(expected_exception=ValidationError) as cm: + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + {"f": "payout_transformation_percent"} + ) + + assert "1 validation error for PayoutTransformation\nkwargs" in str(cm.value) + + with pytest.warns(expected_warning=Warning) as w: + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + { + "f": "payout_transformation_percent", + "kwargs": {"pct": 1, "min_payout": "0.5"}, + } + ) + assert "Are you sure you want to pay respondents >95% of CPI?" in "".join( + [str(i.message) for i in w] + ) + + p.payout_config = PayoutConfig() + assert p.calculate_user_payment(Decimal("0.15")) is None + + def test_payout_xform_amt(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + { + "f": "payout_transformation_amt", + } + ) + + assert "payout_transformation_amt" == p.payout_config.payout_transformation.f + + # This calls get_payout_transformation_func + # 95% of $1.00 + assert p.calculate_user_payment(Decimal(1)) == Decimal("0.95") + assert p.calculate_user_payment(Decimal("1.05")) == Decimal("1.00") + + assert p.calculate_user_payment( + Decimal("0.10"), user_wallet_balance=Decimal(0) + ) == Decimal("0.07") + assert p.calculate_user_payment( + Decimal("1.05"), user_wallet_balance=Decimal(0) + ) == Decimal("0.97") + assert p.calculate_user_payment( + Decimal(".05"), user_wallet_balance=Decimal(1) + ) == Decimal("0.02") + # final balance will be <0, so pay the full amount + assert p.calculate_user_payment( + Decimal(".50"), user_wallet_balance=Decimal(-1) + ) == p.calculate_user_payment(Decimal("0.50")) + # final balance will be >0, so do the 7c rounding + assert p.calculate_user_payment( + Decimal(".50"), user_wallet_balance=Decimal("-0.10") + ) == ( + p.calculate_user_payment(Decimal(".40"), user_wallet_balance=Decimal(0)) + - Decimal("-0.10") + ) + + def test_payout_xform_none(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + payout_config=PayoutConfig(payout_format=None, payout_transformation=None), + ) + assert p.format_payout_format(Decimal("1.00")) is None + + pt = PayoutTransformation.model_validate( + {"kwargs": {"pct": 0.5}, "f": "payout_transformation_percent"} + ) + p.payout_config = PayoutConfig( + payout_format="{payout*10:,.0f} Points", payout_transformation=pt + ) + assert p.format_payout_format(Decimal("1.00")) == "1,000 Points" + + def test_profiling(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + assert p.profiling_config.enabled is True + + p.profiling_config = ProfilingConfig(max_questions=1) + assert p.profiling_config.max_questions == 1 + + def test_bp_account(self, product, thl_lm): + assert product.bp_account is None + + product.prefetch_bp_account(thl_lm=thl_lm) + + from generalresearch.models.thl.ledger import LedgerAccount + + assert isinstance(product.bp_account, LedgerAccount) + + +class TestGlobalProduct: + # We have one product ID that is special; we call it the Global + # Product ID and in prod the. This product stores a bunch of extra + # things in the SourcesConfig + + def test_init_and_props(self): + instance = Product( + name="Global Config", + redirect_url="https://www.example.com", + sources_config=SupplyConfig( + policies=[ + # This is the config for Dynata that any BP is allowed to use + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PLATFORM, + ), + # Spectrum that is using OUR credentials, that anyone is allowed to use. + # Same as the dynata config above, just that the dynata supplier_id is + # inferred by the dynata-grpc; it's not required to be set. + SupplyPolicy( + address=["https://spectrum.internal:50051"], + active=True, + name=Source.SPECTRUM, + supplier_id="example-supplier-id", + # implicit Scope = GLOBAL + # default integration_mode=IntegrationMode.PLATFORM, + ), + # A spectrum config with a different supplier_id, but + # it is OUR supplier, and we are paid for the completes. Only a certain BP + # can use this config. + SupplyPolicy( + address=["https://spectrum.internal:50051"], + active=True, + name=Source.SPECTRUM, + supplier_id="example-supplier-id", + team_ids=["d42194c2dfe44d7c9bec98123bc4a6c0"], + # implicit Scope = TEAM + # default integration_mode=IntegrationMode.PLATFORM, + ), + # The supplier ID is associated with THEIR + # credentials, and we do not get paid for this activity. + SupplyPolicy( + address=["https://cint.internal:50051"], + active=True, + name=Source.CINT, + supplier_id="example-supplier-id", + product_ids=["db8918b3e87d4444b60241d0d3a54caa"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + # We could have another global cint integration available + # to anyone also, or we could have another like above + SupplyPolicy( + address=["https://cint.internal:50051"], + active=True, + name=Source.CINT, + supplier_id="example-supplier-id", + team_ids=["b163972a59584de881e5eab01ad10309"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ), + ) + + assert Product.model_validate_json(instance.model_dump_json()) == instance + + s = instance.sources_config + # Cint should NOT have a global config + assert set(s.global_scoped_policies_dict.keys()) == { + Source.DYNATA, + Source.SPECTRUM, + } + + # The spectrum global config is the one that isn't scoped to a + # specific supplier + assert ( + s.global_scoped_policies_dict[Source.SPECTRUM].supplier_id + == "grl-supplier-id" + ) + + assert set(s.team_scoped_policies_dict.keys()) == { + "b163972a59584de881e5eab01ad10309", + "d42194c2dfe44d7c9bec98123bc4a6c0", + } + # This team has one team-scoped config, and it's for spectrum + assert s.team_scoped_policies_dict[ + "d42194c2dfe44d7c9bec98123bc4a6c0" + ].keys() == {Source.SPECTRUM} + + # For a random product/team, it'll just have the globally-scoped config + random_product = uuid4().hex + random_team = uuid4().hex + res = instance.sources_config.get_policies_for( + product_id=random_product, team_id=random_team + ) + assert res == s.global_scoped_policies_dict + + # It'll have the global config plus cint, and it should use the PRODUCT + # scoped config, not the TEAM scoped! + res = instance.sources_config.get_policies_for( + product_id="db8918b3e87d4444b60241d0d3a54caa", + team_id="b163972a59584de881e5eab01ad10309", + ) + assert set(res.keys()) == { + Source.DYNATA, + Source.SPECTRUM, + Source.CINT, + } + assert res[Source.CINT].supplier_id == "example-supplier-id" + + def test_source_vs_supply_validate(self): + # sources_config can be a SupplyConfig or SourcesConfig. + # make sure they get model_validated correctly + gp = Product( + name="Global Config", + redirect_url="https://www.example.com", + sources_config=SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PLATFORM, + ) + ] + ), + ) + bp = Product( + name="test product config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig( + user_defined=[ + SourceConfig( + active=False, + name=Source.DYNATA, + ) + ] + ), + ) + assert Product.model_validate_json(gp.model_dump_json()) == gp + assert Product.model_validate_json(bp.model_dump_json()) == bp + + def test_validations(self): + with pytest.raises( + ValidationError, match="Can only have one GLOBAL policy per Source" + ): + SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PLATFORM, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ) + with pytest.raises( + ValidationError, + match="Can only have one PRODUCT policy per Source per BP", + ): + SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + product_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PLATFORM, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + product_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ) + with pytest.raises( + ValidationError, + match="Can only have one TEAM policy per Source per Team", + ): + SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + team_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PLATFORM, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + team_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ) + + +class TestGlobalProductConfigFor: + def test_no_user_defined(self): + sc = SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + ) + ] + ) + product = Product( + name="Test Product Config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig(), + ) + res = sc.get_config_for_product(product=product) + assert len(res.policies) == 1 + + def test_user_defined_merge(self): + sc = SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + banned_countries=["mx"], + active=True, + name=Source.DYNATA, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + banned_countries=["ca"], + active=True, + name=Source.DYNATA, + team_ids=[uuid4().hex], + ), + ] + ) + product = Product( + name="Test Product Config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig( + user_defined=[ + SourceConfig( + name=Source.DYNATA, + active=False, + banned_countries=["us"], + ) + ] + ), + ) + res = sc.get_config_for_product(product=product) + assert len(res.policies) == 1 + assert not res.policies[0].active + assert res.policies[0].banned_countries == ["mx", "us"] + + def test_no_eligible(self): + sc = SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + team_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PLATFORM, + ) + ] + ) + product = Product( + name="Test Product Config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig(), + ) + res = sc.get_config_for_product(product=product) + assert len(res.policies) == 0 + + +class TestProductFinancials: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_balance( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + thl_redis_config, + start, + thl_web_rr, + brokerage_product_payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + from generalresearch.models.thl.finance import ProductBalances + from generalresearch.currency import USDCent + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_user_wallet(user=u1) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 0 + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".50"), + started=start + timedelta(days=1), + ) + assert thl_lm.get_account_balance(account=bp_wallet) == 48 + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 1 + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(days=2), + ) + assert thl_lm.get_account_balance(account=bp_wallet) == 143 + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 2 + + with pytest.raises(expected_exception=AssertionError) as cm: + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert "Cannot build Product Balance" in str(cm.value) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert isinstance(p1.balance, ProductBalances) + assert p1.balance.payout == 143 + assert p1.balance.adjustment == 0 + assert p1.balance.expense == 0 + assert p1.balance.net == 143 + assert p1.balance.balance == 143 + assert p1.balance.retainer == 35 + assert p1.balance.available_balance == 108 + + p1.prebuild_payouts( + thl_lm=thl_lm, + bp_pem=brokerage_product_payout_event_manager, + ) + assert p1.payouts is not None + assert len(p1.payouts) == 0 + assert p1.payouts_total == 0 + assert p1.payouts_total_str == "$0.00" + + # -- Now pay them out... + + bp_payout_factory( + product=p1, + amount=USDCent(50), + created=start + timedelta(days=3), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 3 + + # RM the entire directories + shutil.rmtree(ledger_collection.archive_path) + os.makedirs(ledger_collection.archive_path, exist_ok=True) + shutil.rmtree(pop_ledger_merge.archive_path) + os.makedirs(pop_ledger_merge.archive_path, exist_ok=True) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert isinstance(p1.balance, ProductBalances) + assert p1.balance.payout == 143 + assert p1.balance.adjustment == 0 + assert p1.balance.expense == 0 + assert p1.balance.net == 143 + assert p1.balance.balance == 93 + assert p1.balance.retainer == 23 + assert p1.balance.available_balance == 70 + + p1.prebuild_payouts( + thl_lm=thl_lm, + bp_pem=brokerage_product_payout_event_manager, + ) + assert p1.payouts is not None + assert len(p1.payouts) == 1 + assert p1.payouts_total == 50 + assert p1.payouts_total_str == "$0.50" + + # -- Now pay ou another!. + + bp_payout_factory( + product=p1, + amount=USDCent(5), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 4 + + # RM the entire directories + shutil.rmtree(ledger_collection.archive_path) + os.makedirs(ledger_collection.archive_path, exist_ok=True) + shutil.rmtree(pop_ledger_merge.archive_path) + os.makedirs(pop_ledger_merge.archive_path, exist_ok=True) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert isinstance(p1.balance, ProductBalances) + assert p1.balance.payout == 143 + assert p1.balance.adjustment == 0 + assert p1.balance.expense == 0 + assert p1.balance.net == 143 + assert p1.balance.balance == 88 + assert p1.balance.retainer == 22 + assert p1.balance.available_balance == 66 + + p1.prebuild_payouts( + thl_lm=thl_lm, + bp_pem=brokerage_product_payout_event_manager, + ) + assert p1.payouts is not None + assert len(p1.payouts) == 2 + assert p1.payouts_total == 55 + assert p1.payouts_total_str == "$0.55" + + +class TestProductBalance: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_inconsistent( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + ): + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete and Build Parquets 1st time + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # 2. Payout and build Parquets 2nd time + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + with pytest.raises(expected_exception=AssertionError) as cm: + product.prebuild_balance( + thl_lm=thl_lm, ds=mnt_filepath, client=client_no_amm + ) + assert "Sql and Parquet Balance inconsistent" in str(cm) + + def test_not_inconsistent( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + ): + # This is very similar to the test_complete_payout_pq_inconsistent + # test, however this time we're only going to assign the payout + # in real time, and not in the past. This means that even if we + # build the parquet files multiple times, they will include the + # payout. + + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete and Build Parquets 1st time + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # 2. Payout and build Parquets 2nd time but this payout is "now" + # so it hasn't already been archived + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # We just want to call this to confirm it doesn't raise. + product.prebuild_balance(thl_lm=thl_lm, ds=mnt_filepath, client=client_no_amm) + + +class TestProductPOPFinancial: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_base( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + ): + # This is very similar to the test_complete_payout_pq_inconsistent + # test, however this time we're only going to assign the payout + # in real time, and not in the past. This means that even if we + # build the parquet files multiple times, they will include the + # payout. + + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete and Build Parquets 1st time + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # --- test --- + assert product.pop_financial is None + product.prebuild_pop_financial( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + from generalresearch.models.thl.finance import POPFinancial + + assert isinstance(product.pop_financial, list) + assert isinstance(product.pop_financial[0], POPFinancial) + pf1: POPFinancial = product.pop_financial[0] + assert isinstance(pf1.time, datetime) + assert pf1.payout == 71 + assert pf1.net == 71 + assert pf1.adjustment == 0 + for adj in pf1.adjustment_types: + assert adj.amount == 0 + + +class TestProductCache: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_basic( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + ): + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + # Confirm the default / null behavior + rc = thl_redis_config.create_redis_client() + res: Optional[str] = rc.get(product.cache_key) + assert res is None + with pytest.raises(expected_exception=AssertionError): + product.set_cache( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + bp_pem=brokerage_product_payout_event_manager, + redis_config=thl_redis_config, + ) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # Now try again with everything in place + product.set_cache( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + bp_pem=brokerage_product_payout_event_manager, + redis_config=thl_redis_config, + ) + + # Fetch from cache and assert the instance loaded from redis + res: Optional[str] = rc.get(product.cache_key) + assert isinstance(res, str) + from generalresearch.models.thl.ledger import LedgerAccount + + assert isinstance(product.bp_account, LedgerAccount) + + p1: Product = Product.model_validate_json(res) + assert p1.balance.product_id == product.uuid + assert p1.balance.payout_usd_str == "$0.71" + assert p1.balance.retainer_usd_str == "$0.17" + assert p1.balance.available_balance_usd_str == "$0.54" + + def test_neg_balance_cache( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + adj_to_fail_with_tx_factory, + ): + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + # 2. Payout + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + # 3. Recon + adj_to_fail_with_tx_factory( + session=s1, + created=start + timedelta(days=1, minutes=1), + ) + + # Finally, process everything: + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + product.set_cache( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + bp_pem=brokerage_product_payout_event_manager, + redis_config=thl_redis_config, + ) + + # Fetch from cache and assert the instance loaded from redis + rc = thl_redis_config.create_redis_client() + res: Optional[str] = rc.get(product.cache_key) + assert isinstance(res, str) + + p1: Product = Product.model_validate_json(res) + assert p1.balance.product_id == product.uuid + assert p1.balance.payout_usd_str == "$0.71" + assert p1.balance.adjustment == -71 + assert p1.balance.expense == 0 + assert p1.balance.net == 0 + assert p1.balance.balance == -71 + assert p1.balance.retainer_usd_str == "$0.00" + assert p1.balance.available_balance_usd_str == "$0.00" diff --git a/tests/models/thl/test_product_userwalletconfig.py b/tests/models/thl/test_product_userwalletconfig.py new file mode 100644 index 0000000..4583c46 --- /dev/null +++ b/tests/models/thl/test_product_userwalletconfig.py @@ -0,0 +1,56 @@ +from itertools import groupby +from random import shuffle as rshuffle + +from generalresearch.models.thl.product import ( + UserWalletConfig, +) + +from generalresearch.models.thl.wallet import PayoutType + + +def all_equal(iterable): + g = groupby(iterable) + return next(g, True) and not next(g, False) + + +class TestProductUserWalletConfig: + + def test_init(self): + instance = UserWalletConfig() + + assert isinstance(instance, UserWalletConfig) + + # Check the defaults + assert not instance.enabled + assert not instance.amt + + assert isinstance(instance.supported_payout_types, set) + assert len(instance.supported_payout_types) == 3 + + assert instance.min_cashout is None + + def test_model_dump(self): + instance = UserWalletConfig() + + # If we use the defaults, the supported_payout_types are always + # in the same order because they're the same + assert isinstance(instance.model_dump_json(), str) + res = [] + for idx in range(100): + res.append(instance.model_dump_json()) + assert all_equal(res) + + def test_model_dump_payout_types(self): + res = [] + for idx in range(100): + + # Generate a random order of PayoutTypes each time + payout_types = [e for e in PayoutType] + rshuffle(payout_types) + instance = UserWalletConfig.model_validate( + {"supported_payout_types": payout_types} + ) + + res.append(instance.model_dump_json()) + + assert all_equal(res) diff --git a/tests/models/thl/test_soft_pair.py b/tests/models/thl/test_soft_pair.py new file mode 100644 index 0000000..bac0e8d --- /dev/null +++ b/tests/models/thl/test_soft_pair.py @@ -0,0 +1,24 @@ +from generalresearch.models import Source +from generalresearch.models.thl.soft_pair import SoftPairResult, SoftPairResultType + + +def test_model(): + from generalresearch.models.dynata.survey import ( + DynataCondition, + ConditionValueType, + ) + + c1 = DynataCondition( + question_id="1", value_type=ConditionValueType.LIST, values=["a", "b"] + ) + c2 = DynataCondition( + question_id="2", value_type=ConditionValueType.LIST, values=["c", "d"] + ) + sr = SoftPairResult( + pair_type=SoftPairResultType.CONDITIONAL, + source=Source.DYNATA, + survey_id="xxx", + conditions={c1, c2}, + ) + assert sr.grpc_string == "xxx:1;2" + assert sr.survey_sid == "d:xxx" diff --git a/tests/models/thl/test_upkquestion.py b/tests/models/thl/test_upkquestion.py new file mode 100644 index 0000000..e67427e --- /dev/null +++ b/tests/models/thl/test_upkquestion.py @@ -0,0 +1,414 @@ +import pytest +from pydantic import ValidationError + + +class TestUpkQuestion: + + def test_importance(self): + from generalresearch.models.thl.profiling.upk_question import ( + UPKImportance, + ) + + ui = UPKImportance(task_score=1, task_count=None) + ui = UPKImportance(task_score=0) + with pytest.raises(ValidationError) as e: + UPKImportance(task_score=-1) + assert "Input should be greater than or equal to 0" in str(e.value) + + def test_pattern(self): + from generalresearch.models.thl.profiling.upk_question import ( + PatternValidation, + ) + + s = PatternValidation(message="hi", pattern="x") + with pytest.raises(ValidationError) as e: + s.message = "sfd" + assert "Instance is frozen" in str(e.value) + + def test_mc(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestion, + UpkQuestionConfigurationMC, + ) + + q = UpkQuestion( + id="601377a0d4c74529afc6293a8e5c3b5e", + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="whats up", + choices=[ + UpkQuestionChoice(id="1", text="sky", order=1), + UpkQuestionChoice(id="2", text="moon", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=2), + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.SINGLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=1), + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + with pytest.raises(ValidationError) as e: + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.SINGLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=2), + ) + assert "max_select must be 1 if the selector is SA" in str(e.value) + + with pytest.raises(ValidationError) as e: + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=4), + ) + assert "max_select must be >= len(choices)" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_length=2), + ) + assert "Extra inputs are not permitted" in str(e.value) + + def test_te(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionType, + UpkQuestion, + UpkQuestionSelectorTE, + UpkQuestionValidation, + PatternValidation, + UpkQuestionConfigurationTE, + ) + + q = UpkQuestion( + id="601377a0d4c74529afc6293a8e5c3b5e", + country_iso="us", + language_iso="eng", + type=UpkQuestionType.TEXT_ENTRY, + selector=UpkQuestionSelectorTE.MULTI_LINE, + text="whats up", + choices=[], + configuration=UpkQuestionConfigurationTE(max_length=2), + validation=UpkQuestionValidation( + patterns=[PatternValidation(pattern=".", message="x")] + ), + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + assert q.choices is None + + def test_deserialization(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + q = UpkQuestion.model_validate( + { + "id": "601377a0d4c74529afc6293a8e5c3b5e", + "ext_question_id": "m:2342", + "country_iso": "us", + "language_iso": "eng", + "text": "whats up", + "choices": [ + {"id": "1", "text": "yes", "order": 1}, + {"id": "2", "text": "no", "order": 2}, + ], + "importance": None, + "type": "MC", + "selector": "SA", + "configuration": None, + } + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + q = UpkQuestion.model_validate( + { + "id": "601377a0d4c74529afc6293a8e5c3b5e", + "ext_question_id": "m:2342", + "country_iso": "us", + "language_iso": "eng", + "text": "whats up", + "choices": [ + {"id": "1", "text": "yes", "order": 1}, + {"id": "2", "text": "no", "order": 2}, + ], + "importance": None, + "question_type": "MC", + "selector": "MA", + "configuration": {"max_select": 2}, + } + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + def test_from_morning(self): + from generalresearch.models.morning.question import ( + MorningQuestion, + MorningQuestionType, + ) + + q = MorningQuestion( + **{ + "id": "gender", + "country_iso": "us", + "language_iso": "eng", + "name": "Gender", + "text": "What is your gender?", + "type": "s", + "options": [ + {"id": "1", "text": "yes", "order": 1}, + {"id": "2", "text": "no", "order": 2}, + ], + } + ) + q.to_upk_question() + q = MorningQuestion( + country_iso="us", + language_iso="eng", + type=MorningQuestionType.text_entry, + text="how old r u", + id="a", + name="age", + ) + q.to_upk_question() + + def test_order(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestion, + order_exclusive_options, + ) + + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes, no, or NA?", + choices=[ + UpkQuestionChoice(id="1", text="NA", order=0), + UpkQuestionChoice(id="2", text="no", order=1), + UpkQuestionChoice(id="3", text="yes", order=2), + ], + ) + order_exclusive_options(q) + assert ( + UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes, no, or NA?", + choices=[ + UpkQuestionChoice(id="2", text="no", order=0), + UpkQuestionChoice(id="3", text="yes", order=1), + UpkQuestionChoice(id="1", text="NA", order=2, exclusive=True), + ], + ) + == q + ) + + +class TestUpkQuestionValidateAnswer: + def test_validate_answer_SA(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + question = UpkQuestion.model_validate( + { + "choices": [ + {"order": 0, "choice_id": "0", "choice_text": "Male"}, + {"order": 1, "choice_id": "1", "choice_text": "Female"}, + {"order": 2, "choice_id": "2", "choice_text": "Other"}, + ], + "selector": "SA", + "country_iso": "us", + "question_id": "5d6d9f3c03bb40bf9d0a24f306387d7c", + "language_iso": "eng", + "question_text": "What is your gender?", + "question_type": "MC", + } + ) + answer = ("0",) + assert question.validate_question_answer(answer)[0] is True + answer = ("3",) + assert question.validate_question_answer(answer) == ( + False, + "Invalid Options Selected", + ) + answer = ("0", "0") + assert question.validate_question_answer(answer) == ( + False, + "Multiple of the same answer submitted", + ) + answer = ("0", "1") + assert question.validate_question_answer(answer) == ( + False, + "Single Answer MC question with >1 selected " "answers", + ) + + def test_validate_answer_MA(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + question = UpkQuestion.model_validate( + { + "choices": [ + { + "order": 0, + "choice_id": "none", + "exclusive": True, + "choice_text": "None of the above", + }, + { + "order": 1, + "choice_id": "female_under_1", + "choice_text": "Female under age 1", + }, + { + "order": 2, + "choice_id": "male_under_1", + "choice_text": "Male under age 1", + }, + { + "order": 3, + "choice_id": "female_1", + "choice_text": "Female age 1", + }, + {"order": 4, "choice_id": "male_1", "choice_text": "Male age 1"}, + { + "order": 5, + "choice_id": "female_2", + "choice_text": "Female age 2", + }, + ], + # I removed a bunch of choices fyi + "selector": "MA", + "country_iso": "us", + "question_id": "3b65220db85f442ca16bb0f1c0e3a456", + "language_iso": "eng", + "question_text": "Please indicate the age and gender of your child or children:", + "question_type": "MC", + } + ) + answer = ("none",) + assert question.validate_question_answer(answer)[0] is True + answer = ("male_1",) + assert question.validate_question_answer(answer)[0] is True + answer = ("male_1", "female_1") + assert question.validate_question_answer(answer)[0] is True + answer = ("xxx",) + assert question.validate_question_answer(answer) == ( + False, + "Invalid Options Selected", + ) + answer = ("male_1", "male_1") + assert question.validate_question_answer(answer) == ( + False, + "Multiple of the same answer submitted", + ) + answer = ("male_1", "xxx") + assert question.validate_question_answer(answer) == ( + False, + "Invalid Options Selected", + ) + answer = ("male_1", "none") + assert question.validate_question_answer(answer) == ( + False, + "Invalid exclusive selection", + ) + + def test_validate_answer_TE(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + question = UpkQuestion.model_validate( + { + "selector": "SL", + "validation": { + "patterns": [ + { + "message": "Must enter a valid zip code: XXXXX", + "pattern": "^[0-9]{5}$", + } + ] + }, + "country_iso": "us", + "question_id": "543de254e9ca4d9faded4377edab82a9", + "language_iso": "eng", + "configuration": {"max_length": 5, "min_length": 5}, + "question_text": "What is your zip code?", + "question_type": "TE", + } + ) + answer = ("33143",) + assert question.validate_question_answer(answer)[0] is True + answer = ("33143", "33143") + assert question.validate_question_answer(answer) == ( + False, + "Multiple of the same answer submitted", + ) + answer = ("33143", "12345") + assert question.validate_question_answer(answer) == ( + False, + "Only one answer allowed", + ) + answer = ("111",) + assert question.validate_question_answer(answer) == ( + False, + "Must enter a valid zip code: XXXXX", + ) diff --git a/tests/models/thl/test_user.py b/tests/models/thl/test_user.py new file mode 100644 index 0000000..4f10861 --- /dev/null +++ b/tests/models/thl/test_user.py @@ -0,0 +1,688 @@ +import json +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint, choice as rand_choice +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + + +class TestUserUserID: + + def test_valid(self): + from generalresearch.models.thl.user import User + + val = randint(1, 2**30) + user = User(user_id=val) + assert user.user_id == val + + def test_type(self): + from generalresearch.models.thl.user import User + + # It will cast str to int + assert User(user_id="1").user_id == 1 + + # It will cast float to int + assert User(user_id=1.0).user_id == 1 + + # It will cast Decimal to int + assert User(user_id=Decimal("1.0")).user_id == 1 + + # pydantic Validation error is a ValueError, let's check both.. + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=Decimal("1.00000001")) + assert "1 validation error for User" in str(cm.value) + assert "user_id" in str(cm.value) + assert "Input should be a valid integer," in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=Decimal("1.00000001")) + assert "1 validation error for User" in str(cm.value) + assert "user_id" in str(cm.value) + assert "Input should be a valid integer," in str(cm.value) + + def test_zero(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be greater than 0" in str(cm.value) + + def test_negative(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=-1) + assert "1 validation error for User" in str(cm.value) + assert "Input should be greater than 0" in str(cm.value) + + def test_too_big(self): + from generalresearch.models.thl.user import User + + val = 2**31 + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=val) + assert "1 validation error for User" in str(cm.value) + assert "Input should be less than 2147483648" in str(cm.value) + + def test_identifiable(self): + from generalresearch.models.thl.user import User + + val = randint(1, 2**30) + user = User(user_id=val) + assert user.is_identifiable + + +class TestUserProductID: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + + user = User(user_id=self.user_id, product_id=product_id) + assert user.user_id == self.user_id + assert user.product_id == product_id + + def test_type(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=0.0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=Decimal("0")) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + def test_empty(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id="") + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 32 characters" in str(cm.value) + + def test_invalid_len(self): + from generalresearch.models.thl.user import User + + # Valid uuid4s are 32 char long + product_id = uuid4().hex[:31] + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User", str(cm.value) + assert "String should have at least 32 characters", str(cm.value) + + product_id = uuid4().hex * 2 + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + product_id = uuid4().hex + product_id *= 2 + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + def test_invalid_uuid(self): + from generalresearch.models.thl.user import User + + # Modify the UUID to break it + product_id = uuid4().hex[:31] + "x" + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "Invalid UUID" in str(cm.value) + + def test_invalid_hex_form(self): + from generalresearch.models.thl.user import User + + # Sure not in hex form, but it'll get caught for being the + # wrong length before anything else + product_id = str(uuid4()) # '1a93447e-c77b-4cfa-b58e-ed4777d57110' + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + def test_identifiable(self): + """Can't create a User with only a product_id because it also + needs to the product_user_id""" + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + with pytest.raises(expected_exception=ValueError) as cm: + User(product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "Value error, User is not identifiable" in str(cm.value) + + +class TestUserProductUserID: + user_id = randint(1, 2**30) + + def randomword(self, length: int = 50): + # Raw so nothing is escaped to add additional backslashes + _bpuid_allowed = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&()*+,-.:;<=>?@[]^_{|}~" + return "".join(rand_choice(_bpuid_allowed) for i in range(length)) + + def test_valid(self): + from generalresearch.models.thl.user import User + + product_user_id = uuid4().hex[:12] + user = User(user_id=self.user_id, product_user_id=product_user_id) + + assert user.user_id == self.user_id + assert user.product_user_id == product_user_id + + def test_type(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=0.0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, product_user_id=Decimal("0")) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + def test_empty(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id="") + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 3 characters" in str(cm.value) + + def test_invalid_len(self): + from generalresearch.models.thl.user import User + + product_user_id = self.randomword(251) + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 128 characters" in str(cm.value) + + product_user_id = self.randomword(2) + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 3 characters" in str(cm.value) + + def test_invalid_chars_space(self): + from generalresearch.models.thl.user import User + + product_user_id = f"{self.randomword(50)} {self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String cannot contain spaces" in str(cm.value) + + def test_invalid_chars_slash(self): + from generalresearch.models.thl.user import User + + product_user_id = f"{self.randomword(50)}\{self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String cannot contain backslash" in str(cm.value) + + product_user_id = f"{self.randomword(50)}/{self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String cannot contain slash" in str(cm.value) + + def test_invalid_chars_backtick(self): + """Yes I could keep doing these specific character checks. However, + I wanted a test that made sure the regex was hit. I do not know + how we want to provide with the level of specific String checks + we do in here for specific error messages.""" + from generalresearch.models.thl.user import User + + product_user_id = f"{self.randomword(50)}`{self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String is not valid regex" in str(cm.value) + + def test_unique_from_product_id(self): + # We removed this filter b/c these users already exist. the manager checks for this + # though and we can't create new users like this + pass + # product_id = uuid4().hex + # + # with pytest.raises(ValueError) as cm: + # User(product_id=product_id, product_user_id=product_id) + # assert "1 validation error for User", str(cm.exception)) + # assert "product_user_id must not equal the product_id", str(cm.exception)) + + def test_identifiable(self): + """Can't create a User with only a product_user_id because it also + needs to the product_id""" + from generalresearch.models.thl.user import User + + product_user_id = uuid4().hex + with pytest.raises(ValueError) as cm: + User(product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "Value error, User is not identifiable" in str(cm.value) + + +class TestUserUUID: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + uuid_pk = uuid4().hex + + user = User(user_id=self.user_id, uuid=uuid_pk) + assert user.user_id == self.user_id + assert user.uuid == uuid_pk + + def test_type(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=0.0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=Decimal("0")) + assert "1 validation error for User", str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + def test_empty(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid="") + assert "1 validation error for User", str(cm.value) + assert "String should have at least 32 characters", str(cm.value) + + def test_invalid_len(self): + from generalresearch.models.thl.user import User + + # Valid uuid4s are 32 char long + uuid_pk = uuid4().hex[:31] + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 32 characters" in str(cm.value) + + # Valid uuid4s are 32 char long + uuid_pk = uuid4().hex + uuid_pk *= 2 + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + def test_invalid_uuid(self): + from generalresearch.models.thl.user import User + + # Modify the UUID to break it + uuid_pk = uuid4().hex[:31] + "x" + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "Invalid UUID" in str(cm.value) + + def test_invalid_hex_form(self): + from generalresearch.models.thl.user import User + + # Sure not in hex form, but it'll get caught for being the + # wrong length before anything else + uuid_pk = str(uuid4()) # '1a93447e-c77b-4cfa-b58e-ed4777d57110' + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + uuid_pk = str(uuid4())[:32] # '1a93447e-c77b-4cfa-b58e-ed4777d57110' + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "Invalid UUID" in str(cm.value) + + def test_identifiable(self): + from generalresearch.models.thl.user import User + + user_uuid = uuid4().hex + user = User(uuid=user_uuid) + assert user.is_identifiable + + +class TestUserCreated: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + dt = datetime.now(tz=timezone.utc) + user.created = dt + + assert user.created == dt + + def test_tz_naive_throws_init(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=datetime.now(tz=None)) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_naive_throws_setter(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + with pytest.raises(ValueError) as cm: + user.created = datetime.now(tz=None) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_utc(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User( + user_id=self.user_id, + created=datetime.now(tz=timezone(-timedelta(hours=8))), + ) + assert "1 validation error for User" in str(cm.value) + assert "Timezone is not UTC" in str(cm.value) + + def test_not_in_future(self): + from generalresearch.models.thl.user import User + + the_future = datetime.now(tz=timezone.utc) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=the_future) + assert "1 validation error for User" in str(cm.value) + assert "Input is in the future" in str(cm.value) + + def test_after_anno_domini(self): + from generalresearch.models.thl.user import User + + before_ad = datetime( + year=2015, month=1, day=1, tzinfo=timezone.utc + ) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=before_ad) + assert "1 validation error for User" in str(cm.value) + assert "Input is before Anno Domini" in str(cm.value) + + +class TestUserLastSeen: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + dt = datetime.now(tz=timezone.utc) + user.last_seen = dt + + assert user.last_seen == dt + + def test_tz_naive_throws_init(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, last_seen=datetime.now(tz=None)) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_naive_throws_setter(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + with pytest.raises(ValueError) as cm: + user.last_seen = datetime.now(tz=None) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_utc(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User( + user_id=self.user_id, + last_seen=datetime.now(tz=timezone(-timedelta(hours=8))), + ) + assert "1 validation error for User" in str(cm.value) + assert "Timezone is not UTC" in str(cm.value) + + def test_not_in_future(self): + from generalresearch.models.thl.user import User + + the_future = datetime.now(tz=timezone.utc) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, last_seen=the_future) + assert "1 validation error for User" in str(cm.value) + assert "Input is in the future" in str(cm.value) + + def test_after_anno_domini(self): + from generalresearch.models.thl.user import User + + before_ad = datetime( + year=2015, month=1, day=1, tzinfo=timezone.utc + ) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, last_seen=before_ad) + assert "1 validation error for User" in str(cm.value) + assert "Input is before Anno Domini" in str(cm.value) + + +class TestUserBlocked: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id, blocked=True) + assert user.blocked + + def test_str_casting(self): + """We don't want any of these to work, and that's why + we set strict=True on the column""" + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="true") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="True") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="1") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="yes") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="no") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked=uuid4().hex) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + +class TestUserTiming: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + created = datetime.now(tz=timezone.utc) - timedelta(minutes=60) + last_seen = datetime.now(tz=timezone.utc) - timedelta(minutes=59) + + user = User(user_id=self.user_id, created=created, last_seen=last_seen) + assert user.created == created + assert user.last_seen == last_seen + + def test_created_first(self): + from generalresearch.models.thl.user import User + + created = datetime.now(tz=timezone.utc) - timedelta(minutes=60) + last_seen = datetime.now(tz=timezone.utc) - timedelta(minutes=59) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=last_seen, last_seen=created) + assert "1 validation error for User" in str(cm.value) + assert "User created time invalid" in str(cm.value) + + +class TestUserModelVerification: + """Tests that may be dependent on more than 1 attribute""" + + def test_identifiable(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + user = User(product_id=product_id, product_user_id=product_user_id) + assert user.is_identifiable + + def test_valid_helper(self): + from generalresearch.models.thl.user import User + + user_bool = User.is_valid_ubp( + product_id=uuid4().hex, product_user_id=uuid4().hex + ) + assert user_bool + + user_bool = User.is_valid_ubp(product_id=uuid4().hex, product_user_id=" - - - ") + assert not user_bool + + +class TestUserSerialization: + + def test_basic_json(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + + user = User( + product_id=product_id, + product_user_id=product_user_id, + created=datetime.now(tz=timezone.utc), + blocked=False, + ) + + d = json.loads(user.to_json()) + assert d.get("product_id") == product_id + assert d.get("product_user_id") == product_user_id + assert not d.get("blocked") + + assert d.get("product") is None + assert d.get("created").endswith("Z") + + def test_basic_dict(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + + user = User( + product_id=product_id, + product_user_id=product_user_id, + created=datetime.now(tz=timezone.utc), + blocked=False, + ) + + d = user.to_dict() + assert d.get("product_id") == product_id + assert d.get("product_user_id") == product_user_id + assert not d.get("blocked") + + assert d.get("product") is None + assert d.get("created").tzinfo == timezone.utc + + def test_from_json(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + + user = User( + product_id=product_id, + product_user_id=product_user_id, + created=datetime.now(tz=timezone.utc), + blocked=False, + ) + + u = User.model_validate_json(user.to_json()) + assert u.product_id == product_id + assert u.product is None + assert u.created.tzinfo == timezone.utc + + +class TestUserMethods: + + def test_audit_log(self, user, audit_log_manager): + assert user.audit_log is None + user.prefetch_audit_log(audit_log_manager=audit_log_manager) + assert user.audit_log == [] + + audit_log_manager.create_dummy(user_id=user.user_id) + user.prefetch_audit_log(audit_log_manager=audit_log_manager) + assert len(user.audit_log) == 1 + + def test_transactions( + self, user_factory, thl_lm, session_with_tx_factory, product_user_wallet_yes + ): + u1 = user_factory(product=product_user_wallet_yes) + + assert u1.transactions is None + u1.prefetch_transactions(thl_lm=thl_lm) + assert u1.transactions == [] + + session_with_tx_factory(user=u1) + + u1.prefetch_transactions(thl_lm=thl_lm) + assert len(u1.transactions) == 1 + + @pytest.mark.skip(reason="TODO") + def test_location_history(self, user): + assert user.location_history is None diff --git a/tests/models/thl/test_user_iphistory.py b/tests/models/thl/test_user_iphistory.py new file mode 100644 index 0000000..46018e0 --- /dev/null +++ b/tests/models/thl/test_user_iphistory.py @@ -0,0 +1,45 @@ +from datetime import timezone, datetime, timedelta + +from generalresearch.models.thl.user_iphistory import ( + UserIPHistory, + UserIPRecord, +) + + +def test_collapse_ip_records(): + # This does not exist in a db, so we do not need fixtures/ real user ids, whatever + now = datetime.now(tz=timezone.utc) - timedelta(days=1) + # Gets stored most recent first. This is reversed, but the validator will order it + records = [ + UserIPRecord(ip="1.2.3.5", created=now + timedelta(minutes=1)), + UserIPRecord( + ip="1e5c:de49:165a:6aa0:4f89:1433:9af7:aaaa", + created=now + timedelta(minutes=2), + ), + UserIPRecord( + ip="1e5c:de49:165a:6aa0:4f89:1433:9af7:bbbb", + created=now + timedelta(minutes=3), + ), + UserIPRecord(ip="1.2.3.5", created=now + timedelta(minutes=4)), + UserIPRecord( + ip="1e5c:de49:165a:6aa0:4f89:1433:9af7:cccc", + created=now + timedelta(minutes=5), + ), + UserIPRecord( + ip="6666:de49:165a:6aa0:4f89:1433:9af7:aaaa", + created=now + timedelta(minutes=6), + ), + UserIPRecord(ip="1.2.3.6", created=now + timedelta(minutes=7)), + ] + iph = UserIPHistory(user_id=1, ips=records) + res = iph.collapse_ip_records() + + # We should be left with one of the 1.2.3.5 ipv4s, + # and only the 1e5c::cccc and the 6666 ipv6 addresses + assert len(res) == 4 + assert [x.ip for x in res] == [ + "1.2.3.6", + "6666:de49:165a:6aa0:4f89:1433:9af7:aaaa", + "1e5c:de49:165a:6aa0:4f89:1433:9af7:cccc", + "1.2.3.5", + ] diff --git a/tests/models/thl/test_user_metadata.py b/tests/models/thl/test_user_metadata.py new file mode 100644 index 0000000..3d851dc --- /dev/null +++ b/tests/models/thl/test_user_metadata.py @@ -0,0 +1,46 @@ +import pytest + +from generalresearch.models import MAX_INT32 +from generalresearch.models.thl.user_profile import UserMetadata + + +class TestUserMetadata: + + def test_default(self): + # You can initialize it with nothing + um = UserMetadata() + assert um.email_address is None + assert um.email_sha1 is None + + def test_user_id(self): + # This does NOT validate that the user_id exists. When we attempt a db operation, + # at that point it will fail b/c of the foreign key constraint. + UserMetadata(user_id=MAX_INT32 - 1) + + with pytest.raises(expected_exception=ValueError) as cm: + UserMetadata(user_id=MAX_INT32) + assert "Input should be less than 2147483648" in str(cm.value) + + def test_email(self): + um = UserMetadata(email_address="e58375d80f5f4a958138004aae44c7ca@example.com") + assert ( + um.email_sha256 + == "fd219d8b972b3d82e70dc83284027acc7b4a6de66c42261c1684e3f05b545bc0" + ) + assert um.email_sha1 == "a82578f02b0eed28addeb81317417cf239ede1c3" + assert um.email_md5 == "9073a7a3c21cfd6160d1899fb736cd1c" + + # You cannot set the hashes directly + with pytest.raises(expected_exception=AttributeError) as cm: + um.email_md5 = "x" * 32 + # assert "can't set attribute 'email_md5'" in str(cm.value) + assert "property 'email_md5' of 'UserMetadata' object has no setter" in str( + cm.value + ) + + # assert it hasn't changed anything + assert um.email_md5 == "9073a7a3c21cfd6160d1899fb736cd1c" + + # If you update the email, all the hashes change + um.email_address = "greg@example.com" + assert um.email_md5 != "9073a7a3c21cfd6160d1899fb736cd1c" diff --git a/tests/models/thl/test_user_streak.py b/tests/models/thl/test_user_streak.py new file mode 100644 index 0000000..72efd05 --- /dev/null +++ b/tests/models/thl/test_user_streak.py @@ -0,0 +1,96 @@ +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +import pytest +from pydantic import ValidationError + +from generalresearch.models.thl.user_streak import ( + UserStreak, + StreakPeriod, + StreakFulfillment, + StreakState, +) + + +def test_user_streak_empty_fail(): + us = UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=0, + longest_streak=0, + state=StreakState.BROKEN, + ) + assert us.time_remaining_in_period is None + + with pytest.raises( + ValidationError, match="StreakState.BROKEN but current_streak not 0" + ): + UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=0, + state=StreakState.BROKEN, + ) + + with pytest.raises( + ValidationError, match="Current streak can't be longer than longest streak" + ): + UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=0, + state=StreakState.ACTIVE, + ) + + +def test_user_streak_remaining(): + us = UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=1, + state=StreakState.AT_RISK, + ) + now = datetime.now(tz=ZoneInfo("America/New_York")) + end_of_today = now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( + days=1 + ) + print(f"{now.isoformat()=}, {end_of_today.isoformat()=}") + expected = (end_of_today - now).total_seconds() + assert us.time_remaining_in_period.total_seconds() == pytest.approx(expected, abs=1) + + +def test_user_streak_remaining_month(): + us = UserStreak( + period=StreakPeriod.MONTH, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=1, + state=StreakState.AT_RISK, + ) + now = datetime.now(tz=ZoneInfo("America/New_York")) + end_of_month = ( + now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + timedelta(days=32) + ).replace(day=1) + print(f"{now.isoformat()=}, {end_of_month.isoformat()=}") + expected = (end_of_month - now).total_seconds() + assert us.time_remaining_in_period.total_seconds() == pytest.approx(expected, abs=1) + print(us.time_remaining_in_period) diff --git a/tests/models/thl/test_wall.py b/tests/models/thl/test_wall.py new file mode 100644 index 0000000..057aad2 --- /dev/null +++ b/tests/models/thl/test_wall.py @@ -0,0 +1,207 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallStatusCode2, +) +from generalresearch.models.thl.session import Wall + + +class TestWall: + + def test_wall_json(self): + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + ext_status_code_1="1.0", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + started=datetime(2023, 1, 1, 0, 0, 1, tzinfo=timezone.utc), + finished=datetime(2023, 1, 1, 0, 10, 1, tzinfo=timezone.utc), + ) + s = w.to_json() + w2 = Wall.from_json(s) + assert w == w2 + + def test_status_status_code_agreement(self): + # should not raise anything + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.MARKETPLACE_FAIL, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + with pytest.raises(expected_exception=ValidationError) as e: + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.GRS_ABANDON, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + assert "If status is f, status_code_1 should be in" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.GRS_ABANDON, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + assert "If status is f, status_code_1 should be in" in str(e.value) + + def test_status_code_1_2_agreement(self): + # should not raise anything + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.MARKETPLACE_FAIL, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + status_code_2=None, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + status_code_2=None, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + with pytest.raises(expected_exception=ValidationError) as e: + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + assert "If status_code_1 is 1, status_code_2 should be in" in str(e.value) + + def test_annotate_status_code(self): + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + ) + w.annotate_status_codes("1.0") + assert Status.COMPLETE == w.status + assert StatusCode1.COMPLETE == w.status_code_1 + assert w.status_code_2 is None + assert "1.0" == w.ext_status_code_1 + assert w.ext_status_code_2 is None + + def test_buyer_too_long(self): + buyer_id = uuid4().hex + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + buyer_id=buyer_id, + ) + assert buyer_id == w.buyer_id + + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + buyer_id=None, + ) + assert w.buyer_id is None + + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + buyer_id=buyer_id + "abc123", + ) + assert buyer_id == w.buyer_id + + @pytest.mark.skip(reason="TODO") + def test_more_stuff(self): + # todo: .update, test status logic + pass diff --git a/tests/models/thl/test_wall_session.py b/tests/models/thl/test_wall_session.py new file mode 100644 index 0000000..ab140e9 --- /dev/null +++ b/tests/models/thl/test_wall_session.py @@ -0,0 +1,326 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import Status, StatusCode1 +from generalresearch.models.thl.session import Session, Wall +from generalresearch.models.thl.user import User + + +class TestWallSession: + + def test_session_with_no_wall_events(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + s = Session(user=User(user_id=1), started=started) + assert s.status is None + assert s.status_code_1 is None + + # todo: this needs to be set explicitly, not this way + # # If I have no wall events, it's a fail + # s.determine_session_status() + # assert s.status == Status.FAIL + # assert s.status_code_1 == StatusCode1.SESSION_START_FAIL + + def test_session_timeout_with_only_grs(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + s = Session(user=User(user_id=1), started=started) + w = Wall( + user_id=1, + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + ) + s.append_wall_event(w) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + assert Status.TIMEOUT == s.status + assert StatusCode1.GRS_ABANDON == s.status_code_1 + + def test_session_with_only_grs_fail(self): + # todo: this needs to be set explicitly, not this way + pass + # started = datetime(2023, 1, 1, tzinfo=timezone.utc) + # s = Session(user=User(user_id=1), started=started) + # w = Wall(user_id=1, source=Source.GRS, req_survey_id='xxx', + # req_cpi=Decimal(1), session_id=1) + # s.append_wall_event(w) + # w.finish(status=Status.FAIL, status_code_1=StatusCode1.PS_FAIL) + # s.determine_session_status() + # assert s.status == Status.FAIL + # assert s.status_code_1 == StatusCode1.GRS_FAIL + + def test_session_with_only_grs_complete(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + # A Session is started + s = Session(user=User(user_id=1), started=started) + + # The User goes into a GRS survey, and completes it + # @gstupp - should a GRS be allowed with a req_cpi > 0? + w = Wall( + user_id=1, + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + ) + s.append_wall_event(w) + w.finish(status=Status.COMPLETE, status_code_1=StatusCode1.COMPLETE) + + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.FAIL + + # @gstupp changed this behavior on 11/2023 (51471b6ae671f21212a8b1fad60b508181cbb8ca) + # I don't know which is preferred or the consequences of each. However, + # now it's a SESSION_CONTINUE_FAIL instead of a SESSION_START_FAIL so + # change this so the test passes + # self.assertEqual(s.status_code_1, StatusCode1.SESSION_START_FAIL) + assert s.status_code_1 == StatusCode1.SESSION_CONTINUE_FAIL + + @pytest.mark.skip(reason="TODO") + def test_session_with_only_non_grs_complete(self): + # todo: this needs to be set explicitly, not this way + pass + # # This fails... until payout stuff is done + # started = datetime(2023, 1, 1, tzinfo=timezone.utc) + # s = Session(user=User(user_id=1), started=started) + # w = Wall(source=Source.DYNATA, req_survey_id='xxx', req_cpi=Decimal('1.00001'), + # session_id=1, user_id=1) + # s.append_wall_event(w) + # w.finish(status=Status.COMPLETE, status_code_1=StatusCode1.COMPLETE) + # s.determine_session_status() + # assert s.status == Status.COMPLETE + # assert s.status_code_1 is None + + def test_session_with_only_non_grs_fail(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + + s.append_wall_event(w) + w.finish(status=Status.FAIL, status_code_1=StatusCode1.BUYER_FAIL) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.FAIL + assert s.status_code_1 == StatusCode1.BUYER_FAIL + assert s.payout is None + + def test_session_with_only_non_grs_timeout(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + + s.append_wall_event(w) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.TIMEOUT + assert s.status_code_1 == StatusCode1.BUYER_ABANDON + assert s.payout is None + + def test_session_with_grs_and_external(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + started=started, + ) + + s.append_wall_event(w) + w.finish( + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=started + timedelta(minutes=10), + ) + + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + s.append_wall_event(w) + w.finish( + status=Status.ABANDON, + finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + status_code_1=StatusCode1.BUYER_ABANDON, + ) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.ABANDON + assert s.status_code_1 == StatusCode1.BUYER_ABANDON + assert s.payout is None + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + ) + s.append_wall_event(w) + w.finish(status=Status.COMPLETE, status_code_1=StatusCode1.COMPLETE) + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + s.append_wall_event(w) + w.finish(status=Status.FAIL, status_code_1=StatusCode1.PS_DUPLICATE) + + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.FAIL + assert s.status_code_1 == StatusCode1.PS_DUPLICATE + assert s.payout is None + + def test_session_marketplace_fail(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.CINT, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + started=started, + ) + s.append_wall_event(w) + w.finish( + status=Status.FAIL, + status_code_1=StatusCode1.MARKETPLACE_FAIL, + finished=started + timedelta(minutes=10), + ) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + assert Status.FAIL == s.status + assert StatusCode1.SESSION_CONTINUE_QUALITY_FAIL == s.status_code_1 + + def test_session_unknown(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.CINT, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + started=started, + ) + s.append_wall_event(w) + w.finish( + status=Status.FAIL, + status_code_1=StatusCode1.UNKNOWN, + finished=started + timedelta(minutes=10), + ) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + assert Status.FAIL == s.status + assert StatusCode1.BUYER_FAIL == s.status_code_1 + + +# class TestWallSessionPayout: +# product_id = uuid4().hex +# +# def test_session_payout_with_only_non_grs_complete(self): +# sql_helper = self.make_sql_helper() +# user = User(user_id=1, product_id=self.product_id) +# s = Session(user=user, started=datetime(2023, 1, 1, tzinfo=timezone.utc)) +# w = Wall(source=Source.DYNATA, req_survey_id='xxx', req_cpi=Decimal('1.00001')) +# s.append_wall_event(w) +# w.handle_callback(status=Status.COMPLETE) +# s.determine_session_status() +# s.determine_payout(sql_helper=sql_helper) +# assert s.status == Status.COMPLETE +# assert s.status_code_1 is None +# # we're assuming here the commission on this BP is 8.5% and doesn't get changed by someone! +# assert s.payout == Decimal('0.88') +# +# def test_session_payout(self): +# sql_helper = self.make_sql_helper() +# user = User(user_id=1, product_id=self.product_id) +# s = Session(user=user, started=datetime(2023, 1, 1, tzinfo=timezone.utc)) +# w = Wall(source=Source.GRS, req_survey_id='xxx', req_cpi=1) +# s.append_wall_event(w) +# w.handle_callback(status=Status.COMPLETE) +# w = Wall(source=Source.DYNATA, req_survey_id='xxx', req_cpi=Decimal('1.00001')) +# s.append_wall_event(w) +# w.handle_callback(status=Status.COMPLETE) +# s.determine_session_status() +# s.determine_payout(commission_pct=Decimal('0.05')) +# assert s.status == Status.COMPLETE +# assert s.status_code_1 is None +# assert s.payout == Decimal('0.93') + + +# def test_get_from_uuid_vendor_wall(self): +# sql_helper = self.make_sql_helper() +# sql_helper.get_or_create("auth_user", "id", {"id": 1}, { +# "id": 1, "password": "1", +# "last_login": None, "is_superuser": 0, +# "username": "a", "first_name": "a", +# "last_name": "a", "email": "a", +# "is_staff": 0, "is_active": 1, +# "date_joined": "2023-10-13 14:03:20.000000"}) +# sql_helper.get_or_create("vendor_wallsession", "id", {"id": 324}, {"id": 324}) +# sql_helper.create("vendor_wall", { +# "id": "7b3e380babc840b79abf0030d408bbd9", +# "status": "c", +# "started": "2023-10-10 00:51:13.415444", +# "finished": "2023-10-10 01:08:00.676947", +# "req_loi": 1200, +# "req_cpi": 0.63, +# "req_survey_id": "8070750", +# "survey_id": "8070750", +# "cpi": 0.63, +# "user_id": 1, +# "report_notes": None, +# "report_status": None, +# "status_code": "1", +# "req_survey_hashed_opp": None, +# "session_id": 324, +# "source": "i", +# "ubp_id": None +# }) +# Wall +# w = Wall.get_from_uuid_vendor_wall('7b3e380babc840b79abf0030d408bbd9', sql_helper=sql_helper, +# session_id=1) +# assert w.status == Status.COMPLETE +# assert w.source == Source.INNOVATE +# assert w.uuid == '7b3e380babc840b79abf0030d408bbd9' +# assert w.cpi == Decimal('0.63') +# assert w.survey_id == '8070750' +# assert w.user_id == 1 diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..d280de0 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto
\ No newline at end of file diff --git a/tests/sql_helper.py b/tests/sql_helper.py new file mode 100644 index 0000000..247f0cd --- /dev/null +++ b/tests/sql_helper.py @@ -0,0 +1,53 @@ +from uuid import uuid4 + +import pytest +from pydantic import MySQLDsn, MariaDBDsn, ValidationError + + +class TestSqlHelper: + def test_db_property(self): + from generalresearch.sql_helper import SqlHelper + + db_name = uuid4().hex[:8] + dsn = MySQLDsn(f"mysql://root@localhost/{db_name}") + instance = SqlHelper(dsn=dsn) + + assert instance.db == db_name + assert instance.db_name == db_name + assert instance.dbname == db_name + + def test_scheme(self): + from generalresearch.sql_helper import SqlHelper + + dsn = MySQLDsn(f"mysql://root@localhost/test") + instance = SqlHelper(dsn=dsn) + assert instance.is_mysql() + + # This needs psycopg2 installed, and don't need to make this a + # requirement of the package ... todo? + # dsn = PostgresDsn(f"postgres://root@localhost/test") + # instance = SqlHelper(dsn=dsn) + # self.assertTrue(instance.is_postgresql()) + + with pytest.raises(ValidationError): + SqlHelper(dsn=MariaDBDsn(f"maria://root@localhost/test")) + + def test_row_decode(self): + from generalresearch.sql_helper import decode_uuids + + valid_uuid4_1 = "bf432839fd0d4436ab1581af5eb98f26" + valid_uuid4_2 = "e1d8683b9c014e9d80eb120c2fc95288" + invalid_uuid4_2 = "2f3b9edf5a3da6198717b77604775ec1" + + row1 = { + "b": valid_uuid4_1, + "c": valid_uuid4_2, + } + + row2 = { + "a": valid_uuid4_1, + "b": invalid_uuid4_2, + } + + assert row1 == decode_uuids(row1) + assert row1 != decode_uuids(row2) diff --git a/tests/wall_status_codes/__init__.py b/tests/wall_status_codes/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/wall_status_codes/__init__.py diff --git a/tests/wall_status_codes/test_analyze.py b/tests/wall_status_codes/test_analyze.py new file mode 100644 index 0000000..da6efb3 --- /dev/null +++ b/tests/wall_status_codes/test_analyze.py @@ -0,0 +1,150 @@ +from generalresearch.models.thl.definitions import StatusCode1, Status +from generalresearch.wall_status_codes import innovate + + +class TestInnovate: + def test_complete(self): + status, status_code_1, status_code_2 = innovate.annotate_status_code("1", None) + assert Status.COMPLETE == status + assert StatusCode1.COMPLETE == status_code_1 + assert status_code_2 is None + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "1", "whatever" + ) + assert Status.COMPLETE == status + assert StatusCode1.COMPLETE == status_code_1 + assert status_code_2 is None + + def test_unknown(self): + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "69420", None + ) + assert Status.FAIL == status + assert StatusCode1.UNKNOWN == status_code_1 + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "69420", "Speeder" + ) + assert Status.FAIL == status + assert StatusCode1.UNKNOWN == status_code_1 + + def test_ps(self): + status, status_code_1, status_code_2 = innovate.annotate_status_code("5", None) + assert Status.FAIL == status + assert StatusCode1.PS_FAIL == status_code_1 + # The ext_status_code_2 should reclassify this as PS_FAIL + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "8", "DeviceType" + ) + assert Status.FAIL == status + assert StatusCode1.PS_FAIL == status_code_1 + # this should be reclassified from PS_FAIL to PS_OQ + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "5", "Group NA" + ) + assert Status.FAIL == status + assert StatusCode1.PS_OVERQUOTA == status_code_1 + + def test_dupe(self): + # innovate calls it a quality, should be dupe + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "8", "Duplicated to token Tq2SwRVX7PUWnFunGPAYWHk" + ) + assert Status.FAIL == status + assert StatusCode1.PS_DUPLICATE == status_code_1 + # stay as quality + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "8", "Selected threat potential score at joblevel not allow the survey" + ) + assert Status.FAIL == status + assert StatusCode1.PS_QUALITY == status_code_1 + + +# todo: fix me: This got broke because I opened the csv in libreoffice and it broke it +# status codes "1.0" -> 1 (facepalm) + +# class TestAllCsv:: +# def get_wall(self) -> pd.DataFrame: +# df = pd.read_csv(os.path.join(os.path.dirname(__file__), "wall_excerpt.csv.gz")) +# df['started'] = pd.to_datetime(df['started']) +# df['finished'] = pd.to_datetime(df['finished']) +# return df +# +# def test_dynata(self): +# df = self.get_wall() +# df = df[df.source == Source.DYNATA] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(dynata.annotate_status_code(row.status_code)), axis=1) +# self.assertEqual(1419, len(df[df.t_status == Status.COMPLETE])) +# assert len(df[df.t_status == Status.FAIL]) == 2109 +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) == 0 +# assert 1000 < len(df[df.t_status_code_1 == StatusCode1.BUYER_FAIL]) < 1100 +# assert 30 < len(df[df.t_status_code_1 == StatusCode1.PS_BLOCKED]) < 40 +# +# def test_fullcircle(self): +# df = self.get_wall() +# df = df[df.source == Source.FULL_CIRCLE] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(fullcircle.annotate_status_code(row.status_code)), axis=1) +# # assert len(df[df.t_status == Status.COMPLETE]) == 1419 +# # assert len(df[df.t_status == Status.FAIL]) == 2109 +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) == 0 +# +# def test_innovate(self): +# df = self.get_wall() +# df = df[df.source == Source.INNOVATE] +# df = df[~df.status.isin({'r', 'e'})] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(innovate.annotate_status_code( +# row.status_code, row.status_code_2 if pd.notnull(row.status_code_2) else None)), +# axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# # assert len(df[df.t_status == Status.COMPLETE]) == 1419 +# # assert len(df[df.t_status == Status.FAIL]) == 2109 +# # assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) == 0 +# +# def test_morning(self): +# df = self.get_wall() +# df = df[df.source == Source.MORNING_CONSULT] +# df = df[df.status_code.notnull()] +# # we have to do this for old values... +# df['status_code'] = df['status_code'].apply(short_code_to_status_codes_morning.get) +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(morning.annotate_status_code(row.status, row.status_code)), axis=1) +# dff = df[df.t_status != Status.COMPLETE] +# assert len(dff[dff.t_status_code_1 == StatusCode1.UNKNOWN]) / len(dff) < 0.05 +# +# def test_pollfish(self): +# df = self.get_wall() +# df = df[df.source == Source.POLLFISH] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(pollfish.annotate_status_code(row.status_code)), axis=1) +# +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# +# def test_pollfish(self): +# df = self.get_wall() +# df = df[df.source == Source.PRECISION] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(precision.annotate_status_code(row.status_code)), axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# +# def test_sago(self): +# df = self.get_wall() +# df = df[df.source == Source.SAGO] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(sago.annotate_status_code(row.status_code)), axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# +# def test_spectrum(self): +# df = self.get_wall() +# df = df[df.source == Source.SPECTRUM] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(spectrum.annotate_status_code(row.status_code)), axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 diff --git a/tests/wxet/__init__.py b/tests/wxet/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/wxet/__init__.py diff --git a/tests/wxet/models/__init__.py b/tests/wxet/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/wxet/models/__init__.py diff --git a/tests/wxet/models/test_definitions.py b/tests/wxet/models/test_definitions.py new file mode 100644 index 0000000..543b9f1 --- /dev/null +++ b/tests/wxet/models/test_definitions.py @@ -0,0 +1,113 @@ +import pytest + + +class TestWXETStatusCode1: + + def test_is_pre_task_entry_fail_pre(self): + from generalresearch.wxet.models.definitions import ( + WXETStatusCode1, + ) + + assert WXETStatusCode1.UNKNOWN.is_pre_task_entry_fail + assert WXETStatusCode1.WXET_FAIL.is_pre_task_entry_fail + assert WXETStatusCode1.WXET_ABANDON.is_pre_task_entry_fail + + def test_is_pre_task_entry_fail_post(self): + from generalresearch.wxet.models.definitions import ( + WXETStatusCode1, + ) + + assert not WXETStatusCode1.BUYER_OVER_QUOTA.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_DUPLICATE.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_TASK_NOT_AVAILABLE.is_pre_task_entry_fail + + assert not WXETStatusCode1.BUYER_ABANDON.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_FAIL.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_QUALITY_FAIL.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_POSTBACK_NOT_RECEIVED.is_pre_task_entry_fail + assert not WXETStatusCode1.COMPLETE.is_pre_task_entry_fail + + +class TestCheckWXETStatusConsistent: + + def test_completes(self): + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + check_wxet_status_consistent, + ) + + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.COMPLETE, + status_code_1=WXETStatusCode1.UNKNOWN, + status_code_2=None, + ) + + assert ( + "Invalid StatusCode1 when Status=COMPLETE. Use WXETStatusCode1.COMPLETE" + == str(cm.value) + ) + + def test_abandon(self): + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + check_wxet_status_consistent, + ) + + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.ABANDON, + status_code_1=WXETStatusCode1.COMPLETE, + status_code_2=None, + ) + assert ( + "Invalid StatusCode1 when Status=ABANDON. Use WXET_ABANDON or BUYER_ABANDON" + == str(cm.value) + ) + + def test_fail(self): + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + check_wxet_status_consistent, + ) + + for sc1 in [ + WXETStatusCode1.COMPLETE, + WXETStatusCode1.WXET_ABANDON, + WXETStatusCode1.WXET_ABANDON, + ]: + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.FAIL, + status_code_1=sc1, + status_code_2=None, + ) + assert "Invalid StatusCode1 when Status=FAIL." == str(cm.value) + + def test_status_code_2(self): + """Any StatusCode2 should fail if the StatusCode1 isn't + StatusCode1.WXET_FAIL + """ + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + WXETStatusCode2, + check_wxet_status_consistent, + ) + + for sc2 in WXETStatusCode2: + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.COMPLETE, + status_code_2=sc2, + ) + + assert "Invalid StatusCode1 when Status=FAIL." == str(cm.value) diff --git a/tests/wxet/models/test_finish_type.py b/tests/wxet/models/test_finish_type.py new file mode 100644 index 0000000..7bdeea7 --- /dev/null +++ b/tests/wxet/models/test_finish_type.py @@ -0,0 +1,136 @@ +import pytest + +from generalresearch.wxet.models.definitions import WXETStatus, WXETStatusCode1 +from generalresearch.wxet.models.finish_type import FinishType, is_a_finish + + +class TestFinishType: + + def test_init_entrance(self): + instance = FinishType.ENTRANCE + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 5 == len(finish_statuses) + + def test_init_complete(self): + instance = FinishType.COMPLETE + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 1 == len(finish_statuses) + + def test_init_fail_or_complete(self): + instance = FinishType.FAIL_OR_COMPLETE + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 2 == len(finish_statuses) + + def test_init_fail(self): + instance = FinishType.FAIL + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 1 == len(finish_statuses) + + +class TestFunctionIsAFinish: + + def test_init_ft_entrance(self): + assert is_a_finish( + status=None, + status_code_1=None, + finish_type=FinishType.ENTRANCE, + ) + + assert is_a_finish( + status=WXETStatus.ABANDON, + status_code_1=None, + finish_type=FinishType.ENTRANCE, + ) + + assert is_a_finish( + status=WXETStatus.ABANDON, + status_code_1=WXETStatusCode1.BUYER_ABANDON, + finish_type=FinishType.ENTRANCE, + ) + + # If it's a WXET Abandon, they ever entered the Task so don't + # consider it a Finish + assert not is_a_finish( + status=WXETStatus.ABANDON, + status_code_1=WXETStatusCode1.WXET_ABANDON, + finish_type=FinishType.ENTRANCE, + ) + + def test_init_ft_complete(self): + assert is_a_finish( + status=WXETStatus.COMPLETE, + status_code_1=None, + finish_type=FinishType.COMPLETE, + ) + + assert is_a_finish( + status=WXETStatus.COMPLETE, + status_code_1=WXETStatusCode1.COMPLETE, + finish_type=FinishType.COMPLETE, + ) + + def test_init_ft_fail_or_complete(self): + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=None, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.BUYER_FAIL, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + # If it's a WXET Fail, the Worker never made it into a WXET Task + # experience, so it should not be considered a Finish + assert not is_a_finish( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.WXET_FAIL, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + assert is_a_finish( + status=WXETStatus.COMPLETE, + status_code_1=WXETStatusCode1.COMPLETE, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + def test_init_ft_fail(self): + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=None, + finish_type=FinishType.FAIL, + ) + + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.BUYER_FAIL, + finish_type=FinishType.FAIL, + ) + + def test_invalid_status_code_1(self): + for ft in FinishType: + for s in WXETStatus: + with pytest.raises(expected_exception=AssertionError) as cm: + is_a_finish( + status=s, + status_code_1=WXETStatus.COMPLETE, + finish_type=ft, + ) + assert "Invalid status_code_1" == str(cm.value) + + def test_invalid_none_status(self): + for ft in FinishType: + for sc1 in WXETStatusCode1: + with pytest.raises(expected_exception=AssertionError) as cm: + is_a_finish( + status=None, + status_code_1=sc1, + finish_type=ft, + ) + assert "Cannot provide status_code_1 without a status" == str(cm.value) |
