From 91d040211a4ed6e4157896256a762d3854777b5e Mon Sep 17 00:00:00 2001 From: Max Nanis Date: Fri, 6 Mar 2026 16:49:46 -0500 Subject: Initial commit --- .gitignore | 10 + Jenkinsfile | 295 +++ LICENSE | 21 + README.md | 61 + generalresearch/__init__.py | 146 ++ generalresearch/cacheing.py | 47 + generalresearch/config.py | 109 ++ generalresearch/currency.py | 156 ++ generalresearch/decorators.py | 10 + generalresearch/grliq/__init__.py | 0 generalresearch/grliq/managers/__init__.py | 34 + generalresearch/grliq/managers/colormap.py | 263 +++ generalresearch/grliq/managers/event_plotter.py | 158 ++ generalresearch/grliq/managers/forensic_data.py | 794 ++++++++ generalresearch/grliq/managers/forensic_events.py | 290 +++ generalresearch/grliq/managers/forensic_results.py | 104 ++ generalresearch/grliq/managers/forensic_summary.py | 175 ++ generalresearch/grliq/models/__init__.py | 66 + generalresearch/grliq/models/custom_types.py | 6 + generalresearch/grliq/models/decider.py | 53 + generalresearch/grliq/models/events.py | 250 +++ generalresearch/grliq/models/forensic_data.py | 801 ++++++++ generalresearch/grliq/models/forensic_result.py | 288 +++ generalresearch/grliq/models/forensic_summary.py | 282 +++ generalresearch/grliq/models/useragents.py | 246 +++ generalresearch/grliq/utils.py | 36 + generalresearch/grpc.py | 46 + generalresearch/healing_ppe.py | 77 + generalresearch/incite/__init__.py | 0 generalresearch/incite/base.py | 980 ++++++++++ generalresearch/incite/collections/__init__.py | 752 ++++++++ .../incite/collections/thl_marketplaces.py | 37 + generalresearch/incite/collections/thl_web.py | 41 + generalresearch/incite/defaults.py | 196 ++ generalresearch/incite/mergers/__init__.py | 316 ++++ generalresearch/incite/mergers/account_blocks.py | 189 ++ .../incite/mergers/foundations/__init__.py | 167 ++ .../incite/mergers/foundations/enriched_session.py | 331 ++++ .../mergers/foundations/enriched_task_adjust.py | 211 +++ .../incite/mergers/foundations/enriched_wall.py | 336 ++++ .../incite/mergers/foundations/user_id_product.py | 49 + generalresearch/incite/mergers/nginx_core.py | 146 ++ generalresearch/incite/mergers/nginx_fsb.py | 151 ++ generalresearch/incite/mergers/nginx_grs.py | 141 ++ generalresearch/incite/mergers/pop_ledger.py | 131 ++ generalresearch/incite/mergers/ym_survey_wall.py | 149 ++ generalresearch/incite/mergers/ym_wall_summary.py | 195 ++ generalresearch/incite/schemas/__init__.py | 29 + generalresearch/incite/schemas/admin_responses.py | 186 ++ generalresearch/incite/schemas/mergers/__init__.py | 27 + .../incite/schemas/mergers/foundations/__init__.py | 0 .../mergers/foundations/enriched_session.py | 36 + .../mergers/foundations/enriched_task_adjust.py | 98 + .../schemas/mergers/foundations/enriched_wall.py | 144 ++ .../schemas/mergers/foundations/user_id_product.py | 27 + generalresearch/incite/schemas/mergers/nginx.py | 140 ++ .../incite/schemas/mergers/pop_ledger.py | 64 + .../incite/schemas/mergers/ym_survey_wall.py | 101 + .../incite/schemas/mergers/ym_wall_summary.py | 74 + generalresearch/incite/schemas/thl_marketplaces.py | 64 + generalresearch/incite/schemas/thl_web.py | 803 ++++++++ generalresearch/locales/__init__.py | 96 + generalresearch/locales/country_default_lang.json | 250 +++ generalresearch/locales/iso3166-1.json | 1675 +++++++++++++++++ generalresearch/locales/iso639-3.json | 1117 +++++++++++ generalresearch/locales/setup_json.py | 61 + generalresearch/locales/timezone.py | 77 + generalresearch/logging.py | 21 + generalresearch/managers/__init__.py | 16 + generalresearch/managers/base.py | 91 + generalresearch/managers/cint/__init__.py | 0 generalresearch/managers/cint/profiling.py | 62 + generalresearch/managers/cint/survey.py | 145 ++ generalresearch/managers/cint/user_pid.py | 7 + generalresearch/managers/criteria.py | 105 ++ generalresearch/managers/dynata/__init__.py | 0 generalresearch/managers/dynata/profiling.py | 63 + generalresearch/managers/dynata/survey.py | 155 ++ generalresearch/managers/dynata/user_pid.py | 7 + generalresearch/managers/events.py | 863 +++++++++ generalresearch/managers/gr/__init__.py | 0 generalresearch/managers/gr/authentication.py | 331 ++++ generalresearch/managers/gr/business.py | 529 ++++++ generalresearch/managers/gr/team.py | 312 ++++ generalresearch/managers/innovate/__init__.py | 0 generalresearch/managers/innovate/profiling.py | 62 + generalresearch/managers/innovate/survey.py | 179 ++ generalresearch/managers/innovate/user_pid.py | 7 + generalresearch/managers/leaderboard/__init__.py | 37 + generalresearch/managers/leaderboard/manager.py | 212 +++ generalresearch/managers/leaderboard/tasks.py | 59 + generalresearch/managers/lucid/__init__.py | 0 generalresearch/managers/lucid/profiling.py | 83 + generalresearch/managers/marketplace/__init__.py | 23 + generalresearch/managers/marketplace/user_pid.py | 96 + generalresearch/managers/morning/__init__.py | 0 generalresearch/managers/morning/profiling.py | 67 + generalresearch/managers/morning/survey.py | 262 +++ generalresearch/managers/morning/user_pid.py | 7 + generalresearch/managers/pollfish/__init__.py | 0 generalresearch/managers/pollfish/profiling.py | 62 + generalresearch/managers/pollfish/user_pid.py | 7 + generalresearch/managers/precision/__init__.py | 0 generalresearch/managers/precision/profiling.py | 62 + generalresearch/managers/precision/survey.py | 243 +++ generalresearch/managers/precision/user_pid.py | 7 + generalresearch/managers/prodege/__init__.py | 0 generalresearch/managers/prodege/profiling.py | 62 + generalresearch/managers/prodege/survey.py | 155 ++ generalresearch/managers/prodege/user_pid.py | 7 + generalresearch/managers/repdata/__init__.py | 0 generalresearch/managers/repdata/profiling.py | 62 + generalresearch/managers/repdata/survey.py | 185 ++ generalresearch/managers/repdata/user_pid.py | 7 + generalresearch/managers/sago/__init__.py | 0 generalresearch/managers/sago/profiling.py | 62 + generalresearch/managers/sago/survey.py | 186 ++ generalresearch/managers/sago/user_pid.py | 7 + generalresearch/managers/spectrum/__init__.py | 0 generalresearch/managers/spectrum/profiling.py | 62 + generalresearch/managers/spectrum/survey.py | 218 +++ generalresearch/managers/spectrum/user_pid.py | 7 + generalresearch/managers/survey.py | 27 + generalresearch/managers/thl/__init__.py | 0 generalresearch/managers/thl/buyer.py | 113 ++ generalresearch/managers/thl/cashout_method.py | 295 +++ generalresearch/managers/thl/category.py | 56 + generalresearch/managers/thl/contest_manager.py | 1080 +++++++++++ generalresearch/managers/thl/delete_request.py | 178 ++ generalresearch/managers/thl/ipinfo.py | 819 ++++++++ .../managers/thl/ledger_manager/__init__.py | 0 .../managers/thl/ledger_manager/conditions.py | 217 +++ .../managers/thl/ledger_manager/exceptions.py | 49 + .../managers/thl/ledger_manager/ledger.py | 1139 +++++++++++ .../managers/thl/ledger_manager/thl_ledger.py | 1968 ++++++++++++++++++++ generalresearch/managers/thl/maxmind/__init__.py | 162 ++ generalresearch/managers/thl/maxmind/basic.py | 134 ++ generalresearch/managers/thl/maxmind/insights.py | 52 + generalresearch/managers/thl/payout.py | 1256 +++++++++++++ generalresearch/managers/thl/product.py | 570 ++++++ generalresearch/managers/thl/profiling/__init__.py | 0 generalresearch/managers/thl/profiling/question.py | 157 ++ generalresearch/managers/thl/profiling/schema.py | 75 + generalresearch/managers/thl/profiling/uqa.py | 211 +++ generalresearch/managers/thl/profiling/user_upk.py | 343 ++++ generalresearch/managers/thl/session.py | 669 +++++++ generalresearch/managers/thl/survey.py | 791 ++++++++ generalresearch/managers/thl/survey_penalty.py | 112 ++ generalresearch/managers/thl/task_adjustment.py | 187 ++ generalresearch/managers/thl/user_compensate.py | 89 + .../managers/thl/user_manager/__init__.py | 108 ++ .../thl/user_manager/memcached_user_manager.py | 49 + .../thl/user_manager/mysql_user_manager.py | 287 +++ .../managers/thl/user_manager/rate_limit.py | 76 + .../thl/user_manager/redis_user_manager.py | 88 + .../managers/thl/user_manager/user_manager.py | 378 ++++ .../thl/user_manager/user_metadata_manager.py | 141 ++ generalresearch/managers/thl/user_streak.py | 150 ++ generalresearch/managers/thl/userhealth.py | 579 ++++++ generalresearch/managers/thl/wall.py | 675 +++++++ generalresearch/managers/thl/wallet/__init__.py | 147 ++ generalresearch/managers/thl/wallet/approve.py | 52 + generalresearch/managers/thl/wallet/tango.py | 162 ++ generalresearch/mariadb.py | 42 + generalresearch/models/__init__.py | 114 ++ generalresearch/models/admin/__init__.py | 59 + generalresearch/models/admin/request.py | 157 ++ generalresearch/models/cint/__init__.py | 7 + generalresearch/models/cint/question.py | 244 +++ generalresearch/models/cint/survey.py | 532 ++++++ generalresearch/models/cint/task_collection.py | 71 + generalresearch/models/custom_types.py | 282 +++ generalresearch/models/device.py | 15 + generalresearch/models/dynata/__init__.py | 7 + generalresearch/models/dynata/question.py | 269 +++ generalresearch/models/dynata/survey.py | 656 +++++++ generalresearch/models/dynata/task_collection.py | 86 + generalresearch/models/events.py | 299 +++ generalresearch/models/gr/__init__.py | 13 + generalresearch/models/gr/authentication.py | 375 ++++ generalresearch/models/gr/business.py | 743 ++++++++ generalresearch/models/gr/team.py | 346 ++++ generalresearch/models/innovate/__init__.py | 38 + generalresearch/models/innovate/question.py | 244 +++ generalresearch/models/innovate/survey.py | 491 +++++ generalresearch/models/innovate/task_collection.py | 97 + generalresearch/models/legacy/__init__.py | 0 generalresearch/models/legacy/api_status.py | 70 + generalresearch/models/legacy/bucket.py | 772 ++++++++ generalresearch/models/legacy/definitions.py | 11 + generalresearch/models/legacy/offerwall.py | 349 ++++ generalresearch/models/legacy/questions.py | 254 +++ generalresearch/models/lucid/__init__.py | 7 + generalresearch/models/lucid/question.py | 158 ++ generalresearch/models/lucid/survey.py | 105 ++ generalresearch/models/marketplace/__init__.py | 0 generalresearch/models/marketplace/summary.py | 150 ++ generalresearch/models/morning/__init__.py | 16 + generalresearch/models/morning/question.py | 207 ++ generalresearch/models/morning/survey.py | 556 ++++++ generalresearch/models/morning/task_collection.py | 140 ++ generalresearch/models/pollfish/__init__.py | 0 generalresearch/models/pollfish/question.py | 140 ++ generalresearch/models/precision/__init__.py | 16 + generalresearch/models/precision/definitions.py | 322 ++++ generalresearch/models/precision/question.py | 199 ++ generalresearch/models/precision/survey.py | 375 ++++ .../models/precision/task_collection.py | 82 + generalresearch/models/prodege/__init__.py | 37 + generalresearch/models/prodege/definitions.py | 187 ++ generalresearch/models/prodege/question.py | 243 +++ generalresearch/models/prodege/survey.py | 747 ++++++++ generalresearch/models/prodege/task_collection.py | 97 + generalresearch/models/repdata/__init__.py | 16 + generalresearch/models/repdata/question.py | 255 +++ generalresearch/models/repdata/survey.py | 565 ++++++ generalresearch/models/repdata/task_collection.py | 132 ++ generalresearch/models/sago/__init__.py | 13 + generalresearch/models/sago/question.py | 284 +++ generalresearch/models/sago/survey.py | 417 +++++ generalresearch/models/sago/task_collection.py | 81 + generalresearch/models/spectrum/__init__.py | 15 + generalresearch/models/spectrum/question.py | 371 ++++ generalresearch/models/spectrum/survey.py | 514 +++++ generalresearch/models/spectrum/task_collection.py | 110 ++ generalresearch/models/string_utils.py | 12 + generalresearch/models/thl/__init__.py | 34 + generalresearch/models/thl/category.py | 62 + generalresearch/models/thl/contest/__init__.py | 143 ++ generalresearch/models/thl/contest/contest.py | 223 +++ .../models/thl/contest/contest_entry.py | 109 ++ generalresearch/models/thl/contest/definitions.py | 101 + generalresearch/models/thl/contest/examples.py | 404 ++++ generalresearch/models/thl/contest/exceptions.py | 2 + generalresearch/models/thl/contest/io.py | 47 + generalresearch/models/thl/contest/leaderboard.py | 289 +++ generalresearch/models/thl/contest/milestone.py | 226 +++ generalresearch/models/thl/contest/raffle.py | 317 ++++ generalresearch/models/thl/contest/utils.py | 76 + generalresearch/models/thl/definitions.py | 343 ++++ generalresearch/models/thl/demographics.py | 180 ++ generalresearch/models/thl/finance.py | 881 +++++++++ generalresearch/models/thl/grliq.py | 10 + generalresearch/models/thl/ipinfo.py | 348 ++++ generalresearch/models/thl/leaderboard.py | 349 ++++ generalresearch/models/thl/ledger.py | 625 +++++++ generalresearch/models/thl/ledger_example.py | 62 + generalresearch/models/thl/locales.py | 32 + generalresearch/models/thl/maxmind/__init__.py | 0 generalresearch/models/thl/maxmind/definitions.py | 22 + generalresearch/models/thl/offerwall/__init__.py | 321 ++++ generalresearch/models/thl/offerwall/base.py | 685 +++++++ generalresearch/models/thl/offerwall/behavior.py | 45 + generalresearch/models/thl/offerwall/bucket.py | 20 + generalresearch/models/thl/offerwall/cache.py | 59 + generalresearch/models/thl/pagination.py | 22 + generalresearch/models/thl/payout.py | 353 ++++ generalresearch/models/thl/payout_format.py | 96 + generalresearch/models/thl/product.py | 1427 ++++++++++++++ generalresearch/models/thl/profiling/__init__.py | 0 .../models/thl/profiling/marketplace.py | 127 ++ .../models/thl/profiling/other_option.py | 56 + generalresearch/models/thl/profiling/question.py | 46 + .../models/thl/profiling/upk_property.py | 94 + .../models/thl/profiling/upk_question.py | 683 +++++++ .../models/thl/profiling/upk_question_answer.py | 116 ++ generalresearch/models/thl/profiling/user_info.py | 76 + .../models/thl/profiling/user_question_answer.py | 160 ++ generalresearch/models/thl/report_task.py | 50 + generalresearch/models/thl/session.py | 1344 +++++++++++++ generalresearch/models/thl/soft_pair.py | 67 + generalresearch/models/thl/stats.py | 43 + generalresearch/models/thl/supplier_tag.py | 16 + generalresearch/models/thl/survey/__init__.py | 225 +++ generalresearch/models/thl/survey/buyer.py | 218 +++ generalresearch/models/thl/survey/condition.py | 337 ++++ generalresearch/models/thl/survey/model.py | 321 ++++ generalresearch/models/thl/survey/penalty.py | 63 + .../models/thl/survey/task_collection.py | 60 + .../models/thl/synchronize_global_vars.py | 17 + generalresearch/models/thl/task_adjustment.py | 89 + generalresearch/models/thl/task_status.py | 292 +++ generalresearch/models/thl/user.py | 323 ++++ generalresearch/models/thl/user_iphistory.py | 245 +++ generalresearch/models/thl/user_profile.py | 122 ++ generalresearch/models/thl/user_quality_event.py | 90 + generalresearch/models/thl/user_streak.py | 152 ++ generalresearch/models/thl/userhealth.py | 77 + generalresearch/models/thl/wallet/__init__.py | 87 + .../models/thl/wallet/cashout_method.py | 443 +++++ generalresearch/models/thl/wallet/payout.py | 214 +++ generalresearch/models/thl/wallet/user_wallet.py | 42 + generalresearch/models/utils.py | 9 + generalresearch/pg_helper.py | 124 ++ generalresearch/priority_thread_pool.py | 67 + generalresearch/redis_helper.py | 33 + generalresearch/resources/__init__.py | 0 generalresearch/schemas/__init__.py | 0 generalresearch/schemas/survey_stats.py | 159 ++ generalresearch/sql_helper.py | 351 ++++ generalresearch/thl_django/README.md | 70 + generalresearch/thl_django/__init__.py | 1 + generalresearch/thl_django/accounting/__init__.py | 0 generalresearch/thl_django/accounting/models.py | 44 + generalresearch/thl_django/app/__init__.py | 0 generalresearch/thl_django/app/manage.py | 11 + generalresearch/thl_django/app/settings.py | 23 + generalresearch/thl_django/apps.py | 15 + generalresearch/thl_django/common/__init__.py | 0 generalresearch/thl_django/common/models.py | 745 ++++++++ generalresearch/thl_django/contest/__init__.py | 0 generalresearch/thl_django/contest/models.py | 115 ++ generalresearch/thl_django/event/__init__.py | 0 generalresearch/thl_django/event/models.py | 91 + generalresearch/thl_django/marketplace/__init__.py | 0 generalresearch/thl_django/marketplace/models.py | 757 ++++++++ .../thl_django/migrations/0001_initial.py | 1066 +++++++++++ ..._live_alter_surveycategory_strength_and_more.py | 35 + ...rveystat_surveystat_live_survey_idx_and_more.py | 45 + ...004_alter_surveystat_survey_is_live_and_more.py | 28 + ...ve_surveystat_marketplace_updated_439a2d_idx.py | 17 + ...ssion_thl_session_status_d578b7_idx_and_more.py | 35 + .../thl_django/migrations/0007_table_params.py | 68 + ...stion_explanation_fragment_template_and_more.py | 23 + generalresearch/thl_django/migrations/__init__.py | 0 .../thl_django/postgres-table-tuning.md | 62 + generalresearch/thl_django/postgres.md | 143 ++ generalresearch/thl_django/userhealth/__init__.py | 0 generalresearch/thl_django/userhealth/models.py | 127 ++ generalresearch/thl_django/userprofile/__init__.py | 0 generalresearch/thl_django/userprofile/models.py | 178 ++ generalresearch/utils/__init__.py | 0 generalresearch/utils/aggregation.py | 14 + generalresearch/utils/copying_cache.py | 21 + generalresearch/utils/enum.py | 61 + generalresearch/utils/grpc_logger.py | 78 + generalresearch/wall_status_codes/__init__.py | 105 ++ generalresearch/wall_status_codes/cint.py | 15 + generalresearch/wall_status_codes/dynata.py | 128 ++ generalresearch/wall_status_codes/fullcircle.py | 58 + generalresearch/wall_status_codes/innovate.py | 117 ++ generalresearch/wall_status_codes/lucid.py | 128 ++ generalresearch/wall_status_codes/morning.py | 122 ++ generalresearch/wall_status_codes/pollfish.py | 80 + generalresearch/wall_status_codes/precision.py | 96 + generalresearch/wall_status_codes/prodege.py | 56 + generalresearch/wall_status_codes/repdata.py | 89 + generalresearch/wall_status_codes/sago.py | 193 ++ generalresearch/wall_status_codes/spectrum.py | 165 ++ generalresearch/wall_status_codes/wxet.py | 87 + generalresearch/wxet/__init__.py | 0 generalresearch/wxet/models/__init__.py | 0 generalresearch/wxet/models/definitions.py | 320 ++++ generalresearch/wxet/models/finish_type.py | 94 + mypy.ini | 6 + pyproject.toml | 57 + requirements.txt | 108 ++ test_utils/__init__.py | 0 test_utils/conftest.py | 310 +++ test_utils/grliq/__init__.py | 0 test_utils/grliq/conftest.py | 28 + test_utils/grliq/managers/__init__.py | 0 test_utils/grliq/managers/conftest.py | 0 test_utils/grliq/models/__init__.py | 0 test_utils/grliq/models/conftest.py | 0 test_utils/incite/__init__.py | 0 test_utils/incite/collections/__init__.py | 0 test_utils/incite/collections/conftest.py | 205 ++ test_utils/incite/conftest.py | 201 ++ test_utils/incite/mergers/__init__.py | 0 test_utils/incite/mergers/conftest.py | 247 +++ test_utils/managers/__init__.py | 0 test_utils/managers/cashout_methods.py | 76 + test_utils/managers/conftest.py | 701 +++++++ test_utils/managers/contest/__init__.py | 0 test_utils/managers/contest/conftest.py | 295 +++ test_utils/managers/ledger/__init__.py | 0 test_utils/managers/ledger/conftest.py | 678 +++++++ test_utils/managers/upk/__init__.py | 0 test_utils/managers/upk/conftest.py | 161 ++ .../managers/upk/marketplace_category.csv.gz | Bin 0 -> 100990 bytes test_utils/managers/upk/marketplace_item.csv.gz | Bin 0 -> 3225 bytes .../managers/upk/marketplace_property.csv.gz | Bin 0 -> 3315 bytes .../marketplace_propertycategoryassociation.csv.gz | Bin 0 -> 2079 bytes .../upk/marketplace_propertycountry.csv.gz | Bin 0 -> 71359 bytes .../upk/marketplace_propertyitemrange.csv.gz | Bin 0 -> 65389 bytes ...rketplace_propertymarketplaceassociation.csv.gz | Bin 0 -> 4272 bytes .../managers/upk/marketplace_question.csv.gz | Bin 0 -> 283465 bytes test_utils/models/__init__.py | 0 test_utils/models/conftest.py | 608 ++++++ test_utils/spectrum/__init__.py | 0 test_utils/spectrum/conftest.py | 79 + test_utils/spectrum/surveys_json.py | 140 ++ tests/__init__.py | 0 tests/conftest.py | 19 + tests/grliq/__init__.py | 0 tests/grliq/managers/__init__.py | 0 tests/grliq/managers/test_forensic_data.py | 212 +++ tests/grliq/managers/test_forensic_results.py | 16 + tests/grliq/models/__init__.py | 0 tests/grliq/models/test_forensic_data.py | 49 + tests/grliq/test_utils.py | 17 + tests/incite/__init__.py | 137 ++ tests/incite/collections/__init__.py | 0 .../incite/collections/test_df_collection_base.py | 113 ++ .../collections/test_df_collection_item_base.py | 72 + .../collections/test_df_collection_item_thl_web.py | 994 ++++++++++ .../test_df_collection_thl_marketplaces.py | 75 + .../collections/test_df_collection_thl_web.py | 160 ++ .../test_df_collection_thl_web_ledger.py | 32 + tests/incite/mergers/__init__.py | 0 tests/incite/mergers/foundations/__init__.py | 0 .../mergers/foundations/test_enriched_session.py | 138 ++ .../foundations/test_enriched_task_adjust.py | 76 + .../mergers/foundations/test_enriched_wall.py | 236 +++ .../mergers/foundations/test_user_id_product.py | 73 + tests/incite/mergers/test_merge_collection.py | 102 + tests/incite/mergers/test_merge_collection_item.py | 66 + tests/incite/mergers/test_pop_ledger.py | 307 +++ tests/incite/mergers/test_ym_survey_merge.py | 125 ++ tests/incite/schemas/__init__.py | 0 tests/incite/schemas/test_admin_responses.py | 239 +++ tests/incite/schemas/test_thl_web.py | 70 + tests/incite/test_collection_base.py | 318 ++++ tests/incite/test_collection_base_item.py | 223 +++ tests/incite/test_grl_flow.py | 23 + tests/incite/test_interval_idx.py | 23 + tests/managers/__init__.py | 0 tests/managers/gr/__init__.py | 0 tests/managers/gr/test_authentication.py | 125 ++ tests/managers/gr/test_business.py | 150 ++ tests/managers/gr/test_team.py | 125 ++ tests/managers/leaderboard.py | 274 +++ tests/managers/test_events.py | 530 ++++++ tests/managers/test_lucid.py | 23 + tests/managers/test_userpid.py | 68 + tests/managers/thl/__init__.py | 0 tests/managers/thl/test_buyer.py | 25 + tests/managers/thl/test_cashout_method.py | 139 ++ tests/managers/thl/test_category.py | 100 + tests/managers/thl/test_contest/__init__.py | 0 .../managers/thl/test_contest/test_leaderboard.py | 138 ++ tests/managers/thl/test_contest/test_milestone.py | 296 +++ tests/managers/thl/test_contest/test_raffle.py | 474 +++++ tests/managers/thl/test_harmonized_uqa.py | 116 ++ tests/managers/thl/test_ipinfo.py | 117 ++ tests/managers/thl/test_ledger/__init__.py | 0 tests/managers/thl/test_ledger/test_lm_accounts.py | 268 +++ tests/managers/thl/test_ledger/test_lm_tx.py | 235 +++ .../managers/thl/test_ledger/test_lm_tx_entries.py | 26 + tests/managers/thl/test_ledger/test_lm_tx_locks.py | 371 ++++ .../thl/test_ledger/test_lm_tx_metadata.py | 34 + .../thl/test_ledger/test_thl_lm_accounts.py | 411 ++++ .../thl/test_ledger/test_thl_lm_bp_payout.py | 516 +++++ tests/managers/thl/test_ledger/test_thl_lm_tx.py | 1762 ++++++++++++++++++ .../test_ledger/test_thl_lm_tx__user_payouts.py | 505 +++++ tests/managers/thl/test_ledger/test_thl_pem.py | 251 +++ tests/managers/thl/test_ledger/test_user_txs.py | 288 +++ tests/managers/thl/test_ledger/test_wallet.py | 78 + tests/managers/thl/test_maxmind.py | 273 +++ tests/managers/thl/test_payout.py | 1269 +++++++++++++ tests/managers/thl/test_product.py | 362 ++++ tests/managers/thl/test_product_prod.py | 82 + tests/managers/thl/test_profiling/__init__.py | 0 tests/managers/thl/test_profiling/test_question.py | 49 + tests/managers/thl/test_profiling/test_schema.py | 44 + tests/managers/thl/test_profiling/test_uqa.py | 1 + tests/managers/thl/test_profiling/test_user_upk.py | 59 + tests/managers/thl/test_session_manager.py | 137 ++ tests/managers/thl/test_survey.py | 376 ++++ tests/managers/thl/test_survey_penalty.py | 101 + tests/managers/thl/test_task_adjustment.py | 346 ++++ tests/managers/thl/test_task_status.py | 696 +++++++ tests/managers/thl/test_user_manager/__init__.py | 0 tests/managers/thl/test_user_manager/test_base.py | 274 +++ tests/managers/thl/test_user_manager/test_mysql.py | 25 + tests/managers/thl/test_user_manager/test_redis.py | 80 + .../thl/test_user_manager/test_user_fetch.py | 48 + .../thl/test_user_manager/test_user_metadata.py | 88 + tests/managers/thl/test_user_streak.py | 225 +++ tests/managers/thl/test_userhealth.py | 367 ++++ tests/managers/thl/test_wall_manager.py | 283 +++ tests/models/__init__.py | 0 tests/models/admin/__init__.py | 0 tests/models/admin/test_report_request.py | 163 ++ tests/models/custom_types/__init__.py | 0 tests/models/custom_types/test_aware_datetime.py | 82 + tests/models/custom_types/test_dsn.py | 112 ++ tests/models/custom_types/test_therest.py | 42 + tests/models/custom_types/test_uuid_str.py | 51 + tests/models/dynata/__init__.py | 0 tests/models/dynata/test_eligbility.py | 324 ++++ tests/models/dynata/test_survey.py | 164 ++ tests/models/gr/__init__.py | 0 tests/models/gr/test_authentication.py | 313 ++++ tests/models/gr/test_business.py | 1432 ++++++++++++++ tests/models/gr/test_team.py | 296 +++ tests/models/innovate/__init__.py | 0 tests/models/innovate/test_question.py | 85 + tests/models/legacy/__init__.py | 0 tests/models/legacy/data.py | 265 +++ .../models/legacy/test_offerwall_parse_response.py | 186 ++ tests/models/legacy/test_profiling_questions.py | 81 + .../models/legacy/test_user_question_answer_in.py | 304 +++ tests/models/morning/__init__.py | 0 tests/models/morning/test.py | 199 ++ tests/models/precision/__init__.py | 115 ++ tests/models/precision/test_survey.py | 88 + tests/models/precision/test_survey_manager.py | 63 + tests/models/prodege/__init__.py | 0 tests/models/prodege/test_survey_participation.py | 120 ++ tests/models/spectrum/__init__.py | 0 tests/models/spectrum/test_question.py | 216 +++ tests/models/spectrum/test_survey.py | 413 ++++ tests/models/spectrum/test_survey_manager.py | 130 ++ tests/models/test_currency.py | 410 ++++ tests/models/test_device.py | 27 + tests/models/test_finance.py | 929 +++++++++ tests/models/thl/__init__.py | 1 + tests/models/thl/question/__init__.py | 0 tests/models/thl/question/test_question_info.py | 146 ++ tests/models/thl/question/test_user_info.py | 32 + tests/models/thl/test_adjustments.py | 688 +++++++ tests/models/thl/test_bucket.py | 201 ++ tests/models/thl/test_buyer.py | 23 + tests/models/thl/test_contest/__init__.py | 0 tests/models/thl/test_contest/test_contest.py | 23 + .../thl/test_contest/test_leaderboard_contest.py | 213 +++ .../models/thl/test_contest/test_raffle_contest.py | 300 +++ tests/models/thl/test_ledger.py | 130 ++ tests/models/thl/test_marketplace_condition.py | 382 ++++ tests/models/thl/test_payout.py | 10 + tests/models/thl/test_payout_format.py | 46 + tests/models/thl/test_product.py | 1130 +++++++++++ tests/models/thl/test_product_userwalletconfig.py | 56 + tests/models/thl/test_soft_pair.py | 24 + tests/models/thl/test_upkquestion.py | 414 ++++ tests/models/thl/test_user.py | 688 +++++++ tests/models/thl/test_user_iphistory.py | 45 + tests/models/thl/test_user_metadata.py | 46 + tests/models/thl/test_user_streak.py | 96 + tests/models/thl/test_wall.py | 207 ++ tests/models/thl/test_wall_session.py | 326 ++++ tests/pytest.ini | 2 + tests/sql_helper.py | 53 + tests/wall_status_codes/__init__.py | 0 tests/wall_status_codes/test_analyze.py | 150 ++ tests/wxet/__init__.py | 0 tests/wxet/models/__init__.py | 0 tests/wxet/models/test_definitions.py | 113 ++ tests/wxet/models/test_finish_type.py | 136 ++ 551 files changed, 99879 insertions(+) create mode 100644 .gitignore create mode 100644 Jenkinsfile create mode 100644 LICENSE create mode 100644 README.md create mode 100644 generalresearch/__init__.py create mode 100644 generalresearch/cacheing.py create mode 100644 generalresearch/config.py create mode 100644 generalresearch/currency.py create mode 100644 generalresearch/decorators.py create mode 100644 generalresearch/grliq/__init__.py create mode 100644 generalresearch/grliq/managers/__init__.py create mode 100644 generalresearch/grliq/managers/colormap.py create mode 100644 generalresearch/grliq/managers/event_plotter.py create mode 100644 generalresearch/grliq/managers/forensic_data.py create mode 100644 generalresearch/grliq/managers/forensic_events.py create mode 100644 generalresearch/grliq/managers/forensic_results.py create mode 100644 generalresearch/grliq/managers/forensic_summary.py create mode 100644 generalresearch/grliq/models/__init__.py create mode 100644 generalresearch/grliq/models/custom_types.py create mode 100644 generalresearch/grliq/models/decider.py create mode 100644 generalresearch/grliq/models/events.py create mode 100644 generalresearch/grliq/models/forensic_data.py create mode 100644 generalresearch/grliq/models/forensic_result.py create mode 100644 generalresearch/grliq/models/forensic_summary.py create mode 100644 generalresearch/grliq/models/useragents.py create mode 100644 generalresearch/grliq/utils.py create mode 100644 generalresearch/grpc.py create mode 100644 generalresearch/healing_ppe.py create mode 100644 generalresearch/incite/__init__.py create mode 100644 generalresearch/incite/base.py create mode 100644 generalresearch/incite/collections/__init__.py create mode 100644 generalresearch/incite/collections/thl_marketplaces.py create mode 100644 generalresearch/incite/collections/thl_web.py create mode 100644 generalresearch/incite/defaults.py create mode 100644 generalresearch/incite/mergers/__init__.py create mode 100644 generalresearch/incite/mergers/account_blocks.py create mode 100644 generalresearch/incite/mergers/foundations/__init__.py create mode 100644 generalresearch/incite/mergers/foundations/enriched_session.py create mode 100644 generalresearch/incite/mergers/foundations/enriched_task_adjust.py create mode 100644 generalresearch/incite/mergers/foundations/enriched_wall.py create mode 100644 generalresearch/incite/mergers/foundations/user_id_product.py create mode 100644 generalresearch/incite/mergers/nginx_core.py create mode 100644 generalresearch/incite/mergers/nginx_fsb.py create mode 100644 generalresearch/incite/mergers/nginx_grs.py create mode 100644 generalresearch/incite/mergers/pop_ledger.py create mode 100644 generalresearch/incite/mergers/ym_survey_wall.py create mode 100644 generalresearch/incite/mergers/ym_wall_summary.py create mode 100644 generalresearch/incite/schemas/__init__.py create mode 100644 generalresearch/incite/schemas/admin_responses.py create mode 100644 generalresearch/incite/schemas/mergers/__init__.py create mode 100644 generalresearch/incite/schemas/mergers/foundations/__init__.py create mode 100644 generalresearch/incite/schemas/mergers/foundations/enriched_session.py create mode 100644 generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py create mode 100644 generalresearch/incite/schemas/mergers/foundations/enriched_wall.py create mode 100644 generalresearch/incite/schemas/mergers/foundations/user_id_product.py create mode 100644 generalresearch/incite/schemas/mergers/nginx.py create mode 100644 generalresearch/incite/schemas/mergers/pop_ledger.py create mode 100644 generalresearch/incite/schemas/mergers/ym_survey_wall.py create mode 100644 generalresearch/incite/schemas/mergers/ym_wall_summary.py create mode 100644 generalresearch/incite/schemas/thl_marketplaces.py create mode 100644 generalresearch/incite/schemas/thl_web.py create mode 100644 generalresearch/locales/__init__.py create mode 100644 generalresearch/locales/country_default_lang.json create mode 100644 generalresearch/locales/iso3166-1.json create mode 100644 generalresearch/locales/iso639-3.json create mode 100644 generalresearch/locales/setup_json.py create mode 100644 generalresearch/locales/timezone.py create mode 100644 generalresearch/logging.py create mode 100644 generalresearch/managers/__init__.py create mode 100644 generalresearch/managers/base.py create mode 100644 generalresearch/managers/cint/__init__.py create mode 100644 generalresearch/managers/cint/profiling.py create mode 100644 generalresearch/managers/cint/survey.py create mode 100644 generalresearch/managers/cint/user_pid.py create mode 100644 generalresearch/managers/criteria.py create mode 100644 generalresearch/managers/dynata/__init__.py create mode 100644 generalresearch/managers/dynata/profiling.py create mode 100644 generalresearch/managers/dynata/survey.py create mode 100644 generalresearch/managers/dynata/user_pid.py create mode 100644 generalresearch/managers/events.py create mode 100644 generalresearch/managers/gr/__init__.py create mode 100644 generalresearch/managers/gr/authentication.py create mode 100644 generalresearch/managers/gr/business.py create mode 100644 generalresearch/managers/gr/team.py create mode 100644 generalresearch/managers/innovate/__init__.py create mode 100644 generalresearch/managers/innovate/profiling.py create mode 100644 generalresearch/managers/innovate/survey.py create mode 100644 generalresearch/managers/innovate/user_pid.py create mode 100644 generalresearch/managers/leaderboard/__init__.py create mode 100644 generalresearch/managers/leaderboard/manager.py create mode 100644 generalresearch/managers/leaderboard/tasks.py create mode 100644 generalresearch/managers/lucid/__init__.py create mode 100644 generalresearch/managers/lucid/profiling.py create mode 100644 generalresearch/managers/marketplace/__init__.py create mode 100644 generalresearch/managers/marketplace/user_pid.py create mode 100644 generalresearch/managers/morning/__init__.py create mode 100644 generalresearch/managers/morning/profiling.py create mode 100644 generalresearch/managers/morning/survey.py create mode 100644 generalresearch/managers/morning/user_pid.py create mode 100644 generalresearch/managers/pollfish/__init__.py create mode 100644 generalresearch/managers/pollfish/profiling.py create mode 100644 generalresearch/managers/pollfish/user_pid.py create mode 100644 generalresearch/managers/precision/__init__.py create mode 100644 generalresearch/managers/precision/profiling.py create mode 100644 generalresearch/managers/precision/survey.py create mode 100644 generalresearch/managers/precision/user_pid.py create mode 100644 generalresearch/managers/prodege/__init__.py create mode 100644 generalresearch/managers/prodege/profiling.py create mode 100644 generalresearch/managers/prodege/survey.py create mode 100644 generalresearch/managers/prodege/user_pid.py create mode 100644 generalresearch/managers/repdata/__init__.py create mode 100644 generalresearch/managers/repdata/profiling.py create mode 100644 generalresearch/managers/repdata/survey.py create mode 100644 generalresearch/managers/repdata/user_pid.py create mode 100644 generalresearch/managers/sago/__init__.py create mode 100644 generalresearch/managers/sago/profiling.py create mode 100644 generalresearch/managers/sago/survey.py create mode 100644 generalresearch/managers/sago/user_pid.py create mode 100644 generalresearch/managers/spectrum/__init__.py create mode 100644 generalresearch/managers/spectrum/profiling.py create mode 100644 generalresearch/managers/spectrum/survey.py create mode 100644 generalresearch/managers/spectrum/user_pid.py create mode 100644 generalresearch/managers/survey.py create mode 100644 generalresearch/managers/thl/__init__.py create mode 100644 generalresearch/managers/thl/buyer.py create mode 100644 generalresearch/managers/thl/cashout_method.py create mode 100644 generalresearch/managers/thl/category.py create mode 100644 generalresearch/managers/thl/contest_manager.py create mode 100644 generalresearch/managers/thl/delete_request.py create mode 100644 generalresearch/managers/thl/ipinfo.py create mode 100644 generalresearch/managers/thl/ledger_manager/__init__.py create mode 100644 generalresearch/managers/thl/ledger_manager/conditions.py create mode 100644 generalresearch/managers/thl/ledger_manager/exceptions.py create mode 100644 generalresearch/managers/thl/ledger_manager/ledger.py create mode 100644 generalresearch/managers/thl/ledger_manager/thl_ledger.py create mode 100644 generalresearch/managers/thl/maxmind/__init__.py create mode 100644 generalresearch/managers/thl/maxmind/basic.py create mode 100644 generalresearch/managers/thl/maxmind/insights.py create mode 100644 generalresearch/managers/thl/payout.py create mode 100644 generalresearch/managers/thl/product.py create mode 100644 generalresearch/managers/thl/profiling/__init__.py create mode 100644 generalresearch/managers/thl/profiling/question.py create mode 100644 generalresearch/managers/thl/profiling/schema.py create mode 100644 generalresearch/managers/thl/profiling/uqa.py create mode 100644 generalresearch/managers/thl/profiling/user_upk.py create mode 100644 generalresearch/managers/thl/session.py create mode 100644 generalresearch/managers/thl/survey.py create mode 100644 generalresearch/managers/thl/survey_penalty.py create mode 100644 generalresearch/managers/thl/task_adjustment.py create mode 100644 generalresearch/managers/thl/user_compensate.py create mode 100644 generalresearch/managers/thl/user_manager/__init__.py create mode 100644 generalresearch/managers/thl/user_manager/memcached_user_manager.py create mode 100644 generalresearch/managers/thl/user_manager/mysql_user_manager.py create mode 100644 generalresearch/managers/thl/user_manager/rate_limit.py create mode 100644 generalresearch/managers/thl/user_manager/redis_user_manager.py create mode 100644 generalresearch/managers/thl/user_manager/user_manager.py create mode 100644 generalresearch/managers/thl/user_manager/user_metadata_manager.py create mode 100644 generalresearch/managers/thl/user_streak.py create mode 100644 generalresearch/managers/thl/userhealth.py create mode 100644 generalresearch/managers/thl/wall.py create mode 100644 generalresearch/managers/thl/wallet/__init__.py create mode 100644 generalresearch/managers/thl/wallet/approve.py create mode 100644 generalresearch/managers/thl/wallet/tango.py create mode 100644 generalresearch/mariadb.py create mode 100644 generalresearch/models/__init__.py create mode 100644 generalresearch/models/admin/__init__.py create mode 100644 generalresearch/models/admin/request.py create mode 100644 generalresearch/models/cint/__init__.py create mode 100644 generalresearch/models/cint/question.py create mode 100644 generalresearch/models/cint/survey.py create mode 100644 generalresearch/models/cint/task_collection.py create mode 100644 generalresearch/models/custom_types.py create mode 100644 generalresearch/models/device.py create mode 100644 generalresearch/models/dynata/__init__.py create mode 100644 generalresearch/models/dynata/question.py create mode 100644 generalresearch/models/dynata/survey.py create mode 100644 generalresearch/models/dynata/task_collection.py create mode 100644 generalresearch/models/events.py create mode 100644 generalresearch/models/gr/__init__.py create mode 100644 generalresearch/models/gr/authentication.py create mode 100644 generalresearch/models/gr/business.py create mode 100644 generalresearch/models/gr/team.py create mode 100644 generalresearch/models/innovate/__init__.py create mode 100644 generalresearch/models/innovate/question.py create mode 100644 generalresearch/models/innovate/survey.py create mode 100644 generalresearch/models/innovate/task_collection.py create mode 100644 generalresearch/models/legacy/__init__.py create mode 100644 generalresearch/models/legacy/api_status.py create mode 100644 generalresearch/models/legacy/bucket.py create mode 100644 generalresearch/models/legacy/definitions.py create mode 100644 generalresearch/models/legacy/offerwall.py create mode 100644 generalresearch/models/legacy/questions.py create mode 100644 generalresearch/models/lucid/__init__.py create mode 100644 generalresearch/models/lucid/question.py create mode 100644 generalresearch/models/lucid/survey.py create mode 100644 generalresearch/models/marketplace/__init__.py create mode 100644 generalresearch/models/marketplace/summary.py create mode 100644 generalresearch/models/morning/__init__.py create mode 100644 generalresearch/models/morning/question.py create mode 100644 generalresearch/models/morning/survey.py create mode 100644 generalresearch/models/morning/task_collection.py create mode 100644 generalresearch/models/pollfish/__init__.py create mode 100644 generalresearch/models/pollfish/question.py create mode 100644 generalresearch/models/precision/__init__.py create mode 100644 generalresearch/models/precision/definitions.py create mode 100644 generalresearch/models/precision/question.py create mode 100644 generalresearch/models/precision/survey.py create mode 100644 generalresearch/models/precision/task_collection.py create mode 100644 generalresearch/models/prodege/__init__.py create mode 100644 generalresearch/models/prodege/definitions.py create mode 100644 generalresearch/models/prodege/question.py create mode 100644 generalresearch/models/prodege/survey.py create mode 100644 generalresearch/models/prodege/task_collection.py create mode 100644 generalresearch/models/repdata/__init__.py create mode 100644 generalresearch/models/repdata/question.py create mode 100644 generalresearch/models/repdata/survey.py create mode 100644 generalresearch/models/repdata/task_collection.py create mode 100644 generalresearch/models/sago/__init__.py create mode 100644 generalresearch/models/sago/question.py create mode 100644 generalresearch/models/sago/survey.py create mode 100644 generalresearch/models/sago/task_collection.py create mode 100644 generalresearch/models/spectrum/__init__.py create mode 100644 generalresearch/models/spectrum/question.py create mode 100644 generalresearch/models/spectrum/survey.py create mode 100644 generalresearch/models/spectrum/task_collection.py create mode 100644 generalresearch/models/string_utils.py create mode 100644 generalresearch/models/thl/__init__.py create mode 100644 generalresearch/models/thl/category.py create mode 100644 generalresearch/models/thl/contest/__init__.py create mode 100644 generalresearch/models/thl/contest/contest.py create mode 100644 generalresearch/models/thl/contest/contest_entry.py create mode 100644 generalresearch/models/thl/contest/definitions.py create mode 100644 generalresearch/models/thl/contest/examples.py create mode 100644 generalresearch/models/thl/contest/exceptions.py create mode 100644 generalresearch/models/thl/contest/io.py create mode 100644 generalresearch/models/thl/contest/leaderboard.py create mode 100644 generalresearch/models/thl/contest/milestone.py create mode 100644 generalresearch/models/thl/contest/raffle.py create mode 100644 generalresearch/models/thl/contest/utils.py create mode 100644 generalresearch/models/thl/definitions.py create mode 100644 generalresearch/models/thl/demographics.py create mode 100644 generalresearch/models/thl/finance.py create mode 100644 generalresearch/models/thl/grliq.py create mode 100644 generalresearch/models/thl/ipinfo.py create mode 100644 generalresearch/models/thl/leaderboard.py create mode 100644 generalresearch/models/thl/ledger.py create mode 100644 generalresearch/models/thl/ledger_example.py create mode 100644 generalresearch/models/thl/locales.py create mode 100644 generalresearch/models/thl/maxmind/__init__.py create mode 100644 generalresearch/models/thl/maxmind/definitions.py create mode 100644 generalresearch/models/thl/offerwall/__init__.py create mode 100644 generalresearch/models/thl/offerwall/base.py create mode 100644 generalresearch/models/thl/offerwall/behavior.py create mode 100644 generalresearch/models/thl/offerwall/bucket.py create mode 100644 generalresearch/models/thl/offerwall/cache.py create mode 100644 generalresearch/models/thl/pagination.py create mode 100644 generalresearch/models/thl/payout.py create mode 100644 generalresearch/models/thl/payout_format.py create mode 100644 generalresearch/models/thl/product.py create mode 100644 generalresearch/models/thl/profiling/__init__.py create mode 100644 generalresearch/models/thl/profiling/marketplace.py create mode 100644 generalresearch/models/thl/profiling/other_option.py create mode 100644 generalresearch/models/thl/profiling/question.py create mode 100644 generalresearch/models/thl/profiling/upk_property.py create mode 100644 generalresearch/models/thl/profiling/upk_question.py create mode 100644 generalresearch/models/thl/profiling/upk_question_answer.py create mode 100644 generalresearch/models/thl/profiling/user_info.py create mode 100644 generalresearch/models/thl/profiling/user_question_answer.py create mode 100644 generalresearch/models/thl/report_task.py create mode 100644 generalresearch/models/thl/session.py create mode 100644 generalresearch/models/thl/soft_pair.py create mode 100644 generalresearch/models/thl/stats.py create mode 100644 generalresearch/models/thl/supplier_tag.py create mode 100644 generalresearch/models/thl/survey/__init__.py create mode 100644 generalresearch/models/thl/survey/buyer.py create mode 100644 generalresearch/models/thl/survey/condition.py create mode 100644 generalresearch/models/thl/survey/model.py create mode 100644 generalresearch/models/thl/survey/penalty.py create mode 100644 generalresearch/models/thl/survey/task_collection.py create mode 100644 generalresearch/models/thl/synchronize_global_vars.py create mode 100644 generalresearch/models/thl/task_adjustment.py create mode 100644 generalresearch/models/thl/task_status.py create mode 100644 generalresearch/models/thl/user.py create mode 100644 generalresearch/models/thl/user_iphistory.py create mode 100644 generalresearch/models/thl/user_profile.py create mode 100644 generalresearch/models/thl/user_quality_event.py create mode 100644 generalresearch/models/thl/user_streak.py create mode 100644 generalresearch/models/thl/userhealth.py create mode 100644 generalresearch/models/thl/wallet/__init__.py create mode 100644 generalresearch/models/thl/wallet/cashout_method.py create mode 100644 generalresearch/models/thl/wallet/payout.py create mode 100644 generalresearch/models/thl/wallet/user_wallet.py create mode 100644 generalresearch/models/utils.py create mode 100644 generalresearch/pg_helper.py create mode 100644 generalresearch/priority_thread_pool.py create mode 100644 generalresearch/redis_helper.py create mode 100644 generalresearch/resources/__init__.py create mode 100644 generalresearch/schemas/__init__.py create mode 100644 generalresearch/schemas/survey_stats.py create mode 100644 generalresearch/sql_helper.py create mode 100644 generalresearch/thl_django/README.md create mode 100644 generalresearch/thl_django/__init__.py create mode 100644 generalresearch/thl_django/accounting/__init__.py create mode 100644 generalresearch/thl_django/accounting/models.py create mode 100644 generalresearch/thl_django/app/__init__.py create mode 100644 generalresearch/thl_django/app/manage.py create mode 100644 generalresearch/thl_django/app/settings.py create mode 100644 generalresearch/thl_django/apps.py create mode 100644 generalresearch/thl_django/common/__init__.py create mode 100644 generalresearch/thl_django/common/models.py create mode 100644 generalresearch/thl_django/contest/__init__.py create mode 100644 generalresearch/thl_django/contest/models.py create mode 100644 generalresearch/thl_django/event/__init__.py create mode 100644 generalresearch/thl_django/event/models.py create mode 100644 generalresearch/thl_django/marketplace/__init__.py create mode 100644 generalresearch/thl_django/marketplace/models.py create mode 100644 generalresearch/thl_django/migrations/0001_initial.py create mode 100644 generalresearch/thl_django/migrations/0002_surveystat_is_live_alter_surveycategory_strength_and_more.py create mode 100644 generalresearch/thl_django/migrations/0003_remove_surveystat_surveystat_live_survey_idx_and_more.py create mode 100644 generalresearch/thl_django/migrations/0004_alter_surveystat_survey_is_live_and_more.py create mode 100644 generalresearch/thl_django/migrations/0005_remove_surveystat_marketplace_updated_439a2d_idx.py create mode 100644 generalresearch/thl_django/migrations/0006_remove_thlsession_thl_session_status_d578b7_idx_and_more.py create mode 100644 generalresearch/thl_django/migrations/0007_table_params.py create mode 100644 generalresearch/thl_django/migrations/0008_question_explanation_fragment_template_and_more.py create mode 100644 generalresearch/thl_django/migrations/__init__.py create mode 100644 generalresearch/thl_django/postgres-table-tuning.md create mode 100644 generalresearch/thl_django/postgres.md create mode 100644 generalresearch/thl_django/userhealth/__init__.py create mode 100644 generalresearch/thl_django/userhealth/models.py create mode 100644 generalresearch/thl_django/userprofile/__init__.py create mode 100644 generalresearch/thl_django/userprofile/models.py create mode 100644 generalresearch/utils/__init__.py create mode 100644 generalresearch/utils/aggregation.py create mode 100644 generalresearch/utils/copying_cache.py create mode 100644 generalresearch/utils/enum.py create mode 100644 generalresearch/utils/grpc_logger.py create mode 100644 generalresearch/wall_status_codes/__init__.py create mode 100644 generalresearch/wall_status_codes/cint.py create mode 100644 generalresearch/wall_status_codes/dynata.py create mode 100644 generalresearch/wall_status_codes/fullcircle.py create mode 100644 generalresearch/wall_status_codes/innovate.py create mode 100644 generalresearch/wall_status_codes/lucid.py create mode 100644 generalresearch/wall_status_codes/morning.py create mode 100644 generalresearch/wall_status_codes/pollfish.py create mode 100644 generalresearch/wall_status_codes/precision.py create mode 100644 generalresearch/wall_status_codes/prodege.py create mode 100644 generalresearch/wall_status_codes/repdata.py create mode 100644 generalresearch/wall_status_codes/sago.py create mode 100644 generalresearch/wall_status_codes/spectrum.py create mode 100644 generalresearch/wall_status_codes/wxet.py create mode 100644 generalresearch/wxet/__init__.py create mode 100644 generalresearch/wxet/models/__init__.py create mode 100644 generalresearch/wxet/models/definitions.py create mode 100644 generalresearch/wxet/models/finish_type.py create mode 100644 mypy.ini create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 test_utils/__init__.py create mode 100644 test_utils/conftest.py create mode 100644 test_utils/grliq/__init__.py create mode 100644 test_utils/grliq/conftest.py create mode 100644 test_utils/grliq/managers/__init__.py create mode 100644 test_utils/grliq/managers/conftest.py create mode 100644 test_utils/grliq/models/__init__.py create mode 100644 test_utils/grliq/models/conftest.py create mode 100644 test_utils/incite/__init__.py create mode 100644 test_utils/incite/collections/__init__.py create mode 100644 test_utils/incite/collections/conftest.py create mode 100644 test_utils/incite/conftest.py create mode 100644 test_utils/incite/mergers/__init__.py create mode 100644 test_utils/incite/mergers/conftest.py create mode 100644 test_utils/managers/__init__.py create mode 100644 test_utils/managers/cashout_methods.py create mode 100644 test_utils/managers/conftest.py create mode 100644 test_utils/managers/contest/__init__.py create mode 100644 test_utils/managers/contest/conftest.py create mode 100644 test_utils/managers/ledger/__init__.py create mode 100644 test_utils/managers/ledger/conftest.py create mode 100644 test_utils/managers/upk/__init__.py create mode 100644 test_utils/managers/upk/conftest.py create mode 100644 test_utils/managers/upk/marketplace_category.csv.gz create mode 100644 test_utils/managers/upk/marketplace_item.csv.gz create mode 100644 test_utils/managers/upk/marketplace_property.csv.gz create mode 100644 test_utils/managers/upk/marketplace_propertycategoryassociation.csv.gz create mode 100644 test_utils/managers/upk/marketplace_propertycountry.csv.gz create mode 100644 test_utils/managers/upk/marketplace_propertyitemrange.csv.gz create mode 100644 test_utils/managers/upk/marketplace_propertymarketplaceassociation.csv.gz create mode 100644 test_utils/managers/upk/marketplace_question.csv.gz create mode 100644 test_utils/models/__init__.py create mode 100644 test_utils/models/conftest.py create mode 100644 test_utils/spectrum/__init__.py create mode 100644 test_utils/spectrum/conftest.py create mode 100644 test_utils/spectrum/surveys_json.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/grliq/__init__.py create mode 100644 tests/grliq/managers/__init__.py create mode 100644 tests/grliq/managers/test_forensic_data.py create mode 100644 tests/grliq/managers/test_forensic_results.py create mode 100644 tests/grliq/models/__init__.py create mode 100644 tests/grliq/models/test_forensic_data.py create mode 100644 tests/grliq/test_utils.py create mode 100644 tests/incite/__init__.py create mode 100644 tests/incite/collections/__init__.py create mode 100644 tests/incite/collections/test_df_collection_base.py create mode 100644 tests/incite/collections/test_df_collection_item_base.py create mode 100644 tests/incite/collections/test_df_collection_item_thl_web.py create mode 100644 tests/incite/collections/test_df_collection_thl_marketplaces.py create mode 100644 tests/incite/collections/test_df_collection_thl_web.py create mode 100644 tests/incite/collections/test_df_collection_thl_web_ledger.py create mode 100644 tests/incite/mergers/__init__.py create mode 100644 tests/incite/mergers/foundations/__init__.py create mode 100644 tests/incite/mergers/foundations/test_enriched_session.py create mode 100644 tests/incite/mergers/foundations/test_enriched_task_adjust.py create mode 100644 tests/incite/mergers/foundations/test_enriched_wall.py create mode 100644 tests/incite/mergers/foundations/test_user_id_product.py create mode 100644 tests/incite/mergers/test_merge_collection.py create mode 100644 tests/incite/mergers/test_merge_collection_item.py create mode 100644 tests/incite/mergers/test_pop_ledger.py create mode 100644 tests/incite/mergers/test_ym_survey_merge.py create mode 100644 tests/incite/schemas/__init__.py create mode 100644 tests/incite/schemas/test_admin_responses.py create mode 100644 tests/incite/schemas/test_thl_web.py create mode 100644 tests/incite/test_collection_base.py create mode 100644 tests/incite/test_collection_base_item.py create mode 100644 tests/incite/test_grl_flow.py create mode 100644 tests/incite/test_interval_idx.py create mode 100644 tests/managers/__init__.py create mode 100644 tests/managers/gr/__init__.py create mode 100644 tests/managers/gr/test_authentication.py create mode 100644 tests/managers/gr/test_business.py create mode 100644 tests/managers/gr/test_team.py create mode 100644 tests/managers/leaderboard.py create mode 100644 tests/managers/test_events.py create mode 100644 tests/managers/test_lucid.py create mode 100644 tests/managers/test_userpid.py create mode 100644 tests/managers/thl/__init__.py create mode 100644 tests/managers/thl/test_buyer.py create mode 100644 tests/managers/thl/test_cashout_method.py create mode 100644 tests/managers/thl/test_category.py create mode 100644 tests/managers/thl/test_contest/__init__.py create mode 100644 tests/managers/thl/test_contest/test_leaderboard.py create mode 100644 tests/managers/thl/test_contest/test_milestone.py create mode 100644 tests/managers/thl/test_contest/test_raffle.py create mode 100644 tests/managers/thl/test_harmonized_uqa.py create mode 100644 tests/managers/thl/test_ipinfo.py create mode 100644 tests/managers/thl/test_ledger/__init__.py create mode 100644 tests/managers/thl/test_ledger/test_lm_accounts.py create mode 100644 tests/managers/thl/test_ledger/test_lm_tx.py create mode 100644 tests/managers/thl/test_ledger/test_lm_tx_entries.py create mode 100644 tests/managers/thl/test_ledger/test_lm_tx_locks.py create mode 100644 tests/managers/thl/test_ledger/test_lm_tx_metadata.py create mode 100644 tests/managers/thl/test_ledger/test_thl_lm_accounts.py create mode 100644 tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py create mode 100644 tests/managers/thl/test_ledger/test_thl_lm_tx.py create mode 100644 tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py create mode 100644 tests/managers/thl/test_ledger/test_thl_pem.py create mode 100644 tests/managers/thl/test_ledger/test_user_txs.py create mode 100644 tests/managers/thl/test_ledger/test_wallet.py create mode 100644 tests/managers/thl/test_maxmind.py create mode 100644 tests/managers/thl/test_payout.py create mode 100644 tests/managers/thl/test_product.py create mode 100644 tests/managers/thl/test_product_prod.py create mode 100644 tests/managers/thl/test_profiling/__init__.py create mode 100644 tests/managers/thl/test_profiling/test_question.py create mode 100644 tests/managers/thl/test_profiling/test_schema.py create mode 100644 tests/managers/thl/test_profiling/test_uqa.py create mode 100644 tests/managers/thl/test_profiling/test_user_upk.py create mode 100644 tests/managers/thl/test_session_manager.py create mode 100644 tests/managers/thl/test_survey.py create mode 100644 tests/managers/thl/test_survey_penalty.py create mode 100644 tests/managers/thl/test_task_adjustment.py create mode 100644 tests/managers/thl/test_task_status.py create mode 100644 tests/managers/thl/test_user_manager/__init__.py create mode 100644 tests/managers/thl/test_user_manager/test_base.py create mode 100644 tests/managers/thl/test_user_manager/test_mysql.py create mode 100644 tests/managers/thl/test_user_manager/test_redis.py create mode 100644 tests/managers/thl/test_user_manager/test_user_fetch.py create mode 100644 tests/managers/thl/test_user_manager/test_user_metadata.py create mode 100644 tests/managers/thl/test_user_streak.py create mode 100644 tests/managers/thl/test_userhealth.py create mode 100644 tests/managers/thl/test_wall_manager.py create mode 100644 tests/models/__init__.py create mode 100644 tests/models/admin/__init__.py create mode 100644 tests/models/admin/test_report_request.py create mode 100644 tests/models/custom_types/__init__.py create mode 100644 tests/models/custom_types/test_aware_datetime.py create mode 100644 tests/models/custom_types/test_dsn.py create mode 100644 tests/models/custom_types/test_therest.py create mode 100644 tests/models/custom_types/test_uuid_str.py create mode 100644 tests/models/dynata/__init__.py create mode 100644 tests/models/dynata/test_eligbility.py create mode 100644 tests/models/dynata/test_survey.py create mode 100644 tests/models/gr/__init__.py create mode 100644 tests/models/gr/test_authentication.py create mode 100644 tests/models/gr/test_business.py create mode 100644 tests/models/gr/test_team.py create mode 100644 tests/models/innovate/__init__.py create mode 100644 tests/models/innovate/test_question.py create mode 100644 tests/models/legacy/__init__.py create mode 100644 tests/models/legacy/data.py create mode 100644 tests/models/legacy/test_offerwall_parse_response.py create mode 100644 tests/models/legacy/test_profiling_questions.py create mode 100644 tests/models/legacy/test_user_question_answer_in.py create mode 100644 tests/models/morning/__init__.py create mode 100644 tests/models/morning/test.py create mode 100644 tests/models/precision/__init__.py create mode 100644 tests/models/precision/test_survey.py create mode 100644 tests/models/precision/test_survey_manager.py create mode 100644 tests/models/prodege/__init__.py create mode 100644 tests/models/prodege/test_survey_participation.py create mode 100644 tests/models/spectrum/__init__.py create mode 100644 tests/models/spectrum/test_question.py create mode 100644 tests/models/spectrum/test_survey.py create mode 100644 tests/models/spectrum/test_survey_manager.py create mode 100644 tests/models/test_currency.py create mode 100644 tests/models/test_device.py create mode 100644 tests/models/test_finance.py create mode 100644 tests/models/thl/__init__.py create mode 100644 tests/models/thl/question/__init__.py create mode 100644 tests/models/thl/question/test_question_info.py create mode 100644 tests/models/thl/question/test_user_info.py create mode 100644 tests/models/thl/test_adjustments.py create mode 100644 tests/models/thl/test_bucket.py create mode 100644 tests/models/thl/test_buyer.py create mode 100644 tests/models/thl/test_contest/__init__.py create mode 100644 tests/models/thl/test_contest/test_contest.py create mode 100644 tests/models/thl/test_contest/test_leaderboard_contest.py create mode 100644 tests/models/thl/test_contest/test_raffle_contest.py create mode 100644 tests/models/thl/test_ledger.py create mode 100644 tests/models/thl/test_marketplace_condition.py create mode 100644 tests/models/thl/test_payout.py create mode 100644 tests/models/thl/test_payout_format.py create mode 100644 tests/models/thl/test_product.py create mode 100644 tests/models/thl/test_product_userwalletconfig.py create mode 100644 tests/models/thl/test_soft_pair.py create mode 100644 tests/models/thl/test_upkquestion.py create mode 100644 tests/models/thl/test_user.py create mode 100644 tests/models/thl/test_user_iphistory.py create mode 100644 tests/models/thl/test_user_metadata.py create mode 100644 tests/models/thl/test_user_streak.py create mode 100644 tests/models/thl/test_wall.py create mode 100644 tests/models/thl/test_wall_session.py create mode 100644 tests/pytest.ini create mode 100644 tests/sql_helper.py create mode 100644 tests/wall_status_codes/__init__.py create mode 100644 tests/wall_status_codes/test_analyze.py create mode 100644 tests/wxet/__init__.py create mode 100644 tests/wxet/models/__init__.py create mode 100644 tests/wxet/models/test_definitions.py create mode 100644 tests/wxet/models/test_finish_type.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..df17d04 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.idea +.vscode +*.pytest* +*.pyc +__pycache__ +*.csv +generalresearch/resources/brokerage_trust_calculated.csv +tests/.env.test +.env.* +.DS_Store \ No newline at end of file diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 0000000..29d065f --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,295 @@ +pipeline { + agent any + + triggers { + cron('H */12 * * *') + pollSCM('H */6 * * *') + } + + options { + skipDefaultCheckout() + } + + environment { + VENV = "${env.WORKSPACE}/py-utils-venv" + SPECTRUM_CARER_VENV = "${env.WORKSPACE}/thl-spectrum-carer-venv" + GRLIQ_CARER_VENV = "${env.WORKSPACE}/grliq-carer-venv" + GR_CARER_VENV = "${env.WORKSPACE}/gr-carer-venv" + + INCITE_MOUNT_DIR = '/mnt/thl-incite' + TMP_DIR = "${env.WORKSPACE}/tmp" + } + + stages { + stage('python versions') { + matrix { + axes { + axis { + name 'PYTHON_VERSION' + values 'python3.13', 'python3.12', 'python3.11', 'python3.10' + } + } + + stages { + stage('Setup DB') { + steps { + script { + env.DB_NAME = 'unittest-thl-' + UUID.randomUUID().toString().replace('-', '').take(12) + env.THL_WEB_RW_DB = "postgres://${env.DB_USER}:${env.DB_PASSWORD}@${env.DB_POSTGRESQL_HOST}/${env.DB_NAME}" + env.THL_WEB_RR_DB = env.THL_WEB_RW_DB + env.THL_WEB_RO_DB = env.THL_WEB_RW_DB + echo "Using database: ${env.DB_NAME}" + + env.SPECTRUM_DB_NAME = 'unittest-thl-spectrum-' + UUID.randomUUID().toString().replace('-', '').take(12) + env.SPECTRUM_RW_DB = "mariadb://${env.DB_USER}:${env.DB_PASSWORD}@${env.DB_MARIA_HOST}/${env.SPECTRUM_DB_NAME}" + env.SPECTRUM_RR_DB = env.SPECTRUM_RW_DB + echo "Using database: ${env.SPECTRUM_DB_NAME}" + + env.GRLIQ_DB_NAME = 'unittest-grliq-' + UUID.randomUUID().toString().replace('-', '').take(12) + env.GRLIQ_DB = "postgres://${env.DB_USER}:${env.DB_PASSWORD}@${env.DB_POSTGRESQL_HOST}/${env.GRLIQ_DB_NAME}" + echo "Using database: ${env.GRLIQ_DB_NAME}" + + env.GR_DB_NAME = 'unittest-gr-' + UUID.randomUUID().toString().replace('-', '').take(12) + env.GR_DB = "postgres://${env.DB_USER}:${env.DB_PASSWORD}@${env.DB_POSTGRESQL_HOST}/${env.GR_DB_NAME}" + echo "Using database: ${env.GR_DB_NAME}" + } + + sh """ + PGPASSWORD=${env.DB_PASSWORD} psql -h ${env.DB_POSTGRESQL_HOST} -U ${env.DB_USER} -d postgres < + + Branches + Versions + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/generalresearch/__init__.py b/generalresearch/__init__.py new file mode 100644 index 0000000..5993613 --- /dev/null +++ b/generalresearch/__init__.py @@ -0,0 +1,146 @@ +import threading +import time +from functools import wraps + +from decorator import decorator +from wrapt import FunctionWrapper, ObjectProxy + + +def retry(exceptions, tries=4, delay=0.5, backoff=2, logger=None): + """ + https://www.calazan.com/retry-decorator-for-python-3/ + Retry calling the decorated function using an exponential backoff. + + Args: + exceptions: The exception to check. may be a tuple of + exceptions to check. + tries: Number of times to try (not retry) before giving up. + delay: Initial delay between retries in seconds. + backoff: Backoff multiplier (e.g. value of 2 will double the delay + each retry). + logger: Logger to use. If None, print. + """ + + def deco_retry(f): + + @wraps(f) + def f_retry(*args, **kwargs): + mtries, mdelay = tries, delay + while mtries > 1: + try: + return f(*args, **kwargs) + except exceptions as e: + msg = "{}, Retrying in {} seconds...".format(e, mdelay) + if logger: + logger.warning(msg) + else: + print(msg) + time.sleep(mdelay) + mtries -= 1 + mdelay *= backoff + return f(*args, **kwargs) + + return f_retry # true decorator + + return deco_retry + + +def synchronized(wrapped): + # https://wrapt.readthedocs.io/en/latest/examples.html#thread-synchronization + + # Determine if being passed an object which is a synchronization + # primitive. We can't check by type for Lock, RLock, Semaphore etc, + # as the means of creating them isn't the type. Therefore use the + # existence of acquire() and release() methods. This is more + # extensible anyway as it allows custom synchronization mechanisms. + + if hasattr(wrapped, "acquire") and hasattr(wrapped, "release"): + # We remember what the original lock is and then return a new + # decorator which accesses and locks it. When returning the new + # decorator we wrap it with an object proxy so we can override + # the context manager methods in case it is being used to wrap + # synchronized statements with a 'with' statement. + + lock = wrapped + + @decorator + def _synchronized(wrapped, instance, args, kwargs): + # Execute the wrapped function while the original supplied + # lock is held. + + with lock: + return wrapped(*args, **kwargs) + + class _PartialDecorator(ObjectProxy): + + def __enter__(self): + lock.acquire() + return lock + + def __exit__(self, *args): + lock.release() + + return _PartialDecorator(wrapped=_synchronized) + + # Following only apply when the lock is being created + # automatically based on the context of what was supplied. In + # this case we supply a final decorator, but need to use + # FunctionWrapper directly as we want to derive from it to add + # context manager methods in case it is being used to wrap + # synchronized statements with a 'with' statement. + + def _synchronized_lock(context): + # Attempt to retrieve the lock for the specific context. + + lock = vars(context).get("_synchronized_lock", None) + + if lock is None: + # There is no existing lock defined for the context we + # are dealing with so we need to create one. This needs + # to be done in a way to guarantee there is only one + # created, even if multiple threads try and create it at + # the same time. We can't always use the setdefault() + # method on the __dict__ for the context. This is the + # case where the context is a class, as __dict__ is + # actually a dictproxy. What we therefore do is use a + # meta lock on this wrapper itself, to control the + # creation and assignment of the lock attribute against + # the context. + + meta_lock = vars(synchronized).setdefault( + "_synchronized_meta_lock", threading.Lock() + ) + + with meta_lock: + # We need to check again for whether the lock we want + # exists in case two threads were trying to create it + # at the same time and were competing to create the + # meta lock. + + lock = vars(context).get("_synchronized_lock", None) + + if lock is None: + lock = threading.RLock() + setattr(context, "_synchronized_lock", lock) + + return lock + + def _synchronized_wrapper(wrapped, instance, args, kwargs): + # Execute the wrapped function while the lock for the + # desired context is held. If instance is None then the + # wrapped function is used as the context. + + with _synchronized_lock(instance or wrapped): + return wrapped(*args, **kwargs) + + class _FinalDecorator(FunctionWrapper): + + def __enter__(self): + self._self_lock = _synchronized_lock(self.__wrapped__) + self._self_lock.acquire() + return self._self_lock + + def __exit__(self, *args): + self._self_lock.release() + + return _FinalDecorator(wrapped=wrapped, wrapper=_synchronized_wrapper) diff --git a/generalresearch/cacheing.py b/generalresearch/cacheing.py new file mode 100644 index 0000000..de000cf --- /dev/null +++ b/generalresearch/cacheing.py @@ -0,0 +1,47 @@ +from generalresearch import retry + + +class RetryCache: + # Simple pylibmc.Client wrapper that implements a retry on each method + + def __init__(self, client, tries=4, delay=1, backoff=1.5): + import pylibmc + + self.client = client + self.f = retry(pylibmc.Error, tries=tries, delay=delay, backoff=backoff) + + def get(self, key): + @self.f + def _get(key): + return self.client.get(key) + + return _get(key) + + def set(self, key, value, timeout=0): + @self.f + def _set(key, value, timeout): + return self.client.set(key, value, time=timeout) + + return _set(key, value, timeout) + + def delete_multi(self, keys): + @self.f + def _delete_multi(keys): + return self.client.delete_multi(keys) + + return _delete_multi(keys) + + def delete(self, key): + @self.f + def _delete(key): + return self.client.delete(key) + + return _delete(key) + + +if __name__ == "__main__": + import pylibmc + + CACHE = RetryCache(pylibmc.Client(["127.0.0.1:11211"], binary=True)) + CACHE.set("foo", "bar") + print(CACHE.get("foo")) diff --git a/generalresearch/config.py b/generalresearch/config.py new file mode 100644 index 0000000..e44124b --- /dev/null +++ b/generalresearch/config.py @@ -0,0 +1,109 @@ +from datetime import datetime, timezone +from typing import Optional + +from pydantic import RedisDsn, Field, MariaDBDsn, DirectoryPath, PostgresDsn +from pydantic_settings import BaseSettings + +from generalresearch.models.custom_types import DaskDsn, SentryDsn, MySQLOrMariaDsn + + +def is_debug() -> bool: + import os + + is_developer: bool = os.getenv("USER") in {"nanis", "gstupp"} + is_pytest1: bool = bool(os.getenv("PYTEST_TEST", False)) + is_pytest2: bool = bool(os.getenv("PYTEST_CURRENT_TEST", False)) + is_pytest3: bool = bool(os.getenv("PYTEST_VERSION", False)) + is_debugging1: bool = os.getenv("DEBUG", "").lower() in ("1", "true", "yes") + is_debugging2: bool = os.getenv("PYTHON_DEBUG", "").lower() in ("1", "true", "yes") + is_jenkins: bool = bool(os.getenv("JENKINS_HOME")) or bool(os.getenv("JENKINS_URL")) + is_vscode: bool = ( + os.getenv("DEBUGPY_RUNNING") == "true" or os.getenv("TERM_PROGRAM") == "vscode" + ) + + return ( + is_developer + or is_pytest1 + or is_pytest2 + or is_pytest3 + or is_debugging1 + or is_debugging2 + or is_jenkins + or is_vscode + ) + + +class GRLBaseSettings(BaseSettings): + debug: bool = Field(default=True) + + redis: Optional[RedisDsn] = Field(default="redis://127.0.0.1:6379") + redis_timeout: float = Field(default=0.10) + + thl_redis: Optional[RedisDsn] = Field(default="redis://127.0.0.1:6379") + + dask: Optional[DaskDsn] = Field(default="tcp://127.0.0.1:8786", description="") + + sentry: Optional[SentryDsn] = Field( + default=None, description="The sentry.io DSN for connecting to a project" + ) + + thl_mkpl_rw_db: Optional[MariaDBDsn] = Field(default="mariadb://root:@127.0.0.1/") + thl_mkpl_rr_db: Optional[MariaDBDsn] = Field(default="mariadb://root:@127.0.0.1/") + + # Primary DB, SELECT permissions + thl_web_ro_db: Optional[PostgresDsn] = Field( + default="postgres://postgres:password@localhost:5432/thl" + ) + # Primary DB, SELECT, INSERT, UPDATE permissions + thl_web_rw_db: Optional[PostgresDsn] = Field(default=None) + # Primary DB, SELECT, INSERT, UPDATE, DELETE permissions + thl_web_rwd_db: Optional[PostgresDsn] = Field(default=None) + # Slave/secondary/read-replica SELECT permission only + thl_web_rr_db: Optional[PostgresDsn] = Field(default=None) + + tmp_dir: DirectoryPath = Field(default="/tmp") + + spectrum_rw_db: Optional[MariaDBDsn] = Field(default=None) + spectrum_rr_db: Optional[MariaDBDsn] = Field(default=None) + + precision_rw_db: Optional[MariaDBDsn] = Field(default=None) + precision_rr_db: Optional[MariaDBDsn] = Field(default=None) + + # --- GR ---- + gr_db: Optional[PostgresDsn] = Field(default=None) + gr_redis: Optional[RedisDsn] = Field(default="redis://127.0.0.1:6379") + + # --- GRL IQ --- + grliq_db: Optional[PostgresDsn] = Field(default=None) + mnt_grliq_archive_dir: Optional[str] = Field( + default=None, + description="Where gr-api can pull GRL-IQ Forensic archive items like" + "the captured screenshots.", + ) + + mnt_gr_api_dir: Optional[str] = Field( + default=None, + description="Where gr-api can pull parquet files from.", + ) + + # --- TangoCard Configuration --- + tango_platform_name: Optional[str] = Field(default=None) + tango_platform_key: Optional[str] = Field(default=None) + tango_account_id: Optional[str] = Field(default=None) + tango_customer_id: Optional[str] = Field(default=None) + + # --- Keeping this here as we use these ids regardless of the AMT account + amt_bonus_cashout_method: Optional[str] = Field(default=None) + amt_assignment_cashout_method: Optional[str] = Field(default=None) + + # --- Maxmind Configuration --- + maxmind_account_id: Optional[str] = Field(default=None) + maxmind_license_key: Optional[str] = Field(default=None) + + +EXAMPLE_PRODUCT_ID = "1108d053e4fa47c5b0dbdcd03a7981e7" + +# AMT accounting was changed many times and txs before this date +# are either missing AMT bonuses, or not accounting for hit rewards. +JAMES_BILLINGS_BPID = "888dbc589987425fa846d6e2a8daed04" +JAMES_BILLINGS_TX_CUTOFF = datetime(2026, 1, 1, tzinfo=timezone.utc) diff --git a/generalresearch/currency.py b/generalresearch/currency.py new file mode 100644 index 0000000..2948d38 --- /dev/null +++ b/generalresearch/currency.py @@ -0,0 +1,156 @@ +import warnings +from decimal import Decimal +from enum import Enum +from typing import Any + +from pydantic import GetCoreSchemaHandler, NonNegativeInt +from pydantic_core import CoreSchema, core_schema + +from generalresearch.utils.enum import ReprEnumMeta + + +class LedgerCurrency(str, Enum, metaclass=ReprEnumMeta): + USD = "USD" + USDCent = "USDCent" + USDMill = "USDMill" + TEST = "test" + + +def format_usd_cent(usd_cent: int) -> str: + """USDCent can't be negative. However, we want some helper properties + so show the value even if it's negative + """ + v = USDCent(abs(usd_cent)).to_usd_str() + return f"-{v}" if usd_cent < 0 else v + + +class USDCent(int): + def __new__(cls, value, *args, **kwargs): + + if isinstance(value, float): + warnings.warn( + "USDCent init with a float. Rounding behavior may " "be unexpected" + ) + + if isinstance(value, Decimal): + warnings.warn( + "USDCent init with a Decimal. Rounding behavior may " "be unexpected" + ) + + if value < 0: + raise ValueError("USDCent not be less than zero") + + return super(cls, cls).__new__(cls, value) + + def __add__(self, other): + assert isinstance(other, USDCent) + res = super(USDCent, self).__add__(other) + return self.__class__(res) + + def __sub__(self, other): + assert isinstance(other, USDCent) + res = super(USDCent, self).__sub__(other) + return self.__class__(res) + + def __mul__(self, other): + assert isinstance(other, USDCent) + res = super(USDCent, self).__mul__(other) + return self.__class__(res) + + def __abs__(self): + res = super(USDCent, self).__abs__() + return self.__class__(res) + + def __truediv__(self, other): + raise ValueError("Division not allowed for USDCent") + + def __str__(self): + return "%d" % int(self) + + def __repr__(self): + return "USDCent(%d)" % int(self) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ + """ + return core_schema.no_info_after_validator_function( + cls, handler(NonNegativeInt) + ) + + def to_usd(self) -> Decimal: + return Decimal(int(self) / 100).quantize(Decimal(".01")) + + def to_usd_str(self) -> str: + return "${:,.2f}".format(float(self.to_usd())) + + +class USDMill(int): + """ + This represents 1/1000 of a US dollar, or 1/10th of a USD cent. + """ + + def __new__(cls, value, *args, **kwargs): + + if isinstance(value, float): + warnings.warn( + "USDMill init with a float. Rounding behavior " "may be unexpected" + ) + + if isinstance(value, Decimal): + warnings.warn( + "USDMill init with a Decimal. Rounding behavior " "may be unexpected" + ) + + if value < 0: + raise ValueError("USDMill not be less than zero") + + return super(cls, cls).__new__(cls, value) + + def __add__(self, other): + assert isinstance(other, USDMill) + res = super(USDMill, self).__add__(other) + return self.__class__(res) + + def __sub__(self, other): + assert isinstance(other, USDMill) + res = super(USDMill, self).__sub__(other) + return self.__class__(res) + + def __mul__(self, other): + assert isinstance(other, USDMill) + res = super(USDMill, self).__mul__(other) + return self.__class__(res) + + def __abs__(self): + res = super(USDMill, self).__abs__() + return self.__class__(res) + + def __truediv__(self, other): + raise ValueError("Division not allowed for USDMill") + + def __str__(self): + return "%d" % int(self) + + def __repr__(self): + return "USDMill(%d)" % int(self) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + """ + https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__ + """ + return core_schema.no_info_after_validator_function( + cls, handler(NonNegativeInt) + ) + + def to_usd(self) -> Decimal: + return Decimal(int(self) / 1_000).quantize(Decimal(".001")) + + def to_usd_str(self) -> str: + return "${:,.3f}".format(float(self.to_usd())) diff --git a/generalresearch/decorators.py b/generalresearch/decorators.py new file mode 100644 index 0000000..5490c8e --- /dev/null +++ b/generalresearch/decorators.py @@ -0,0 +1,10 @@ +import logging + +# --- Logging --- + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s:%(name)s:%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +LOG = logging.getLogger("generalresearch") diff --git a/generalresearch/grliq/__init__.py b/generalresearch/grliq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/grliq/managers/__init__.py b/generalresearch/grliq/managers/__init__.py new file mode 100644 index 0000000..849b6c2 --- /dev/null +++ b/generalresearch/grliq/managers/__init__.py @@ -0,0 +1,34 @@ +from generalresearch.grliq.models.forensic_data import GrlIqData +from generalresearch.grliq.models.forensic_result import ( + GrlIqCheckerResults, + GrlIqForensicCategoryResult, +) + +DUMMY_GRLIQ_DATA = [ + { + "data": GrlIqData.model_validate_json( + """{"mid": "3722ed29314940fabd37b42d808dcf5a", "uuid": "b11441da5a854dfbb8401d4c32e56db5", "phase": "offerwall-enter", "events": null, "vendor": "Google Inc.", "app_name": "Netscape", "calendar": "gregory", "language": "en-US", "platform": "Linux x86_64", "timezone": "America/Mexico_City", "client_ip": "131.196.250.250", "timestamp": "2025-02-27T16:05:34-06:00", "webrtc_ip": "131.196.250.250", "created_at": "2025-02-27T22:05:35.370589Z", "language_2": "en-US", "language_3": null, "platform_2": "Linux x86_64", "platform_3": null, "prefetched": true, "product_id": "d0606a0b5d034a8d81b1e3579d1f76fd", "webgl_flag": true, "webgl_hash": "da27e1b9b660057a3f5e185d3f5deabe", "canvas_hash": "14ed764326ec454d976c322261d99f16", "color_gamut": "3", "country_iso": "mx", "inner_width": 612, "outer_width": 1813, "product_sub": "20030107", "audio_codecs": "1,1,1,1,1,3,1,3,1,3,3,1,1,3,3,3,3,1,3,3,3,2,1,1", "cookie_check": "", "graphics_api": "WebKit WebGL", "inner_height": 1174, "mouse_events": null, "ontouchstart": false, "outer_height": 1261, "plugins_hash": "4c05fa2f766a444d4f253ead792c8b0e|2", "screen_width": 2560, "video_codecs": "1,3,3,3,3,3,3,3,3,3,1,1,1,1,1,1,3,1,1,1,3,3,1", "webgl_hash_2": "fc73fd5db75e2c36222fe34251be3971", "webrtc_error": false, "window_opera": false, "battery_level": 0.9, "canvas_hash_2": "bd11ebbf5c26fd20e0217820b4159752", "dynamic_range": false, "error_message": "Cannot read", "forced_colors": false, "math_result_1": "1.9275814160560204e-50", "math_result_2": "1.6182817135715877", "screen_height": 1440, "webgl_check_1": true, "webgl_context": "webgl2", "window_chrome": true, "connection_rtt": 150, "history_length": 16, "user_agent_str": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", "web_sql_exists": false, "calender_locale": "en-US", "connection_type": "", "inverted_colors": true, "navigator_brave": false, "product_user_id": "d1d55df1-959e-4740-b77c-fa1f4fc457ae", "request_headers": {"host": "test", "accept": "*/*", "connection": "keep-alive", "user-agent": "python-httpx/0.27.0", "content-length": "3646", "accept-encoding": "gzip, deflate", "x-forwarded-for": "131.196.250.250"}, "timezone_offset": 360, "webrtc_local_ip": "50486637-6b64-4812-b10a-0a75337c31bd.local", "battery_charging": true, "client_ip_detail": {"continent_code": "EU", "continent_name": "Europe", "country_name": "France", "is_in_european_union": true, "ip": "131.196.250.250", "isp": null, "latitude": null, "city_name": null, "longitude": null, "time_zone": null, "user_type": null, "country_iso": "mx", "postal_code": null, "is_anonymous": null, "accuracy_radius": null, "static_ip_score": null, "subdivision_1_iso": null, "subdivision_2_iso": null, "subdivision_1_name": null, "subdivision_2_name": null, "registered_country_iso": null}, "max_touch_points": 0, "numbering_system": "latn", "path_fingerprint": 3252, "prefers_contrast": "0", "rendering_engine": "WebKit", "timezone_success": "pass", "user_agent_hints": {"model": null, "brands": [{"brand": "Google Chrome", "version": "131"}, {"brand": "Chromium", "version": "131"}, {"brand": "Not_A Brand", "version": "24"}], "mobile": false, "bitness": "64", "platform": "Linux", "brands_full": [{"brand": "Google Chrome", "version": "131.0.6778.204"}, {"brand": "Chromium", "version": "131.0.6778.204"}, {"brand": "Not_A Brand", "version": "24.0.0.0"}], "architecture": "x86", "platform_version": "6.2.0"}, "user_agent_str_2": null, "webgl_extensions": "EXT_clip_control|EXT_color_buffer_float|EXT_color_buffer_half_float|EXT_conservative_depth|EXT_depth_clamp|EXT_disjoint_timer_query_webgl2|EXT_float_blend|EXT_polygon_offset_clamp|EXT_render_snorm|EXT_texture_compression_bptc|EXT_texture_compression_rgtc|EXT_texture_filter_anisotropic|EXT_texture_mirror_clamp_to_edge|EXT_texture_norm16|KHR_parallel_shader_compile|NV_shader_noperspective_interpolation|OES_draw_buffers_indexed|OES_sample_variables|OES_shader_multisample_interpolation|OES_texture_float_linear|OVR_multiview2|WEBGL_blend_func_extended|WEBGL_clip_cull_distance|WEBGL_compressed_texture_astc|WEBGL_compressed_texture_etc|WEBGL_compressed_texture_etc1|WEBGL_compressed_texture_s3tc|WEBGL_compressed_texture_s3tc_srgb|WEBGL_debug_renderer_info|WEBGL_debug_shaders|WEBGL_lose_context|WEBGL_multi_draw|WEBGL_polygon_mode|WEBGL_provoking_vertex|WEBGL_stencil_texturing", "webrtc_ip_detail": {"continent_code": "EU", "continent_name": "Europe", "country_name": "France", "is_in_european_union": true, "ip": "131.196.250.250", "isp": null, "latitude": null, "city_name": null, "longitude": null, "time_zone": null, "user_type": null, "country_iso": "mx", "postal_code": null, "is_anonymous": null, "accuracy_radius": null, "static_ip_score": null, "subdivision_1_iso": null, "subdivision_2_iso": null, "subdivision_1_name": null, "subdivision_2_name": null, "registered_country_iso": null}, "chrome_extensions": "", "execution_time_ms": 371.0999999642372, "graphics_renderer": "WebGL 2.0 (OpenGL ES 3.0 Chromium)", "keyboard_detected": true, "mime_types_length": 2, "request_fs_exists": true, "audio_context_flag": "pass", "audio_context_hash": "9307303774dec3248c18a939392090da", "canvas_fingerprint": 258, "canvas_pixel_check": false, "device_pixel_ratio": 1.0, "indexedDbData_blob": true, "navigator_keys_len": 79, "no_edge_pdf_plugin": false, "screen_avail_width": 2560, "webdriver_detected": false, "window_orientation": 0, "connection_downlink": 10.0, "navigator_webdriver": false, "non_native_function": false, "screen_avail_height": 1400, "supported_fonts_str": "72|768|262144|1073741824|0|0|540672|73728|7340032|1342177280|117446656|256|16|0|543|4290797636|1677723648|4168998400|0|1048576|262144|268500994|1342177280|262144|125829376|37888000|0|435363842|0|2147483648|109543424|1880099872|268435471", "text_2d_fingerprint": "bfcce91c9e71d11af7b14dbee4c75f83", "webrtc_is_supported": "pass", "canvas_support_level": "full", "do_not_track_enabled": "1", "hardware_concurrency": 12, "keyboard_layout_size": 48, "prefers_color_scheme": false, "webgl_max_anisotropy": 16, "battery_charging_time": 0.0, "browser_by_properties": "c", "eval_to_string_length": 33, "performance_loop_time": 0.09999996423721313, "session_storage_check": "pass", "unmasked_vendor_webgl": "Google Inc. (Intel)", "hardware_concurrency_2": 12, "hardware_concurrency_3": null, "localStorage_available": true, "memory_jsHeapSizeLimit": 4294705152, "mozilla_web_app_exists": false, "navigator_deviceMemory": 8.0, "navigator_java_enabled": false, "prefers_reduced_motion": false, "storage_estimate_quota": 1178717110272, "webdriver_detected_msg": "", "window_active_x_object": false, "window_external_exists": true, "color_depth_pixel_depth": "24-24", "indexedDbData_available": true, "navigator_cookieEnabled": true, "unmasked_renderer_webgl": "ANGLE (Intel, Mesa Intel(R) Graphics (RPL-P), OpenGL 4.6)", "battery_discharging_time": 0.0, "connection_effectiveType": "4g", "non_native_function_flag": "", "speech_synthesis_voice_1": "Google Bahasa Indonesia", "window_client_information": true, "audio_compressor_reduction": 20.538288116455078, "navigator_mediaDevices_len": 3, "audio_intensity_fingerprint": 124.04347527516074, "speech_synthesis_voice_hash": "8010ee3313813de521e48e63bd5a6f13", "microsoft_credentials_exists": false, "window_installTrigger_exists": false, "speech_synthesis_voices_count": 19, "webgl_shading_language_version": "WebGL GLSL ES 3.00 (OpenGL ES GLSL ES 3.0 Chromium)", "error_message_stack_access_count": 0, "speech_synthesis_avail_voices_count": 19, "error_message_stack_access_count_worker": 0}""" + ), + "result_data": GrlIqCheckerResults.model_validate_json( + """{"uuid": "b11441da5a854dfbb8401d4c32e56db5", "check_codecs": {"score": 0}, "check_timezone": {"score": 0}, "check_timestamp": {"score": 0}, "check_user_type": {"score": 0}, "check_ip_changes": {"score": 0}, "check_ip_country": {"score": 0}, "check_environment": {"score": 0}, "check_ip_timezone": {"score": 0}, "check_isp_changes": {"score": 0}, "check_useragent_js": {"score": 0}, "check_required_fonts": {"score": 0}, "check_user_anonymous": {"score": 0}, "check_webrtc_success": {"score": 0}, "check_seen_timestamps": {"msg": "duplicate timestamp", "score": 100}, "check_country_timezone": {"score": 0}, "check_prohibited_fonts": {"score": 0}, "check_timezone_changes": {"score": 0}, "check_execution_time_ms": {"msg": "duplicate execution_time_ms", "score": 100}, "check_fingerprint_reuse": {"score": 0}, "check_fingerprint_cycling": {"score": 0}, "check_ip_webrtc_ip_detail": {"score": 0}, "check_environment_critical": {"score": 0}, "check_useragent_other_enums": {"score": 0}, "check_useragent_ip_properties": {"score": 0}, "check_useragent_data_properties": {"score": 0}, "check_useragent_device_family_brand": {"score": 0}}""" + ), + "category_result": GrlIqForensicCategoryResult.model_validate_json( + """{"uuid": "b11441da5a854dfbb8401d4c32e56db5", "is_bot": 0, "is_tampered": 100, "is_velocity": 0, "is_anonymous": 0, "suspicious_ip": 0, "is_oscillating": 0, "is_teleporting": 0, "is_inconsistent": 0, "platform_ip_inconsistent": 0}""" + ), + "fraud_score": 100, + "is_attempt_allowed": False, + }, + { + "data": GrlIqData.model_validate_json( + """{"mid": "35f6f5c30bc74ea7ac4aca7b40a02352", "uuid": "d54509f2f310499f8ab74839b10b2a41", "phase": "offerwall-enter", "events": null, "vendor": "Google Inc.", "app_name": "Netscape", "calendar": "gregory", "language": "en-US", "platform": "Linux x86_64", "timezone": "America/Los_Angeles", "client_ip": "104.9.125.144", "timestamp": "2025-02-28T11:34:39-08:00", "webrtc_ip": "172.56.209.195", "created_at": "2025-02-28T19:34:39.681872Z", "language_2": "en-US", "language_3": null, "platform_2": "Linux x86_64", "platform_3": null, "prefetched": true, "product_id": "d0606a0b5d034a8d81b1e3579d1f76fd", "webgl_flag": true, "webgl_hash": "da27e1b9b660057a3f5e185d3f5deabe", "canvas_hash": "e6e4d17da26050ce85ad00d3c6ea999e", "color_gamut": "3", "country_iso": "us", "inner_width": 841, "outer_width": 1680, "product_sub": "20030107", "audio_codecs": "1,1,1,1,1,3,1,3,1,3,3,1,1,3,3,3,3,1,3,3,3,2,1,1", "cookie_check": "", "graphics_api": "WebKit WebGL", "inner_height": 891, "mouse_events": null, "ontouchstart": false, "outer_height": 978, "plugins_hash": "4c05fa2f766a444d4f253ead792c8b0e|2", "screen_width": 1680, "video_codecs": "1,3,3,3,3,3,3,3,3,3,1,1,1,1,1,1,3,1,1,1,3,3,1", "webgl_hash_2": "fc73fd5db75e2c36222fe34251be3971", "webrtc_error": false, "window_opera": false, "battery_level": 0.41, "canvas_hash_2": "e0559d49b1864985cafc0d1c3a6b053c", "dynamic_range": false, "error_message": "Cannot read", "forced_colors": false, "math_result_1": "1.9275814160560204e-50", "math_result_2": "1.6182817135715877", "screen_height": 1050, "webgl_check_1": true, "webgl_context": "webgl2", "window_chrome": true, "connection_rtt": 100, "history_length": 11, "user_agent_str": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", "web_sql_exists": false, "calender_locale": "en-US", "connection_type": "", "inverted_colors": true, "navigator_brave": false, "product_user_id": "test-unit", "request_headers": {"dnt": "1", "host": "127.0.0.1:8081", "accept": "application/json, lk/null q=0.1", "origin": "http://127.0.0.1:8080", "referer": "http://127.0.0.1:8080/", "sec-ch-ua": "\\"Google Chrome\\";v=\\"131\\", \\"Chromium\\";v=\\"131\\", \\"Not_A Brand\\";v=\\"24\\"", "connection": "keep-alive", "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", "content-type": "application/json", "content-length": "3313", "sec-fetch-dest": "empty", "sec-fetch-mode": "cors", "sec-fetch-site": "same-site", "accept-encoding": "gzip, deflate, br, zstd", "accept-language": "en-US,en;q=0.9", "sec-ch-ua-mobile": "?0", "sec-ch-ua-platform": "\\"Linux\\""}, "timezone_offset": 480, "webrtc_local_ip": "10.253.217.45,[2607:fb91:20c5:c6af:cda0:10b4:830a:a85e]", "battery_charging": false, "client_ip_detail": {"continent_code": "EU", "continent_name": "Europe", "country_name": "France", "is_in_european_union": true, "ip": "104.9.125.144", "isp": "AT&T Internet", "latitude": 37.3897, "city_name": "Mountain View", "longitude": -122.083, "time_zone": "America/Los_Angeles", "user_type": "residential", "country_iso": "us", "postal_code": "94041", "is_anonymous": false, "accuracy_radius": 5, "static_ip_score": 40.3, "subdivision_1_iso": "CA", "subdivision_2_iso": null, "subdivision_1_name": "California", "subdivision_2_name": null, "registered_country_iso": "us"}, "max_touch_points": 0, "numbering_system": "latn", "path_fingerprint": 3252, "prefers_contrast": "0", "rendering_engine": "WebKit", "timezone_success": "pass", "user_agent_hints": {"model": null, "brands": [{"brand": "Google Chrome", "version": "131"}, {"brand": "Chromium", "version": "131"}, {"brand": "Not_A Brand", "version": "24"}], "mobile": false, "bitness": "64", "platform": "Linux", "brands_full": [{"brand": "Google Chrome", "version": "131.0.6778.204"}, {"brand": "Chromium", "version": "131.0.6778.204"}, {"brand": "Not_A Brand", "version": "24.0.0.0"}], "architecture": "x86", "platform_version": "6.2.0"}, "user_agent_str_2": null, "webgl_extensions": "EXT_clip_control|EXT_color_buffer_float|EXT_color_buffer_half_float|EXT_conservative_depth|EXT_depth_clamp|EXT_disjoint_timer_query_webgl2|EXT_float_blend|EXT_polygon_offset_clamp|EXT_render_snorm|EXT_texture_compression_bptc|EXT_texture_compression_rgtc|EXT_texture_filter_anisotropic|EXT_texture_mirror_clamp_to_edge|EXT_texture_norm16|KHR_parallel_shader_compile|NV_shader_noperspective_interpolation|OES_draw_buffers_indexed|OES_sample_variables|OES_shader_multisample_interpolation|OES_texture_float_linear|OVR_multiview2|WEBGL_blend_func_extended|WEBGL_clip_cull_distance|WEBGL_compressed_texture_astc|WEBGL_compressed_texture_etc|WEBGL_compressed_texture_etc1|WEBGL_compressed_texture_s3tc|WEBGL_compressed_texture_s3tc_srgb|WEBGL_debug_renderer_info|WEBGL_debug_shaders|WEBGL_lose_context|WEBGL_multi_draw|WEBGL_polygon_mode|WEBGL_provoking_vertex|WEBGL_stencil_texturing", "webrtc_ip_detail": {"continent_code": "EU", "continent_name": "Europe", "country_name": "France", "is_in_european_union": true, "ip": "172.56.209.195", "isp": null, "latitude": null, "city_name": null, "longitude": null, "time_zone": null, "user_type": null, "country_iso": "us", "postal_code": null, "is_anonymous": null, "accuracy_radius": null, "static_ip_score": null, "subdivision_1_iso": null, "subdivision_2_iso": null, "subdivision_1_name": null, "subdivision_2_name": null, "registered_country_iso": null}, "chrome_extensions": "", "execution_time_ms": 924.5, "graphics_renderer": "WebGL 2.0 (OpenGL ES 3.0 Chromium)", "keyboard_detected": true, "mime_types_length": 2, "request_fs_exists": true, "audio_context_flag": "pass", "audio_context_hash": "9307303774dec3248c18a939392090da", "canvas_fingerprint": 258, "canvas_pixel_check": false, "device_pixel_ratio": 1.0, "indexedDbData_blob": true, "navigator_keys_len": 79, "no_edge_pdf_plugin": false, "screen_avail_width": 1680, "webdriver_detected": false, "window_orientation": 0, "connection_downlink": 10.0, "navigator_webdriver": false, "non_native_function": false, "screen_avail_height": 1010, "supported_fonts_str": "72|17152|327680|1073741824|0|0|540736|73728|7340032|1342177280|117446657|256|16|0|262687|4290797636|1677723648|4168998400|0|1048576|262144|268500994|1342177280|262144|125829376|37888000|0|435363842|0|2147483648|109543680|1880099888|301989903", "text_2d_fingerprint": "bfcce91c9e71d11af7b14dbee4c75f83", "webrtc_is_supported": "pass", "canvas_support_level": "full", "do_not_track_enabled": "1", "hardware_concurrency": 12, "keyboard_layout_size": 48, "prefers_color_scheme": false, "webgl_max_anisotropy": 16, "battery_charging_time": 0.0, "browser_by_properties": "c", "eval_to_string_length": 33, "performance_loop_time": 0.09999999962747097, "session_storage_check": "pass", "unmasked_vendor_webgl": "Google Inc. (Intel)", "hardware_concurrency_2": 12, "hardware_concurrency_3": null, "localStorage_available": true, "memory_jsHeapSizeLimit": 4294705152, "mozilla_web_app_exists": false, "navigator_deviceMemory": 8.0, "navigator_java_enabled": false, "prefers_reduced_motion": false, "storage_estimate_quota": 1178717110272, "webdriver_detected_msg": "", "window_active_x_object": false, "window_external_exists": true, "color_depth_pixel_depth": "24-24", "indexedDbData_available": true, "navigator_cookieEnabled": true, "unmasked_renderer_webgl": "ANGLE (Intel, Mesa Intel(R) Graphics (RPL-P), OpenGL 4.6)", "battery_discharging_time": 4844.0, "connection_effectiveType": "4g", "non_native_function_flag": "", "speech_synthesis_voice_1": "Google Bahasa Indonesia", "window_client_information": true, "audio_compressor_reduction": 20.538288116455078, "navigator_mediaDevices_len": 8, "audio_intensity_fingerprint": 124.04347527516074, "speech_synthesis_voice_hash": "8010ee3313813de521e48e63bd5a6f13", "microsoft_credentials_exists": false, "window_installTrigger_exists": false, "speech_synthesis_voices_count": 19, "webgl_shading_language_version": "WebGL GLSL ES 3.00 (OpenGL ES GLSL ES 3.0 Chromium)", "error_message_stack_access_count": 2, "speech_synthesis_avail_voices_count": 19, "error_message_stack_access_count_worker": 2}""" + ), + "result_data": GrlIqCheckerResults.model_validate_json( + """{"uuid": "d54509f2f310499f8ab74839b10b2a41", "check_codecs": {"score": 0}, "check_timezone": {"score": 0}, "check_timestamp": {"score": 0}, "check_user_type": {"score": 0}, "check_ip_changes": {"score": 0}, "check_ip_country": {"score": 0}, "check_environment": {"msg": "error_message_stack_access_count: 2", "score": 100}, "check_ip_timezone": {"score": 0}, "check_isp_changes": {"score": 0}, "check_useragent_js": {"score": 0}, "check_required_fonts": {"score": 0}, "check_user_anonymous": {"score": 0}, "check_webrtc_success": {"score": 0}, "check_seen_timestamps": {"score": 0}, "check_country_timezone": {"score": 0}, "check_prohibited_fonts": {"score": 0}, "check_timezone_changes": {"score": 0}, "check_execution_time_ms": {"score": 0}, "check_fingerprint_reuse": {"score": 0}, "check_fingerprint_cycling": {"score": 0}, "check_ip_webrtc_ip_detail": {"score": 0}, "check_environment_critical": {"score": 0}, "check_useragent_other_enums": {"score": 0}, "check_useragent_ip_properties": {"score": 0}, "check_useragent_data_properties": {"score": 0}, "check_useragent_device_family_brand": {"score": 0}}""" + ), + "category_result": GrlIqForensicCategoryResult.model_validate_json( + """{"uuid": "d54509f2f310499f8ab74839b10b2a41", "is_bot": 0, "is_tampered": 0, "is_velocity": 0, "is_anonymous": 0, "suspicious_ip": 0, "is_oscillating": 0, "is_teleporting": 0, "is_inconsistent": 10, "platform_ip_inconsistent": 0}""" + ), + "fraud_score": 10, + "is_attempt_allowed": True, + }, +] diff --git a/generalresearch/grliq/managers/colormap.py b/generalresearch/grliq/managers/colormap.py new file mode 100644 index 0000000..dc55c76 --- /dev/null +++ b/generalresearch/grliq/managers/colormap.py @@ -0,0 +1,263 @@ +# To avoid a matplotlib dependency +import numpy as np + +turbo_colormap_data = np.array( + [ + [0.18995, 0.07176, 0.23217], + [0.19483, 0.08339, 0.26149], + [0.19956, 0.09498, 0.29024], + [0.20415, 0.10652, 0.31844], + [0.20860, 0.11802, 0.34607], + [0.21291, 0.12947, 0.37314], + [0.21708, 0.14087, 0.39964], + [0.22111, 0.15223, 0.42558], + [0.22500, 0.16354, 0.45096], + [0.22875, 0.17481, 0.47578], + [0.23236, 0.18603, 0.50004], + [0.23582, 0.19720, 0.52373], + [0.23915, 0.20833, 0.54686], + [0.24234, 0.21941, 0.56942], + [0.24539, 0.23044, 0.59142], + [0.24830, 0.24143, 0.61286], + [0.25107, 0.25237, 0.63374], + [0.25369, 0.26327, 0.65406], + [0.25618, 0.27412, 0.67381], + [0.25853, 0.28492, 0.69300], + [0.26074, 0.29568, 0.71162], + [0.26280, 0.30639, 0.72968], + [0.26473, 0.31706, 0.74718], + [0.26652, 0.32768, 0.76412], + [0.26816, 0.33825, 0.78050], + [0.26967, 0.34878, 0.79631], + [0.27103, 0.35926, 0.81156], + [0.27226, 0.36970, 0.82624], + [0.27334, 0.38008, 0.84037], + [0.27429, 0.39043, 0.85393], + [0.27509, 0.40072, 0.86692], + [0.27576, 0.41097, 0.87936], + [0.27628, 0.42118, 0.89123], + [0.27667, 0.43134, 0.90254], + [0.27691, 0.44145, 0.91328], + [0.27701, 0.45152, 0.92347], + [0.27698, 0.46153, 0.93309], + [0.27680, 0.47151, 0.94214], + [0.27648, 0.48144, 0.95064], + [0.27603, 0.49132, 0.95857], + [0.27543, 0.50115, 0.96594], + [0.27469, 0.51094, 0.97275], + [0.27381, 0.52069, 0.97899], + [0.27273, 0.53040, 0.98461], + [0.27106, 0.54015, 0.98930], + [0.26878, 0.54995, 0.99303], + [0.26592, 0.55979, 0.99583], + [0.26252, 0.56967, 0.99773], + [0.25862, 0.57958, 0.99876], + [0.25425, 0.58950, 0.99896], + [0.24946, 0.59943, 0.99835], + [0.24427, 0.60937, 0.99697], + [0.23874, 0.61931, 0.99485], + [0.23288, 0.62923, 0.99202], + [0.22676, 0.63913, 0.98851], + [0.22039, 0.64901, 0.98436], + [0.21382, 0.65886, 0.97959], + [0.20708, 0.66866, 0.97423], + [0.20021, 0.67842, 0.96833], + [0.19326, 0.68812, 0.96190], + [0.18625, 0.69775, 0.95498], + [0.17923, 0.70732, 0.94761], + [0.17223, 0.71680, 0.93981], + [0.16529, 0.72620, 0.93161], + [0.15844, 0.73551, 0.92305], + [0.15173, 0.74472, 0.91416], + [0.14519, 0.75381, 0.90496], + [0.13886, 0.76279, 0.89550], + [0.13278, 0.77165, 0.88580], + [0.12698, 0.78037, 0.87590], + [0.12151, 0.78896, 0.86581], + [0.11639, 0.79740, 0.85559], + [0.11167, 0.80569, 0.84525], + [0.10738, 0.81381, 0.83484], + [0.10357, 0.82177, 0.82437], + [0.10026, 0.82955, 0.81389], + [0.09750, 0.83714, 0.80342], + [0.09532, 0.84455, 0.79299], + [0.09377, 0.85175, 0.78264], + [0.09287, 0.85875, 0.77240], + [0.09267, 0.86554, 0.76230], + [0.09320, 0.87211, 0.75237], + [0.09451, 0.87844, 0.74265], + [0.09662, 0.88454, 0.73316], + [0.09958, 0.89040, 0.72393], + [0.10342, 0.89600, 0.71500], + [0.10815, 0.90142, 0.70599], + [0.11374, 0.90673, 0.69651], + [0.12014, 0.91193, 0.68660], + [0.12733, 0.91701, 0.67627], + [0.13526, 0.92197, 0.66556], + [0.14391, 0.92680, 0.65448], + [0.15323, 0.93151, 0.64308], + [0.16319, 0.93609, 0.63137], + [0.17377, 0.94053, 0.61938], + [0.18491, 0.94484, 0.60713], + [0.19659, 0.94901, 0.59466], + [0.20877, 0.95304, 0.58199], + [0.22142, 0.95692, 0.56914], + [0.23449, 0.96065, 0.55614], + [0.24797, 0.96423, 0.54303], + [0.26180, 0.96765, 0.52981], + [0.27597, 0.97092, 0.51653], + [0.29042, 0.97403, 0.50321], + [0.30513, 0.97697, 0.48987], + [0.32006, 0.97974, 0.47654], + [0.33517, 0.98234, 0.46325], + [0.35043, 0.98477, 0.45002], + [0.36581, 0.98702, 0.43688], + [0.38127, 0.98909, 0.42386], + [0.39678, 0.99098, 0.41098], + [0.41229, 0.99268, 0.39826], + [0.42778, 0.99419, 0.38575], + [0.44321, 0.99551, 0.37345], + [0.45854, 0.99663, 0.36140], + [0.47375, 0.99755, 0.34963], + [0.48879, 0.99828, 0.33816], + [0.50362, 0.99879, 0.32701], + [0.51822, 0.99910, 0.31622], + [0.53255, 0.99919, 0.30581], + [0.54658, 0.99907, 0.29581], + [0.56026, 0.99873, 0.28623], + [0.57357, 0.99817, 0.27712], + [0.58646, 0.99739, 0.26849], + [0.59891, 0.99638, 0.26038], + [0.61088, 0.99514, 0.25280], + [0.62233, 0.99366, 0.24579], + [0.63323, 0.99195, 0.23937], + [0.64362, 0.98999, 0.23356], + [0.65394, 0.98775, 0.22835], + [0.66428, 0.98524, 0.22370], + [0.67462, 0.98246, 0.21960], + [0.68494, 0.97941, 0.21602], + [0.69525, 0.97610, 0.21294], + [0.70553, 0.97255, 0.21032], + [0.71577, 0.96875, 0.20815], + [0.72596, 0.96470, 0.20640], + [0.73610, 0.96043, 0.20504], + [0.74617, 0.95593, 0.20406], + [0.75617, 0.95121, 0.20343], + [0.76608, 0.94627, 0.20311], + [0.77591, 0.94113, 0.20310], + [0.78563, 0.93579, 0.20336], + [0.79524, 0.93025, 0.20386], + [0.80473, 0.92452, 0.20459], + [0.81410, 0.91861, 0.20552], + [0.82333, 0.91253, 0.20663], + [0.83241, 0.90627, 0.20788], + [0.84133, 0.89986, 0.20926], + [0.85010, 0.89328, 0.21074], + [0.85868, 0.88655, 0.21230], + [0.86709, 0.87968, 0.21391], + [0.87530, 0.87267, 0.21555], + [0.88331, 0.86553, 0.21719], + [0.89112, 0.85826, 0.21880], + [0.89870, 0.85087, 0.22038], + [0.90605, 0.84337, 0.22188], + [0.91317, 0.83576, 0.22328], + [0.92004, 0.82806, 0.22456], + [0.92666, 0.82025, 0.22570], + [0.93301, 0.81236, 0.22667], + [0.93909, 0.80439, 0.22744], + [0.94489, 0.79634, 0.22800], + [0.95039, 0.78823, 0.22831], + [0.95560, 0.78005, 0.22836], + [0.96049, 0.77181, 0.22811], + [0.96507, 0.76352, 0.22754], + [0.96931, 0.75519, 0.22663], + [0.97323, 0.74682, 0.22536], + [0.97679, 0.73842, 0.22369], + [0.98000, 0.73000, 0.22161], + [0.98289, 0.72140, 0.21918], + [0.98549, 0.71250, 0.21650], + [0.98781, 0.70330, 0.21358], + [0.98986, 0.69382, 0.21043], + [0.99163, 0.68408, 0.20706], + [0.99314, 0.67408, 0.20348], + [0.99438, 0.66386, 0.19971], + [0.99535, 0.65341, 0.19577], + [0.99607, 0.64277, 0.19165], + [0.99654, 0.63193, 0.18738], + [0.99675, 0.62093, 0.18297], + [0.99672, 0.60977, 0.17842], + [0.99644, 0.59846, 0.17376], + [0.99593, 0.58703, 0.16899], + [0.99517, 0.57549, 0.16412], + [0.99419, 0.56386, 0.15918], + [0.99297, 0.55214, 0.15417], + [0.99153, 0.54036, 0.14910], + [0.98987, 0.52854, 0.14398], + [0.98799, 0.51667, 0.13883], + [0.98590, 0.50479, 0.13367], + [0.98360, 0.49291, 0.12849], + [0.98108, 0.48104, 0.12332], + [0.97837, 0.46920, 0.11817], + [0.97545, 0.45740, 0.11305], + [0.97234, 0.44565, 0.10797], + [0.96904, 0.43399, 0.10294], + [0.96555, 0.42241, 0.09798], + [0.96187, 0.41093, 0.09310], + [0.95801, 0.39958, 0.08831], + [0.95398, 0.38836, 0.08362], + [0.94977, 0.37729, 0.07905], + [0.94538, 0.36638, 0.07461], + [0.94084, 0.35566, 0.07031], + [0.93612, 0.34513, 0.06616], + [0.93125, 0.33482, 0.06218], + [0.92623, 0.32473, 0.05837], + [0.92105, 0.31489, 0.05475], + [0.91572, 0.30530, 0.05134], + [0.91024, 0.29599, 0.04814], + [0.90463, 0.28696, 0.04516], + [0.89888, 0.27824, 0.04243], + [0.89298, 0.26981, 0.03993], + [0.88691, 0.26152, 0.03753], + [0.88066, 0.25334, 0.03521], + [0.87422, 0.24526, 0.03297], + [0.86760, 0.23730, 0.03082], + [0.86079, 0.22945, 0.02875], + [0.85380, 0.22170, 0.02677], + [0.84662, 0.21407, 0.02487], + [0.83926, 0.20654, 0.02305], + [0.83172, 0.19912, 0.02131], + [0.82399, 0.19182, 0.01966], + [0.81608, 0.18462, 0.01809], + [0.80799, 0.17753, 0.01660], + [0.79971, 0.17055, 0.01520], + [0.79125, 0.16368, 0.01387], + [0.78260, 0.15693, 0.01264], + [0.77377, 0.15028, 0.01148], + [0.76476, 0.14374, 0.01041], + [0.75556, 0.13731, 0.00942], + [0.74617, 0.13098, 0.00851], + [0.73661, 0.12477, 0.00769], + [0.72686, 0.11867, 0.00695], + [0.71692, 0.11268, 0.00629], + [0.70680, 0.10680, 0.00571], + [0.69650, 0.10102, 0.00522], + [0.68602, 0.09536, 0.00481], + [0.67535, 0.08980, 0.00449], + [0.66449, 0.08436, 0.00424], + [0.65345, 0.07902, 0.00408], + [0.64223, 0.07380, 0.00401], + [0.63082, 0.06868, 0.00401], + [0.61923, 0.06367, 0.00410], + [0.60746, 0.05878, 0.00427], + [0.59550, 0.05399, 0.00453], + [0.58336, 0.04931, 0.00486], + [0.57103, 0.04474, 0.00529], + [0.55852, 0.04028, 0.00579], + [0.54583, 0.03593, 0.00638], + [0.53295, 0.03169, 0.00705], + [0.51989, 0.02756, 0.00780], + [0.50664, 0.02354, 0.00863], + [0.49321, 0.01963, 0.00955], + [0.47960, 0.01583, 0.01055], + ] +) diff --git a/generalresearch/grliq/managers/event_plotter.py b/generalresearch/grliq/managers/event_plotter.py new file mode 100644 index 0000000..b879c5c --- /dev/null +++ b/generalresearch/grliq/managers/event_plotter.py @@ -0,0 +1,158 @@ +import html +from typing import List +import webbrowser +import numpy as np +from more_itertools import windowed +from scipy.spatial.distance import euclidean + +from generalresearch.grliq.managers.colormap import turbo_colormap_data +from generalresearch.grliq.models.events import MouseEvent, KeyboardEvent + + +def make_events_svg( + mouse_events: List[MouseEvent], keyboard_events: List[KeyboardEvent] +) -> str: + if len(mouse_events) + len(keyboard_events) == 0: + return f'\n' + "\n" + + t = np.array([pm.timeStamp for pm in mouse_events]) + t_diff = t.max() - t.min() + for x in mouse_events: + if x.type in {"pointerdown", "pointerup"} and x.pointerType == "touch": + x.type = "pointermove" + move_events = [x for x in mouse_events if x.type == "pointermove"] + clicks = [x for x in mouse_events if x.type == "click"] + click_type = ( + "touch" if any(x.pointerType == "touch" for x in mouse_events) else "mouse" + ) + + svg_elements = [] + for ee in windowed(move_events, 2): + e1 = ee[0] + e2 = ee[1] + ts_idx = (e2.timeStamp - t.min()) / t_diff + r, g, b = turbo_colormap_data[round(ts_idx * 255)] + color = f"rgb({int(r*255)},{int(g*255)},{int(b*255)})" + svg_elements.append( + f'' + ) + for c in clicks: + cx = c.pageX + cy = c.pageY + if cx is not None and cy is not None: + ts_idx = (c.timeStamp - t.min()) / t_diff + r, g, b = turbo_colormap_data[round(ts_idx * 255)] + color = f"rgb({int(r*255)},{int(g*255)},{int(b*255)})" + if c._elementBounds is not None: + b = c._elementBounds + svg_elements.append( + f'' + ) + if click_type == "mouse": + svg_elements.append( + f'' + ) + else: + # Inner solid red circle + svg_elements.append( + f'' + ) + # Middle semi-transparent larger circle + svg_elements.append( + f'' + ) + # Outer faint larger circle with more transparency + svg_elements.append( + f'' + ) + + groups = group_input_events_by_xy( + mouse_events=mouse_events, keyboard_events=keyboard_events + ) + + for group in groups: + cx, cy = group[0] + text = "".join(group[1]) + text = text.replace("DELETECONTENTBACKWARD", "BACKSPACE") + text = text.replace(">", ">\n") + if len(text) > 5: + font_size = 10 + else: + font_size = 20 + svg_elements.append(svg_multiline_text(text, cx + 5, cy - 5, font_size)) + + svg = ( + f'' + + "\n".join(svg_elements) + + "\n" + ) + return svg + + +def view_plot(svg: str): + fp = "/tmp/test.svg" + with open(fp, "w") as f: + f.write(svg) + webbrowser.open("file://" + fp) + + +def svg_multiline_text( + text: str, x: float, y: float, font_size: int = 20, line_spacing: float = 1.2 +) -> str: + lines = html.escape(text).split("\n") + tspan_elements = [f'{lines[0]}'] + [ + f'{line}' + for line in lines[1:] + ] + return ( + f'' + f"{''.join(tspan_elements)}" + ) + + +def group_input_events_by_xy( + mouse_events: List[MouseEvent], keyboard_events: List[KeyboardEvent] +): + """ + Each keypress is its own event. For plotting, we want to group together + all keypresses that were made when the mouse was at the same position, + and then concat them. Otherwise, if we just plot each letter at the position + where the mouse was, then if the mouse doesn't move, all letter will be on top + of each other. + """ + groups = [] + last_pos = None + last_time = None + current_chars = [] + for e in keyboard_events: + # Get the most recent position of the mouse in the time before the key was pressed + mouse_events_before = [x for x in mouse_events if x.timeStamp < e.timeStamp] + if mouse_events_before: + mouse_event = mouse_events_before[-1] + cx = mouse_event.pageX + cy = mouse_event.pageY + else: + cx, cy = 0, 0 + char = e.text + if not char: + continue + if last_time is None: + last_time = e.timeStamp + if last_pos is None: + last_pos = (cx, cy) + current_chars.append(char) + else: + if abs(last_time - e.timeStamp) < 2000: + char = char + "\n" + if euclidean((cx, cy), last_pos) < 20: + current_chars.append(char) + else: + groups.append((last_pos, current_chars)) + current_chars = [char] + last_pos = (cx, cy) + if current_chars: + groups.append((last_pos, current_chars)) + return groups diff --git a/generalresearch/grliq/managers/forensic_data.py b/generalresearch/grliq/managers/forensic_data.py new file mode 100644 index 0000000..0f58d54 --- /dev/null +++ b/generalresearch/grliq/managers/forensic_data.py @@ -0,0 +1,794 @@ +from datetime import datetime, timezone +from typing import Optional, List, Collection, Dict, Tuple, Any +from uuid import uuid4 + +from psycopg import sql +from pydantic import PositiveInt, NonNegativeInt + +from generalresearch.grliq.managers import DUMMY_GRLIQ_DATA +from generalresearch.grliq.models.events import PointerMove, TimingData +from generalresearch.grliq.models.forensic_data import GrlIqData +from generalresearch.grliq.models.forensic_result import ( + GrlIqForensicCategoryResult, + GrlIqCheckerResults, + Phase, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig + + +class GrlIqDataManager: + + def __init__(self, postgres_config: PostgresConfig): + self.postgres_config = postgres_config + + def create_dummy( + self, + is_attempt_allowed: True, + product_id: Optional[str] = None, + product_user_id: Optional[str] = None, + uuid: Optional[str] = None, + mid: Optional[str] = None, + created_at: Optional[datetime] = None, + ) -> GrlIqData: + """ + Creates a dummy record in the db with a GrlIqData (data), GrlIqCheckerResults (result_data), + and GrlIqForensicCategoryResult (category_results) + :param is_attempt_allowed: Whether the attempt is allowed. + :param product_id: product_id of user + :param product_user_id: product_user_id of user + :param uuid: uuid for the grliq data record + :param mid: the thl_session:uuid / mid for the attempt. + :return: + """ + import copy + + res: GrlIqData = copy.deepcopy(DUMMY_GRLIQ_DATA[int(is_attempt_allowed)]) + + product_id = product_id or uuid4().hex + product_user_id = product_user_id or uuid4().hex + uuid = uuid or uuid4().hex + mid = mid or uuid4().hex + created_at = created_at or datetime.now(tz=timezone.utc) + + res["data"].product_id = product_id + res["data"].product_user_id = product_user_id + res["data"].uuid = uuid + res["data"].mid = mid + res["data"].created_at = created_at + res["result_data"].uuid = uuid + res["category_result"].uuid = uuid + + return self.create( + iq_data=res["data"], + result_data=res["result_data"], + category_result=res["category_result"], + fraud_score=res["category_result"].fraud_score, + is_attempt_allowed=res["category_result"].is_attempt_allowed(), + ) + + def create( + self, + iq_data: GrlIqData, + result_data: Optional[GrlIqCheckerResults] = None, + category_result: Optional[GrlIqForensicCategoryResult] = None, + fraud_score: Optional[int] = None, + is_attempt_allowed: Optional[bool] = None, + ) -> GrlIqData: + + data = iq_data.model_dump_sql(exclude={"events", "mouse_events", "timing_data"}) + + data["result_data"] = None + if result_data: + data["result_data"] = result_data.model_dump_json( + exclude_none=True, exclude={"is_complete"} + ) + iq_data.results = result_data + + data["category_result"] = None + if category_result: + data["category_result"] = category_result.model_dump_json() + iq_data.category_result = category_result + + data["fingerprint"] = iq_data.fingerprint + data["fraud_score"] = fraud_score + data["is_attempt_allowed"] = is_attempt_allowed + + query = sql.SQL( + """ + INSERT INTO grliq_forensicdata + (uuid, session_uuid, created_at, product_id, product_user_id, + country_iso, client_ip, ua_browser_family, ua_browser_version, + ua_os_family, ua_os_version, ua_device_family, ua_device_brand, + ua_device_model, ua_hash, data, phase, + fingerprint, fraud_score, is_attempt_allowed, + result_data, category_result) + VALUES + (%(uuid)s, %(session_uuid)s, %(created_at)s, %(product_id)s, %(product_user_id)s, + %(country_iso)s, %(client_ip)s, %(ua_browser_family)s, %(ua_browser_version)s, + %(ua_os_family)s, %(ua_os_version)s, %(ua_device_family)s, %(ua_device_brand)s, + %(ua_device_model)s, %(ua_hash)s, %(data)s, %(phase)s, + %(fingerprint)s, %(fraud_score)s, %(is_attempt_allowed)s, + %(result_data)s, %(category_result)s) + RETURNING id + """ + ) + + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + pk = c.fetchone()["id"] + conn.commit() + + iq_data.id = pk + + return iq_data + + def set_results( + self, + uuid: UUIDStr, + result_data: GrlIqCheckerResults, + category_result: GrlIqForensicCategoryResult, + fingerprint: Optional[str] = None, + fraud_score: Optional[int] = None, + is_attempt_allowed: Optional[bool] = None, + ) -> None: + data = {"uuid": uuid} + data["result_data"] = result_data.model_dump_json(exclude_none=True) + data["category_result"] = category_result.model_dump_json() + data["fingerprint"] = fingerprint + data["fraud_score"] = fraud_score + data["is_attempt_allowed"] = is_attempt_allowed + + query = sql.SQL( + """ + UPDATE grliq_forensicdata + SET result_data = %(result_data)s, + category_result = %(category_result)s, + fingerprint = %(fingerprint)s, + fraud_score = %(fraud_score)s, + is_attempt_allowed = %(is_attempt_allowed)s + WHERE uuid = %(uuid)s + """ + ) + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + if c.rowcount != 1: + raise ValueError( + f"Expected 1 row to be updated, but {c.rowcount} rows were affected." + ) + conn.commit() + + return None + + def update_fingerprint(self, iq_data: GrlIqData) -> None: + # We should only run this if we modified the fingerprint algorithm + if "fingerprint" in iq_data.__dict__: + # make sure it's not cached + del iq_data.__dict__["fingerprint"] + data = {"uuid": iq_data.uuid, "fingerprint": iq_data.fingerprint} + query = sql.SQL( + """ + UPDATE grliq_forensicdata + SET fingerprint = %(fingerprint)s + WHERE uuid = %(uuid)s + """ + ) + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + if c.rowcount != 1: + raise ValueError( + f"Expected 1 row to be updated, but {c.rowcount} rows were affected." + ) + conn.commit() + + def update_data(self, iq_data: GrlIqData) -> None: + # We should only run this if we structured new fields and want to + # back-populate them in the db + data = {"id": iq_data.id, "data": iq_data.model_dump_sql()["data"]} + query = sql.SQL( + """ + UPDATE grliq_forensicdata + SET data = %(data)s + WHERE id = %(id)s + """ + ) + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, data) + if c.rowcount != 1: + raise ValueError( + f"Expected 1 row to be updated, but {c.rowcount} rows were affected." + ) + conn.commit() + + def get_data_if_exists( + self, forensic_uuid: UUIDStr, load_events: bool = False + ) -> Optional[GrlIqData]: + try: + return self.get_data(forensic_uuid=forensic_uuid, load_events=load_events) + except AssertionError: + return None + + def get_data( + self, + forensic_id: Optional[PositiveInt] = None, + forensic_uuid: Optional[UUIDStr] = None, + load_events: bool = False, + ) -> GrlIqData: + from generalresearch.grliq.managers.forensic_events import ( + GrlIqEventManager, + ) + + assert any([forensic_id, forensic_uuid]), "Must provide a Forensic ID or UUID" + + if load_events: + # Gets the forensicevents where the 1) session_uuid matches the + # forensic items' session, 2) event_start is closest to the + # created_at for this forensic item, and within 1 minute. + + query = sql.SQL( + """ + SELECT d.id, d.data, e.events, e.mouse_events, t.timing_data + FROM grliq_forensicdata d + -- Closest event_start within 1 minute + LEFT JOIN LATERAL ( + SELECT events, mouse_events, timing_data + FROM grliq_forensicevents e + WHERE e.session_uuid = d.session_uuid + AND ABS(EXTRACT(EPOCH FROM (e.event_start - d.created_at))) <= 60 + ORDER BY ABS(EXTRACT(EPOCH FROM (e.event_start - d.created_at))) ASC + LIMIT 1 + ) e ON true + -- Most recent timing_data by id + LEFT JOIN LATERAL ( + SELECT timing_data + FROM grliq_forensicevents e2 + WHERE e2.session_uuid = d.session_uuid + ORDER BY e2.id DESC + LIMIT 1 + ) t ON true + """ + ) + + else: + query = sql.SQL( + """ + SELECT d.id, d.data + FROM grliq_forensicdata d + """ + ) + + if forensic_id is not None: + column_name = "id" + param_value = forensic_id + else: + column_name = "uuid" + param_value = forensic_uuid + + where_clause = sql.SQL(" WHERE {} = %s").format( + sql.Identifier("d", column_name) + ) + limit_clause = sql.SQL(" LIMIT 1") + q1 = sql.Composed([query, where_clause, limit_clause]) + + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=q1, params=(param_value,)) + x = c.fetchone() + + assert x is not None, f"GrlIqDataManager.get_data({forensic_uuid=}) not found" + + self.temporary_add_missing_fields(x["data"]) + x["data"]["id"] = x["id"] + d = GrlIqData.model_validate(x["data"]) + + if load_events: + d.events = x["events"] if x["events"] is not None else [] + d.pointer_move_events = ( + [PointerMove.from_dict(e) for e in x["mouse_events"]] + if x["mouse_events"] is not None + else [] + ) + d.timing_data = ( + TimingData.model_validate(x["timing_data"]) + if x["timing_data"] is not None + else None + ) + d.mouse_events = ( + GrlIqEventManager.process_mouse_events( + events=d.events or [], + pointer_moves=d.pointer_move_events or [], + ) + if d.events is not None + else [] + ) + d.keyboard_events = ( + GrlIqEventManager.process_keyboard_events(events=d.events) + if d.events is not None + else [] + ) + + return d + + def filter_timing_data( + self, + created_between: Optional[Tuple[datetime, datetime]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> List[Dict]: + limit_str = f"LIMIT {limit}" if limit is not None else "" + offset_str = f"OFFSET {offset}" if offset is not None else "" + params = { + "created_after": created_between[0], + "created_before": created_between[1], + } + query = f""" + SELECT + d.id, d.session_uuid, d.client_ip, d.country_iso, + d.created_at, d.product_id, d.product_user_id, + d.fraud_score, e.timing_data, d.phase + FROM grliq_forensicdata d + JOIN LATERAL ( + SELECT timing_data + FROM grliq_forensicevents e + WHERE e.session_uuid = d.session_uuid AND e.timing_data IS NOT NULL + ORDER BY e.id DESC + LIMIT 1 + ) e ON TRUE + WHERE d.created_at BETWEEN %(created_after)s AND %(created_before)s + {limit_str} {offset_str}; + """ + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params) + res: List[Dict] = c.fetchall() + for x in res: + x["timing_data"] = TimingData.model_validate(x["timing_data"]) + + return res + + def get_unique_user_count_by_fingerprint( + self, + product_id: str, + fingerprint: str, + product_user_id_not: str, + ) -> NonNegativeInt: + # This is used for filtering for other forensic posts with a certain + # fingerprint, in this product_id, but NOT for this user. + query = sql.SQL( + """ + SELECT COUNT(DISTINCT product_user_id) as user_count + FROM grliq_forensicdata d + WHERE product_id = %(product_id)s + AND fingerprint = %(fingerprint)s + AND product_user_id != %(product_user_id)s + AND created_at > NOW() - INTERVAL '30 DAYS' + """ + ) + params = { + "product_id": product_id, + "fingerprint": fingerprint, + "product_user_id": product_user_id_not, + } + # print(query) + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params) + user_count = c.fetchone()["user_count"] + return int(user_count) + + def filter_data( + self, + session_uuid: Optional[str] = None, + fingerprint: Optional[str] = None, + fingerprints: Optional[Collection[str]] = None, + product_id: Optional[str] = None, + product_ids: Optional[Collection[str]] = None, + uuids: Optional[Collection[str]] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + users: Optional[Collection[User]] = None, + phase: Optional[Phase] = None, + order_by: str = "created_at DESC", + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> List[GrlIqData]: + + res = self.filter( + select_str="d.id, d.data", + session_uuid=session_uuid, + fingerprint=fingerprint, + fingerprints=fingerprints, + product_id=product_id, + product_ids=product_ids, + created_after=created_after, + created_before=created_before, + created_between=created_between, + uuids=uuids, + user=user, + users=users, + phase=phase, + order_by=order_by, + limit=limit, + offset=offset, + ) + return [x["data"] for x in res] + + def filter_results( + self, + session_uuid: Optional[str] = None, + uuid: Optional[str] = None, + product_ids: Optional[Collection[str]] = None, + product_id: Optional[str] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + order_by: str = "created_at DESC", + ) -> List[GrlIqCheckerResults]: + select_str = ( + "id, session_uuid, product_id, product_user_id, created_at, result_data" + ) + res = self.filter( + select_str=select_str, + session_uuid=session_uuid, + uuids=[uuid] if uuid else None, + product_ids=product_ids, + product_id=product_id, + created_after=created_after, + created_before=created_before, + created_between=created_between, + user=user, + limit=limit, + offset=offset, + order_by=order_by, + ) + for x in res: + x["result_data"] = ( + GrlIqCheckerResults.model_validate(x["result_data"]) + if x["result_data"] + else None + ) + return [x["result_data"] for x in res] + + def filter_category_results( + self, + session_uuid: Optional[str] = None, + uuid: Optional[str] = None, + product_id: Optional[str] = None, + product_ids: Optional[Collection[str]] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + order_by: str = "created_at DESC", + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> List[GrlIqForensicCategoryResult]: + select_str = ( + "id, session_uuid, product_id, product_user_id, created_at, category_result" + ) + res = self.filter( + select_str=select_str, + session_uuid=session_uuid, + uuids=[uuid] if uuid else None, + product_id=product_id, + product_ids=product_ids, + created_after=created_after, + created_before=created_before, + created_between=created_between, + user=user, + order_by=order_by, + limit=limit, + offset=offset, + ) + + return [x["category_result"] for x in res] + + @staticmethod + def make_filter_str( + session_uuid: Optional[str] = None, + fingerprint: Optional[str] = None, + fingerprints: Optional[Collection[str]] = None, + uuids: Optional[Collection[str]] = None, + product_id: Optional[str] = None, + product_ids: Optional[Collection[str]] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + users: Optional[Collection[User]] = None, + phase: Optional[Phase] = None, + ) -> Tuple[str, Dict[str, Any]]: + + filters = [] + params: Dict[str, Any] = {} + + if session_uuid: + params["session_uuid"] = session_uuid + filters.append("d.session_uuid = %(session_uuid)s") + + if fingerprint: + params["fingerprint"] = fingerprint + filters.append("d.fingerprint = %(fingerprint)s") + + if fingerprints: + params["fingerprints"] = list(set(fingerprints)) + filters.append("d.fingerprint = ANY(%(fingerprints)s)") + + if product_ids and len(product_ids) == 1: + product_id = list(product_ids)[0] + product_ids = None + + if product_ids: + assert ( + users is None and user is None and product_id is None + ), "user, users, product_id, and product_ids are mutually exclusive" + params["product_ids"] = list(set(product_ids)) + filters.append("d.product_id = ANY(%(product_ids)s::UUID[])") + + if product_id: + assert ( + users is None and user is None and product_ids is None + ), "user, users, product_id, and product_ids are mutually exclusive" + params["product_id"] = product_id + filters.append("d.product_id = %(product_id)s") + + if uuids: + params["uuids"] = uuids + filters.append("d.uuid = ANY(%(uuids)s)") + + if created_after: + params["created_after"] = created_after + filters.append( + "d.created_at >= %(created_after)s::timestamp with time zone" + ) + + if created_before: + params["created_before"] = created_before + filters.append( + "d.created_at < %(created_before)s::timestamp with time zone" + ) + + if created_between: + assert ( + created_after is None + ), "Cannot pass both created_after and created_between" + assert ( + created_before is None + ), "Cannot pass both created_before and created_between" + params["created_after"] = created_between[0] + params["created_before"] = created_between[1] + filters.append( + "d.created_at BETWEEN %(created_after)s::timestamptz AND %(created_before)s::timestamptz" + ) + + if user: + assert ( + product_ids is None and users is None + ), "user, users, and product_ids are mutually exclusive" + params["product_id"] = user.product_id + params["product_user_id"] = user.product_user_id + filters.append( + "(d.product_id = %(product_id)s AND d.product_user_id = %(product_user_id)s)" + ) + + if users: + assert ( + product_ids is None and user is None + ), "user, users, and product_ids are mutually exclusive" + user_args = ", ".join( + [f"(%(bp_{i})s, %(bpuid_{i})s)" for i in range(len(users))] + ) + filters.append(f"(d.product_id, d.product_user_id) IN ({user_args})") + for i, user in enumerate(users): + params[f"bp_{i}"] = user.product_id + params[f"bpuid_{i}"] = user.product_user_id + + if phase: + params["phase"] = phase.value + filters.append("d.phase = %(phase)s") + + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + return filter_str, params + + def filter_count( + self, + session_uuid: Optional[str] = None, + fingerprint: Optional[str] = None, + fingerprints: Optional[Collection[str]] = None, + uuids: Optional[Collection[str]] = None, + product_id: Optional[str] = None, + product_ids: Optional[Collection[str]] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + users: Optional[Collection[User]] = None, + phase: Optional[Phase] = None, + ) -> NonNegativeInt: + filter_str, params = self.make_filter_str( + session_uuid=session_uuid, + fingerprint=fingerprint, + fingerprints=fingerprints, + uuids=uuids, + product_id=product_id, + product_ids=product_ids, + created_after=created_after, + created_before=created_before, + created_between=created_between, + user=user, + users=users, + phase=phase, + ) + + only_product_id = ( + product_id is not None + and session_uuid is None + and fingerprint is None + and fingerprints is None + and uuids is None + and product_ids is None + and created_after is None + and created_before is None + and created_between is None + and user is None + and users is None + and phase is None + ) + + if only_product_id: + try: + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + SELECT count AS c + FROM grliq_forensicdata_product_counts + WHERE product_id = %s + LIMIT 1 + """, + params=(product_id,), + ) + res = c.fetchone() + if res and res["c"] >= 0: + return int(res["c"]) + + except (Exception,) as e: + pass + + query = f""" + SELECT COUNT(1) AS c + FROM grliq_forensicdata d + {filter_str} + """ + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + res = c.fetchone() + return int(res["c"]) + + def filter( + self, + select_str: str, + session_uuid: Optional[str] = None, + fingerprint: Optional[str] = None, + fingerprints: Optional[Collection[str]] = None, + uuids: Optional[Collection[str]] = None, + product_id: Optional[str] = None, + product_ids: Optional[Collection[str]] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + users: Optional[Collection[User]] = None, + phase: Optional[Phase] = None, + order_by: str = "created_at DESC", + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> List[Dict]: + """ + Accepts lots of optional filters. + """ + if not limit: + limit = 5000 + + if not offset: + offset = 0 + + if product_ids: + # It doesn't use the (product_id, created_at) index with multiple product_ids + assert ( + offset == 0 + ), "Cannot paginate using product_ids, use product_id instead" + + filter_str, params = self.make_filter_str( + session_uuid=session_uuid, + fingerprint=fingerprint, + fingerprints=fingerprints, + uuids=uuids, + product_id=product_id, + product_ids=product_ids, + created_after=created_after, + created_before=created_before, + created_between=created_between, + user=user, + users=users, + phase=phase, + ) + + query = f""" + SELECT {select_str} + FROM grliq_forensicdata d + {filter_str} + ORDER BY {order_by} + LIMIT {limit} + OFFSET {offset} + """ + # print(query) + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + res: List = c.fetchall() + + for x in res: + + if "data" in x: + self.temporary_add_missing_fields(x["data"]) + x["data"]["id"] = x["id"] + x["data"] = GrlIqData.model_validate(x["data"]) if x["data"] else None + + if "result_data" in x: + if x["result_data"]: + x["result_data"].pop("is_complete", None) + x["result_data"] = ( + GrlIqCheckerResults.model_validate(x["result_data"]) + if x["result_data"] + else None + ) + + if "category_result" in x: + x["category_result"] = ( + GrlIqForensicCategoryResult.model_validate(x["category_result"]) + if x["category_result"] + else None + ) + + return res + + @staticmethod + def temporary_add_missing_fields(d: Dict): + # The following fields were added recently, and so we must give them + # a value or old db rows won't be parseable. Once logs are backfilled + # then this can be removed + field_default = { + "audio_codecs": None, + "video_codecs": None, + "color_gamut": "2", + "prefers_contrast": "0", + "prefers_reduced_motion": False, + "dynamic_range": False, + "inverted_colors": False, + "forced_colors": False, + "prefers_color_scheme": False, + } + for k, v in field_default.items(): + if k not in d: + d[k] = v + + # We made a mistake once and saved the grliq data object with the events fields set. + # Make sure they are not set here. We load them from the events table, not here! + d.pop("events", None) + d.pop("pointer_move_events", None) + d.pop("mouse_events", None) + d.pop("keyboard_events", None) + d.pop("timing_data", None) diff --git a/generalresearch/grliq/managers/forensic_events.py b/generalresearch/grliq/managers/forensic_events.py new file mode 100644 index 0000000..2633db7 --- /dev/null +++ b/generalresearch/grliq/managers/forensic_events.py @@ -0,0 +1,290 @@ +import json +from datetime import datetime +from typing import Optional, List, Collection, Dict +from uuid import uuid4 + +from psycopg import sql + +from generalresearch.grliq.models.events import ( + TimingData, + PointerMove, + MouseEvent, + KeyboardEvent, + Bounds, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.pg_helper import PostgresConfig + + +class GrlIqEventManager: + + def __init__(self, postgres_config: PostgresConfig): + self.postgres_config = postgres_config + + def update_or_create_timing( + self, + session_uuid: UUIDStr, + timing_data: TimingData, + ): + data = { + "session_uuid": session_uuid, + "timing_data": ( + timing_data.model_dump_json() if timing_data is not None else None + ), + "uuid": uuid4().hex, + } + + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute("SELECT pg_advisory_xact_lock(hashtext(%s))", (session_uuid,)) + # Try to update first + update_query = sql.SQL( + """ + UPDATE grliq_forensicevents + SET timing_data = %(timing_data)s + WHERE session_uuid = %(session_uuid)s + AND timing_data IS NULL + RETURNING id + """ + ) + c.execute(update_query, data) + result = c.fetchone() + + if result: + pk = result["id"] + conn.commit() + return pk + + # No matching row to update. Do an insert + insert_query = sql.SQL( + """ + INSERT INTO grliq_forensicevents + (uuid, session_uuid, timing_data) + VALUES + (%(uuid)s, %(session_uuid)s, %(timing_data)s) + RETURNING id + """ + ) + c.execute(insert_query, data) + pk = c.fetchone()["id"] + conn.commit() + + return pk + + def update_or_create_events( + self, + session_uuid: UUIDStr, + event_start: datetime, + event_end: datetime, + events: Optional[List[Dict]] = None, + mouse_events: Optional[List[Dict]] = None, + ): + data = { + "uuid": uuid4().hex, + "session_uuid": session_uuid, + "events": json.dumps(events) if events is not None else None, + "mouse_events": ( + json.dumps(mouse_events) if mouse_events is not None else None + ), + "event_start": event_start, + "event_end": event_end, + } + + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute("SELECT pg_advisory_xact_lock(hashtext(%s))", (session_uuid,)) + # Try to update first + update_query = sql.SQL( + """ + UPDATE grliq_forensicevents + SET events = %(events)s, + mouse_events = %(mouse_events)s, + event_start = %(event_start)s, + event_end = %(event_end)s + WHERE session_uuid = %(session_uuid)s + AND events IS NULL + RETURNING id + """ + ) + c.execute(update_query, data) + result = c.fetchone() + + if result: + pk = result["id"] + conn.commit() + return pk + + # No matching row to update. Do an insert + insert_query = sql.SQL( + """ + INSERT INTO grliq_forensicevents + (uuid, session_uuid, events, mouse_events, + event_start, event_end) + VALUES + (%(uuid)s, %(session_uuid)s, %(events)s, %(mouse_events)s, + %(event_start)s, %(event_end)s) + RETURNING id + """ + ) + c.execute(insert_query, data) + pk = c.fetchone()["id"] + conn.commit() + + return pk + + def filter( + self, + select_str: Optional[str] = None, + session_uuid: Optional[str] = None, + session_uuids: Optional[Collection[str]] = None, + uuids: Optional[Collection[str]] = None, + started_since: Optional[datetime] = None, + limit: Optional[int] = None, + order_by: str = "event_start DESC", + ) -> List[Dict]: + """ """ + if not limit: + limit = 100 + if not select_str: + select_str = "*" + filters = [] + params = {} + if session_uuid: + params["session_uuid"] = session_uuid + filters.append("session_uuid = %(session_uuid)s") + if session_uuids: + params["session_uuids"] = session_uuids + filters.append("session_uuid = ANY(%(session_uuids)s)") + if uuids: + params["uuids"] = uuids + filters.append("uuid = ANY(%(uuids)s)") + if started_since: + params["started_since"] = started_since + filters.append("event_start >= %(started_since)s") + + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + query = f""" + SELECT {select_str} + FROM grliq_forensicevents + {filter_str} + ORDER BY {order_by} LIMIT {limit} + """ + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + res = c.fetchall() + + for x in res: + if x.get("mouse_events"): + x["mouse_events"] = [ + PointerMove.from_dict(e) for e in x["mouse_events"] + ] + if x.get("timing_data"): + x["timing_data"] = TimingData.model_validate(x["timing_data"]) + + events = x.get("events", []) or [] + pointer_moves = x.get("mouse_events", []) or [] + x["mouse_events"] = self.process_mouse_events( + events=events, pointer_moves=pointer_moves + ) + x["keyboard_events"] = self.process_keyboard_events(events=events) + return res + + def filter_distinct_timing( + self, + session_uuids: Collection[str], + ) -> List[Dict]: + params = {"session_uuids": list(session_uuids)} + query = sql.SQL( + """ + SELECT DISTINCT ON (fe.session_uuid) + timing_data, + fe.session_uuid, + country_iso, + data ->> 'client_ip_detail' as client_ip_detail + FROM grliq_forensicevents fe + JOIN grliq_forensicdata d on fe.session_uuid = d.session_uuid + WHERE fe.session_uuid = ANY(%(session_uuids)s) + AND timing_data IS NOT NULL + ORDER BY session_uuid, fe.id DESC; + """ + ) + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params) + res = c.fetchall() + + for x in res: + x["timing_data"] = TimingData.model_validate(x["timing_data"]) + x["client_ip_detail"] = ( + json.loads(x["client_ip_detail"]) if x["client_ip_detail"] else None + ) + + return res + + @staticmethod + def process_mouse_events(pointer_moves: List[PointerMove], events: List[Dict]): + """ + In the db column 'mouse_events' we put all 'pointermove' events. Pull those + out, and then any 'pointerdown' and 'pointerup' events from the 'events' column, + and merge them all together into a list of MouseEvent objects + """ + mouse_events = [ + # these contain only pointermove events + MouseEvent( + type=x.type, + pageX=x.pageX, + pageY=x.pageY, + pointerType=x.pointerType, + _elementId=x._elementId, + _elementTagName=x._elementTagName, + _elementBounds=x._elementBounds, + timeStamp=x.timeStamp, + ) + for x in pointer_moves + ] + mouse_events.extend( + [ + MouseEvent( + type=x["type"], + pageX=x["pageX"], + pageY=x["pageY"], + pointerType=x.get("pointerType"), + _elementId=x.get("_elementId"), + _elementTagName=x.get("_elementTagName"), + _elementBounds=( + Bounds(**x["_elementBounds"]) + if x.get("_elementBounds") + else None + ), + timeStamp=x["timeStamp"], + ) + for x in events + if x.get("type") in {"pointerdown", "pointerup", "click"} + ] + ) + mouse_events = sorted(mouse_events, key=lambda x: x.timeStamp) + return mouse_events + + @staticmethod + def process_keyboard_events(events: List[Dict]): + res = [ + KeyboardEvent( + type=x["type"], + inputType=x.get("inputType"), + key=x.get("key"), + data=x.get("data"), + _elementId=x.get("_elementId"), + _elementTagName=x.get("_elementTagName"), + timeStamp=x["timeStamp"], + _elementBounds=( + Bounds(**x["_elementBounds"]) if x.get("_elementBounds") else None + ), + ) + for x in events + if x.get("type") in {"keydown", "input"} + ] + # There's a lot of events that have nothing! on them.... ? + res = [x for x in res if x.data or x.inputType or x.key] + return res diff --git a/generalresearch/grliq/managers/forensic_results.py b/generalresearch/grliq/managers/forensic_results.py new file mode 100644 index 0000000..dd1b039 --- /dev/null +++ b/generalresearch/grliq/managers/forensic_results.py @@ -0,0 +1,104 @@ +from datetime import datetime +from typing import Optional, List, Collection, Dict, Tuple + +from generalresearch.grliq.models.forensic_result import ( + GrlIqForensicCategoryResult, + Phase, +) +from generalresearch.grliq.models.useragents import GrlUserAgent +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig + + +class GrlIqCategoryResultsReader: + def __init__(self, postgres_config: PostgresConfig): + self.postgres_config = postgres_config + + def filter_category_results( + self, + session_uuid: Optional[str] = None, + fingerprint: Optional[str] = None, + phase: Optional[Phase] = None, + uuids: Optional[Collection[str]] = None, + product_ids: Optional[Collection[str]] = None, + created_since: Optional[datetime] = None, + created_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + limit: Optional[int] = None, + ) -> List[Dict]: + """ + For retrieving GrlIqForensicCategoryResult objects from db. + :return: List of Dict. Keys are below in the 'select_str'. + """ + select_str = ( + "id, uuid, session_uuid, product_id, product_user_id, created_at," + " country_iso, client_ip, phase, data," + " data->>'user_agent_str' AS user_agent_str," + " category_result, is_attempt_allowed, fraud_score" + ) + if not limit: + limit = 5000 + + filters = [] + params = {} + if session_uuid: + params["session_uuid"] = session_uuid + filters.append("d.session_uuid = %(session_uuid)s") + if fingerprint: + params["fingerprint"] = fingerprint + filters.append("d.fingerprint = %(fingerprint)s") + if phase: + params["phase"] = phase.value + filters.append("d.phase = %(phase)s") + if product_ids: + params["product_ids"] = product_ids + filters.append("d.product_id = ANY(%(product_ids)s::UUID[])") + if uuids: + params["uuids"] = uuids + filters.append("d.uuid = ANY(%(uuids)s)") + if created_since: + params["created_since"] = created_since + filters.append( + "d.created_at >= %(created_since)s::timestamp with time zone" + ) + if created_between: + assert ( + created_since is None + ), "Cannot pass both created_until and created_between" + params["created_since"] = created_between[0] + params["created_until"] = created_between[1] + filters.append( + "d.created_at BETWEEN %(created_since)s::timestamptz AND %(created_until)s::timestamptz" + ) + if user: + assert product_ids is None, "Cannot pass both product_ids and user" + params["product_id"] = user.product_id + params["product_user_id"] = user.product_user_id + filters.append( + "(d.product_id = %(product_id)s AND d.product_user_id = %(product_user_id)s)" + ) + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + query = f""" + SELECT {select_str} + FROM grliq_forensicdata d + {filter_str} + ORDER BY created_at DESC LIMIT {limit} + """ + with self.postgres_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params) + res = c.fetchall() + + for x in res: + x["client_ip"] = str(x["client_ip"]) + x["category_result"] = ( + GrlIqForensicCategoryResult.model_validate(x["category_result"]) + if x["category_result"] + else None + ) + if x.get("user_agent_str"): + x["user_agent"] = GrlUserAgent.from_ua_str(x["user_agent_str"]) + x.pop("user_agent_str", None) + + return res diff --git a/generalresearch/grliq/managers/forensic_summary.py b/generalresearch/grliq/managers/forensic_summary.py new file mode 100644 index 0000000..c44daf2 --- /dev/null +++ b/generalresearch/grliq/managers/forensic_summary.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import statistics +from collections import defaultdict +from datetime import datetime, timezone, timedelta +from typing import List, Dict + +import numpy as np + +from generalresearch.grliq.managers.forensic_data import GrlIqDataManager +from generalresearch.grliq.managers.forensic_events import ( + GrlIqEventManager, +) +from generalresearch.grliq.models.forensic_result import ( + GrlIqForensicCategoryResult, + GrlIqCheckerResults, +) +from generalresearch.grliq.models.forensic_summary import ( + GrlIqForensicCategorySummary, + GrlIqCheckerResultsSummary, + UserForensicSummary, + CountryRTTDistribution, + TimingDataCountrySummary, +) +from generalresearch.models.thl.user import User +from generalresearch.redis_helper import RedisConfig + + +def calculate_category_summary( + res: List[GrlIqForensicCategoryResult], +) -> GrlIqForensicCategorySummary: + totals = defaultdict(int) + is_complete_count = 0 + is_attempt_allowed_count = 0 + n = len(res) + fields = GrlIqForensicCategoryResult.model_score_fields() + fraud_score = 0 + + for r in res: + fraud_score += r.fraud_score + if r.is_complete: + is_complete_count += 1 + if r.is_attempt_allowed(): + is_attempt_allowed_count += 1 + for field in fields: + totals[field] += getattr(r, field) + + return GrlIqForensicCategorySummary( + attempt_count=n, + is_attempt_allowed_count=is_attempt_allowed_count, + is_complete_rate=is_complete_count / n if n else 0.0, + fraud_score_avg=fraud_score / n if n else None, + **{f"{field}_avg": totals[field] / n if n else 0.0 for field in totals}, + ) + + +def calculate_checker_summary( + res: List[GrlIqCheckerResults], +) -> GrlIqCheckerResultsSummary: + totals = defaultdict(list) + none_totals = defaultdict(int) + n = len(res) + fields = [f for f in GrlIqCheckerResults.model_fields if f.startswith("check_")] + + for r in res: + for field in fields: + value = getattr(r, field) + if value is None: + none_totals[field] += 1 + else: + totals[field].append(value.score) + field_avg = {f"{k}_avg": statistics.mean(v) if v else 0 for k, v in totals.items()} + field_pct_none = { + f"{k}_pct_none": v / n if n else 0 for k, v in none_totals.items() + } + field_avg.update( + { + k.replace("_pct_none", "_avg"): None + for k, v in field_pct_none.items() + if v == 1 + } + ) + + return GrlIqCheckerResultsSummary(**field_avg, **field_pct_none) + + +def calculate_timing_summary( + redis_config: RedisConfig, timing_res +) -> Dict[str, TimingDataCountrySummary]: + country_median_rtts = defaultdict(list) + for x in timing_res: + s = x["timing_data"].summarize + if s: + country_median_rtts[x["country_iso"]].append( + float(np.exp(s.median_log_rtt)) + ) + country_isos = list(country_median_rtts.keys()) + + rc = redis_config.create_redis_client() + country_distributions = dict( + zip( + country_isos, + rc.hmget("grl-iq:country_rtt_distributions", *country_isos), + ) + ) + country_distributions = { + k: CountryRTTDistribution.model_validate_json(v) + for k, v in country_distributions.items() + } + + out = dict() + for country_iso, median_rtts in country_median_rtts.items(): + country_stats = country_distributions[country_iso] + z_scores = [ + (np.log(x) - country_stats.rtt_log_mean) / country_stats.rtt_log_std + for x in median_rtts + ] + out[country_iso] = TimingDataCountrySummary( + country_iso=country_iso, + rtt_min=min(median_rtts), + rtt_max=max(median_rtts), + rtt_mean=statistics.mean(median_rtts), + rtt_median=statistics.median(median_rtts), + rtt_q25=float(np.quantile(median_rtts, 0.25)), + rtt_q75=float(np.quantile(median_rtts, 0.75)), + expected_rtt_range=country_distributions[country_iso].expected_rtt_range, + mean_z_score=statistics.mean(z_scores), + ) + return out + + +def run_user_forensic_summary( + iq_dm: GrlIqDataManager, + iq_em: GrlIqEventManager, + redis_config: RedisConfig, + user: User, +) -> UserForensicSummary: + now = datetime.now(tz=timezone.utc) + created_between = (now - timedelta(days=90), now) + select_str = "id, session_uuid, product_id, product_user_id, created_at, result_data, category_result" + res = iq_dm.filter( + select_str=select_str, + user=user, + created_between=created_between, + limit=500, + order_by="created_at DESC", + ) + period_start = min([x["created_at"] for x in res]) if res else None + period_end = max([x["created_at"] for x in res]) if res else None + + category_result_summary = ( + calculate_category_summary([x["category_result"] for x in res]) if res else None + ) + checker_result_summary = ( + calculate_checker_summary([x["result_data"] for x in res]) if res else None + ) + + session_uuids = {x["session_uuid"] for x in res} + timing_res: List[Dict] = iq_em.filter_distinct_timing(session_uuids=session_uuids) + + country_timing_data_summary = ( + calculate_timing_summary(redis_config=redis_config, timing_res=timing_res) + if timing_res + else dict() + ) + + s = UserForensicSummary( + period_start=period_start, + period_end=period_end, + category_result_summary=category_result_summary, + checker_result_summary=checker_result_summary, + country_timing_data_summary=country_timing_data_summary, + ip_timing_data_summary={}, + ) + return s diff --git a/generalresearch/grliq/models/__init__.py b/generalresearch/grliq/models/__init__.py new file mode 100644 index 0000000..957e64c --- /dev/null +++ b/generalresearch/grliq/models/__init__.py @@ -0,0 +1,66 @@ +import json +from enum import Enum +from typing import List + + +class RiskWeighting(str, Enum): + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +VIDEO_CODEC_NAMES = [ + 'video/3gpp; codecs="mp4v.20.8, samr"', + 'video/mp4; codecs="avc1.42E01E"', + 'video/mp4; codecs="avc1.58A01E"', + 'video/mp4; codecs="avc1.4D401E"', + 'video/mp4; codecs="avc1.64001E"', + 'video/mp4; codecs="avc1.42E01E, mp4a.40.2"', + 'video/mp4; codecs="avc1.58A01E, mp4a.40.2"', + 'video/mp4; codecs="avc1.4D401E, mp4a.40.2"', + 'video/mp4; codecs="avc1.64001E, mp4a.40.2"', + 'video/mp4; codecs="flac"', + 'video/mp4; codecs="H.264, mp3"', + 'video/mp4; codecs="H.264, aac"', + 'video/mp4; codecs="mp4v.20.8, mp4a.40.2"', + 'video/mp4; codecs="mp4v.20.240, mp4a.40.2"', + 'video/mpeg; codec="H.264"', + 'video/ogg; codecs="dirac, vorbis"', + 'video/ogg; codecs="opus"', + 'video/ogg; codecs="theora"', + 'video/ogg; codecs="theora, vorbis"', + 'video/ogg; codecs="theora, speex"', + 'video/webm; codecs="vp9, opus"', + 'video/webm; codecs="vp8, vorbis"', + 'video/x-matroska; codecs="theora, vorbis"', +] +AUDIO_CODEC_NAMES = [ + "audio/3gpp", + "audio/3gpp2", + "audio/AMR-NB", + "audio/AMR-WB", + "audio/GSM", + "audio/aac", + "audio/basic", + "audio/flac", + "audio/midi", + "audio/mpeg", + 'audio/mp4; codecs="mp4a.40.2"', + 'audio/mp4; codecs="ac-3"', + 'audio/mp4; codecs="ec-3"', + 'audio/mpeg; codecs="mp3"', + 'audio/ogg; codecs="flac"', + 'audio/ogg; codecs="vorbis"', + 'audio/ogg; codecs="opus"', + 'audio/ogg; codecs="speex"', + 'audio/wav; codecs="1"', + 'audio/webm; codecs="vorbis"', + 'audio/webm; codecs="opus"', + "audio/x-m4a", + "audio/x-aiff", + "audio/x-mpegurl", +] + +font_str = '[".Aqua Kana",".Helvetica LT MM",".Times LT MM","18thCentury","8514oem","AR BERKLEY","AR JULIAN","AR PL UKai CN","AR PL UMing CN","AR PL UMing HK","AR PL UMing TW","AR PL UMing TW MBE","Aakar","Abadi MT Condensed Extra Bold","Abadi MT Condensed Light","Abyssinica SIL","AcmeFont","Adobe Arabic","Agency FB","Aharoni","Aharoni Bold","Al Bayan","Al Bayan Bold","Al Bayan Plain","Al Nile","Al Tarikh","Aldhabi","Alfredo","Algerian","Alien Encounters","Almonte Snow","American Typewriter","American Typewriter Bold","American Typewriter Condensed","American Typewriter Light","Amethyst","Andale Mono","Andale Mono Version","Andalus","Angsana New","AngsanaUPC","Ani","AnjaliOldLipi","Aparajita","Apple Braille","Apple Braille Outline 6 Dot","Apple Braille Outline 8 Dot","Apple Braille Pinpoint 6 Dot","Apple Braille Pinpoint 8 Dot","Apple Chancery","Apple Color Emoji","Apple LiGothic Medium","Apple LiSung Light","Apple SD Gothic Neo","Apple SD Gothic Neo Regular","Apple SD GothicNeo ExtraBold","Apple Symbols","AppleGothic","AppleGothic Regular","AppleMyungjo","AppleMyungjo Regular","AquaKana","Arabic Transparent","Arabic Typesetting","Arial","Arial Baltic","Arial Black","Arial Bold","Arial Bold Italic","Arial CE","Arial CYR","Arial Greek","Arial Hebrew","Arial Hebrew Bold","Arial Italic","Arial Narrow","Arial Narrow Bold","Arial Narrow Bold Italic","Arial Narrow Italic","Arial Rounded Bold","Arial Rounded MT Bold","Arial TUR","Arial Unicode MS","ArialHB","Arimo","Asimov","Autumn","Avenir","Avenir Black","Avenir Book","Avenir Next","Avenir Next Bold","Avenir Next Condensed","Avenir Next Condensed Bold","Avenir Next Demi Bold","Avenir Next Heavy","Avenir Next Regular","Avenir Roman","Ayuthaya","BN Jinx","BN Machine","BOUTON International Symbols","Baby Kruffy","Baghdad","Bahnschrift","Balthazar","Bangla MN","Bangla MN Bold","Bangla Sangam MN","Bangla Sangam MN Bold","Baskerville","Baskerville Bold","Baskerville Bold Italic","Baskerville Old Face","Baskerville SemiBold","Baskerville SemiBold Italic","Bastion","Batang","BatangChe","Bauhaus 93","Beirut","Bell MT","Bell MT Bold","Bell MT Italic","Bellerose","Berlin Sans FB","Berlin Sans FB Demi","Bernard MT Condensed","BiauKai","Big Caslon","Big Caslon Medium","Birch Std","Bitstream Charter","Bitstream Vera Sans","Blackadder ITC","Blackoak Std","Bobcat","Bodoni 72","Bodoni MT","Bodoni MT Black","Bodoni MT Poster Compressed","Bodoni Ornaments","BolsterBold","Book Antiqua","Book Antiqua Bold","Bookman Old Style","Bookman Old Style Bold","Bookshelf Symbol 7","Borealis","Bradley Hand","Bradley Hand ITC","Braggadocio","Brandish","Britannic Bold","Broadway","Browallia New","BrowalliaUPC","Brush Script","Brush Script MT","Brush Script MT Italic","Brush Script Std","Brussels","Calibri","Calibri Bold","Calibri Light","Californian FB","Calisto MT","Calisto MT Bold","Calligraphic","Calvin","Cambria","Cambria Bold","Cambria Math","Candara","Candara Bold","Candles","Carrois Gothic SC","Castellar","Centaur","Century","Century Gothic","Century Gothic Bold","Century Schoolbook","Century Schoolbook Bold","Century Schoolbook L","Chalkboard","Chalkboard Bold","Chalkboard SE","Chalkboard SE Bold","ChalkboardBold","Chalkduster","Chandas","Chaparral Pro","Chaparral Pro Light","Charlemagne Std","Charter","Chilanka","Chiller","Chinyen","Clarendon","Cochin","Cochin Bold","Colbert","Colonna MT","Comic Sans MS","Comic Sans MS Bold","Commons","Consolas","Consolas Bold","Constantia","Constantia Bold","Coolsville","Cooper Black","Cooper Std Black","Copperplate","Copperplate Bold","Copperplate Gothic Bold","Copperplate Light","Corbel","Corbel Bold","Cordia New","CordiaUPC","Corporate","Corsiva","Corsiva Hebrew","Corsiva Hebrew Bold","Courier","Courier 10 Pitch","Courier Bold","Courier New","Courier New Baltic","Courier New Bold","Courier New CE","Courier New Italic","Courier Oblique","Cracked Johnnie","Creepygirl","Curlz MT","Cursor","Cutive Mono","DFKai-SB","DIN Alternate","DIN Condensed","Damascus","Damascus Bold","Dancing Script","DaunPenh","David","Dayton","DecoType Naskh","Deja Vu","DejaVu LGC Sans","DejaVu Sans","DejaVu Sans Mono","DejaVu Serif","Deneane","Desdemona","Detente","Devanagari MT","Devanagari MT Bold","Devanagari Sangam MN","Didot","Didot Bold","Digifit","DilleniaUPC","Dingbats","Distant Galaxy","Diwan Kufi","Diwan Kufi Regular","Diwan Thuluth","Diwan Thuluth Regular","DokChampa","Dominican","Dotum","DotumChe","Droid Sans","Droid Sans Fallback","Droid Sans Mono","Dyuthi","Ebrima","Edwardian Script ITC","Elephant","Emmett","Engravers MT","Engravers MT Bold","Enliven","Eras Bold ITC","Estrangelo Edessa","Ethnocentric","EucrosiaUPC","Euphemia","Euphemia UCAS","Euphemia UCAS Bold","Eurostile","Eurostile Bold","Expressway Rg","FangSong","Farah","Farisi","Felix Titling","Fingerpop","Fixedsys","Flubber","Footlight MT Light","Forte","FrankRuehl","Frankfurter Venetian TT","Franklin Gothic Book","Franklin Gothic Book Italic","Franklin Gothic Medium","Franklin Gothic Medium Cond","Franklin Gothic Medium Italic","FreeMono","FreeSans","FreeSerif","FreesiaUPC","Freestyle Script","French Script MT","Futura","Futura Condensed ExtraBold","Futura Medium","GB18030 Bitmap","Gabriola","Gadugi","Garamond","Garamond Bold","Gargi","Garuda","Gautami","Gazzarelli","Geeza Pro","Geeza Pro Bold","Geneva","GenevaCY","Gentium","Gentium Basic","Gentium Book Basic","GentiumAlt","Georgia","Georgia Bold","Geotype TT","Giddyup Std","Gigi","Gill","Gill Sans","Gill Sans Bold","Gill Sans MT","Gill Sans MT Bold","Gill Sans MT Condensed","Gill Sans MT Ext Condensed Bold","Gill Sans MT Italic","Gill Sans Ultra Bold","Gill Sans Ultra Bold Condensed","Gisha","Glockenspiel","Gloucester MT Extra Condensed","Good Times","Goudy","Goudy Old Style","Goudy Old Style Bold","Goudy Stout","Greek Diner Inline TT","Gubbi","Gujarati MT","Gujarati MT Bold","Gujarati Sangam MN","Gujarati Sangam MN Bold","Gulim","GulimChe","GungSeo Regular","Gungseouche","Gungsuh","GungsuhChe","Gurmukhi","Gurmukhi MN","Gurmukhi MN Bold","Gurmukhi MT","Gurmukhi Sangam MN","Gurmukhi Sangam MN Bold","Haettenschweiler","Hand Me Down S (BRK)","Hansen","Harlow Solid Italic","Harrington","Harvest","HarvestItal","Haxton Logos TT","HeadLineA Regular","HeadlineA","Heavy Heap","Hei","Hei Regular","Heiti SC","Heiti SC Light","Heiti SC Medium","Heiti TC","Heiti TC Light","Heiti TC Medium","Helvetica","Helvetica Bold","Helvetica CY Bold","Helvetica CY Plain","Helvetica LT Std","Helvetica Light","Helvetica Neue","Helvetica Neue Bold","Helvetica Neue Medium","Helvetica Oblique","HelveticaCY","HelveticaNeueLT Com 107 XBlkCn","Herculanum","High Tower Text","Highboot","Hiragino Kaku Gothic Pro W3","Hiragino Kaku Gothic Pro W6","Hiragino Kaku Gothic ProN W3","Hiragino Kaku Gothic ProN W6","Hiragino Kaku Gothic Std W8","Hiragino Kaku Gothic StdN W8","Hiragino Maru Gothic Pro W4","Hiragino Maru Gothic ProN W4","Hiragino Mincho Pro W3","Hiragino Mincho Pro W6","Hiragino Mincho ProN W3","Hiragino Mincho ProN W6","Hiragino Sans GB W3","Hiragino Sans GB W6","Hiragino Sans W0","Hiragino Sans W1","Hiragino Sans W2","Hiragino Sans W3","Hiragino Sans W4","Hiragino Sans W5","Hiragino Sans W6","Hiragino Sans W7","Hiragino Sans W8","Hiragino Sans W9","Hobo Std","Hoefler Text","Hoefler Text Black","Hoefler Text Ornaments","Hollywood Hills","Hombre","Huxley Titling","ITC Stone Serif","ITF Devanagari","ITF Devanagari Marathi","ITF Devanagari Medium","Impact","Imprint MT Shadow","InaiMathi","Induction","Informal Roman","Ink Free","IrisUPC","Iskoola Pota","Italianate","Jamrul","JasmineUPC","Javanese Text","Jokerman","Juice ITC","KacstArt","KacstBook","KacstDecorative","KacstDigital","KacstFarsi","KacstLetter","KacstNaskh","KacstOffice","KacstOne","KacstPen","KacstPoster","KacstQurn","KacstScreen","KacstTitle","KacstTitleL","Kai","Kai Regular","KaiTi","Kailasa","Kailasa Regular","Kaiti SC","Kaiti SC Black","Kalapi","Kalimati","Kalinga","Kannada MN","Kannada MN Bold","Kannada Sangam MN","Kannada Sangam MN Bold","Kartika","Karumbi","Kedage","Kefa","Kefa Bold","Keraleeyam","Keyboard","Khmer MN","Khmer MN Bold","Khmer OS","Khmer OS System","Khmer Sangam MN","Khmer UI","Kinnari","Kino MT","KodchiangUPC","Kohinoor Bangla","Kohinoor Devanagari","Kohinoor Telugu","Kokila","Kokonor","Kokonor Regular","Kozuka Gothic Pr6N B","Kristen ITC","Krungthep","KufiStandardGK","KufiStandardGK Regular","Kunstler Script","Laksaman","Lao MN","Lao Sangam MN","Lao UI","LastResort","Latha","Leelawadee","Letter Gothic Std","LetterOMatic!","Levenim MT","LiHei Pro","LiSong Pro","Liberation Mono","Liberation Sans","Liberation Sans Narrow","Liberation Serif","Likhan","LilyUPC","Limousine","Lithos Pro Regular","LittleLordFontleroy","Lohit Assamese","Lohit Bengali","Lohit Devanagari","Lohit Gujarati","Lohit Gurmukhi","Lohit Hindi","Lohit Kannada","Lohit Malayalam","Lohit Odia","Lohit Punjabi","Lohit Tamil","Lohit Tamil Classical","Lohit Telugu","Loma","Lucida Blackletter","Lucida Bright","Lucida Bright Demibold","Lucida Bright Demibold Italic","Lucida Bright Italic","Lucida Calligraphy","Lucida Calligraphy Italic","Lucida Console","Lucida Fax","Lucida Fax Demibold","Lucida Fax Regular","Lucida Grande","Lucida Grande Bold","Lucida Handwriting","Lucida Handwriting Italic","Lucida Sans","Lucida Sans Demibold Italic","Lucida Sans Typewriter","Lucida Sans Typewriter Bold","Lucida Sans Unicode","Luminari","Luxi Mono","MS Gothic","MS Mincho","MS Outlook","MS PGothic","MS PMincho","MS Reference Sans Serif","MS Reference Specialty","MS Sans Serif","MS Serif","MS UI Gothic","MT Extra","MV Boli","Mael","Magneto","Maiandra GD","Malayalam MN","Malayalam MN Bold","Malayalam Sangam MN","Malayalam Sangam MN Bold","Malgun Gothic","Mallige","Mangal","Manorly","Marion","Marion Bold","Marker Felt","Marker Felt Thin","Marlett","Martina","Matura MT Script Capitals","Meera","Meiryo","Meiryo Bold","Meiryo UI","MelodBold","Menlo","Menlo Bold","Mesquite Std","Microsoft","Microsoft Himalaya","Microsoft JhengHei","Microsoft JhengHei UI","Microsoft New Tai Lue","Microsoft PhagsPa","Microsoft Sans Serif","Microsoft Tai Le","Microsoft Tai Le Bold","Microsoft Uighur","Microsoft YaHei","Microsoft YaHei UI","Microsoft Yi Baiti","Minerva","MingLiU","MingLiU-ExtB","MingLiU_HKSCS","Minion Pro","Miriam","Mishafi","Mishafi Gold","Mistral","Modern","Modern No. 20","Monaco","Mongolian Baiti","Monospace","Monotype Corsiva","Monotype Sorts","MoolBoran","Moonbeam","MotoyaLMaru","Mshtakan","Mshtakan Bold","Mukti Narrow","Muna","Myanmar MN","Myanmar MN Bold","Myanmar Sangam MN","Myanmar Text","Mycalc","Myriad Arabic","Myriad Hebrew","Myriad Pro","NISC18030","NSimSun","Nadeem","Nadeem Regular","Nakula","Nanum Barun Gothic","Nanum Gothic","Nanum Myeongjo","NanumBarunGothic","NanumGothic","NanumGothic Bold","NanumGothicCoding","NanumMyeongjo","NanumMyeongjo Bold","Narkisim","Nasalization","Navilu","Neon Lights","New Peninim MT","New Peninim MT Bold","News Gothic MT","News Gothic MT Bold","Niagara Engraved","Niagara Solid","Nimbus Mono L","Nimbus Roman No9 L","Nimbus Sans L","Nimbus Sans L Condensed","Nina","Nirmala UI","Nirmala.ttf","Norasi","Noteworthy","Noteworthy Bold","Noto Color Emoji","Noto Emoji","Noto Mono","Noto Naskh Arabic","Noto Nastaliq Urdu","Noto Sans","Noto Sans Armenian","Noto Sans Bengali","Noto Sans CJK","Noto Sans Canadian Aboriginal","Noto Sans Cherokee","Noto Sans Devanagari","Noto Sans Ethiopic","Noto Sans Georgian","Noto Sans Gujarati","Noto Sans Gurmukhi","Noto Sans Hebrew","Noto Sans JP","Noto Sans KR","Noto Sans Kannada","Noto Sans Khmer","Noto Sans Lao","Noto Sans Malayalam","Noto Sans Myanmar","Noto Sans Oriya","Noto Sans SC","Noto Sans Sinhala","Noto Sans Symbols","Noto Sans TC","Noto Sans Tamil","Noto Sans Telugu","Noto Sans Thai","Noto Sans Yi","Noto Serif","Notram","November","Nueva Std","Nueva Std Cond","Nyala","OCR A Extended","OCR A Std","Old English Text MT","OldeEnglish","Onyx","OpenSymbol","OpineHeavy","Optima","Optima Bold","Optima Regular","Orator Std","Oriya MN","Oriya MN Bold","Oriya Sangam MN","Oriya Sangam MN Bold","Osaka","Osaka-Mono","OsakaMono","PCMyungjo Regular","PCmyoungjo","PMingLiU","PMingLiU-ExtB","PR Celtic Narrow","PT Mono","PT Sans","PT Sans Bold","PT Sans Caption Bold","PT Sans Narrow Bold","PT Serif","Padauk","Padauk Book","Padmaa","Pagul","Palace Script MT","Palatino","Palatino Bold","Palatino Linotype","Palatino Linotype Bold","Papyrus","Papyrus Condensed","Parchment","Parry Hotter","PenultimateLight","Perpetua","Perpetua Bold","Perpetua Titling MT","Perpetua Titling MT Bold","Phetsarath OT","Phosphate","Phosphate Inline","Phosphate Solid","PhrasticMedium","PilGi Regular","Pilgiche","PingFang HK","PingFang SC","PingFang TC","Pirate","Plantagenet Cherokee","Playbill","Poor Richard","Poplar Std","Pothana2000","Prestige Elite Std","Pristina","Purisa","QuiverItal","Raanana","Raanana Bold","Raavi","Rachana","Rage Italic","RaghuMalayalam","Ravie","Rekha","Roboto","Rockwell","Rockwell Bold","Rockwell Condensed","Rockwell Extra Bold","Rockwell Italic","Rod","Roland","Rondalo","Rosewood Std Regular","RowdyHeavy","Russel Write TT","SF Movie Poster","STFangsong","STHeiti","STIXGeneral","STIXGeneral-Bold","STIXGeneral-Regular","STIXIntegralsD","STIXIntegralsD-Bold","STIXIntegralsSm","STIXIntegralsSm-Bold","STIXIntegralsUp","STIXIntegralsUp-Bold","STIXIntegralsUp-Regular","STIXIntegralsUpD","STIXIntegralsUpD-Bold","STIXIntegralsUpD-Regular","STIXIntegralsUpSm","STIXIntegralsUpSm-Bold","STIXNonUnicode","STIXNonUnicode-Bold","STIXSizeFiveSym","STIXSizeFiveSym-Regular","STIXSizeFourSym","STIXSizeFourSym-Bold","STIXSizeOneSym","STIXSizeOneSym-Bold","STIXSizeThreeSym","STIXSizeThreeSym-Bold","STIXSizeTwoSym","STIXSizeTwoSym-Bold","STIXVariants","STIXVariants-Bold","STKaiti","STSong","STXihei","SWGamekeys MT","Saab","Sahadeva","Sakkal Majalla","Salina","Samanata","Samyak Devanagari","Samyak Gujarati","Samyak Malayalam","Samyak Tamil","Sana","Sana Regular","Sans","Sarai","Sathu","Savoye LET Plain:1.0","Sawasdee","Script","Script MT Bold","Segoe MDL2 Assets","Segoe Print","Segoe Pseudo","Segoe Script","Segoe UI","Segoe UI Emoji","Segoe UI Historic","Segoe UI Semilight","Segoe UI Symbol","Serif","Shonar Bangla","Showcard Gothic","Shree Devanagari 714","Shruti","SignPainter-HouseScript","Silom","SimHei","SimSun","SimSun-ExtB","Simplified Arabic","Simplified Arabic Fixed","Sinhala MN","Sinhala MN Bold","Sinhala Sangam MN","Sinhala Sangam MN Bold","Sitka","Skia","Skia Regular","Skinny","Small Fonts","Snap ITC","Snell Roundhand","Snowdrift","Songti SC","Songti SC Black","Songti TC","Source Code Pro","Splash","Standard Symbols L","Stencil","Stencil Std","Stephen","Sukhumvit Set","Suruma","Sylfaen","Symbol","Symbole","System","System Font","TAMu_Kadambri","TAMu_Kalyani","TAMu_Maduram","TSCu_Comic","TSCu_Paranar","TSCu_Times","Tahoma","Tahoma Negreta","TakaoExGothic","TakaoExMincho","TakaoGothic","TakaoMincho","TakaoPGothic","TakaoPMincho","Tamil MN","Tamil MN Bold","Tamil Sangam MN","Tamil Sangam MN Bold","Tarzan","Tekton Pro","Tekton Pro Cond","Tekton Pro Ext","Telugu MN","Telugu MN Bold","Telugu Sangam MN","Telugu Sangam MN Bold","Tempus Sans ITC","Terminal","Terminator Two","Thonburi","Thonburi Bold","Tibetan Machine Uni","Times","Times Bold","Times New Roman","Times New Roman Baltic","Times New Roman Bold","Times New Roman Italic","Times Roman","Tlwg Mono","Tlwg Typewriter","Tlwg Typist","Tlwg Typo","TlwgMono","TlwgTypewriter","Toledo","Traditional Arabic","Trajan Pro","Trattatello","Trebuchet MS","Trebuchet MS Bold","Tunga","Tw Cen MT","Tw Cen MT Bold","Tw Cen MT Italic","URW Bookman L","URW Chancery L","URW Gothic L","URW Palladio L","Ubuntu","Ubuntu Condensed","Ubuntu Mono","Ukai","Ume Gothic","Ume Mincho","Ume P Gothic","Ume P Mincho","Ume UI Gothic","Uming","Umpush","UnBatang","UnDinaru","UnDotum","UnGraphic","UnGungseo","UnPilgi","Untitled1","Urdu Typesetting","Uroob","Utkal","Utopia","Utsaah","Valken","Vani","Vemana2000","Verdana","Verdana Bold","Vijaya","Viner Hand ITC","Vivaldi","Vivian","Vladimir Script","Vrinda","Waree","Waseem","Waverly","Webdings","WenQuanYi Bitmap Song","WenQuanYi Micro Hei","WenQuanYi Micro Hei Mono","WenQuanYi Zen Hei","Whimsy TT","Wide Latin","Wingdings","Wingdings 2","Wingdings 3","Woodcut","X-Files","Year supply of fairy cakes","Yu Gothic","Yu Mincho","Yuppy SC","Yuppy SC Regular","Yuppy TC","Yuppy TC Regular","Zapf Dingbats","Zapfino","Zawgyi-One","gargi","lklug","mry_KacstQurn","ori1Uni"]' +SUPPORTED_FONTS: List[str] = json.loads(font_str) diff --git a/generalresearch/grliq/models/custom_types.py b/generalresearch/grliq/models/custom_types.py new file mode 100644 index 0000000..1b2c9de --- /dev/null +++ b/generalresearch/grliq/models/custom_types.py @@ -0,0 +1,6 @@ +from typing_extensions import Annotated +import annotated_types + +GrlIqScore = Annotated[int, annotated_types.Ge(0), annotated_types.Le(100)] +GrlIqAvgScore = Annotated[float, annotated_types.Ge(0), annotated_types.Le(100)] +GrlIqRate = Annotated[float, annotated_types.Ge(0), annotated_types.Le(1)] diff --git a/generalresearch/grliq/models/decider.py b/generalresearch/grliq/models/decider.py new file mode 100644 index 0000000..94802cc --- /dev/null +++ b/generalresearch/grliq/models/decider.py @@ -0,0 +1,53 @@ +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field + +from generalresearch.models.custom_types import AwareDatetimeISO + + +class Decider(str, Enum): + # This decision was made in the thl-core: pre-offerwall-entry view + PRE_ENTRY = "pre_entry" + # This decision made by grl-iq (synchronously) + GRL_IQ = "grl_iq" + # This decision made by ym-user-predict (asynchronously) + YM_USER = "ym_user" + + +class AttemptDecision(str, Enum): + # This attempt should be allowed to continue + PASS = "pass" + # This attempt is deemed fraudulent + FAIL = "fail" + + +class GrlIqAttemptResult(BaseModel): + """ + This model is used via Redis to communicate between GRL-IQ/YM and thl-core + to set or update a real-time (or close to real-time) decision about if + a session should be allowed to proceed. + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: AwareDatetimeISO = Field( + description="When this decision was made", + default_factory=lambda: datetime.now(tz=timezone.utc), + ) + decider: Decider = Field(description="Where this decision was made") + decision: AttemptDecision = Field( + description="Whether an attempt should be allowed to continue, based on the evidence" + "available to the decider at this point in time" + ) + fraud_score: Optional[int] = Field( + ge=0, + le=100, + description="Higher equals more likely to be fraudulent", + default=None, + ) + fingerprint: Optional[str] = Field( + default=None, + description="Fingerprint that should be unique to this particular device", + ) diff --git a/generalresearch/grliq/models/events.py b/generalresearch/grliq/models/events.py new file mode 100644 index 0000000..2c8fa64 --- /dev/null +++ b/generalresearch/grliq/models/events.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +from collections import namedtuple +from dataclasses import dataclass, fields +from functools import cached_property +from typing import List, Optional + +import numpy as np +from pydantic import ( + BaseModel, + ConfigDict, + Field, + NonNegativeInt, + NonNegativeFloat, + PositiveFloat, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import AwareDatetimeISO, IPvAnyAddressStr + + +class Bounds(namedtuple("BoundsBase", ["left", "top", "width", "height"])): + __slots__ = () + + @property + def right(self): + return self.left + self.width + + @property + def bottom(self): + return self.top + self.height + + +@dataclass(kw_only=True) +class Event: + type: str + # in microseconds, since page load (?) + timeStamp: float + # optional ID of the event target (e.g.: where the mouse is hovering) + _elementId: Optional[str] = None + # optional tag name of the event target + _elementTagName: Optional[str] = None + # extracted coordinates for the element being interacted with + _elementBounds: Optional[Bounds] = None + + @classmethod + def from_dict(cls, data: dict) -> Self: + data = {k: v for k, v in data.items() if k in cls.__dataclass_fields__} + bounds = data.get("_elementBounds") + if bounds is not None and not isinstance(bounds, Bounds): + data = {**data, "_elementBounds": Bounds(**bounds)} + return cls(**data) + + +@dataclass +class PointerMove(Event): + # should always be 'pointermove' + type: str + # (mouse, touch, pen) + pointerType: str + # coordinate relative to the screen + screenX: float + screenY: float + # coordinate relative to the document (unaffected by scrolling) + pageX: float + pageY: float + # pageX/Y divided by the document Width/Height. This is calculated in JS and sent, which + # it must be b/c we don't know the document width/height at each time otherwise. + normalizedX: float + normalizedY: float + + +pointermove_keys = {f.name for f in fields(PointerMove)} + + +@dataclass +class MouseEvent(Event): + """ + More general than PointerMove. To be used for handling touch events/mobile + also, which don't generate pointermove events. + """ + + # should be {'pointerdown', 'pointerup', 'pointermove', 'click'} + type: str + # Type of input (mouse, touch, pen) + pointerType: str + # coordinate relative to the document (unaffected by scrolling) + pageX: float + pageY: float + + +@dataclass +class KeyboardEvent(Event): + """ """ + + # should be {'keydown', 'input'} + type: str + # "insertText", "insertCompositionText", "deleteCompositionText", + # "insertFromComposition", "deleteContentBackward" + inputType: Optional[str] + # e.g., 'Enter', 'a', 'Backspace' + key: Optional[str] = None + # This is the actual text, if applicable + data: Optional[str] = None + + @property + def key_text(self): + # if we get the input and keydown for a single char press, we don't need both + return ( + f"<{self.key.upper()}>" + if self.key + and self.key.lower() not in {"unidentified", ""} + and len(self.key) > 1 + else None + ) + + @property + def input_type_text(self): + return f"<{self.inputType.upper()}>" if self.inputType else None + + @property + def text(self): + return self.data or self.key_text or self.input_type_text or "" + + +class TimingDataSummary(BaseModel): + """ + Summarizes the pings from a single TimingData + (measurements from a single websocket connection / session, for one user + on one IP) + """ + + count: NonNegativeInt = Field(description="After filtering out outliers") + outlier_count: NonNegativeInt = Field() + outlier_frac: NonNegativeFloat = Field(ge=0, le=1) + + median_log_rtt: PositiveFloat = Field() + mean_log_rtt: PositiveFloat = Field() + std_log_rtt: PositiveFloat = Field() + + median_rtt: PositiveFloat = Field() + mean_rtt: PositiveFloat = Field() + std_rtt: PositiveFloat = Field() + + +class TimingData(BaseModel): + """ + Stores collected RTTs from websocket pings. + todo: can also store bandwidth info collected from router + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + client_rtts: List[float] = Field() + server_rtts: List[float] = Field() + + # Have to be optional for backwards-compatibility, but should always be set. + started_at: Optional[AwareDatetimeISO] = Field(default=None) + ended_at: Optional[AwareDatetimeISO] = Field(default=None) + client_ip: Optional[IPvAnyAddressStr] = Field( + description="This comes from the websocket request's headers", + examples=["72.39.217.116"], + default=None, + ) + server_hostname: Optional[str] = Field( + description="The hostname of the server that handled this request", + examples=["grliq-web-0"], + default=None, + ) + + @property + def server_location(self) -> str: + # TODO: when we have more locations ... + return ( + "fremont_ca" + if self.server_hostname in {"grliq-web-0", "grliq-web-1"} + else "fremont_ca" + ) + + @property + def has_data(self): + return len(self.client_rtts) > 0 and len(self.server_rtts) > 0 + + def filter_rtts(self, rtts): + # Skip the first 5 pings, unless we have <10 pings, then get the last + # 5 instead. + # The first couple pings are usually outliers as they are running + # when a lot of initial JS is also running. + if len(self.client_rtts) >= 10: + rtts = rtts[5:] + else: + rtts = rtts[-5:] + return rtts + + @cached_property + def client_rtts_filtered(self): + return self.filter_rtts(self.client_rtts) + + @property + def client_rtt_mean(self): + rtts = self.client_rtts_filtered + return sum(rtts) / len(rtts) + + @cached_property + def server_rtts_filtered(self): + return self.filter_rtts(self.server_rtts) + + @property + def server_rtt_mean(self): + rtts = self.server_rtts_filtered + return sum(rtts) / len(rtts) + + @property + def filtered_rtts(self): + return self.server_rtts_filtered + self.client_rtts_filtered + + @property + def cleaned_rtts(self): + # Trim outliers + rtts = np.array(self.filtered_rtts) + rtts = rtts[(rtts > 3) & (rtts < 1000)] + if rtts.size > 0: + p5, p95 = np.percentile(rtts, [5, 95]) + rtts = rtts[(rtts >= p5) & (rtts <= p95)] + return rtts + + @property + def summarize(self) -> Optional[TimingDataSummary]: + if len(self.filtered_rtts) < 5: + return None + + orig_len = len(self.filtered_rtts) + rtts = np.array(self.cleaned_rtts) + if len(rtts) < 5: + # We started with 5 or more observations, but removed enough so that + # we have < 5 now. This is probably a signal of something + return None + + log_rtts = np.log(rtts) + + return TimingDataSummary( + count=len(rtts), + outlier_count=orig_len - len(rtts), + outlier_frac=(orig_len - len(rtts)) / orig_len, + median_rtt=float(np.median(rtts)), + mean_rtt=float(np.mean(rtts)), + std_rtt=float(np.std(rtts)), + mean_log_rtt=float(np.mean(log_rtts)), + median_log_rtt=float(np.median(log_rtts)), + std_log_rtt=float(np.std(log_rtts)), + ) diff --git a/generalresearch/grliq/models/forensic_data.py b/generalresearch/grliq/models/forensic_data.py new file mode 100644 index 0000000..c182f43 --- /dev/null +++ b/generalresearch/grliq/models/forensic_data.py @@ -0,0 +1,801 @@ +import hashlib +import re +from collections import Counter +from datetime import datetime, timezone, timedelta +from enum import Enum +from functools import cached_property +from typing import Literal, Optional, Dict, List, Set, Any +from uuid import uuid4 + +import pycountry +from faker import Faker +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + StringConstraints, + AfterValidator, + AwareDatetime, + NonNegativeInt, +) +from pydantic.json_schema import SkipJsonSchema +from pydantic_extra_types.timezone_name import TimeZoneName +from typing_extensions import Self, Annotated + +from generalresearch.grliq.models import ( + AUDIO_CODEC_NAMES, + VIDEO_CODEC_NAMES, + SUPPORTED_FONTS, +) +from generalresearch.grliq.models.events import ( + PointerMove, + TimingData, + MouseEvent, + KeyboardEvent, +) +from generalresearch.grliq.models.forensic_result import ( + GrlIqForensicCategoryResult, + GrlIqCheckerResults, +) +from generalresearch.grliq.models.forensic_result import Phase +from generalresearch.grliq.models.useragents import ( + GrlUserAgent, + OSFamily, + UserAgentHints, +) +from generalresearch.models.custom_types import ( + UUIDStr, + IPvAnyAddressStr, + BigAutoInteger, + AwareDatetimeISO, + CountryISOLike, +) +from generalresearch.models.thl.ipinfo import GeoIPInformation +from generalresearch.models.thl.session import Session + +fake = Faker() + + +class Platform(str, Enum): + MAC_INTEL = "MacIntel" + ARM = "ARM" + IPAD = "iPad" + IPHONE = "iPhone" + WIN32 = "Win32" + WIN64 = "Win64" + LINUX_X86_64 = "Linux x86_64" + LINUX_ARMV81 = "Linux armv81" + LINUX_ARMV8l = "Linux armv8l" + LINUX_ARMV7l = "Linux armv7l" + LINUX_AARCH64 = "Linux aarch64" + OTHER = "Other" + + +class PassFailError(str, Enum): + PASS = "pass" + FAIL = "fail" + ERROR = "error" + + @classmethod + def from_int_1_0_1(cls, v: str | int) -> Self: + lookup = {1: cls.PASS, 0: cls.FAIL, -1: cls.ERROR} + try: + return cls(v) + except ValueError: + return lookup.get(int(v), cls.ERROR) + + @classmethod + def from_int_2_1_0(cls, v: str | int) -> Self: + try: + return cls(v) + except ValueError: + return {2: cls.PASS, 1: cls.FAIL, 0: cls.ERROR, -1: cls.ERROR}[int(v)] + + +class SupportLevel(str, Enum): + # Used for checking if certain features are available in the browser + FULL = "full" + PARTIAL = "partial" + NONE = "none" + + @classmethod + def from_int(cls, v: str | int) -> Self: + try: + return cls(v) + except ValueError: + return {2: cls.FULL, 1: cls.PARTIAL, 0: cls.NONE}[int(v)] + + +def check_valid_hex(v: str) -> str: + if not all(char in "0123456789abcdefABCDEF" for char in v): + raise ValueError("The hash128 must only contain hexadecimal characters.") + return v + + +Hash128 = Annotated[ + str, + StringConstraints(min_length=32, max_length=32), + AfterValidator(check_valid_hex), +] + + +class GrlIqData(BaseModel): + """Forensics JS POSTs ~200 pipe-separated fields to the fake SSO view. We + parse that (along with a couple other fields) into this object. + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + # --- Attributes on the db table directly --- + + id: Optional[BigAutoInteger] = Field(default=None, exclude=True) + uuid: UUIDStr = Field( + default_factory=lambda: uuid4().hex, + description="A unique identifier for this data object", + examples=[uuid4().hex], + ) + + mid: Optional[UUIDStr] = Field( + description="The mid the of the User's attempt (thl-session) that " + "is associated with this data", + examples=[uuid4().hex], + ) + phase: Optional[Phase] = Field( + description="The phase of a thl-session in which this data was collected", + default=Phase.OFFERWALL_ENTER, + ) + product_id: Optional[UUIDStr] = Field( + default=None, + description="The Brokerage Product ID (BPID)", + examples=[uuid4().hex], + ) + product_user_id: Optional[str] = Field( + default=None, + description="The Brokerage Product User ID (BPUID).", + examples=["test-user-2dbeaaf4"], + ) + + country_iso: CountryISOLike = Field( + examples=["us"], + description="This is the country that the offerwall was requested " + "for. Looked up in the thl_session table, via the mid.", + ) + + client_ip: IPvAnyAddressStr = Field( + description="This comes from the actual web request's headers", + examples=["72.39.217.116"], + ) + client_ip_detail: Optional[GeoIPInformation] = Field(default=None) + + created_at: AwareDatetimeISO = Field( + description="When we actually received this data. The timestamp field " + "below comes from the post body and could be manipulated " + "by a baddie." + ) + + request_headers: Dict = Field( + description="The full request headers from the actual HTTP call that was made." + ) + + # data: Dict = Field() + # result_data: Dict = Field() + # fraud_score: int = Field() + # is_attempt_allowed: Optional[bool] = Field(default=None) + results: Optional[GrlIqCheckerResults] = Field(default=None) + category_result: Optional[GrlIqForensicCategoryResult] = Field( + default=None, description="Saved in the database as a jsonb" + ) + + # --- Attributes in the json dict data field --- + + # Note: request_headers should contain origin (the webpage that made the request) + # and referer (URL of the previous page). + + # ---- Below here are parsed from the post body ---- + + timezone_success: PassFailError = Field(description="Should always be 1") + timezone: TimeZoneName = Field(examples=["America/Mexico_City"]) + timezone_offset: int = Field( + description="timezone offset from utc in minutes", examples=[360] + ) + calendar: str = Field(examples=["gregory"]) + numbering_system: str = Field(examples=["latn"]) + + timestamp: AwareDatetime = Field( + examples=["Fri Jan 10 2025 16:52:17 GMT-0600 (Central Standard Time)"] + ) + + user_agent_str: str = Field( + examples=[ + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" + ] + ) + user_agent_str_2: Optional[str] = Field( + description="This will only be set if different than user_agent_str" + ) + user_agent_hints: Optional[UserAgentHints] = Field( + description="Comes from the User-Agent Client Hints API", default=None + ) + + platform: Platform = Field(description="navigator.platform") + platform_2: Optional[Platform] = Field(description="navigator.platform") + platform_3: Optional[Platform] = Field() + language: str = Field(examples=["en-US"]) + language_2: str = Field(examples=["en-US"]) + language_3: Optional[str] = Field() + calender_locale: str = Field(examples=["en-US"]) + + screen_width: NonNegativeInt = Field() + screen_height: NonNegativeInt = Field() + screen_avail_width: NonNegativeInt = Field() + screen_avail_height: NonNegativeInt = Field() + inner_width: NonNegativeInt = Field() + inner_height: NonNegativeInt = Field() + outer_width: NonNegativeInt = Field() + outer_height: NonNegativeInt = Field() + device_pixel_ratio: float = Field() + color_depth_pixel_depth: str = Field() + + app_name: Literal["Netscape"] = Field( + description="Navigator.appName. Always 'Netscape'" + ) + product_sub: Optional[Literal["20030107", "20100101"]] = Field( + description="Navigator.productSub" + ) + vendor: Optional[Literal["Apple Computer, Inc.", "Google Inc.", "NAVER Corp."]] = ( + Field(description="Navigator.vendor.") + ) + + history_length: int = Field(description="window.history.length. Current tab only") + + webrtc_is_supported: PassFailError = Field() + webrtc_error: bool = Field() + webrtc_local_ip: str = Field() + webrtc_ip: Optional[IPvAnyAddressStr] = Field(examples=[fake.ipv4_public()]) + webrtc_ip_detail: Optional[GeoIPInformation] = Field(default=None) + + hardware_concurrency: Optional[int] = Field( + description="Sometimes this is an empty str" + ) + hardware_concurrency_2: Optional[int] = Field() + hardware_concurrency_3: Optional[int] = Field() + + # Browser/session properties + navigator_java_enabled: bool = Field() + do_not_track_enabled: str = Field(description="unspecified or '' ? ") + mime_types_length: int = Field() + # todo: some report actual values, some are (always) faked by the os/browser + # as anti-fingerprint 10737418240 (10gb) typical on firefox windows, + # 2147483648 (20gb) chrome, iPhones typically only 8 different values + # android seems to show real, user-specific values. + storage_estimate_quota: int = Field() + navigator_cookieEnabled: bool = Field() + + # Browser properties + rendering_engine: str = Field() + graphics_api: str = Field() + graphics_renderer: str = Field() + + eval_to_string_length: int = Field( + description="eval.toString().length. Different in different browsers" + ) + navigator_keys_len: int = Field( + description="Object.keys(Object.getPrototypeOf(navigator)).length" + ) + mozilla_web_app_exists: bool = Field( + description="checking for presence of https://developer.mozilla.org/en-US/docs/Web/Progressive_web_apps" + ) + microsoft_credentials_exists: bool = Field( + description="checking for presence of microsoft credential manager" + ) + window_external_exists: bool = Field( + description="checking if window.external exists" + ) + window_client_information: bool = Field( + description="checking if window.clientInformation exists" + ) + window_opera: bool = Field(description="checking if window.opera exists") + window_chrome: bool = Field(description="checking if window.chrome exists") + navigator_brave: bool = Field(description="checking if navigator.brave exists") + window_active_x_object: bool = Field( + description="checking if ActiveXObject in window" + ) + no_edge_pdf_plugin: bool = Field( + description="checking if we don't have a microsoft edge pdf plugin" + ) + request_fs_exists: bool = Field( + description="checking if window.requestFileSystem or webkitRequestFileSystem exists." + ) + indexedDbData_available: bool = Field( + description="indexedDbData available and functional." + ) + localStorage_available: bool = Field( + description="localStorage available and functional." + ) + web_sql_exists: bool = Field(description="whether window.openDatabase exists") + webgl_flag: bool = Field() + webgl_check_1: bool = Field() + window_installTrigger_exists: bool = Field() + error_message: str = Field( + description="Its checking what the browser's error message is" + ) + math_result_1: str = Field() + math_result_2: str = Field() + # todo: validate per device + speech_synthesis_voices_count: int = Field(description="221 for iphones") + speech_synthesis_voice_1: str = Field() + browser_by_properties: str = Field( + description="looking for the presence of certain objects, to try to " + "detect the browser. 'c' means chrome. Look in code for " + "the list (function ee(n))" + ) + indexedDbData_blob: bool = Field( + description="whether blobs are supported in IndexedDB T/F" + ) + plugins_hash: str = Field( + description="pipe-sep hash|count of installed plugins. typically either " + "'de355917bf33e0789539450797b843f9|5' (windows, iphone, mac) or '|0' (typically android). " + ) + chrome_extensions: str = Field(description="comma sep str of chrome extensions") + audio_codecs: Optional[str] = Field( + examples=["1,1,1,1,1,3,1,3,1,3,3,1,1,3,3,3,3,1,3,3,3,2,1,1"], + description="canPlayType: {'3': probably, '2': maybe, '1': no, 0: error}", + min_length=47, + max_length=47, + ) + video_codecs: Optional[str] = Field( + examples=["1,3,3,3,3,3,3,3,3,3,1,1,1,1,1,1,3,1,1,1,3,3,1"], + description="canPlayType: {'3': probably, '2': maybe, '1': no, 0: error}", + min_length=45, + max_length=45, + ) + + # Browser Functionality Check + session_storage_check: PassFailError = Field( + description="If this doesn't work, something is probably manipulated" + ) + cookie_check: str = Field(description="idk") + audio_context_flag: PassFailError = Field( + description="checking some audio buffer thing idk" + ) + canvas_pixel_check: bool = Field() + + # Browser Automation Checks + navigator_webdriver: bool = Field() + webdriver_detected: bool = Field() + webdriver_detected_msg: str = Field() + non_native_function: bool = Field(description='checking for "[native code]"') + non_native_function_flag: str = Field( + description="comma-sep str", examples=["6,9,16"] + ) + error_message_stack_access_count: int = Field() + error_message_stack_access_count_worker: int = Field( + description="I think should equal error_message_stack_access_count?" + ) + + # Device Properties + ontouchstart: bool = Field() + # todo: confirm this is 5 for an iphone + max_touch_points: int = Field() + navigator_deviceMemory: Optional[float] = Field() + memory_jsHeapSizeLimit: int = Field() + navigator_mediaDevices_len: int = Field() + unmasked_vendor_webgl: str = Field() + unmasked_renderer_webgl: str = Field() + keyboard_detected: bool = Field() + keyboard_layout_size: Optional[int] = Field(description="mobile safari None?") + window_orientation: int = Field(description="0 or 1. idk which is which") + + # Session properties + # todo: we should check this across POSTs for a user, b/c it should be + # *different* each (to make sure they aren't reusing posts) + execution_time_ms: float = Field() + performance_loop_time: float = Field() + connection_rtt: Optional[int] = Field() + connection_downlink: Optional[float] = Field() + connection_type: str = Field() + connection_effectiveType: str = Field() + + # fingerprint stuff + canvas_support_level: SupportLevel = Field() + canvas_hash: Optional[Hash128] = Field( + description="dfiq's canvas image fingerprint" + ) + canvas_hash_2: Optional[Hash128] = Field( + description="simpler canvas image fingerprint stolen from amiunique.org", + default=None, + ) + webgl_hash: Optional[Hash128] = Field( + description="DFIQ's version of webgl hash. It has stuff included in the hash: anisotropy, supported " + "extensions, etc." + ) + webgl_context: Optional[ + Literal[ + "webgl2", + "webgl", + "experimental-webgl2", + "experimental-webgl", + "webkit-3d", + "moz-webgl", + ] + ] = Field(default=None) + webgl_max_anisotropy: Optional[int] = Field(default=None, examples=[16]) + webgl_shading_language_version: Optional[str] = Field( + default=None, + examples=["WebGL GLSL ES 3.00 (OpenGL ES GLSL ES 3.0 Chromium)"], + ) + webgl_hash_2: Optional[Hash128] = Field( + description="hash128 of the canvas image, without additional stuff concatenated to it", + default=None, + ) + webgl_extensions: Optional[str] = Field( + description="pipe-separated list of webgl extensions", + default=None, + examples=[ + "EXT_clip_control;EXT_color_buffer_float;EXT_color_buffer_half_float" + ], + ) + + audio_context_hash: Optional[Hash128] = Field() + audio_intensity_fingerprint: Optional[float] = Field() + audio_compressor_reduction: Optional[float] = Field() + speech_synthesis_voice_hash: Optional[Hash128] = Field() + speech_synthesis_avail_voices_count: int = Field() + path_fingerprint: int = Field( + description="maybe is consistent? sum of pixel value of some path." + ) + text_2d_fingerprint: Optional[Hash128] = Field() + canvas_fingerprint: int = Field() + + # User Preferences + color_gamut: Optional[Literal["1", "2", "3", "0"]] = Field( + description="{'1': 'rec2020', '2':'p3', '3':'srgb', '0': none} # p3 typically used in macbooks n stuff" + ) + prefers_contrast: Optional[Literal["0", "1", "2", "3", "4", "5", "9"]] = Field( + description="{'no-preference': 0, 'high': 1, 'more': 2, 'low': 3, 'less': 4, 'forced': 5, None: 9}" + ) + prefers_reduced_motion: bool = Field(description="reduce (1) vs no-preference (0)") + dynamic_range: bool = Field(description="high (1) vs standard (0)") + inverted_colors: bool = Field(description="{'inverted': 1, 'none': 0}") + forced_colors: bool = Field(description="{'active': 1, 'none': 0}") + prefers_color_scheme: bool = Field(description="{'dark': 1, '?': 0}") + + # Battery Info + battery_charging: Optional[bool] = Field(default=None) + battery_charging_time: Optional[float] = Field(default=None) + battery_discharging_time: Optional[float] = Field(default=None) + battery_level: Optional[float] = Field(default=None, ge=0, le=1) + + supported_fonts_str: Optional[str] = Field( + default=None, + description="Bit-packed string for font support. Each element is 32 bits, with each bit representing T/F for " + "font support.", + examples=[ + "72|768|262144|1073741824|0|0|540672|73728|7340032|1342177280|117446656|256|16|0|543|4290797636" + "|1677723648|4168998400|0|1048576|262144|268500994|1342177280|262144|125829376|37888000|0|435363842|0" + "|2147483648|109543424|1880099872|268435471" + ], + ) + + # Time it took for the client to download the logo.jpg + logo_download_ms: Optional[float] = Field(default=None, gt=0) + + # --- Not from post body ---- + + prefetched: SkipJsonSchema[bool] = Field( + default=False, description="Has prefetch been run?" + ) + + # Can optionally be loaded from the grliq_forensicevents table + events: Optional[List[Dict]] = Field(default=None) + pointer_move_events: Optional[List[PointerMove]] = Field(default=None) + mouse_events: Optional[List[MouseEvent]] = Field(default=None) + keyboard_events: Optional[List[KeyboardEvent]] = Field(default=None) + + timing_data: Optional[TimingData] = Field(default=None) + + @property + def session_uuid(self) -> Optional[UUIDStr]: + return self.mid + + @cached_property + def useragent(self) -> GrlUserAgent: + return GrlUserAgent.from_ua_str(self.user_agent_str) + + @cached_property + def fingerprint_keys(self) -> List[str]: + fp_cols = [ + "country_iso", + "canvas_hash", + "canvas_hash_2", + # This is removed in favor of webgl_hash_2, which is just the image hash + # "webgl_hash", + "webgl_hash_2", + "audio_context_hash", + "audio_intensity_rounded", + "audio_compressor_reduction", + "speech_synthesis_voice_hash", + "speech_synthesis_avail_voices_count", + "path_fingerprint", + "text_2d_fingerprint", + "canvas_fingerprint", + "device_pixel_ratio", + "storage_estimate_quota", + "audio_codecs", + "video_codecs", + "color_gamut", + "prefers_contrast", + "prefers_reduced_motion", + "dynamic_range", + "inverted_colors", + "forced_colors", + "prefers_color_scheme", + ] + if self.useragent.os.family in {OSFamily.IOS, OSFamily.MAC_OSX}: + fp_cols += ["screen_width", "screen_height"] + return fp_cols + + @cached_property + def fingerprint(self) -> str: + s = "|".join(map(str, [getattr(self, k) for k in self.fingerprint_keys])) + return hashlib.md5(s.encode()).hexdigest() + + @cached_property + def audio_codecs_named(self) -> Dict: + return dict( + zip( + AUDIO_CODEC_NAMES, + [True if x == "3" else False for x in self.audio_codecs.split(",")], + ) + ) + + @cached_property + def video_codecs_named(self) -> Dict: + return dict( + zip( + VIDEO_CODEC_NAMES, + [True if x == "3" else False for x in self.video_codecs.split(",")], + ) + ) + + @cached_property + def supported_fonts_binary(self) -> str: + return "".join( + [ + format(int(packed_int), "032b") + for packed_int in self.supported_fonts_str.split("|") + ] + )[-len(SUPPORTED_FONTS) :] + + @cached_property + def supported_fonts(self) -> Set[str]: + return { + f for x, f in zip(self.supported_fonts_binary, SUPPORTED_FONTS) if x == "1" + } + + @cached_property + def audio_intensity_rounded(self) -> Optional[float]: + # The audio intensity fingerprint seems to be purposely manipulated + # to add randomness, but the level of randomness if very low, past + # the 6th decimal point. + return ( + round(self.audio_intensity_fingerprint, 6) + if self.audio_intensity_fingerprint + else None + ) + + @cached_property + def event_type_count(self) -> Counter: + return Counter([x["type"] for x in self.events]) + + # @field_validator( + # "hardware_concurrency", + # mode="before", + # ) + # @classmethod + # def str_to_int(cls, value: str) -> int: + # return int(value) + + @field_validator( + "hardware_concurrency", + "hardware_concurrency_2", + "hardware_concurrency_3", + "language_3", + "connection_rtt", + "keyboard_layout_size", + mode="before", + ) + @classmethod + def str_to_int_or_null(cls, value: str) -> Optional[int]: + return int(value) if value not in {None, ""} else None + + @field_validator( + "connection_downlink", + "audio_intensity_fingerprint", + "audio_compressor_reduction", + "navigator_deviceMemory", + mode="before", + ) + @classmethod + def str_to_float_or_null(cls, value: str) -> Optional[int]: + return float(value) if value not in {None, ""} else None + + @field_validator( + "vendor", + "product_sub", + "speech_synthesis_voice_hash", + "webgl_hash", + "canvas_hash", + "audio_context_hash", + "text_2d_fingerprint", + "audio_codecs", + "video_codecs", + mode="before", + ) + @classmethod + def str_or_null(cls, value: str) -> Optional[str]: + return value or None + + @field_validator( + "webrtc_error", + "no_edge_pdf_plugin", + "indexedDbData_blob", + "keyboard_detected", + "non_native_function", + "navigator_brave", + "window_chrome", + "navigator_webdriver", + "webgl_check_1", + "window_installTrigger_exists", + "navigator_cookieEnabled", + "window_active_x_object", + "window_opera", + "window_external_exists", + "microsoft_credentials_exists", + "mozilla_web_app_exists", + "webdriver_detected", + "canvas_pixel_check", + "prefers_reduced_motion", + "dynamic_range", + "inverted_colors", + "forced_colors", + "prefers_color_scheme", + mode="before", + ) + @classmethod + def str_to_bool(cls, value: str) -> bool: + return bool(int(value)) + + @field_validator( + "request_fs_exists", + "indexedDbData_available", + "localStorage_available", + "webgl_flag", + mode="before", + ) + @classmethod + def str_to_bool_2_1(cls, value: str | bool) -> bool: + # 2 is True, 1 is False !!!! (why?) + if isinstance(value, str): + return bool(int(value) - 1) + else: + return value + + @field_validator("timezone_success", "webrtc_is_supported", mode="before") + @classmethod + def pass_fail_error(cls, value: int) -> PassFailError: + if not isinstance(value, PassFailError): + return PassFailError.from_int_1_0_1(value) + + @field_validator("session_storage_check", "audio_context_flag", mode="before") + @classmethod + def pass_fail_error_210(cls, value: int) -> PassFailError: + if not isinstance(value, PassFailError): + return PassFailError.from_int_2_1_0(value) + + @field_validator("canvas_support_level", mode="before") + @classmethod + def support_level(cls, value: int) -> SupportLevel: + if not isinstance(value, SupportLevel): + return SupportLevel.from_int(value) + + @field_validator("country_iso") + @classmethod + def validate_country_iso(cls, value: str) -> str: + if not pycountry.countries.get(alpha_2=value.lower()): + raise ValueError(f"{value} is not a valid ISO 3166-1 alpha-2 country code.") + return value.lower() + + @field_validator("platform", "platform_2", "platform_3", mode="before") + @classmethod + def platform_enum_or_other(cls, value: Optional[str]) -> Optional[Platform]: + if value is None or value == "": + return None + try: + return Platform(value) + except ValueError: + return Platform.OTHER + + @field_validator("webrtc_ip", mode="before") + @classmethod + def preprocess_ip(cls, ip: str) -> Optional[str]: + # Strip square brackets if present + return re.sub(r"^\[|\]$", "", ip) if ip else None + + # Doesn't work. "vi" is a valid value (en-US) + # @field_validator("language") + # @classmethod + # def validate_language_country(cls, value: str) -> str: + # language_code, country_code = value.split("-") + # try: + # pycountry.languages.get(alpha_2=language_code) + # except KeyError: + # raise ValueError(f"Invalid language code '{language_code}'.") + # try: + # pycountry.countries.get(alpha_2=country_code.upper()) + # except KeyError: + # raise ValueError(f"Invalid country code '{country_code.upper()}'.") + # return value + + @field_validator("timestamp", mode="before") + @classmethod + def parse_timestamp(cls, value: str) -> str | datetime: + if isinstance(value, datetime): + return value + if "GMT" not in value: + return value + try: + value = value.split("(")[0].strip() + return datetime.strptime(value, "%a %b %d %Y %H:%M:%S GMT%z") + except ValueError: + raise ValueError(f"Invalid date format: {value}") + + def validate_with_session(self, session: Session) -> None: + # product_id and product_user_id are parsed from the post body. make sure + # they match the session whose mid was specified + assert self.product_id == session.user.product_id, "product_id mismatch" + assert ( + self.product_user_id == session.user.product_user_id + ), "product_user_id mismatch" + + # validate the Session's mid is "recent" + assert (datetime.now(tz=timezone.utc) - session.started) < timedelta( + minutes=90 + ), "expired session" + + return None + + def model_dump_sql(self, **kwargs) -> Dict[str, Any]: + d = dict() + d["uuid"] = self.uuid + d["session_uuid"] = self.mid + d["created_at"] = self.created_at + d.update(self.useragent.ua_string_values) + d["data"] = self.model_dump_json(**kwargs) + keys = [ + "client_ip", + "country_iso", + "product_id", + "product_user_id", + "phase", + ] + + for k in keys: + d[k] = getattr(self, k) + + return d + + @classmethod + def from_db(cls, d: Dict) -> Self: + res = GrlIqData.model_validate(d["data"]) + + if d.get("category_result"): + res.category_result = GrlIqForensicCategoryResult.model_validate( + d["category_result"] + ) + + if d.get("result_data"): + res.results = GrlIqCheckerResults.model_validate(d["result_data"]) + + return res diff --git a/generalresearch/grliq/models/forensic_result.py b/generalresearch/grliq/models/forensic_result.py new file mode 100644 index 0000000..7f906c2 --- /dev/null +++ b/generalresearch/grliq/models/forensic_result.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +from enum import Enum +from typing import Optional, List, Set +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, computed_field + +from generalresearch.grliq.models.custom_types import GrlIqScore +from generalresearch.grliq.models.decider import ( + Decider, + AttemptDecision, + GrlIqAttemptResult, +) +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO + + +class Phase(str, Enum): + # The 'phase' of a THL-Session experience. grliq may be collected in + # multiple places multiple times within one session + + # Within a custom offerwall. Very optional, as most BPs won't be running our code + OFFERWALL = "offerwall" + # When a user clicks on a bucket. Each session should go through this + OFFERWALL_ENTER = "offerwall-enter" + # Running in GRS. Not every session will have this. + PROFILING = "profiling" + # We could run grl-iq again when a user continues a session + SESSION_CONTINUE = "session-continue" + + +class GrlIqForensicCategoryResult(BaseModel): + """ + This is for reporting external to GRL. + + There is a balance between exposing enough to answer "why did this user get blocked?" without + giving away technical knowledge that could be used to bypass. + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + uuid: Optional[UUIDStr] = Field( + description="The uuid for the GrlIqData model these results are based on", + default=None, + examples=[uuid4().hex], + ) + + updated_at: Optional[AwareDatetimeISO] = Field(default=None) + is_complete: bool = Field( + description="This is based on whether or not the GrlIqCheckerResults" + "object that this data was based on was complete at that time.", + default=False, + ) + + # ----- Behavioral ----- + is_bot: GrlIqScore = Field( + description="User is behaving in a bot-like manner, for e.g. clicking " + "buttons without moving the mouse", + default=0, + ) + is_velocity: GrlIqScore = Field( + description="User is making HTTP Requests faster than a typical user", + default=0, + ) + is_oscillating: GrlIqScore = Field( + description="User is changing IPs suspiciously (eg: may indicate " + "non-sticky SOCKS connections.", + default=0, + ) + is_teleporting: GrlIqScore = Field( + description="User is moving Countries, Geographic Regions, and/or " + "locations faster than is humanly possible.", + default=0, + ) + # ..... + + # ----- Technical ----- + is_inconsistent: GrlIqScore = Field( + description="The User's platform (browser/device/OS) is inconsistent.", + default=0, + ) + is_tampered: GrlIqScore = Field( + description="The User attempted to interfere with or modify the " + "GRL-IQ security platform.", + default=0, + ) + + # ----- GeoIP ----- + + # Should this be a bool??? would it ever not be 0 or 100 ? answer: I guess if we check + # the IP via multiple sources, it could not be. + is_anonymous: GrlIqScore = Field( + description="The User's IP is flagged as anonymous", + default=0, + ) + suspicious_ip: GrlIqScore = Field( + description="The User's IP properties are suspicious", + default=0, + ) + platform_ip_inconsistent: GrlIqScore = Field( + description="The User's platform (browser/device/OS) is inconsistent " + "with the User's IP", + default=0, + ) + + @staticmethod + def model_score_fields() -> List[str]: + return [ + "is_bot", + "is_velocity", + "is_oscillating", + "is_teleporting", + "is_inconsistent", + "is_tampered", + "is_anonymous", + "suspicious_ip", + "platform_ip_inconsistent", + ] + + @property + def fraud_score(self) -> GrlIqScore: + return max([getattr(self, k) for k in self.model_score_fields()]) + + def is_attempt_allowed(self) -> bool: + # this could take the buyer's security tolerances + threshold = 50 + return all(getattr(self, k) < threshold for k in self.model_score_fields()) + + def make_decision(self) -> GrlIqAttemptResult: + decision = ( + AttemptDecision.PASS if self.is_attempt_allowed() else AttemptDecision.FAIL + ) + return GrlIqAttemptResult( + decider=Decider.GRL_IQ, decision=decision, fraud_score=self.fraud_score + ) + + +class GrlIqCheckerResult(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + score: GrlIqScore = Field(default=0) + msg: Optional[str] = Field(default=None) + + @property + def passes(self) -> bool: + return self.score < 50 + + +class GrlIqObservations(BaseModel): + + fingerprint_count: int = Field( + default=0, description="Count of unique fingerprints (past 30 days)" + ) + shared_fingerprint_count: int = Field( + default=0, + description="Count of users sharing the same fingerprints (past 30 days, same product_id)", + ) + cellular_ip_count: int = Field( + default=0, description="Count of unique cellular IPs used (past 30 days)" + ) + non_cellular_ip_count: int = Field( + default=0, description="Count of unique cellular IPs used (past 30 days)" + ) + isp_count: int = Field(default=0, description="Count of unique ISPs (past 30 days)") + timezone_count: int = Field( + default=0, description="Count of unique timezones (by IP) (past 30 days)" + ) + + paste_event_count: Optional[int] = Field( + default=None, description="Count of paste events (user pasted in text)" + ) + visibilitychange_event_count: Optional[int] = Field( + default=None, + description="Count of visibilitychange events (entire page isn't visible)", + ) + blur_event_count: Optional[int] = Field( + default=None, description="Count of blur events (page lost focus)" + ) + devicemotion_event_count: Optional[int] = Field( + default=None, + description="Count of devicemotion events (device gyroscope motion)", + ) + click_event_count: Optional[int] = Field( + default=None, + description="Count of click events (any pointer type)", + ) + # all clicks are marked as pointerType = 'mouse', but other pointermove events have a pointerType + # of 'touch' or 'mouse' + pointermove_pointer_types: Optional[Set[str]] = Field( + default=None, description="pointer types" + ) + + +class GrlIqCheckerResults(BaseModel): + """ + Holds results for each individual checker. + Used to calculate the category-based results and a final attempt-allowed decision + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + uuid: Optional[UUIDStr] = Field( + description="The uuid for the GrlIqData model these results are based on", + default=None, + examples=[uuid4().hex], + ) + + updated_at: Optional[AwareDatetimeISO] = Field(default=None) + + observations: Optional[GrlIqObservations] = Field(default=None) + + # browser_props + check_environment: GrlIqCheckerResult = Field() + check_environment_critical: GrlIqCheckerResult = Field() + check_codecs: GrlIqCheckerResult = Field(default_factory=GrlIqCheckerResult) + + # fingerprints + check_fingerprint_cycling: GrlIqCheckerResult = Field( + default_factory=GrlIqCheckerResult + ) + check_fingerprint_reuse: GrlIqCheckerResult = Field( + default_factory=GrlIqCheckerResult + ) + + # System Fonts + check_required_fonts: GrlIqCheckerResult = Field(default_factory=GrlIqCheckerResult) + check_prohibited_fonts: GrlIqCheckerResult = Field( + default_factory=GrlIqCheckerResult + ) + + # IP Info + check_ip_country: GrlIqCheckerResult = Field() + check_user_type: GrlIqCheckerResult = Field() + check_ip_timezone: GrlIqCheckerResult = Field() + check_user_anonymous: GrlIqCheckerResult = Field() + check_ip_changes: GrlIqCheckerResult = Field(default_factory=GrlIqCheckerResult) + check_isp_changes: GrlIqCheckerResult = Field(default_factory=GrlIqCheckerResult) + check_timezone_changes: GrlIqCheckerResult = Field( + default_factory=GrlIqCheckerResult + ) + + # tampered + check_timestamp: GrlIqCheckerResult = Field(default_factory=GrlIqCheckerResult) + check_seen_timestamps: GrlIqCheckerResult = Field( + default_factory=GrlIqCheckerResult + ) + check_execution_time_ms: GrlIqCheckerResult = Field( + default_factory=GrlIqCheckerResult + ) + + # timezone + check_timezone: GrlIqCheckerResult = Field() + check_country_timezone: GrlIqCheckerResult = Field() + + # useragents + check_useragent_other_enums: GrlIqCheckerResult = Field() + check_useragent_ip_properties: GrlIqCheckerResult = Field() + check_useragent_js: GrlIqCheckerResult = Field() + check_useragent_data_properties: GrlIqCheckerResult = Field() + check_useragent_device_family_brand: GrlIqCheckerResult = Field() + + # webrtc + check_webrtc_success: GrlIqCheckerResult = Field() + check_ip_webrtc_ip_detail: GrlIqCheckerResult = Field() + + # websocket (events) + check_page_load_events: Optional[GrlIqCheckerResult] = Field(default=None) + check_grliq_events: Optional[GrlIqCheckerResult] = Field(default=None) + check_pasting: Optional[GrlIqCheckerResult] = Field(default=None) + check_pointer_movements: Optional[GrlIqCheckerResult] = Field(default=None) + check_device_motion: Optional[GrlIqCheckerResult] = Field(default=None) + check_pointer_type: Optional[GrlIqCheckerResult] = Field(default=None) + check_for_bad_events: Optional[GrlIqCheckerResult] = Field(default=None) + + # websocket (ping) + check_average_rtt: Optional[GrlIqCheckerResult] = Field(default=None) + + # todo: we might also have a "fingerprint" in here ??? + + @property + def checker_fields(self): + fields = list(self.model_fields.keys()) + return [f for f in fields if f.startswith("check_")] + + @computed_field + @property + def is_complete(self) -> bool: + return all(getattr(self, f) is not None for f in self.checker_fields) diff --git a/generalresearch/grliq/models/forensic_summary.py b/generalresearch/grliq/models/forensic_summary.py new file mode 100644 index 0000000..2d38435 --- /dev/null +++ b/generalresearch/grliq/models/forensic_summary.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import random +from typing import ( + List, + Literal, + Optional, + Tuple, + Dict, + get_type_hints, + get_origin, + Union, + get_args, +) + +import numpy as np +from pydantic import ( + BaseModel, + ConfigDict, + Field, + NonNegativeInt, + create_model, + computed_field, +) +from scipy.stats import lognorm + +from generalresearch.grliq.models.custom_types import GrlIqAvgScore, GrlIqRate +from generalresearch.grliq.models.forensic_result import ( + GrlIqCheckerResults, + GrlIqCheckerResult, +) +from generalresearch.models.custom_types import IPvAnyAddressStr, AwareDatetimeISO +from generalresearch.models.thl.locales import CountryISO +from generalresearch.models.thl.maxmind.definitions import UserType + +example_rtt_percentiles = ( + [133.332] + + list( + map(float, lognorm.ppf(np.linspace(0.01, 0.99, 99), s=0.1, scale=175).round(3)) + ) + + [890.006] +) + + +class UserForensicSummary(BaseModel): + """ + 'Top-level' forensic summary for a user + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + period_start: Optional[AwareDatetimeISO] = Field( + default=None, + description="Timestamp of the earliest attempt included in this summary (UTC)", + ) + period_end: Optional[AwareDatetimeISO] = Field( + default=None, + description="Timestamp of the latest attempt included in this summary (UTC)", + ) + + # These must be nullable in case a user has 0 attempts! + category_result_summary: Optional[GrlIqForensicCategorySummary] = Field( + default=None + ) + checker_result_summary: Optional[GrlIqCheckerResultsSummary] = Field(default=None) + + country_timing_data_summary: Dict[CountryISO, TimingDataCountrySummary] = Field( + default_factory=dict + ) + ip_timing_data_summary: Dict[IPvAnyAddressStr, IPTimingDataSummary] = Field( + default_factory=dict + ) + + +class GrlIqForensicCategorySummary(BaseModel): + """ + GrlIqForensicCategoryResult Summary across multiple attempts by a single user. + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + attempt_count: NonNegativeInt = Field( + description="Number of attempts included in this summary", examples=[42] + ) + is_attempt_allowed_count: NonNegativeInt = Field( + description="The count of attempts that were allowed.", examples=[40] + ) + + is_complete_rate: GrlIqRate = Field( + description="Proportion of attempts where is_complete=True", + examples=[random.random()], + ) + + is_bot_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + is_velocity_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + is_oscillating_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + is_teleporting_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + is_inconsistent_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + is_tampered_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + is_anonymous_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + suspicious_ip_avg: GrlIqAvgScore = Field(examples=[random.randint(0, 100)]) + platform_ip_inconsistent_avg: GrlIqAvgScore = Field( + examples=[random.randint(0, 100)] + ) + fraud_score_avg: Optional[GrlIqAvgScore] = Field( + default=None, examples=[random.randint(0, 100)] + ) + + +def is_optional(t): + return get_origin(t) is Union and type(None) in get_args(t) + + +def unwrap_optional(t): + if is_optional(t): + return next(arg for arg in get_args(t) if arg is not type(None)) + return t + + +def generate_GrlIqCheckerResultsSummary(): + """ + Dynamically generate GrlIqCheckerResultsSummary model from the + GrlIqCheckerResults model. Each check_* field is added + as a float with the name check_*_avg. + """ + fields = {} + for field_name, hint in get_type_hints(GrlIqCheckerResults).items(): + is_opt = is_optional(hint) + base_type = unwrap_optional(hint) + if base_type == GrlIqCheckerResult: + if is_opt: + fields[f"{field_name}_avg"] = ( + Optional[GrlIqAvgScore], + Field(default=None, examples=[random.randint(0, 100)]), + ) + fields[f"{field_name}_pct_none"] = ( + GrlIqRate, + Field(examples=[random.random()], default=0), + ) + else: + fields[f"{field_name}_avg"] = ( + GrlIqAvgScore, + Field(examples=[random.randint(0, 100)]), + ) + + summary_model = create_model( + "GrlIqCheckerResultsSummary", + __doc__="GrlIqCheckerResults Summary across multiple attempts by a single user.", + __config__=ConfigDict(extra="forbid", validate_assignment=True), + **fields, + ) + return summary_model + + +GrlIqCheckerResultsSummary = generate_GrlIqCheckerResultsSummary() + + +class TimingDataCountrySummary(BaseModel): + """ + Summary of timing data results for a single user across all of their observed ips, + within 1 country (to one server_location). + """ + + country_iso: CountryISO = Field(examples=["us"]) + server_location: Literal["fremont_ca"] = Field(default="fremont_ca") + + rtt_min: float = Field(gt=0, examples=[133.332]) + rtt_q25: float = Field(gt=0, examples=[144.928]) + rtt_median: float = Field(gt=0, examples=[167.743]) + rtt_mean: float = Field(gt=0, examples=[179.302]) + rtt_q75: float = Field(gt=0, examples=[220.232]) + rtt_max: float = Field(gt=0, examples=[890.006]) + + expected_rtt_range: Tuple[float, float] = Field( + description="The expected rtt range for this IP (based on country_iso/user_type) to server_location", + examples=[(45.193, 120.841)], + ) + mean_z_score: float = Field( + examples=[1.22238], + description="Mean of all z-scores. A z-score is calculated" + "for each of the user's sessions.", + ) + + +class IPTimingDataSummary(BaseModel): + """ + Summary of timing data results for a single user on a single ip (across all of + this user's sessions on this IP) + """ + + client_ip: IPvAnyAddressStr = Field(examples=["123.123.123.123"]) + country_iso: CountryISO = Field(examples=["us"]) + server_location: Literal["fremont_ca"] = Field(default="fremont_ca") + user_type: Optional[UserType] = Field(default=None, examples=[UserType.RESIDENTIAL]) + expected_rtt_range: Tuple[float, float] = Field( + description="The expected rtt range for this IP (based on country_iso/user_type) to server_location", + examples=[(45.193, 120.841)], + ) + observed_rtt_mean: float = Field(gt=0, examples=[382.983]) + mean_z_score: float = Field(examples=[2.411]) + + +class CountryRTTDistribution(BaseModel): + """The distribution of observed RTTs (optionally filtered by `is_fraud`) + from `country_iso` to `server_location`. + + This would be returned from its own endpoint (Get Expected RTT by Country), + where you could optionally pass url params 'is_fraud', 'user_type'. + """ + + server_location: Literal["fremont_ca"] = Field(default="fremont_ca") + country_iso: CountryISO = Field( + description="Country client_ip is located in", examples=["fr"] + ) + # For users marked as fraud or not + is_fraud: Optional[bool] = Field( + default=None, + description="If timing data from sessions determined to be fraud are included", + ) + + # we could split by this optionally + user_type: Optional[UserType] = Field( + default=None, + description="user_type of the client_ip as determined by MaxMind", + examples=[UserType.RESIDENTIAL], + ) + + rtt_min: float = Field(gt=0, examples=[133.332]) + rtt_median: float = Field(gt=0, examples=[167.743]) + rtt_mean: float = Field(gt=0, examples=[179.302]) + rtt_max: float = Field(gt=0, examples=[890.006]) + rtt_std: float = Field(gt=0, examples=[46.831]) + rtt_percentiles: List[float] = Field( + min_length=101, max_length=101, examples=[example_rtt_percentiles] + ) + + rtt_log_median: float = Field(gt=0, examples=[5.122]) + rtt_log_mean: float = Field(gt=0, examples=[5.167]) + rtt_log_std: float = Field(gt=0, examples=[0.191]) + + @computed_field( + examples=[(119.844, 256.873)], + description="The 95% confidence interval calculated in log-space", + ) + @property + def expected_rtt_range(self) -> Tuple[float, float]: + # This is the log_mean +- 2 log_std, then converted back to non-log space. + # This is not just the mean + 2x std b/c we calculate the expected + # range in log-space (due to high skewness) + edge = self.rtt_log_std * 2 + return float(np.exp(self.rtt_log_mean - edge)), float( + np.exp(self.rtt_log_mean + edge) + ) + + def boxplot(self): + """ + Render a boxplot from the RTT percentiles. + """ + try: + # annoying pycharm error + import matplotlib.pyplot as plt + except ImportError as e: + raise e + + p = self.rtt_percentiles + data = { + "whislo": p[5], + "q1": p[25], + "med": p[50], + "q3": p[75], + "whishi": p[95], + "fliers": [p[0]] + ([p[100]] if p[100] > p[95] else []), + } + + fig, ax = plt.subplots(figsize=(4, 1.5)) + ax.bxp([data], showfliers=True, vert=False) + ax.set_title(f"RTT Boxplot for {self.country_iso}") + ax.set_xlabel("RTT (ms)") + ax.set_yticks([]) + + plt.tight_layout() + plt.show() diff --git a/generalresearch/grliq/models/useragents.py b/generalresearch/grliq/models/useragents.py new file mode 100644 index 0000000..6cea2dc --- /dev/null +++ b/generalresearch/grliq/models/useragents.py @@ -0,0 +1,246 @@ +import hashlib +from enum import Enum +from typing import Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing_extensions import Self +from user_agents import parse as ua_parse +from user_agents.parsers import UserAgent + + +class BrowserFamily(str, Enum): + CHROME_MOBILE = "Chrome Mobile" + CHROME = "Chrome" + CHROME_MOBILE_WEBVIEW = "Chrome Mobile WebView" + MOBILE_SAFARI_UI_WKWEBVIEW = "Mobile Safari UI/WKWebView" + MOBILE_SAFARI = "Mobile Safari" + EDGE = "Edge" + FIREFOX = "Firefox" + SAMSUNG_INTERNET = "Samsung Internet" + SAFARI = "Safari" + CHROME_MOBILE_IOS = "Chrome Mobile iOS" + OPERA = "Opera" + EDGE_MOBILE = "Edge Mobile" + MIUI_BROWSER = "MiuiBrowser" + OPERA_MOBILE = "Opera Mobile" + FIREFOX_MOBILE = "Firefox Mobile" + GOOGLE = "Google" + AMAZON_SILK = "Amazon Silk" + FIREFOX_IOS = "Firefox iOS" + YANDEX_BROWSER = "Yandex Browser" + OTHER = "Other" + + +class OSFamily(str, Enum): + ANDROID = "Android" + WINDOWS = "Windows" + IOS = "iOS" + MAC_OSX = "Mac OS X" + LINUX = "Linux" + CHROME_OS = "Chrome OS" + UBUNTU = "Ubuntu" + OTHER = "Other" + + +class DeviceBrand(str, Enum): + GENERIC_ANDROID = "Generic_Android" + NONE = "None" + APPLE = "Apple" + SAMSUNG = "Samsung" + OPPO = "Oppo" + MOTOROLA = "Motorola" + GENERIC_ANDROID_TABLET = "Generic_Android_Tablet" + VIVO = "vivo" + HUAWEI = "Huawei" + XIAOMI = "XiaoMi" + INFINIX = "Infinix" + GENERIC = "Generic" + NOKIA = "Nokia" + GOOGLE = "Google" + TECNO = "Tecno" + AMAZON = "Amazon" + ONE_PLUS = "OnePlus" + LENOVO = "Lenovo" + OTHER = "Other" + + +class DeviceModelFamily(str, Enum): + NONE = "None" + OTHER = "Other" + K = "K" + IPHONE = "iPhone" + MAC = "Mac" + IPAD = "iPad" + + +firefox_families = { + BrowserFamily.FIREFOX, + BrowserFamily.FIREFOX_MOBILE, + BrowserFamily.FIREFOX_IOS, +} + +safari_families = { + BrowserFamily.SAFARI, + BrowserFamily.MOBILE_SAFARI, + BrowserFamily.MOBILE_SAFARI_UI_WKWEBVIEW, +} + +chrome_families = { + BrowserFamily.CHROME, + BrowserFamily.CHROME_MOBILE, + BrowserFamily.CHROME_MOBILE_IOS, + BrowserFamily.CHROME_MOBILE_WEBVIEW, +} + +mobile_families = { + BrowserFamily.CHROME_MOBILE, + BrowserFamily.CHROME_MOBILE_IOS, + BrowserFamily.CHROME_MOBILE_WEBVIEW, + BrowserFamily.MOBILE_SAFARI, + BrowserFamily.MOBILE_SAFARI_UI_WKWEBVIEW, + BrowserFamily.FIREFOX_MOBILE, + BrowserFamily.FIREFOX_IOS, +} + + +class OSInfo(BaseModel): + family: OSFamily = Field() + version_string: Optional[str] = Field() + + @field_validator("family", mode="before") + @classmethod + def enum_or_other(cls, value: str) -> OSFamily: + try: + return OSFamily(value) + except ValueError: + return OSFamily.OTHER + + +class BrowserInfo(BaseModel): + family: BrowserFamily = Field() + version_string: Optional[str] = Field() + + @field_validator("family", mode="before") + @classmethod + def enum_or_other(cls, value: str) -> BrowserFamily: + try: + return BrowserFamily(value) + except ValueError: + return BrowserFamily.OTHER + + +class DeviceInfo(BaseModel): + family: DeviceModelFamily = Field() + brand: DeviceBrand = Field() + model: DeviceModelFamily = Field() + + @field_validator("family", "model", mode="before") + @classmethod + def enum_or_other(cls, value: str) -> DeviceModelFamily: + try: + return ( + DeviceModelFamily(value) + if value is not None + else DeviceModelFamily.NONE + ) + except ValueError: + return DeviceModelFamily.OTHER + + @field_validator("brand", mode="before") + @classmethod + def enum_or_other2(cls, value: str) -> DeviceBrand: + try: + return DeviceBrand(value) if value is not None else DeviceBrand.NONE + except ValueError: + return DeviceBrand.OTHER + + +class GrlUserAgent(BaseModel): + """ + The UserAgent library parses useragents, but does not enumerate possible + values for anything. We go a step further here and have Enums for + things like OS families, browsers, etc. + """ + + model_config = ConfigDict( + extra="forbid", validate_assignment=True, arbitrary_types_allowed=True + ) + + ua_string: str = Field(description="The actual useragent string") + ua_parsed: UserAgent = Field( + description="UserAgent object. Internally is a bunch of namedtuples with strings." + ) + + os: OSInfo = Field() + browser: BrowserInfo = Field() + device: DeviceInfo = Field() + + # These 4 are determined by the UserAgent library, not be me. + is_mobile: bool = Field() + is_tablet: bool = Field() + is_pc: bool = Field() + is_bot: bool = Field() + + @property + def ua_string_values(self) -> Dict[str, str]: + # Returns the raw parsed string values for each of these. To be used + # for db filtering, identifying trends, etc. + d = dict() + d["ua_browser_family"] = self.ua_parsed.browser.family + d["ua_browser_version"] = self.ua_parsed.browser.version_string + d["ua_os_family"] = self.ua_parsed.os.family + d["ua_os_version"] = self.ua_parsed.os.version_string + d["ua_device_family"] = self.ua_parsed.device.family + d["ua_device_brand"] = self.ua_parsed.device.brand + d["ua_device_model"] = self.ua_parsed.device.model + s = "|".join([str(x[1]) for x in sorted(d.items(), key=lambda x: x[0])]) + d["ua_hash"] = hashlib.md5(s.encode("utf-8")).hexdigest() + return d + + @classmethod + def from_ua_str(cls, user_agent: str) -> Self: + ua = ua_parse(user_agent) + return cls.model_validate( + { + "ua_string": user_agent, + "ua_parsed": ua, + "os": { + "family": ua.os.family, + "version_string": ua.os.version_string, + }, + "browser": { + "family": ua.browser.family, + "version_string": ua.browser.version_string, + }, + "device": { + "family": ua.device.family, + "brand": ua.device.brand, + "model": ua.device.model, + }, + "is_mobile": ua.is_mobile, + "is_tablet": ua.is_tablet, + "is_pc": ua.is_pc, + "is_bot": ua.is_bot, + } + ) + + +class UserAgentHints(BaseModel): + """The forensic post also includes output from the useragent hints API + + https://developer.mozilla.org/en-US/docs/Web/API/User-Agent_Client_Hints_API + https://developer.chrome.com/docs/privacy-security/user-agent-client-hints + """ + + model_config = ConfigDict( + extra="forbid", validate_assignment=True, populate_by_name=True + ) + + brands: Optional[List[Dict]] = Field(validation_alias="b", default=None) + brands_full: Optional[List[Dict]] = Field(validation_alias="fv", default=None) + mobile: bool = Field(validation_alias="m", default=False) + model: Optional[str] = Field(validation_alias="md", default=None) + platform: Optional[str] = Field(validation_alias="o", default=None) + platform_version: Optional[str] = Field(validation_alias="ov", default=None) + architecture: Optional[str] = Field(validation_alias="a", default=None) + bitness: Optional[str] = Field(validation_alias="bt", default=None) diff --git a/generalresearch/grliq/utils.py b/generalresearch/grliq/utils.py new file mode 100644 index 0000000..6e563f9 --- /dev/null +++ b/generalresearch/grliq/utils.py @@ -0,0 +1,36 @@ +import os +from datetime import datetime, timezone +from typing import Optional, Union +from uuid import UUID + +# from generalresearch.config import +from generalresearch.models.custom_types import UUIDStr +from pathlib import Path + + +def get_screenshot_fp( + created_at: datetime, + forensic_uuid: Union[UUIDStr, UUID], + grliq_archive_dir: Path = "/tmp", + grliq_ss_dir_name: str = "canvas2html", + create_dir_if_not_exists: bool = True, +) -> Optional[Path]: + assert created_at.tzinfo == timezone.utc + + if isinstance(forensic_uuid, UUID): + forensic_uuid = forensic_uuid.hex + + directory_path = os.path.join( + grliq_archive_dir, + grliq_ss_dir_name, + created_at.strftime("%Y"), + created_at.strftime("%m"), + created_at.strftime("%d"), + ) + + if create_dir_if_not_exists: + os.makedirs(directory_path, exist_ok=True) + + fp = Path(os.path.join(directory_path, f"{forensic_uuid}.png")) + + return fp diff --git a/generalresearch/grpc.py b/generalresearch/grpc.py new file mode 100644 index 0000000..bf3f0e2 --- /dev/null +++ b/generalresearch/grpc.py @@ -0,0 +1,46 @@ +from datetime import datetime, timedelta +from typing import Optional + +from google.protobuf.duration_pb2 import Duration +from google.protobuf.timestamp_pb2 import Timestamp + + +def timestamp_from_datetime(dt: datetime) -> Timestamp: + ts = Timestamp() + ts.FromDatetime(dt) + return ts + + +def timestamp_from_datetime_nullable(dt: Optional[datetime]) -> Timestamp: + ts = Timestamp() + if dt: + ts.FromDatetime(dt) + return ts + + +def timestamp_to_datetime(ts: Timestamp) -> datetime: + return datetime.utcfromtimestamp(ts.seconds + ts.nanos / 1e9) + + +def timestamp_to_datetime_nullable(ts: Timestamp) -> Optional[datetime]: + # grpc has no None. If a google.protobuf.Timestamp field is not set, it gets interpreted as timestamp 0 + default = datetime.utcfromtimestamp(0) + d = datetime.utcfromtimestamp(ts.seconds + ts.nanos / 1e9) + return None if d == default else d + + +def timestamp_to_json_nullable(ts: Timestamp) -> Optional[str]: + # 1) grpc converts a null timestamp to '1970-01-01T00:00:00Z'. Not what we want... + # 2) grpc uses different formatting for the microseconds depending on if it's divisible by 0, 3, or 6 digits. + # I don't understand why anyone would want to do this ... + # This forces 6 digit microsecond (even if it is .000000), uses a Z for utc (grpc Timestamp does not + # support timezones and so is always UTC) and handles None properly. + dt = timestamp_to_datetime_nullable(ts) + dt = dt.isoformat(timespec="microseconds") + "Z" if dt else None + return dt + + +def duration_from_timedelta(td: timedelta) -> Duration: + d = Duration() + d.FromTimedelta(td) + return d diff --git a/generalresearch/healing_ppe.py b/generalresearch/healing_ppe.py new file mode 100644 index 0000000..2a3aecf --- /dev/null +++ b/generalresearch/healing_ppe.py @@ -0,0 +1,77 @@ +import logging +import os +import signal +import time +from collections import defaultdict +from concurrent import futures +from concurrent.futures.process import BrokenProcessPool + +logger = logging.getLogger() + +signal_int_name = defaultdict( + lambda: "Unknown", {x.value: x.name for x in signal.Signals} +) + + +class HealingProcessPoolExecutor: + def __init__( + self, max_workers=None, name=None, slack_token=None, slack_channel=None + ): + if not name: + try: + name = f"Process Pool: {__file__}" + except NameError: + name = "Process Pool" + self._name = name + self._max_workers = max_workers + self._pool = futures.ProcessPoolExecutor(max_workers=max_workers) + self._pool.submit(do_nothing) + # noinspection PyUnresolvedReferences + self._processes = self._pool._processes + + def get_qsize(self): + # noinspection PyUnresolvedReferences + return len(self._pool._pending_work_items) + + def submit(self, *args, **kwargs): + try: + return self._pool.submit(*args, **kwargs) + except BrokenProcessPool: + ps = list(self._processes.values()) + exit_codes = [signal_int_name[abs(p.exitcode)] for p in ps] + + msg = f"{self._name} is broken. Restarting executor." + msg += "\n" + f"exit codes: {exit_codes}" + logger.warning(msg) + + self._pool.shutdown(wait=True) + self._pool = futures.ProcessPoolExecutor(max_workers=self._max_workers) + # Submitting "do_nothing" here is probably not useful anymore. + + # This call happens on the pool's submit so if it is still broken, it will + # now raise an exception + return self._pool.submit(*args, **kwargs) + + +def do_nothing(): + # We submit this to process pools on init in order to have the needed processes fork + # before we load up a lot of stuff in the parent process. + test_logger = logging.getLogger("test") + test_logger.setLevel(logging.INFO) + test_logger.info("doing nothing") + time.sleep(2) + test_logger.info("did nothing") + + +def test(): + pool = HealingProcessPoolExecutor(2, name="test") + pool.submit(do_nothing) + time.sleep(0.5) + + # Kill a process in the pool + pid = list(pool._processes.keys())[0] + os.kill(pid, signal.SIGKILL) + time.sleep(0.5) + + # re-schedule a job + pool.submit(do_nothing) diff --git a/generalresearch/incite/__init__.py b/generalresearch/incite/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/incite/base.py b/generalresearch/incite/base.py new file mode 100644 index 0000000..bd346f9 --- /dev/null +++ b/generalresearch/incite/base.py @@ -0,0 +1,980 @@ +from __future__ import annotations + +import glob +import logging +import os +import re +import shutil +import subprocess +import warnings +from concurrent.futures import Future +from datetime import datetime, timezone, timedelta +from os import access, R_OK, listdir +from os.path import join as pjoin, isdir +from pathlib import Path +from sys import platform +from typing import ( + Optional, + Tuple, + List, + Sequence, + Any, + Union, + Callable, + TYPE_CHECKING, +) +from uuid import uuid4 + +import dask +import dask.dataframe as dd +import pandas as pd +import pyarrow.parquet as pq +from distributed import Client +from pandera import DataFrameSchema +from pydantic import ( + BaseModel, + ConfigDict, + DirectoryPath, + PrivateAttr, + Field, + model_validator, + FilePath, + field_validator, + ValidationInfo, +) +from pydantic.json_schema import SkipJsonSchema +from sentry_sdk import capture_exception +from typing_extensions import Self + +from generalresearch.config import is_debug +from generalresearch.incite.schemas import ( + ARCHIVE_AFTER, + empty_dataframe_from_schema, +) +from generalresearch.models.custom_types import AwareDatetimeISO + +if TYPE_CHECKING: + from generalresearch.incite.mergers import MergeType, MergeCollection + from generalresearch.incite.collections.thl_marketplaces import ( + DFCollectionType, + ) + from generalresearch.incite.collections import DFCollection + + Collection = Union[DFCollection, MergeCollection] + +logging.basicConfig() +LOG = logging.getLogger() + +# Item = Union["DFCollectionItem", "MergeCollectionItem"] +Item = Any +Items = Sequence[Item] +DT_STR = "%Y-%m-%d %H:%M:%S" + + +class NFSMount(BaseModel): + address: str = Field(default="127.0.0.1") + point: str = Field(default="grl-data-example") + + +class GRLDatasets(BaseModel): + """ + The "idea" of this class is to manage the Mount point, or source of + where Sambda or NFS data may be coming from.. I don't think it needs + to manage individual folders directly, but it should be aware of if + a drive is mounted, if it has correct permissions. + + Each field maps to a single network mount.. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + data_src: Optional[Path] = Field(default=None) + incite: Optional[NFSMount] = Field(default=None) + + @model_validator(mode="after") + def check_data_src_and_et_path(self) -> Self: + from generalresearch.incite.mergers import MergeType + from generalresearch.incite.collections.thl_marketplaces import ( + DFCollectionType, + ) + + # Create the base folders and confirm we have read access + self.data_src.mkdir(parents=True, exist_ok=True) + assert access( + path=self.data_src, mode=R_OK + ), f"can't access data_src: {self.data_src}" + + for enum_type in [MergeType, DFCollectionType]: + for et in enum_type: + et: MergeType | DFCollectionType + + if not is_debug() and "test" in et.value: + continue + + p = self.archive_path(enum_type=et) + + if is_debug(): + # Try to make any of them + p.mkdir(parents=True, exist_ok=True) + + assert access(path=p, mode=R_OK), f"Cannot read {p}" + return self + + def archive_path(self, enum_type: Union["MergeType", "DFCollectionType"]) -> Path: + """ + TODO: Extend this so that it takes any type of Enum and that + inputs in the correct parent dir for the respective Enum + type.. + """ + + from generalresearch.incite.mergers import MergeType + + folder = "mergers" if isinstance(enum_type, MergeType) else "raw/df-collections" + return Path( + pjoin(self.data_src, self.incite.point, folder, str(enum_type.value)) + ) + + def has_data(self, enum_type: Union["MergeType", "DFCollectionType"]) -> bool: + path_dir = self.archive_path(enum_type=enum_type) + if isdir(path_dir): + return bool(listdir(path_dir)) + else: + return False + + +class CollectionBase(BaseModel): + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=True, + # This is needed to auto assign a dask client + validate_default=True, + extra="forbid", + ) + + archive_path: DirectoryPath = Field(default="/tmp/") + df: SkipJsonSchema[pd.DataFrame] = Field( + default_factory=lambda: pd.DataFrame(), exclude=True + ) + + # I want to intentionally keep these as native python types, and not + # pandas specific types. This could also be called "duration" or + # "ItemSize". It is the length of time a CollectionItem stores. + offset: str = Field(default="72h", max_length=5) + + start: AwareDatetimeISO = Field( + default=datetime(year=2018, month=1, day=1, tzinfo=timezone.utc), + description="This is the starting point in which data will be retrieved" + "in chunks from.", + frozen=True, + ) + + finished: Optional[AwareDatetimeISO] = Field( + default=None, + description="Finished is only set if we don't want a rolling window", + ) + + _client: Optional[Client] = PrivateAttr(default=None) + + # --- Validators --- + @model_validator(mode="before") + @classmethod + def check_model_before(cls, data: Any) -> Any: + + assert isinstance(data, dict), "check_model_before.isinstance(data, dict)" + + # We must be able to read from the archive_path + ap = data.get("archive_path", None) + assert isinstance(ap, Path), "check_model_before.isinstance(ap, Path)" + + if not ap.is_dir(): + raise ValueError(f"Path does not point to a directory") + + if not access(path=ap, mode=R_OK): + raise ValueError(f"Cannot read archive_path") + + df: Optional[pd.DataFrame] = data.get("df", None) + if df is not None: + if not df.empty or len(df.columns) != 0: + raise ValueError("Do not provide a pd.DataFrame") + + return data + + @model_validator(mode="after") + def check_model_after(self) -> Self: + if self.offset is None or self.start is None: + return self + + offset_total_sec = pd.Timedelta(self.offset).total_seconds() + start_total_sec = (datetime.now(tz=timezone.utc) - self.start).total_seconds() + + if offset_total_sec > start_total_sec: + raise ValueError("Offset must be equal to, or smaller the start timestamp") + + return self + + @field_validator("start") + def check_start( + cls, start: Optional[datetime], info: ValidationInfo + ) -> Optional[datetime]: + if start and start.microsecond != 0: + raise ValueError("Collection.start must not have microseconds") + return start + + @field_validator("offset") + def check_offset(cls, v: Optional[str], info: ValidationInfo): + # pd.offsets.__all__ + if v is None: + # In MergeCollections, offset can be None + return v + try: + pd.Timedelta(v) + except (Exception,) as e: + capture_exception(error=e) + raise ValueError( + "Invalid offset alias provided. Please review: " + "https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases" + ) + + total_seconds: float = pd.Timedelta(v).total_seconds() + + if total_seconds < timedelta(minutes=1).total_seconds(): + raise ValueError("Must be equal to, or longer than 1 min") + + if total_seconds > timedelta(days=365).total_seconds() * 100: + raise ValueError("Must be equal to, or less than 100 years") + + return v + + # --- Properties --- + def _interval_range(self, end: datetime) -> pd.IntervalIndex: + assert end, "an end value must be provided" + + _start = self.interval_start + + if end.tzinfo is None: + # A Naive end was passed in. We probably did this on purpose. + _start = _start.replace(tzinfo=None) + + assert _start.tzinfo == end.tzinfo, "Timezones must match" + + if self.offset: + iv_r: pd.IntervalIndex = pd.interval_range( + start=_start, end=end, freq=self.offset, closed="left" + ) + res = iv_r.to_list() + + # If there is a defined start (there always should be), + # but the end isn't in the IntervalIndex range because + # the offset is longer than the end - start + if self.start is not None and end not in iv_r[-1]: + right = iv_r[-1].right + pd.Timedelta(self.offset) + iv = pd.Interval(left=iv_r[-1].right, right=right) + res.append(iv) + + else: + iv_r: pd.IntervalIndex = pd.interval_range( + start=_start, end=end, periods=1, closed="left" + ) + res = iv_r.to_list() + + return pd.IntervalIndex.from_tuples( + data=[(iv.left, iv.right) for iv in res], closed="left" + ) + + @property + def interval_start(self) -> Optional[datetime]: + # In DFCollections, start must be set, so the interval_start = start. In merged + # this may be overridden with different behavior. + return self.start + + @property + def interval_range(self) -> List[Tuple]: + """closed='left', so 0 <= x < 5""" + end = self.finished or datetime.now(tz=timezone.utc).replace(microsecond=0) + iv_r = self._interval_range(end) + return [(iv.left.to_pydatetime(), iv.right.to_pydatetime()) for iv in iv_r] + + @property + def progress(self) -> pd.DataFrame: + records = [i.to_dict() for i in self.items] + end = self.finished if self.finished else datetime.now(tz=timezone.utc) + return pd.DataFrame.from_records(records, index=self._interval_range(end)) + + @property + def items(self) -> pd.DataFrame: + raise NotImplementedError("Must override") + + @property + def _schema(self) -> DataFrameSchema: + raise NotImplementedError("Must override") + + # --- Methods --- + def fetch_force_rr_latest(self, sources) -> list: + raise NotImplementedError("Must override") + + def fetch_all_paths( + self, + items: Optional[Items] = None, + force_rr_latest=False, + include_partial=False, + ) -> List[FilePath]: + LOG.info( + f"CollectionBase.fetch_all(items={len(items or [])}, " + f"{force_rr_latest=}, {include_partial=})" + ) + + items = items or self.items + + # (1) All the originally available archives + sources: List[FilePath] = [ + i.path for i in items if i.has_archive(include_empty=False) + ] + + # (2) All the "ephemerally" available rr fetched tmp archives + # --- + # Sometimes we may want to force in the latest CollectionItem if it's + # important for some reason. However, many times we probably don't + # need the absolutely most recent data... and it's not worth the slow + # rr operation to do so + if force_rr_latest: + sources = self.fetch_force_rr_latest(sources) + if include_partial is False: + LOG.warning( + "If force_rr_latest, then by definition partial is included. " + "Set include_partial to True to remove this warning." + ) + include_partial = True + # (3) + if include_partial: + rr_items = [i for i in items if not i.has_archive()] + for rr_item in rr_items: + pp = rr_item.partial_path + if rr_item.has_partial_archive() and pp not in sources: + sources.append(pp) + + return sources + + def ddf( + self, + items: Optional[Items] = None, + force_rr_latest=False, + columns=None, + filters=None, + categories=None, + include_partial=False, + graph: Optional[Callable] = None, + ) -> Optional[dd.DataFrame]: + """ + + Args: + items (list): These are any of the Collection Item that we want to + pull rows from. If it is empty, it includes all of the .items + in the DFCollection + + force_rr_latest (bool): Sometimes we may want to force in the latest + CollectionItem if it's important for some reason. However, many + times we probably don't need the absolutely most recent data + and it's not worth the slow rr operation to do so + + columns: Often times it isn't required to return all of the columns + in the DFCollection. This allows us to limit which are returned. + + filters: Apply these filters to an Item basis to limit the total + rows that are returned. It uses the tuple syntax + + categories (list): Define any columns that may be categorical as it + allows parquet to optimize how the data is read. + + include_partial (bool): when .fetch_all_paths() is called to get + all the CollectionItems, this boolean is used to determine + if the last items + + Returns: + dd.DataFrame: this is a Dask Dataframe + + """ + + if isinstance(items, list) and len(items): + sources: List[FilePath] = [ + i.path for i in items if i.has_archive(include_empty=False) + ] + + sources.extend( + [ + i.partial_path + for i in items + if i.has_partial_archive() + and not i.has_archive(include_empty=False) + ] + ) + + else: + sources: List[FilePath] = self.fetch_all_paths( + items=None, + force_rr_latest=force_rr_latest, + include_partial=include_partial, + ) + + if len(sources) == 0: + return None + + ddfs = [] + for s in sources: + _ddf = dd.read_parquet( + path=s, + columns=columns, + filters=filters, + categories=categories, + calculate_divisions=False, + engine="pyarrow", + ) + + if graph: + _ddf = graph(_ddf) + + ddfs.append(_ddf) + + if len(ddfs) == 0: + raise AssertionError("Must provide parquet sources") + + # Look into interleave_partitions, default False + ddf = dd.concat(ddfs) + return ddf + + # --- Methods: Cleanup --- + def schedule_cleanup( + self, client=None, sync=True, client_resources=None + ) -> Union[pd.DataFrame, Future]: + LOG.info(f"cleanup(archive_path={self.archive_path})") + + fs = [] + for item in self.items: + fs.append(dask.delayed(item.cleanup_partials)()) + fs.append(dask.delayed(item.clear_corrupt_archive)()) + fs.append(dask.delayed(self.clear_tmp_archives)()) + res = client.compute( + collections=fs, + sync=sync, + priority=2, + client_resources=client_resources, + ) + return res + + def cleanup(self) -> None: + # Same as schedule_cleanup but runs locally + self.cleanup_partials() + self.clear_tmp_archives() + self.clear_corrupt_archives() + # self.check_empty() # what did this do?? + + return None + + def cleanup_partials(self) -> None: + """If an item is "closed", remove any partial files that may be around...""" + for item in self.items: + item.cleanup_partials() + + return None + + def clear_tmp_archives(self) -> None: + regex = re.compile(r"\.parquet\.[0-9a-f]{32}", re.I) + + for fn in os.listdir(self.archive_path): + if regex.search(fn): + LOG.info(f"Removing {fn}") + CollectionItemBase.delete_archive( + Path(os.path.join(self.archive_path, fn)) + ) + + return None + + def clear_corrupt_archives(self) -> None: + for item in self.items: + item.clear_corrupt_archive() + + return None + + def rebuild_symlinks(self) -> None: + """ + When copying "things" between filesystems, and using Sambda mmfsylinks, + we can't ensure links are properly shared. + """ + + for item in reversed(self.items): + item: CollectionItemBase + reg_path = item.path.as_posix() + partial_path = item.partial_path.as_posix() + empty_path = item.empty_path.as_posix() + + # --- Partial Path --- + if os.path.exists(partial_path) and os.path.isfile(partial_path): + os.remove(partial_path) + + # Don't "continue" onto the next CollectionItem. Later on, + # we may need to create a symlink for the most recent partial + pass + + # --- Empty Path --- + if os.path.exists(empty_path): + # A symlink isn't used for empty path CollectionItems + continue + + # --- Regular Path --- + if os.path.exists(reg_path): + + if os.path.isfile(reg_path): + # These should never be a file, clean up + os.remove(reg_path) + continue + + if os.path.isdir(reg_path) and not os.path.islink(reg_path): + # All is good and how it should be! + continue + + if os.path.islink(reg_path): + # A symlink already exists for this CollectionItem. However, + # don't "continue" on because we will want to ensure + # it's at the most recent version. + pass + + highest_version: Path = item.search_highest_numbered_path() + + if highest_version is None: + # No version of the file at all was found in the directory, + # so don't try to make a symlink for it + continue + + # Make sure these are in the same dir. b/c the symlink has to be + # relative, not an absolute path + assert ( + item.path.parent == highest_version.parent + ), "Can't have numbered_path in a different directory" + + try: + pq.ParquetDataset(highest_version).read().to_pandas() + except (Exception,): + # If the most recent version isn't valid, we don't want to + # create a symlink to it. + # TODO: We could try to be smart and iterate down the most recent + # available partials until we find one that isn't broken. However, + # this isn't a huge priority because it should fix itself upon + # the next sync cmd is run (every 1 min) + continue + + # if os.path.exists(item.path.as_posix()) and not os.path.islink(path.as_posix()): + # This will fail when going from the old way to using symlinks, if self.path already exists + # and is a directory. + # raise ValueError( + # f"first time we run this, make sure the path doesn't exist: {path.as_posix()}") + + # After running for a while, it appears that symlinks have a + # tendency to break for some reason. While it's unclear why, there + # shouldn't be any harm in always removing the file before the + # `ln` command is run. -- Max 2024-07-26 + try: + os.remove(item.path.as_posix()) + except FileNotFoundError as e: + pass + + if platform == "darwin": + subprocess.call(["ln", "-sfn", highest_version, item.path.as_posix()]) + else: + subprocess.call(["ln", "-sfnT", highest_version, item.path.as_posix()]) + + return None + + # -- Methods: Source timing + def get_item(self, interval: pd.Interval) -> Item: + return next(x for x in self.items if x.interval == interval) + + def get_item_start(self, start: pd.Timestamp) -> Items: + return next(x for x in self.items if x.interval.left == start) + + def get_items(self, since: datetime) -> Items: + res = [] + first_match = True + + for idx, item in enumerate(self.items): + item: "DFCollectionItem" + + # TODO: This appears to be a bug. It should be using the + # IntervalRange overlaps approach - Max 2024-06-07 + if item.start >= since: + + # We want to retrieve the item that falls before the + # first item, so we aren't missing any partial time ranges + if first_match and idx != 0: + res.append(self.items[idx - 1]) + + res.append(item) + first_match = False + + res: List[Item] = [i for i in res if not i.is_empty()] + if len([1 for i in res if i.should_archive() and not i.has_archive()]): + warnings.warn( + message="DFCollectionItem has missing archives", + category=ResourceWarning, + ) + + return res + + def get_items_from_year(self, year: int) -> Items: + ts = datetime(year=year, month=1, day=1) + return self.get_items(since=ts) + + def get_items_last90(self) -> Items: + ts = datetime.now(tz=timezone.utc) - timedelta(days=90) + return self.get_items(since=ts) + + def get_items_last365(self) -> Items: + ts = datetime.now(tz=timezone.utc) - timedelta(days=365) + return self.get_items(since=ts) + + +class CollectionItemBase(BaseModel): + # I want to intentionally keep these as native python types, and not + # pandas specific types. + start: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc).replace(microsecond=0) + ) + + # --- Private attrs --- + _collection: Collection = PrivateAttr() + + @property + def name(self) -> str: + coll = self._collection + if hasattr(coll, "data_type"): + name = coll.data_type.value + else: + name = coll.merge_type.value + return name + + def __str__(self): + coll = self._collection + offset = coll.offset or "–" + return f"{self.name}({self.interval.left.strftime('%x %X')} @ {offset})" + + # --- Validators --- + @model_validator(mode="after") + def check_start(self): + """We don't want to support CollectionItems that start on a + fractional second. + """ + assert ( + self.start.microsecond == 0 + ), "CollectionItem.start must not have microsecond precision" + return self + + # --- Properties --- + @property + def finish(self) -> datetime: + return ( + pd.Timestamp(self.start) + pd.Timedelta(self._collection.offset) + ).to_pydatetime() + + @property + def interval(self) -> pd.Interval: + return pd.Interval( + left=pd.Timestamp(self.start), + right=pd.Timestamp(self.finish), + closed="left", + ) + + # --- Properties: paths + filenames --- + @property + def filename(self) -> str: + raise NotImplementedError("Do not use CollectionItemBase directly.") + + @property + def partial_filename(self) -> str: + # This is an archive for a CollectionItem that is not yet closed. It is temporary and should + # never get backed up. + return f"{self.filename}.partial" + + @property + def empty_filename(self) -> str: + # If this file or directory exists on disk, it means this CollectionItem + # truly has no data. We want to use this to distinguish this from a + # broken file or a failed query. + return f"{self.filename}.empty" + + @property + def path(self) -> FilePath: + return FilePath(os.path.join(self._collection.archive_path, self.filename)) + + @property + def partial_path(self) -> FilePath: + return FilePath( + os.path.join(self._collection.archive_path, self.partial_filename) + ) + + @property + def empty_path(self) -> FilePath: + return FilePath( + os.path.join(self._collection.archive_path, self.empty_filename) + ) + + # --- Methods --- + + @staticmethod + def path_exists(generic_path: FilePath) -> bool: + return os.path.exists(generic_path) + + @staticmethod + def next_numbered_path(path: Path) -> Path: + # We assume the item.path is pointing to the current version. To get the "next", we increment it. + target = os.path.realpath(path) + if path == target: + # This is not yet a symlink, start with .00000 + return Path(f"{path}.{0:>05}") + + # We assume the target ends with ".####". If not, we'll append .00000 + try: + left, right = target.rsplit(".", 1) + right_int = int(right) + except ValueError: + return Path(f"{path}.{0:>05}") + + right_int += 1 + return Path(f"{path}.{right_int:>05}") + + def search_highest_numbered_path(self) -> Optional[Path]: + """This is used for when things are broken, and we want to rebuild + our symlinks. We can't trust or use any exist symlinks... so given + a path or a partial path... find the highest available "versioned" + build there is + """ + coll: CollectionBase = self._collection + + # TODO: is_partial support??? + + # regex = re.compile(r'\.parquet\.[0-9a-f]{32}', re.I) + builds = [] + for fn in os.listdir(coll.archive_path): + if fn.startswith(self.filename): + + # Don't include the "broken link" or mmfsymlink text file + if fn != self.filename and fn != self.partial_filename: + builds.append(fn) + + if len(builds) == 0: + return None + + # --- this doesn't work with if the incrementing file is a partial --- + # nums = sorted([b.rsplit(".", 1)[1] for b in builds], reverse=True) + # return Path(f"{self.path}.{nums[0]}") + + files: List[str] = sorted( + builds, key=lambda b: b.rsplit(".", 1)[1], reverse=True + ) + return Path(os.path.join(coll.archive_path, files[0])) + + def tmp_filename(self) -> str: + # Not a @property b/c I don't want to accidentally have this get mixed + # up as always returning the same tmp filename + return f"{self.filename}.{uuid4().hex}" + + def tmp_path(self) -> FilePath: + return FilePath( + os.path.join(self._collection.archive_path, self.tmp_filename()) + ) + + # --- --- --- --- + # If it has a partial, it isn't always going to be a partial. However, + # if it has an empty, will we ever try to recheck? + + def is_empty(self) -> bool: + return self.path_exists(self.empty_path) + + def has_empty(self) -> bool: + return self.is_empty() + + def has_partial_archive(self) -> bool: + return self.path_exists(self.partial_path) + + # --- --- --- --- + + def has_archive(self, include_empty=False) -> bool: + if include_empty: + return self.path_exists(generic_path=self.path) or self.path_exists( + generic_path=self.empty_path + ) + else: + return self.path_exists(generic_path=self.path) + + @staticmethod + def delete_archive(generic_path: Path) -> None: + # If a partial directory or file exists, delete it. + if os.path.exists(generic_path): + + if os.path.isfile(generic_path): + os.remove(generic_path) + + if os.path.isdir(generic_path): + # TODO: this is broken on Mac... + # os.path.islink(path.as_posix()): + shutil.rmtree(generic_path) + else: + LOG.warning(f"tried removing non-existent file: {generic_path}") + pass + + def should_archive(self) -> bool: + # Determine if enough time has passed to move out of a partial file into an + # archive. + archive_after: timedelta = self._collection._schema.metadata[ARCHIVE_AFTER] + + if archive_after is None: + return False + + if datetime.now(tz=timezone.utc) > self.finish + archive_after: + return True + return False + + def set_empty(self): + assert ( + self.should_archive() + ), "Can not set_empty on an item that is not archive-able" + assert not self.is_empty(), "set_empty is already set; why are you doing this?" + self.empty_path.touch() + assert self.is_empty(), "set_empty(): something is wrong" + + def valid_archive( + self, + generic_path: Optional[FilePath] = None, + sample: Optional[int] = None, + ) -> bool: + """ + Attempts to confirm if the parquet file or directory that is + written to Disk for a DFCollectionItem is not corrupted or otherwise + in a state that would prevent its use. + """ + path: str = generic_path.as_posix() if generic_path else self.path.as_posix() + try: + if os.path.isfile(path): + parquet = pq.ParquetFile(path) + elif os.path.isdir(path): + # This will not fail on a empty directory. However, it will + # return the .read().to_pandas() as an empty pd.DataFrame + # without any rows or columns + parquet = pq.ParquetDataset(path) + else: + # TODO: are there even other types; eg: are symlinks .isfile=True? + raise ValueError("Unknown path type.") + + df = parquet.read().to_pandas() + except (Exception,): + LOG.warning(f"Invalid archive {path=}") + df = None + + # Check if it's None or a totally empty pd.DataFrame before we waste + # any time on trying to hit pandera + if df is None or sum(df.shape) == 0: + return False + + return self.validate_df(df=df, sample=sample) is not None + + def validate_df( + self, df: pd.DataFrame, sample: Optional[int] = None + ) -> Optional[pd.DataFrame]: + if sample is not None: + sample = min(len(df), sample) + try: + schema: DataFrameSchema = self._collection._schema + return schema.validate(check_obj=df, lazy=True, sample=sample) + except Exception as e: + LOG.exception(e) + capture_exception(error=e) + return None + + # def validate_ddf(self, ddf: dd.DataFrame) -> Optional[pd.DataFrame]: + # """ WARNING: this accepts a dd.DataFrame, but returns a pd.DataFrame + # """ + # # TODO: This is absolutely a way to do this with pyArrow Schemas.. However, + # # we'd first need to figure out how to go from Pandera to a pyArrow Schema. + # try: + # df = self._collection._client.compute( + # collections=ddf, + # sync=True, + # priority=1, + # resources=self._collection._client_resources + # ) + # except (Exception,) as e: + # capture_exception(error=e) + # return None + # + # return _validate_df(df=df, schema=self._collection._schema) + + # --- ORM / Data handlers--- + def from_archive( + self, + include_empty: bool = True, + generic_path: Optional[FilePath] = None, + ) -> Optional[dd.DataFrame]: + + if include_empty and self.path_exists(generic_path=self.empty_path): + # Return an empty dd.DataFrame with the correct columns + return dd.from_pandas(empty_dataframe_from_schema(self._collection._schema)) + + if not self.path_exists(generic_path=generic_path or self.path): + return None + + return dd.read_parquet( + path=generic_path or self.path, + calculate_divisions=False, + engine="pyarrow", + ) + + def to_archive(self, ddf: dd.DataFrame, is_partial: bool = False) -> bool: + raise NotImplementedError("Must override") + + # --- ORM / Data handlers--- + def _to_dict(self, *args, **kwargs) -> dict: + return dict( + should_archive=self.should_archive(), + has_archive=self.has_archive(), + filename=self.filename, + path=self.path, + start=self.start, + finish=self.finish, + ) + + def delete_partial(self): + # If a Collection Item is archived, we want to delete the partial file. + assert self.should_archive(), "please wait until item is archived" + if not self.path_exists(self.partial_path): + LOG.info(f"no partial to delete: {self.partial_path}") + return + if not self.partial_path.is_symlink(): + LOG.warning(f"expected symlink: {self.partial_path}") + return + target = self.partial_path.parent / self.partial_path.readlink() + os.remove(self.partial_path) + shutil.rmtree(target) + + def cleanup_partials(self): + if self.path_exists(self.partial_path): + if self.should_archive() and self.has_archive(include_empty=True): + self.delete_dangling_partials() + self.delete_partial() + else: + self.delete_dangling_partials(keep_latest=2) + + def delete_dangling_partials(self, keep_latest=None, target_path=None) -> List[str]: + # Specifically looking for numbered partials that are NOT associated + # with a symlink. It does not matter if the item is archiveable or not. + if target_path is None: + target_path = self.partial_path + fps = glob.glob(target_path.as_posix() + ".*") + fps = {x for x in fps if x.split(".")[-1].isnumeric()} + # Note: if the dir itself is sym-linked, this is going to be wrong. + # Use the relative paths existing_link = os.path.realpath(target_path) + if target_path.exists() and target_path.is_symlink(): + existing_link = target_path.parent / target_path.readlink() + fps.discard(existing_link.as_posix()) + fps = sorted(fps) + if keep_latest is not None: + fps = fps[:-keep_latest] + for fp in fps: + self.delete_archive(fp) + return fps diff --git a/generalresearch/incite/collections/__init__.py b/generalresearch/incite/collections/__init__.py new file mode 100644 index 0000000..051c5a1 --- /dev/null +++ b/generalresearch/incite/collections/__init__.py @@ -0,0 +1,752 @@ +import logging +import os +import subprocess +import time +from datetime import datetime +from enum import Enum +from sys import platform +from typing import Optional, List + +import dask +import dask.dataframe as dd +import pandas as pd +import pyarrow.parquet as pq +from dask.distributed import Future +from distributed import Client, as_completed +from more_itertools import chunked +from pandera import DataFrameSchema +from psycopg import Cursor +from pydantic import Field, FilePath, field_validator, ValidationInfo +from sentry_sdk import capture_exception + +from generalresearch.incite.base import CollectionBase, CollectionItemBase +from generalresearch.incite.schemas import ( + ORDER_KEY, + ARCHIVE_AFTER, + PARTITION_ON, + empty_dataframe_from_schema, +) +from generalresearch.incite.schemas.thl_marketplaces import ( + InnovateSurveyHistorySchema, + MorningSurveyTimeseriesSchema, + SagoSurveyHistorySchema, + SpectrumSurveyTimeseriesSchema, +) +from generalresearch.incite.schemas.thl_web import ( + TxSchema, + TxMetaSchema, + THLUserSchema, + THLTaskAdjustmentSchema, + THLWallSchema, + THLSessionSchema, + THLIPInfoSchema, + TransactionMetadataColumns, + UserHealthIPHistorySchema, + UserHealthAuditLogSchema, + UserHealthIPHistoryWSSchema, + LedgerSchema, +) +from generalresearch.pg_helper import PostgresConfig +from generalresearch.sql_helper import SqlHelper + +LOG = logging.getLogger("incite") + +DT_STR = "%Y-%m-%d %H:%M:%S" + + +class DFCollectionType(str, Enum): + TEST = "test" + + USER = "thl_user" + SESSION = "thl_session" + WALL = "thl_wall" + TASK_ADJUSTMENT = "thl_taskadjustment" + IP_INFO = "thl_ipinformation" + + AUDIT_LOG = "userhealth_auditlog" + IP_HISTORY = "userhealth_iphistory" + IP_HISTORY_WS = "userhealth_iphistory_ws" + + LEDGER = "ledger" + + INNOVATE_SURVEY_HISTORY = "innovate_surveyhistory" + MORNING_SURVEY_TIMESERIES = "morning_surveytimeseries" + SAGO_SURVEY_HISTORY = "sago_surveyhistory" + SPECTRUM_SURVEY_TIMESERIES = "spectrum_surveytimeseries" + + +DFCollectionTypeSchemas = { + DFCollectionType.USER: THLUserSchema, + DFCollectionType.WALL: THLWallSchema, + DFCollectionType.SESSION: THLSessionSchema, + DFCollectionType.IP_INFO: THLIPInfoSchema, + DFCollectionType.TASK_ADJUSTMENT: THLTaskAdjustmentSchema, + DFCollectionType.IP_HISTORY: UserHealthIPHistorySchema, + DFCollectionType.IP_HISTORY_WS: UserHealthIPHistoryWSSchema, + DFCollectionType.AUDIT_LOG: UserHealthAuditLogSchema, + DFCollectionType.LEDGER: LedgerSchema, + DFCollectionType.INNOVATE_SURVEY_HISTORY: InnovateSurveyHistorySchema, + DFCollectionType.MORNING_SURVEY_TIMESERIES: MorningSurveyTimeseriesSchema, + DFCollectionType.SAGO_SURVEY_HISTORY: SagoSurveyHistorySchema, + DFCollectionType.SPECTRUM_SURVEY_TIMESERIES: SpectrumSurveyTimeseriesSchema, +} + + +class DFCollectionItem(CollectionItemBase): + + # --- Properties --- + @property + def filename(self) -> str: + return ( + f"{self._collection.data_type.name.lower()}-{self._collection.offset}" + f"-{self.start.strftime('%Y-%m-%d-%H-%M-%S')}.parquet" + ) + + # --- Methods --- + + def has_mysql(self) -> bool: + if self._collection.sql_helper is None: + return False + + connected = True + try: + self._collection.sql_helper.execute_sql_query("""SELECT 1;""") + except: + connected = False + + return connected + + def has_postgres(self) -> bool: + if self._collection.pg_config is None: + return False + + connected = True + try: + self._collection.pg_config.execute_sql_query("""SELECT 1;""") + except: + connected = False + + return connected + + def has_db(self) -> bool: + return self.has_mysql() or self.has_postgres() + + def update_partial_archive(self) -> bool: + if not self.valid_archive(self.partial_path, sample=1000): + LOG.error(f"invalid partial archive: {self.partial_path}") + return self.create_partial_archive() + df = pq.ParquetDataset(self.partial_path).read().to_pandas() + + order_key = self._collection._schema.metadata[ORDER_KEY] + archive_after = self._collection._schema.metadata[ARCHIVE_AFTER] + + partial_max = df[order_key].max().to_pydatetime() + + since = partial_max - archive_after + since = max([since, self.start]) # don't allow to query before the item's start + df = df[df[order_key] < since].copy() + + _df = self.from_mysql(since=since) + + if _df is not None: + df = pd.concat([df, _df]) + self.to_archive(ddf=dd.from_pandas(df, npartitions=1), is_partial=True) + else: + # The update to the partial returned no rows, but the partial + # still exists, so we'll continue with whatever was calling this. + # We don't need to re-write the partial or really do anything. + pass + return True + + def create_partial_archive(self) -> bool: + _df = self.from_mysql() + if _df is None: + # Returned no rows, but the period is not closed, so we + # don't want to mark as empty. Do nothing. + return False + return self.to_archive(ddf=dd.from_pandas(_df, npartitions=1), is_partial=True) + + # --- ORM / Data handlers--- + def to_dict(self, *args, **kwargs) -> dict: + return self._to_dict() + + def from_mysql(self, since: Optional[datetime] = None) -> Optional[pd.DataFrame]: + if self._collection.data_type == DFCollectionType.LEDGER: + assert since is None, "Shouldn't pass since for Ledger item" + assert self._collection.pg_config is not None + return self.from_postgres_ledger() + else: + if self._collection.sql_helper: + return self.from_mysql_standard(since=since) + else: + return self.from_postgres_standard(since=since) + + def from_mysql_standard( + self, since: Optional[datetime] = None + ) -> Optional[pd.DataFrame]: + + assert ( + self._collection.data_type != DFCollectionType.LEDGER + ), "Can't call from_mysql_standard for Ledger DFCollectionItem" + + start, finish = self.start, self.finish + LOG.debug( + f"{self._collection.data_type.value}.from_mysql(" + f"start={start.strftime(DT_STR)}, " + f"finish={finish.strftime(DT_STR)})" + ) + coll = self._collection + schema = coll._schema + sql_helper = coll.sql_helper + + start = since or start + order_key = schema.metadata[ORDER_KEY] + cols = list(schema.columns.keys()) + [schema.index.name] + cols_str = ",".join(map(sql_helper._quote, cols)) + db_name = sql_helper.db + + try: + res = sql_helper.execute_sql_query( + query=f""" + SELECT {cols_str} + FROM `{db_name}`.`{coll.data_type.value}` + WHERE `{order_key}` >= %s AND `{order_key}` < %s; + """, + params=[start, finish], + ) + except (Exception,) as e: + capture_exception(error=e) + LOG.error(f"_from_mysql Exception: {e}") + return None + + if not res: + LOG.warning(f"_from_mysql query returned nothing") + # Return an empty df.DataFrame with the correct columns + return empty_dataframe_from_schema(coll._schema) + + df = pd.DataFrame.from_records(res).set_index(coll._schema.index.name) + df = self.validate_df(df=df) + + if df is None: + LOG.warning(f"_from_mysql query results failed validation") + # Schema validation can fail... + return None + + return df + + def from_postgres_standard( + self, since: Optional[datetime] = None + ) -> Optional[pd.DataFrame]: + assert ( + self._collection.data_type != DFCollectionType.LEDGER + ), "Can't call from_postgres_standard for Ledger DFCollectionItem" + + start, finish = self.start, self.finish + LOG.debug( + f"{self._collection.data_type.value}.from_postgres(" + f"start={start.strftime(DT_STR)}, " + f"finish={finish.strftime(DT_STR)})" + ) + coll = self._collection + schema = coll._schema + pg_config = coll.pg_config + + start = since or start + order_key = schema.metadata[ORDER_KEY] + cols = list(schema.columns.keys()) + [schema.index.name] + cols_str = ", ".join(cols) + + try: + res = pg_config.execute_sql_query( + query=f""" + SELECT {cols_str} + FROM {coll.data_type.value} + WHERE {order_key} >= %s AND {order_key} < %s; + """, + params=[start, finish], + ) + except (Exception,) as e: + capture_exception(error=e) + LOG.error(f"_from_postgres Exception: {e}") + return None + + if not res: + LOG.warning(f"_from_postgres query returned nothing") + # Return an empty df.DataFrame with the correct columns + return empty_dataframe_from_schema(coll._schema) + + df = pd.DataFrame.from_records(res).set_index(coll._schema.index.name) + df = self.validate_df(df=df) + + if df is None: + LOG.warning(f"_from_postgres query results failed validation") + # Schema validation can fail... + return None + + return df + + def from_postgres_ledger(self) -> Optional[pd.DataFrame]: + assert ( + self._collection.data_type == DFCollectionType.LEDGER + ), "Can only call from_postgres_ledger on Ledger DFCollectionItem" + + start, finish = self.start, self.finish + LOG.info( + f"{self._collection.data_type.value}.from_postgres_ledger(" + f"start={start.strftime(DT_STR)}, " + f"finish={finish.strftime(DT_STR)})" + ) + + coll = self._collection + pg_config: PostgresConfig = coll.pg_config + + limit = 20000 + offset = 0 + res = [] + while True: + logging.info( + f"{self._collection.data_type.value}.from_postgres_ledger({limit=}, {offset=})" + ) + chunk = pg_config.execute_sql_query( + query=f""" + SELECT lt.id AS tx_id, lt.created, lt.ext_description, lt.tag, + le.id AS entry_id, le.direction, le.amount, le.account_id, + la.display_name, la.qualified_name, la.account_type, + la.normal_balance, la.reference_type, la.reference_uuid, + la.currency + FROM ledger_transaction AS lt + LEFT JOIN ledger_entry AS le + ON lt.id = le.transaction_id + LEFT JOIN ledger_account AS la + ON la.uuid = le.account_id + WHERE lt.created >= %s AND lt.created < %s + AND le.id IS NOT NULL + ORDER BY lt.created + LIMIT {limit} OFFSET {offset}; + """, + params=[start, finish], + ) + res.extend(chunk) + if not chunk: + break + offset += limit + + if len(res) == 0: + return None + + # Note (AND le.id IS NOT NULL): It is possible we have transactions with + # no ledger entries. This is because the transaction creation failed + # for some reason. The ledger is not unbalanced, it is just an orphan + # transaction. Just skip those here. + + tx_df = TxSchema.validate( + check_obj=pd.DataFrame.from_records(res).set_index("entry_id"), + lazy=True, + ) + + tx_ids = list(tx_df["tx_id"].unique()) + metadata_res = [] + # "MySQL server has gone away" if this is too big + conn = pg_config.make_connection() + c: Cursor = conn.cursor() + for chunk in chunked(tx_ids, n=5_000): + c.execute( + query=f""" + SELECT ltm.transaction_id AS tx_id, + ltm.id AS tx_metadata_id, + ltm.key, ltm.value + FROM ledger_transactionmetadata AS ltm + WHERE ltm.transaction_id = ANY(%s); + """, + params=[chunk], + ) + metadata_res += c.fetchall() + + conn.close() + + tx_meta = ( + pd.DataFrame( + TxMetaSchema.validate( + check_obj=pd.DataFrame.from_records(metadata_res).set_index( + ["tx_id", "tx_metadata_id"] + ), + lazy=True, + ).pivot(columns="key", values="value"), + # This makes sure we expand to have all the possible columns + columns=[e.value for e in TransactionMetadataColumns], + ) + .groupby("tx_id") + .first() + ) + + df = tx_df.merge(tx_meta, how="left", left_on="tx_id", right_index=True) + df = self.validate_df(df=df) + + if df is None: + # Schema validation can fail... + return None + + return df + + def to_archive( + self, + ddf: dd.DataFrame, + is_partial: bool = False, + overwrite: bool = False, + ) -> bool: + """ + :returns: bool (saved_successful) + """ + assert isinstance(ddf, dd.DataFrame), "must pass dask df" + + client: Optional[Client] = self._collection._client + # client = None + + if client: + row_len = client.compute(collections=ddf.shape[0], sync=True) + else: + row_len = len(ddf.index) + is_empty = row_len == 0 + + if is_partial: + return self.to_archive_numbered_partial(ddf=ddf) + else: + return self._to_archive( + ddf=ddf, + is_empty=is_empty, + overwrite=overwrite, + ) + + def _to_archive( + self, + ddf: dd.DataFrame, + is_empty: bool, + overwrite: bool = False, + ) -> bool: + """ + For archiving an item. Will write an empty file if ddf is empty. This + is NOT for writing partials. + + :returns: bool (saved_successful) + """ + + if ddf is None: + return False + + should_archive = self.should_archive() + if not should_archive: + LOG.warning(f"Cannot create archive for such new data: {self.path}") + return False + + if overwrite is False: + has_archive = self.has_archive(include_empty=True) + if has_archive: + LOG.warning(f"archive already exists: {self.path}") + return False + + if is_empty: + # Create an .empty only if the Item is "archiveable" (which we checked above) + self.set_empty() + return True + + # Incase the file saving is interrupted, or otherwise fails + # save it to a tmp file first, then rename once we can confirm + # that it successfully loads + tmp_path = self.tmp_path() + try: + schema = self._collection._schema + partition = schema.metadata.get(PARTITION_ON, None) + + ddf.to_parquet( + path=tmp_path, + partition_on=partition, + engine="pyarrow", + overwrite=True, + write_metadata_file=True, + compression="brotli", + ) + + except (Exception,) as e: + LOG.exception(e) + self.delete_archive(tmp_path) + return False + + # It was saved, but the file seems to be corrupt + if not self.valid_archive(tmp_path): + LOG.error(f"not valid archive: {tmp_path}") + self.delete_archive(tmp_path) + # File did not save correctly so return it as saved=False + return False + + # To debug, just set this key to auto expire in 5 seconds + # RC.set(name=f"_to_archive:{self.path.as_posix()}", value=1, ex=15) + # with RC.lock(f"_to_archive:{self.path.as_posix()}:lock", timeout=15): + + if os.path.isfile(tmp_path): + # If the file was saved okay, seems okay, rename it + os.replace(tmp_path, self.path) + os.remove(tmp_path) + + if os.path.isdir(tmp_path): + if os.path.exists(self.path.as_posix()): + if overwrite: + subprocess.call(["rm", "-r", self.path.as_posix()]) + time.sleep(1) + else: + LOG.error(f"already exists: {self.path.as_posix()}") + return False + + if platform == "darwin": + subprocess.call(["mv", tmp_path.as_posix(), self.path.as_posix()]) + else: + # -T will (should) cause the mv to fail if path wasn't successfully deleted + subprocess.call(["mv", "-T", tmp_path.as_posix(), self.path.as_posix()]) + return True + + def to_archive_numbered_partial(self, ddf: dd.DataFrame) -> bool: + """ + For partial files/dirs only. Writes the .partial file with a number + at the end (.partial.####) and then creates a symlink + from .partial -> .partial.#### + + :returns: bool (saved_successful) + """ + if ddf is None: + return False + collection = self._collection + schema = collection._schema + client: Optional[Client] = collection._client + + next_numbered_path = self.next_numbered_path(self.partial_path) + partial_path = self.partial_path + finish = self.finish + + # Make sure these are in the same dir. b/c the symlink has to be + # relative, not an absolute path + assert ( + partial_path.parent == next_numbered_path.parent + ), "Can't have numbered_path in a different directory" + target = ( + next_numbered_path.name + ) # this is the symlink's target. it is a relative path (only the name) + + should_archive = self.should_archive() + assert should_archive is False, "Don't write partial if the item is archiveable" + + if client: + row_len = client.compute(collections=ddf.shape[0], sync=True) + else: + row_len = len(ddf.index) + + if row_len == 0: + LOG.warning("Skipping, don't partial save an empty dd.DataFrame") + return False + + try: + partition = schema.metadata.get(PARTITION_ON, None) + ddf.to_parquet( + path=next_numbered_path, + partition_on=partition, + engine="pyarrow", + overwrite=True, + write_metadata_file=True, + compression="brotli", + ) + except (Exception,) as e: + LOG.exception(e) + self.delete_archive(next_numbered_path) + return False + + if platform == "darwin": + subprocess.call(["ln", "-sfn", target, partial_path]) + else: + subprocess.call(["ln", "-sfnT", target, partial_path]) + + return True + + def initial_load(self, overwrite: bool = False) -> bool: + + if overwrite is False: + assert not self.has_archive(include_empty=True), "already archived" + + assert self.should_archive(), "not ready to archive!" + + df: pd.DataFrame = self.from_mysql() + + if df is None: + self.set_empty() + return False + + ddf = dd.from_pandas(df, npartitions=1) + return self.to_archive(ddf=ddf, is_partial=False, overwrite=overwrite) + + def clear_corrupt_archive(self): + if self.has_archive(include_empty=False): + if not self.valid_archive(self.path): + LOG.warning(f"invalid archive, deleting: {self.path}") + self.delete_archive(self.path) + + +class DFCollection(CollectionBase): + data_type: Optional[DFCollectionType] = Field(default=None) + + # --- Private --- + pg_config: Optional[PostgresConfig] = Field(default=None) + sql_helper: Optional[SqlHelper] = Field(default=None) + + def __repr__(self): + res = self.signature() + "\n" + if len(self.items) > 6: + items = self.items[:3] + ["..."] + self.items[-3:] + else: + items = self.items + + for i in items: + res += f" – {repr(i) if isinstance(i, DFCollectionItem) else i}\n" + + return res + + def signature(self): + arr = [ + 1 if i.has_archive(include_empty=True) else 0 + for i in self.items + if i.should_archive() + ] + repr_str = ( + f"items={len(self.items)}; start={self.start} @ {self.offset}; {int(sum(arr) / len(arr) * 100)}% " + f"archived" + ) + res = f"{self.__repr_name__()}({repr_str})" + return res + + @field_validator("data_type") + def check_data_type(cls, data_type, info: ValidationInfo): + if data_type is None: + raise ValueError("Must explicitly provide a data_type") + + if data_type not in DFCollectionTypeSchemas: + raise ValueError("Must provide a supported data_type") + + return data_type + + # --- Properties --- + @property + def items(self) -> List[DFCollectionItem]: + items = [] + for iv in self.interval_range: + cm = DFCollectionItem(start=iv[0]) + cm._collection = self + items.append(cm) + return items + + @property + def _schema(self) -> DataFrameSchema: + return DFCollectionTypeSchemas[self.data_type] + + # --- Methods --- + + def initial_load( + self, + client: Optional[Client] = None, + sync=True, + since: Optional[datetime] = None, + client_resources=None, + timeout: Optional[float] = None, + ) -> List[Future]: + # This can be used to just build all local archive files + # We typically want to go backwards first, so we can most quickly + # populate the last 90 days for example + + client = client or self._client + + LOG.info(f"{self.data_type.value}.initial_load({since=}, {sync=})") + + items = self.items + if since: + items = self.get_items(since=since) + + if client is None: + for item in reversed(items): + if item.has_archive(include_empty=True): + continue + if not item.should_archive(): + continue + item.initial_load() + return [] + + fs = [] + for item in items: + if item.has_archive(include_empty=True): + continue + if not item.should_archive(): + continue + f = dask.delayed(item.initial_load)() + fs.append(f) + + if sync: + fs = client.compute(fs, sync=False, priority=2, resources=client_resources) + ac = as_completed(fs, timeout=timeout) + return fs + + else: + return client.compute(fs, sync=True, priority=2, resources=client_resources) + + def fetch_force_rr_latest(self, sources) -> List[FilePath]: + LOG.info( + f"{self.data_type.value}.fetch_force_rr_latest(sources={len(sources)})" + ) + + # We only want 'partial-able' items (those that can not yet be archived). + rr_items = [ + i for i in self.items if not i.should_archive() and not i.is_empty() + ] + if rr_items: + # If the ARCHIVE_AFTER time is > the collection offset (which it is always currently), + # then there typically wouldn't be more than 1 un-archivable item. + _start = rr_items[0].start + _end = rr_items[-1].finish + rr_duration = (_end - _start).total_seconds() + + # TODO: Do we want to be smarter about any rr selects max durations? + # allowing 2x the length of the offset. If we have more than this not archived, + # we want to run the archive first, not fetch from rr + archive_after = self._schema.metadata[ARCHIVE_AFTER] + allowed_rr_duration = ( + (pd.Timedelta(self.offset) * 2) + archive_after + ).total_seconds() + if rr_duration > allowed_rr_duration: + raise ValueError( + f"rr select duration exceeds {pd.Timedelta(allowed_rr_duration)}" + ) + + for rr_item in rr_items: + if ( + rr_item.has_partial_archive() + and self.data_type != DFCollectionType.LEDGER + ): + saved = rr_item.update_partial_archive() + else: + saved = rr_item.create_partial_archive() + if saved: + sources.append(rr_item.partial_path) + + return sources + + def force_rr_latest( + self, client: Client, client_resources=None, sync: bool = True + ) -> List[Future]: + # For forcing update of any partials asynchronously if desired + LOG.info(f"{self.data_type.value}.force_rr_latest({client=})") + rr_items = [ + i for i in self.items if not i.should_archive() and not i.is_empty() + ] + fs = [] + for rr_item in rr_items: + if ( + rr_item.has_partial_archive() + and self.data_type != DFCollectionType.LEDGER + ): + fs.append(dask.delayed(rr_item.update_partial_archive)()) + else: + fs.append(dask.delayed(rr_item.create_partial_archive)()) + return client.compute(fs, sync=sync, priority=2, resources=client_resources) diff --git a/generalresearch/incite/collections/thl_marketplaces.py b/generalresearch/incite/collections/thl_marketplaces.py new file mode 100644 index 0000000..fe2b01f --- /dev/null +++ b/generalresearch/incite/collections/thl_marketplaces.py @@ -0,0 +1,37 @@ +from typing import Literal + +from generalresearch.incite.collections import DFCollection, DFCollectionType +from generalresearch.incite.schemas.thl_marketplaces import ( + InnovateSurveyHistorySchema, + MorningSurveyTimeseriesSchema, + SagoSurveyHistorySchema, + SpectrumSurveyTimeseriesSchema, +) + + +class InnovateSurveyHistoryCollection(DFCollection): + data_type: Literal[DFCollectionType.INNOVATE_SURVEY_HISTORY] = ( + DFCollectionType.INNOVATE_SURVEY_HISTORY + ) + _schema = InnovateSurveyHistorySchema + + +class MorningSurveyTimeseriesCollection(DFCollection): + data_type: Literal[DFCollectionType.MORNING_SURVEY_TIMESERIES] = ( + DFCollectionType.MORNING_SURVEY_TIMESERIES + ) + _schema = MorningSurveyTimeseriesSchema + + +class SagoSurveyHistoryCollection(DFCollection): + data_type: Literal[DFCollectionType.SAGO_SURVEY_HISTORY] = ( + DFCollectionType.SAGO_SURVEY_HISTORY + ) + _schema = SagoSurveyHistorySchema + + +class SpectrumSurveyTimeseriesCollection(DFCollection): + data_type: Literal[DFCollectionType.SPECTRUM_SURVEY_TIMESERIES] = ( + DFCollectionType.SPECTRUM_SURVEY_TIMESERIES + ) + _schema = SpectrumSurveyTimeseriesSchema diff --git a/generalresearch/incite/collections/thl_web.py b/generalresearch/incite/collections/thl_web.py new file mode 100644 index 0000000..951406c --- /dev/null +++ b/generalresearch/incite/collections/thl_web.py @@ -0,0 +1,41 @@ +from typing import Literal + +from generalresearch.incite.collections import DFCollection, DFCollectionType + + +class UserDFCollection(DFCollection): + data_type: Literal[DFCollectionType.USER] = DFCollectionType.USER + + +class WallDFCollection(DFCollection): + data_type: Literal[DFCollectionType.WALL] = DFCollectionType.WALL + + +class SessionDFCollection(DFCollection): + data_type: Literal[DFCollectionType.SESSION] = DFCollectionType.SESSION + + +class IPInfoDFCollection(DFCollection): + data_type: Literal[DFCollectionType.IP_INFO] = DFCollectionType.IP_INFO + + +class IPHistoryDFCollection(DFCollection): + data_type: Literal[DFCollectionType.IP_HISTORY] = DFCollectionType.IP_HISTORY + + +class IPHistoryWSDFCollection(DFCollection): + data_type: Literal[DFCollectionType.IP_HISTORY_WS] = DFCollectionType.IP_HISTORY_WS + + +class TaskAdjustmentDFCollection(DFCollection): + data_type: Literal[DFCollectionType.TASK_ADJUSTMENT] = ( + DFCollectionType.TASK_ADJUSTMENT + ) + + +class AuditLogDFCollection(DFCollection): + data_type: Literal[DFCollectionType.AUDIT_LOG] = DFCollectionType.AUDIT_LOG + + +class LedgerDFCollection(DFCollection): + data_type: Literal[DFCollectionType.LEDGER] = DFCollectionType.LEDGER diff --git a/generalresearch/incite/defaults.py b/generalresearch/incite/defaults.py new file mode 100644 index 0000000..e200ddc --- /dev/null +++ b/generalresearch/incite/defaults.py @@ -0,0 +1,196 @@ +from datetime import datetime, timezone + +from generalresearch.incite.base import GRLDatasets +from generalresearch.incite.collections import DFCollectionType +from generalresearch.incite.collections.thl_marketplaces import ( + InnovateSurveyHistoryCollection, + MorningSurveyTimeseriesCollection, + SagoSurveyHistoryCollection, + SpectrumSurveyTimeseriesCollection, +) +from generalresearch.incite.collections.thl_web import ( + SessionDFCollection, + WallDFCollection, + UserDFCollection, + TaskAdjustmentDFCollection, + LedgerDFCollection, +) +from generalresearch.incite.mergers import MergeType +from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, +) +from generalresearch.incite.mergers.foundations.enriched_task_adjust import ( + EnrichedTaskAdjustMerge, +) +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, +) +from generalresearch.incite.mergers.foundations.user_id_product import ( + UserIdProductMerge, +) +from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge +from generalresearch.incite.mergers.ym_survey_wall import YMSurveyWallMerge +from generalresearch.pg_helper import PostgresConfig +from generalresearch.sql_helper import SqlHelper + + +# --- THL Web --- # + + +def session_df_collection( + ds: "GRLDatasets", pg_config: PostgresConfig +) -> SessionDFCollection: + return SessionDFCollection( + offset="37h", + pg_config=pg_config, + start=datetime(year=2022, month=5, day=3, hour=12, tzinfo=timezone.utc), + archive_path=ds.archive_path(enum_type=DFCollectionType.SESSION), + ) + + +def wall_df_collection( + ds: "GRLDatasets", pg_config: PostgresConfig +) -> WallDFCollection: + return WallDFCollection( + offset="49h", + pg_config=pg_config, + start=datetime(year=2022, month=5, day=3, hour=12, tzinfo=timezone.utc), + archive_path=ds.archive_path(enum_type=DFCollectionType.WALL), + ) + + +def user_df_collection( + ds: "GRLDatasets", pg_config: PostgresConfig +) -> UserDFCollection: + return UserDFCollection( + offset="73h", + pg_config=pg_config, + start=datetime(year=2016, month=7, day=13, hour=1, tzinfo=timezone.utc), + archive_path=ds.archive_path(enum_type=DFCollectionType.USER), + ) + + +def task_df_collection( + ds: "GRLDatasets", pg_config: PostgresConfig +) -> TaskAdjustmentDFCollection: + return TaskAdjustmentDFCollection( + offset="48h", + pg_config=pg_config, + start=datetime(year=2022, month=7, day=16, hour=0, tzinfo=timezone.utc), + archive_path=ds.archive_path(enum_type=DFCollectionType.TASK_ADJUSTMENT), + ) + + +def ledger_df_collection( + ds: "GRLDatasets", pg_config: PostgresConfig +) -> LedgerDFCollection: + return LedgerDFCollection( + offset="12d", + pg_config=pg_config, + # thl_web:ledger_transaction - 1st record is 2018-03-14 20:22:17.408232 + start=datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc), + archive_path=ds.archive_path(enum_type=DFCollectionType.LEDGER), + ) + + +# --- Marketplace Specifics --- # +def innovate_survey_history_collection( + ds: "GRLDatasets", sql_helper: SqlHelper +) -> InnovateSurveyHistoryCollection: + return InnovateSurveyHistoryCollection( + offset="12h", + sql_helper=sql_helper, + start=datetime(year=2024, month=3, day=1, hour=0, tzinfo=timezone.utc), + archive_path=ds.archive_path( + enum_type=DFCollectionType.INNOVATE_SURVEY_HISTORY + ), + ) + + +def morning_survey_ts_collection( + ds: "GRLDatasets", sql_helper: SqlHelper +) -> MorningSurveyTimeseriesCollection: + return MorningSurveyTimeseriesCollection( + offset="12h", + sql_helper=sql_helper, + start=datetime(year=2024, month=3, day=1, hour=0, tzinfo=timezone.utc), + archive_path=ds.archive_path( + enum_type=DFCollectionType.MORNING_SURVEY_TIMESERIES + ), + ) + + +def sago_survey_history_collection( + ds: "GRLDatasets", sql_helper: SqlHelper +) -> SagoSurveyHistoryCollection: + return SagoSurveyHistoryCollection( + offset="12h", + sql_helper=sql_helper, + start=datetime(year=2024, month=3, day=1, hour=0, tzinfo=timezone.utc), + archive_path=ds.archive_path(enum_type=DFCollectionType.SAGO_SURVEY_HISTORY), + ) + + +def spectrum_survey_ts_collection( + ds: "GRLDatasets", sql_helper: SqlHelper +) -> SpectrumSurveyTimeseriesCollection: + return SpectrumSurveyTimeseriesCollection( + offset="12h", + sql_helper=sql_helper, + start=datetime(year=2024, month=3, day=1, hour=0, tzinfo=timezone.utc), + archive_path=ds.archive_path( + enum_type=DFCollectionType.SPECTRUM_SURVEY_TIMESERIES + ), + ) + + +# --- Mergers: Foundations --- # +def user_id_product(ds: "GRLDatasets") -> UserIdProductMerge: + return UserIdProductMerge( + start=datetime(year=2010, month=1, day=1, tzinfo=timezone.utc), + offset=None, + archive_path=ds.archive_path(enum_type=MergeType.USER_ID_PRODUCT), + ) + + +def enriched_session(ds: "GRLDatasets") -> EnrichedSessionMerge: + return EnrichedSessionMerge( + start=datetime(year=2023, month=5, day=1, tzinfo=timezone.utc), + offset="14d", + archive_path=ds.archive_path(enum_type=MergeType.ENRICHED_SESSION), + ) + + +def enriched_wall(ds: "GRLDatasets") -> EnrichedWallMerge: + return EnrichedWallMerge( + # start=datetime(year=2022, month=5, day=1, tzinfo=timezone.utc), + start=datetime(year=2023, month=7, day=23, tzinfo=timezone.utc), + offset="14d", + archive_path=ds.archive_path(enum_type=MergeType.ENRICHED_WALL), + ) + + +def enriched_task_adjust(ds: "GRLDatasets") -> EnrichedTaskAdjustMerge: + return EnrichedTaskAdjustMerge( + start=datetime(year=2010, month=1, day=1, tzinfo=timezone.utc), + offset=None, + archive_path=ds.archive_path(enum_type=MergeType.ENRICHED_TASK_ADJUST), + ) + + +# --- Mergers: Others --- # +def pop_ledger(ds: "GRLDatasets") -> PopLedgerMerge: + return PopLedgerMerge( + # thl_web:ledger_transaction - 1st record is 2018-03-14 20:22:17.408232 + start=datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc), + offset="30d", + archive_path=ds.archive_path(enum_type=MergeType.POP_LEDGER), + ) + + +def ym_survey_wall(ds: "GRLDatasets") -> YMSurveyWallMerge: + return YMSurveyWallMerge( + start=None, + offset="10D", + archive_path=ds.archive_path(enum_type=MergeType.YM_SURVEY_WALL), + ) diff --git a/generalresearch/incite/mergers/__init__.py b/generalresearch/incite/mergers/__init__.py new file mode 100644 index 0000000..22ac603 --- /dev/null +++ b/generalresearch/incite/mergers/__init__.py @@ -0,0 +1,316 @@ +import logging +import os.path +import subprocess +from datetime import datetime, timezone +from enum import Enum +from sys import platform +from typing import Optional, List, Type + +import dask.dataframe as dd +import pandas as pd +from dask.distributed import Client +from pandera import DataFrameSchema +from pydantic import Field, field_validator, ValidationInfo, model_validator +from typing_extensions import Self + +from generalresearch.incite.base import CollectionBase, CollectionItemBase +from generalresearch.incite.schemas import PARTITION_ON +from generalresearch.incite.schemas.mergers.foundations.enriched_session import ( + EnrichedSessionSchema, +) +from generalresearch.incite.schemas.mergers.foundations.enriched_task_adjust import ( + EnrichedTaskAdjustSchema, +) +from generalresearch.incite.schemas.mergers.foundations.enriched_wall import ( + EnrichedWallSchema, +) +from generalresearch.incite.schemas.mergers.foundations.user_id_product import ( + UserIdProductSchema, +) +from generalresearch.incite.schemas.mergers.nginx import ( + NGINXGRSSchema, + NGINXCoreSchema, + NGINXFSBSchema, +) +from generalresearch.incite.schemas.mergers.pop_ledger import ( + PopLedgerSchema, +) +from generalresearch.incite.schemas.mergers.ym_survey_wall import ( + YMSurveyWallSchema, +) +from generalresearch.incite.schemas.mergers.ym_wall_summary import ( + YMWallSummarySchema, +) +from generalresearch.models.custom_types import AwareDatetimeISO + +LOG = logging.getLogger("incite") + + +class MergeType(str, Enum): + TEST = "test" + YM_SURVEY_WALL = "ym_survey_wall" + YM_WALL_SUMMARY = "ym_wall_summary" + + NGINX_GRS = "nginx_grs" + NGINX_FSB = "nginx_fsb" + NGINX_CORE = "nginx_core" + + POP_LEDGER = "pop_ledger" + + # --- Foundations --- + USER_ID_PRODUCT = "user_id_product" + ENRICHED_WALL = "enriched_wall" + ENRICHED_SESSION = "enriched_session" + ENRICHED_TASK_ADJUST = "enriched_task_adjust" + + +MergeTypeSchemas = { + MergeType.YM_SURVEY_WALL: YMSurveyWallSchema, + MergeType.YM_WALL_SUMMARY: YMWallSummarySchema, + MergeType.NGINX_GRS: NGINXGRSSchema, + MergeType.NGINX_FSB: NGINXFSBSchema, + MergeType.NGINX_CORE: NGINXCoreSchema, + MergeType.POP_LEDGER: PopLedgerSchema, + # --- Foundations --- + MergeType.USER_ID_PRODUCT: UserIdProductSchema, + MergeType.ENRICHED_WALL: EnrichedWallSchema, + MergeType.ENRICHED_SESSION: EnrichedSessionSchema, + MergeType.ENRICHED_TASK_ADJUST: EnrichedTaskAdjustSchema, +} + + +class MergeCollectionItem(CollectionItemBase): + + # --- Properties --- + + @property + def finish(self) -> datetime: + # A MergeCollection can have offset = None + if self._collection.offset: + return ( + pd.Timestamp(self.start) + pd.Timedelta(self._collection.offset) + ).to_pydatetime() + else: + return datetime.now(tz=timezone.utc).replace(microsecond=0) + + @property + def filename(self) -> str: + grouped_key = self._collection.grouped_key + offset = self._collection.offset + start = self.start.strftime("%Y-%m-%d-%H-%M-%S") + f = [self._collection.merge_type.name.lower()] + if offset: + f.append(offset) + if grouped_key: + f.append(grouped_key) + if self._collection.start is not None: + # This is a collection that is "looking back" 'offset' time (1 item). + f.append(start) + s = "-".join(f) + s += ".parquet" + return s + + # --- ORM / Data handlers--- + def to_dict(self, *args, **kwargs) -> dict: + res = self._to_dict() + res["group_by"] = self._collection.group_by + return res + + def to_archive( + self, + client: Client, + ddf: dd.DataFrame, + is_partial: bool = False, + client_resources=None, + ) -> bool: + assert is_partial is False, "use to_archive_symlink" + return self._to_archive(client=client, ddf=ddf, client_resources=None) + + def _to_archive( + self, client: Client, ddf: dd.DataFrame, client_resources=None + ) -> bool: + """ + For archiving an item. Will write an empty file if ddf is empty. + This is NOT for writing partials. + + :returns: bool (saved_successful) + """ + if ddf is None: + return False + + row_len = client.compute(collections=ddf.shape[0], sync=True) + assert row_len > 0, "empty ddf" + + tmp_path = self.tmp_path() + schema = self._collection._schema + partition = schema.metadata.get(PARTITION_ON, None) + f = ddf.to_parquet( + compute=False, + path=tmp_path, + partition_on=partition, + engine="pyarrow", + overwrite=True, + write_metadata_file=True, + compression="brotli", + ) + client.compute(f, sync=True, priority=2, resources=client_resources) + assert not os.path.exists( + self.path.as_posix() + ), f"already exits!: {self.path.as_posix()}" + + if platform == "darwin": + subprocess.call(["mv", tmp_path.as_posix(), self.path.as_posix()]) + else: + # -T will (should) cause the mv to fail if `path` wasn't successfully deleted + subprocess.call(["mv", "-T", tmp_path.as_posix(), self.path.as_posix()]) + return True + + def to_archive_symlink( + self, + client: Client, + ddf: dd.DataFrame, + is_partial: bool = False, + client_resources=None, + validate_after=True, + ) -> bool: + """ + This differs from to_archive(): + 1) to_parquet is run in this process. If the df is already + computed, there is no point in sending it to another worker + to write. + + 2) symlink to next_numbered_path is created whether or not + is_partial (to_archive only does this on partials) + + 3) we do not validate the written file. seems not useful to do + this, as the file will probably get overwritten on the next + loop anyway + """ + path = self.partial_path if is_partial else self.path + next_numbered_path = self.next_numbered_path(path) + collection = self._collection + LOG.warning(f"{collection.merge_type.value}.to_archive_symlink()") + + if not isinstance(ddf, dd.DataFrame): + raise ValueError("must pass a dask df") + + # We should validate before or after!!! + # _validate_df(self.compute(ddf), coll._schema) + target = ( + next_numbered_path.name + ) # this is the symlink's target. it is a relative path (only the name) + + schema = self._collection._schema + partition = schema.metadata.get(PARTITION_ON, None) + f = ddf.to_parquet( + compute=False, + path=next_numbered_path.as_posix(), + partition_on=partition, + engine="pyarrow", + overwrite=True, + write_metadata_file=True, + compression="brotli", + ) + client.compute(f, sync=True, priority=2, resources=client_resources) + + if os.path.exists(path.as_posix()) and not os.path.islink(path.as_posix()): + # This will fail when going from the old way to using symlinks, + # if self.path already exists and is a directory. + raise ValueError( + f"first time we run this, make sure the path doesnt exist: {path.as_posix()}" + ) + + if platform == "darwin": + subprocess.call(["ln", "-sfn", target, path.as_posix()]) + else: + subprocess.call(["ln", "-sfnT", target, path.as_posix()]) + + if validate_after: + if not self.valid_archive(self.path): + LOG.error( + f"{collection.merge_type.value} failed validation: {self.path}" + ) + self.delete_archive(self.path) + return False + return True + + # todo: unclear what the common interface should be here ... ? + def fetch(self, *args, **kwargs) -> pd.DataFrame | dd.DataFrame: + raise NotImplementedError("implement in subclass") + + def build(self, *args, **kwargs) -> pd.DataFrame | dd.DataFrame: + raise NotImplementedError("implement in subclass") + + +class MergeCollection(CollectionBase): + """Mergers take instances of DFCollections, and/or other Mergers""" + + # In a merge, we can set offset = None which indicates that there is only 1 + # period/item where the range is 'start' until now. + offset: Optional[str] = Field(default="72h") + # In a merge, we can set start = None which indicates that there is only 1 + # period/item where the range is (now - offset) until now. + start: Optional[AwareDatetimeISO] = Field( + default=None, + description="This is the starting point in which data will" + " be retrieved in chunks from.", + frozen=True, + ) + + merge_type: Optional[MergeType] = Field(default=None) + group_by: Optional[str] = Field(default=None) + grouped_key: Optional[str] = Field(default=None) + collection_item_class: Type[MergeCollectionItem] = MergeCollectionItem + + @model_validator(mode="after") + def check_start_and_offset_nullable(self) -> Self: + if self.offset is None and self.start is None: + raise AssertionError("cannot set both start and offset to None") + return self + + @field_validator("merge_type") + def check_merge_type(cls, merge_type, info: ValidationInfo): + if merge_type is None: + raise ValueError("Must explicitly provide a merge_type") + + if merge_type not in MergeTypeSchemas: + raise ValueError("Must provide a supported merge_type") + + return merge_type + + # --- Properties --- + @property + def interval_start(self) -> Optional[datetime]: + # if self.start is None and self.offset is set, the inferred start is (now - offset) + if self.start is None: + return datetime.now(tz=timezone.utc).replace(microsecond=0) - pd.Timedelta( + self.offset + ) + return self.start + + @property + def items(self) -> List[MergeCollectionItem]: + items = [] + for iv in self.interval_range: + cm = self.collection_item_class(start=iv[0]) + cm._collection = self + items.append(cm) + return items + + @property + def _schema(self) -> DataFrameSchema: + return MergeTypeSchemas[self.merge_type] + + def signature(self) -> str: + arr = [ + 1 if i.has_archive(include_empty=True) else 0 + for i in self.items + if i.should_archive() + ] + repr_str = ( + f"path={self.archive_path.as_posix()}; " + f"items={len(self.items)}; start={self.start} @ {self.offset}; {int(sum(arr) / len(arr) * 100)}% " + f"archived" + ) + res = f"{self.__repr_name__()}({repr_str})" + return res diff --git a/generalresearch/incite/mergers/account_blocks.py b/generalresearch/incite/mergers/account_blocks.py new file mode 100644 index 0000000..01bbc0a --- /dev/null +++ b/generalresearch/incite/mergers/account_blocks.py @@ -0,0 +1,189 @@ +# import json +# import logging +# import os +# from typing import Optional +# +# import grpc +# import pandas as pd +# from generalresearch.locales import Localelator +# from generalresearch.sql_helper import SqlHelper +# from pandas import DataFrame +# +# from incite.data.build import BuildObject +# from incite.protos import generalresearch_pb2, generalresearch_pb2_grpc +# +# web_sql_helper = SqlHelper(**WEB_CONFIG_DICT) +# locale_helper = Localelator() +# +# logging.basicConfig() +# logger = logging.getLogger() +# logger.setLevel(LOG_LEVEL) +# +# +# class BuildAccountBlocksActive(BuildObject): +# +# def build(self) -> None: +# df = self.get_account_blocks_active_df() +# self.result = df.reset_index() +# +# def export(self, dry_run=True, god_only: Optional[bool] = False, rebuild: Optional[bool] = False) -> None: +# file_path = os.path.join(EXPORT_DIR, "_account_blocks_active.feather") +# +# if dry_run: +# logger.info(f"[dryrun] saving: {file_path}: {' x '.join([str(i) for i in self.result.shape])}") +# else: +# logger.info(f"saving: {file_path}: {' x '.join([str(i) for i in self.result.shape])}") +# self.result.to_feather(file_path) +# +# def get_account_blocks_active_df(self) -> DataFrame: +# logger.info(f"BuildAccountBlocksGeo.get_account_blocks_active_df") +# +# res = web_sql_helper.execute_sql_query(f"""SELECT bp.id FROM userprofile_brokerageproduct bp""") +# product_ids = [x["id"] for x in res] +# +# # 1. Global config +# with grpc.insecure_channel(GRPC_SERVER) as channel: +# stub = generalresearch_pb2_grpc.GeneralResearchStub(channel) +# msg = generalresearch_pb2.GetBPConfigRequest(product_id=GLOBAL_CONFIG) +# res = list(stub.GetBPConfig(msg)) +# +# # Global config always needs routers +# routers = {x.key: json.loads(x.value) for x in res}["routers"] +# sources = set([r["name"] for r in routers]) +# from incite.data.utils import MARKETPLACE_KEYS +# opts = set(MARKETPLACE_KEYS.keys()) +# assert opts.issuperset(sources), "Default router definitions not available option" +# +# df = pd.DataFrame( +# index=pd.MultiIndex.from_product([product_ids, ["global_config", "product_config"]], +# names=["product_id", "reason"]), +# columns=sources).fillna(0) +# +# for router in routers: +# if not router["active"]: +# df.loc[(slice(None), "global_config"), [router["name"]]] = 1 +# +# # 2. bpid specific router definitions +# with grpc.insecure_channel(GRPC_SERVER) as channel: +# stub = generalresearch_pb2_grpc.GeneralResearchStub(channel) +# +# for product_id in product_ids: +# msg = generalresearch_pb2.GetBPConfigRequest(product_id=product_id) +# res = list(stub.GetBPConfig(msg)) +# +# routers = {x.key: json.loads(x.value) for x in res}["routers"] +# +# for router in routers: +# if not router["active"]: +# df.loc[(product_id, "product_config"), [router["name"]]] = 1 +# +# return df + + +# +# import json +# import logging +# import os +# from typing import Optional +# +# import grpc +# import pandas as pd +# from generalresearch.locales import Localelator +# from generalresearch.sql_helper import SqlHelper +# from google.protobuf.json_format import MessageToDict +# from pandas import DataFrame +# +# from incite.data.build import BuildObject +# from incite.protos import thl_pb2, thl_pb2_grpc, generalresearch_pb2, generalresearch_pb2_grpc +# +# web_sql_helper = SqlHelper(**WEB_CONFIG_DICT) +# locale_helper = Localelator() +# +# logging.basicConfig() +# logger = logging.getLogger() +# logger.setLevel(LOG_LEVEL) +# +# +# class BuildAccountBlocksGeo(BuildObject): +# +# def build(self) -> None: +# df = self.get_account_blocks_geo_df() +# self.result = df.reset_index() +# +# def export(self, dry_run=True, god_only: Optional[bool] = False, rebuild: Optional[bool] = False) -> None: +# file_path = os.path.join(DS["exports"], "_account_blocks_geo.feather") +# +# if dry_run: +# logger.info(f"[dryrun] saving: {file_path}: {' x '.join([str(i) for i in self.result.shape])}") +# else: +# logger.info(f"saving: {file_path}: {' x '.join([str(i) for i in self.result.shape])}") +# self.result.to_feather(file_path) +# +# def get_account_blocks_geo_df(self) -> DataFrame: +# logger.info(f"BuildAccountBlocksGeo.get_account_blocks_geo_df") +# +# res = web_sql_helper.execute_sql_query(f"""SELECT bp.id FROM userprofile_brokerageproduct bp""") +# product_ids = [x["id"] for x in res] +# # (TODO) Stupid GB vs UK here... can cleanup on admin portal... not sure what/where to do it +# geos = list(locale_helper.get_all_countries()) + ["uk"] +# +# # 1. Global config +# with grpc.insecure_channel(GRPC_SERVER) as channel: +# stub = generalresearch_pb2_grpc.GeneralResearchStub(channel) +# msg = generalresearch_pb2.GetBPConfigRequest(product_id=GLOBAL_CONFIG) +# res = list(stub.GetBPConfig(msg)) +# +# # Global config always needs routers +# routers = {x.key: json.loads(x.value) for x in res}["routers"] +# sources = set([r["name"] for r in routers]) +# from incite.data.utils import MARKETPLACE_KEYS +# opts = set(MARKETPLACE_KEYS.keys()) +# assert opts.issuperset(sources), "Default router definitions not available option" +# +# df = pd.DataFrame( +# index=pd.MultiIndex.from_product([product_ids, sources, ["global_config", "product_config", +# "eligibility"]], +# names=["product_id", "source", "reason"]), +# columns=geos).fillna(0) +# +# for router in routers: +# df.loc[(slice(None), router["name"], "global_config"), router.get("banned_countries", [])] = 1 +# +# # 2. bpid specific router definitions +# with grpc.insecure_channel(GRPC_SERVER) as channel: +# stub = generalresearch_pb2_grpc.GeneralResearchStub(channel) +# +# for product_id in product_ids: +# msg = generalresearch_pb2.GetBPConfigRequest(product_id=product_id) +# res = list(stub.GetBPConfig(msg)) +# +# routers = {x.key: json.loads(x.value) for x in res}["routers"] +# +# for router in routers: +# df.loc[(product_id, router["name"], "product_config"), router.get("banned_countries", [])] = 1 +# +# # 3. bpid specific eligibility values +# with grpc.insecure_channel(GRPC_SERVER) as channel: +# thl_stub = thl_pb2_grpc.THLStub(channel) +# +# for product_id in product_ids: +# req = thl_pb2.GetPlatformStatsRequest(bpid=product_id) +# res = thl_stub.GetPlatformStats(req) +# res = MessageToDict(res, including_default_value_fields=True, +# preserving_proto_field_name=True, +# use_integers_for_enums=True) +# stats = res.get("stats", None) +# +# if not stats: +# continue +# +# res = [dict(name=x[0], value=x[1]) for x in zip(stats.keys(), stats.values())] +# +# for stat in res: +# statn = stat["name"].split("MARKETPLACE_ELIGIBILITY.") +# if len(statn) == 2: +# source, geo = statn[1].split(".") +# if stat.get("value") == 0: +# df.loc[(product_id, source, "eligibility"), [geo]] = 1 +# +# return df diff --git a/generalresearch/incite/mergers/foundations/__init__.py b/generalresearch/incite/mergers/foundations/__init__.py new file mode 100644 index 0000000..d3db74c --- /dev/null +++ b/generalresearch/incite/mergers/foundations/__init__.py @@ -0,0 +1,167 @@ +import logging +from typing import Collection, List, Dict + +import pandas as pd +from more_itertools import chunked + +from generalresearch.pg_helper import PostgresConfig + +LOG = logging.getLogger("incite") + + +def annotate_product_id( + df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500 +) -> pd.DataFrame: + """ + Dask map_partitions is being called on a dask df. However, the function + it applies to each partition is being passed a chunk of the dask + df AS a pandas df. + + expects column 'user_id', adds column 'product_id' + """ + LOG.warning(f"annotate_product_id.chunk: {df.shape}") + assert "user_id" in df.columns, "must have a user_id column to join on" + + user_ids = df["user_id"].dropna() + user_ids = set(user_ids) + assert len(user_ids) >= 1, "must have user_ids" + LOG.warning(f"annotate_product_id.len(user_ids): {len(user_ids)}") + + res: List[Dict] = [] + with pg_config.make_connection() as conn: + for chunk in chunked(user_ids, chunksize): + try: + with conn.cursor() as c: + c.execute( + query=""" + SELECT id as user_id, product_id + FROM thl_user + WHERE id = ANY(%s)""", + params=[list(chunk)], + ) + res.extend(c.fetchall()) + + except Exception: + LOG.exception(f"annotate_product_id: {chunk}") + raise + + dfu = pd.DataFrame(res, columns=["user_id", "product_id"]) + + return df.merge(dfu, on="user_id", how="left") + + +def lookup_product_and_team_id( + user_ids: Collection[int], + pg_config: PostgresConfig, +) -> List[Dict]: + user_ids = set(user_ids) + LOG.info(f"lookup_product_and_team_id: {len(user_ids)}") + LOG.info({type(x) for x in user_ids}) + assert all(type(x) is int for x in user_ids), "must pass all integers" + assert len(user_ids) >= 1, "must have user_ids" + assert len(user_ids) <= 1000, "you should chunk this bro" + + res: List[Dict] = [] + with pg_config.make_connection() as conn: + try: + with conn.cursor() as c: + c.execute( + query=""" + SELECT u.id AS user_id, + u.product_id, + bp.team_id + FROM thl_user u + INNER JOIN userprofile_brokerageproduct AS bp + ON bp.id = u.product_id + WHERE u.id = ANY(%s); + """, + params=[list(user_ids)], + ) + res.extend(c.fetchall()) + + except Exception as e: + LOG.exception(f"lookup_product_and_team_id: {e}") + raise + + return res + + +def annotate_product_and_team_id( + df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500 +) -> pd.DataFrame: + """ + Dask map_partitions is being called on a dask df. However, the function + it applies to each partition is being passed a chunk of the dask + df AS a pandas df. + + expects column 'user_id', adds column 'product_id' and team_id + + """ + LOG.info(f"annotate_product_and_team_id.chunk: {df.shape}") + assert "user_id" in df.columns, "must have a user_id column to join on" + + user_ids = df["user_id"].dropna() + user_ids = set(user_ids) + assert len(user_ids) >= 1, "must have user_ids" + LOG.warning(f"annotate_product_and_team_id.len(user_ids): {len(user_ids)}") + + res: List[Dict] = [] + with pg_config.make_connection() as conn: + for chunk in chunked(user_ids, chunksize): + try: + with conn.cursor() as c: + c.execute( + query=f""" + SELECT u.id AS user_id, u.product_id, + bp.team_id + FROM thl_user u + INNER JOIN userprofile_brokerageproduct AS bp + ON bp.id = u.product_id + WHERE u.id = ANY(%s); + """, + params=[list(chunk)], + ) + res.extend(c.fetchall()) + + except Exception: + LOG.exception(f"annotate_product_and_team_id: {chunk}") + raise + + dfu = pd.DataFrame(res, columns=["user_id", "product_id", "team_id"]) + + return df.merge(dfu, on="user_id", how="left") + + +def annotate_product_user( + df: pd.DataFrame, pg_config: PostgresConfig, chunksize=500 +) -> pd.DataFrame: + LOG.info(f"annotate_product_user.chunk: {df.shape}") + assert "user_id" in df.columns, "must have a user_id column to join on" + + user_ids = df["user_id"].dropna() + user_ids = set(user_ids) + assert len(user_ids) >= 1, "must have user_ids" + LOG.warning(f"annotate_product_user.len(user_ids): {len(user_ids)}") + + res: List[Dict] = [] + with pg_config.make_connection() as conn: + for chunk in chunked(user_ids, chunksize): + try: + with conn.cursor() as c: + c.execute( + query=""" + SELECT u.id AS user_id, u.product_user_id + FROM thl_user u + WHERE u.id = ANY(%s); + """, + params=[list(chunk)], + ) + res.extend(c.fetchall()) + + except Exception: + LOG.exception(f"annotate_product_user: {chunk}") + raise + + dfu = pd.DataFrame(res, columns=["user_id", "product_user_id"]) + + return df.merge(dfu, on="user_id", how="left") diff --git a/generalresearch/incite/mergers/foundations/enriched_session.py b/generalresearch/incite/mergers/foundations/enriched_session.py new file mode 100644 index 0000000..7fdcb50 --- /dev/null +++ b/generalresearch/incite/mergers/foundations/enriched_session.py @@ -0,0 +1,331 @@ +import logging +from datetime import timedelta +from typing import Literal, Optional, List, TYPE_CHECKING + +import dask.dataframe as dd +import pandas as pd +from dask.distributed import as_completed +from distributed import Client +from more_itertools import chunked, flatten + +from generalresearch.incite.collections.thl_web import ( + SessionDFCollection, + WallDFCollection, +) +from generalresearch.incite.mergers import ( + MergeCollection, + MergeCollectionItem, + MergeType, +) +from generalresearch.incite.mergers.foundations import ( + lookup_product_and_team_id, +) +from generalresearch.incite.schemas import empty_dataframe_from_schema +from generalresearch.incite.schemas.admin_responses import ( + AdminPOPSessionSchema, +) +from generalresearch.incite.schemas.mergers.foundations.enriched_session import ( + EnrichedSessionSchema, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig + +if TYPE_CHECKING: + from generalresearch.models.admin.request import ReportRequest + +LOG = logging.getLogger("incite") + + +class EnrichedSessionMergeItem(MergeCollectionItem): + + def build( + self, + session_coll: SessionDFCollection, + wall_coll: WallDFCollection, + pg_config: PostgresConfig, + client: Optional[Client] = None, + client_resources=None, + ) -> None: + + ir: pd.Interval = self.interval + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + + LOG.warning(f"EnrichedSessionMergeItem.build({ir})") + + # Skip which already exist + if self.has_archive(include_empty=True): + return None + + # --- Session --- + LOG.warning(f"EnrichedSessionMergeItem: get session_collection") + session_items = [w for w in session_coll.items if w.interval.overlaps(ir)] + if len(session_items) == 0: + LOG.warning(f"EnrichedSessionMergeItem: no session items. set_empty.") + if self.should_archive(): + self.set_empty() + return None + if not ( + session_items[-1].has_partial_archive() or session_items[-1].has_archive() + ): + LOG.warning(f"EnrichedSessionMergeItem: session isn't updated!") + return None + + sddf = session_coll.ddf( + items=session_items, + include_partial=True, + force_rr_latest=False, + filters=[("started", ">=", start), ("started", "<", end)], + ) + + # --- Walls --- + LOG.warning(f"EnrichedSessionMergeItem: merge wall_collection") + wall_items = [ + w + for w in wall_coll.items + if w.interval.overlaps( + pd.Interval( + ir.left - timedelta(hours=2), + ir.right + timedelta(hours=2), + closed="both", + ) + ) + ] + + if len(wall_items) == 0: + LOG.error(f"EnrichedSessionMergeItem: no wall items") + return None + + wddf = wall_coll.ddf( + items=wall_items, + include_partial=True, + force_rr_latest=False, + columns=["session_id"], + filters=[ + ("started", ">=", start - timedelta(hours=2)), + ("started", "<", end + timedelta(hours=2)), + ], + ) + + if wddf is None: + return None + + attempt_cnt_ddf = ( + wddf.groupby("session_id").size().rename("attempt_count").to_frame() + ) + ddf = sddf.join(attempt_cnt_ddf, how="left", npartitions=12) + ddf["attempt_count"] = ddf["attempt_count"].fillna(0) + # ddf = ddf.repartition(npartitions=4) + ddf = ddf.reset_index() + + # Unclear if this is needed. We are client.computing ddf literally + # in the next line, so I think it is not. + # client.persist(ddf) + + df: pd.DataFrame = client.compute(ddf, sync=True) + + user_ids = set( + map(int, df["user_id"].unique()) + ) # must int, otherwise it's a np.int sigh + + # Submit at most N tasks at a time. Will be useful when we have 32 workers again. + futures = set() + for chunk in chunked(user_ids, 500): + ac = as_completed(futures) + while ac.count() >= 4: + next(ac) # Wait for tasks to finish before submitting a new one + futures.add(client.submit(lookup_product_and_team_id, chunk, pg_config)) + + try: + results = client.gather(list(futures)) + except Exception as e: + client.cancel(futures, asynchronous=False, force=True) + raise e + + dfp = pd.DataFrame( + list(flatten(results)), columns=["user_id", "product_id", "team_id"] + ).astype({"user_id": int, "product_id": str, "team_id": str}) + df = df.merge(dfp, on="user_id", how="left") + + df = df.set_index("id") + + df = df[df["started"].between(start, end)] + + is_missing = df[["product_id"]].isna().sum().sum() > 0 + session_is_partial = any([w.should_archive() is False for w in session_items]) + session_is_missing = any( + [ + w.should_archive() is True and w.has_archive() is False + for w in session_items + ] + ) + wall_is_missing = any( + [ + w.should_archive() is True and w.has_archive() is False + for w in wall_items + ] + ) + is_partial = ( + is_missing or session_is_partial or session_is_missing or wall_is_missing + ) + + LOG.warning(f"missing user_ids: {df[df.product_id.isnull()].user_id.unique()}") + # Lots of downstream issues with this... + df = df[df.product_id.notna()] + df.product_id = df.product_id.astype(str) + + if is_partial: + ddf = dd.from_pandas(df, npartitions=6) + self.to_archive_symlink( + client, + ddf=ddf, + is_partial=True, + validate_after=False, + client_resources=client_resources, + ) + else: + df = self.validate_df(df=df) + ddf = dd.from_pandas(df, npartitions=6) + self.to_archive( + client, + ddf=ddf, + is_partial=False, + client_resources=client_resources, + ) + + +class EnrichedSessionMerge(MergeCollection): + merge_type: Literal[MergeType.ENRICHED_SESSION] = MergeType.ENRICHED_SESSION + _schema = EnrichedSessionSchema + collection_item_class: Literal[EnrichedSessionMergeItem] = EnrichedSessionMergeItem + + def build( + self, + client: Client, + session_coll: SessionDFCollection, + wall_coll: WallDFCollection, + pg_config: PostgresConfig, + ) -> None: + LOG.info( + f"EnrichedSessionMerge.build(session_coll={session_coll.signature()}, " + f"wall_coll={wall_coll.signature()}, " + f"pg_config={pg_config})" + ) + + assert isinstance(session_coll, SessionDFCollection) + assert isinstance(wall_coll, WallDFCollection) + assert isinstance(pg_config, PostgresConfig) + + for item in reversed(self.items): + if item.has_archive(include_empty=True): + continue + LOG.info(item) + item.build( + client=client, + session_coll=session_coll, + wall_coll=wall_coll, + pg_config=pg_config, + ) + + def to_admin_response( + self, + rr: "ReportRequest", + client: Client, + product_ids: Optional[List[UUIDStr]] = None, + user: Optional[User] = None, + ) -> pd.DataFrame: + """ + We don't have the concept of a Team yet so product_ids will be a list + """ + + filters = [] + + if user: + assert ( + len(product_ids) <= 1 + ), "Can't search more than 1 Product ID for a specific User" + assert ( + user.product_id in product_ids + ), "The provided user must be associated with the Product ID" + filters.append( + ("user_id", "==", user.user_id), + ) + + if product_ids: + assert ( + len(product_ids) >= 1 + ), "Don't provide an empty list. Pass None if SELECT ALL is desired" + filters.append( + ("product_id", "in", product_ids), + ) + + # es_items = [w for w in self.items if w.interval.overlaps(rr.pd_interval)] + ddf = self.ddf( + # items=es_items, + force_rr_latest=False, + include_partial=True, + columns=[ + "product_id", + "user_id", + "started", + "finished", + "attempt_count", + "status", + "status_code_1", + "status_code_2", + "country_iso", + "device_type", + "payout", + ], + filters=filters, + ) + + if ddf is None: + return empty_dataframe_from_schema(schema=EnrichedSessionSchema) + + ddf["elapsed"] = (ddf["finished"] - ddf["started"]).dt.total_seconds() + ddf["status"] = ddf.status.fillna("e") + ddf["status_code_1"] = ddf.status_code_1.fillna(0) + ddf["status_code_2"] = ddf.status_code_2.fillna(0) + ddf["complete"] = ddf.status.eq("c") + + dfa = client.compute( + collections=ddf, + sync=True, + priority=1, + ) + # dfa = dfa[dfa["started"].between(rr.start, rr.finish)] + + assert rr.interval == "5min" + group_arr = [pd.Grouper(key="started", freq=rr.interval), rr.index1] + df = dfa.groupby(group_arr).aggregate( + elapsed_avg=("elapsed", "mean"), + elapsed_total=("elapsed", "sum"), + payout_total=("payout", "sum"), + attempts_avg=("attempt_count", "mean"), + attempts_total=("attempt_count", "sum"), + entrances=("complete", "size"), + completes=("complete", "sum"), + users=("user_id", "nunique"), + ) + + # Completes only + df_completes = ( + dfa[(dfa.status == "c") & (dfa.payout > 0)] + .groupby(group_arr) + .aggregate( + payout_avg=("payout", "mean"), + elapsed_avg=("elapsed", "mean"), + elapsed_total=("elapsed", "sum"), + ) + ) + + df["payout_avg"] = df_completes.payout_avg + df["conversion"] = df.completes / df.entrances # system conversion + df["epc"] = df.payout_total / df.entrances # earnings per click + df["eph"] = df.payout_total / (df.elapsed_total / 3_600) # earnings per hour + # df["eph"] = df.payout_total / (df_completes.elapsed_total / 3_600) # earnings per hour + df["cpc"] = df_completes.payout_avg * df.conversion + + df.index = df.index.set_names(names=["index0", "index1"]) + return AdminPOPSessionSchema.validate(df).fillna(0) diff --git a/generalresearch/incite/mergers/foundations/enriched_task_adjust.py b/generalresearch/incite/mergers/foundations/enriched_task_adjust.py new file mode 100644 index 0000000..1749f9a --- /dev/null +++ b/generalresearch/incite/mergers/foundations/enriched_task_adjust.py @@ -0,0 +1,211 @@ +import logging +from typing import Literal + +import dask.dataframe as dd +import pandas as pd +from distributed import Client +from sentry_sdk import capture_exception + +from generalresearch.incite.collections.thl_web import ( + TaskAdjustmentDFCollection, +) +from generalresearch.incite.mergers import ( + MergeCollection, + MergeType, + MergeCollectionItem, +) +from generalresearch.incite.mergers.foundations import ( + annotate_product_and_team_id, +) +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, +) +from generalresearch.incite.schemas.mergers.foundations.enriched_task_adjust import ( + EnrichedTaskAdjustSchema, +) +from generalresearch.pg_helper import PostgresConfig + +LOG = logging.getLogger("incite") + + +class EnrichedTaskAdjustMergeItem(MergeCollectionItem): + """Because a single wall event can have multiple "alerted" times, + we're basing the time event for the TaskAdjustDetailMergeCollection + off the wall.started timestamp. + """ + + def build( + self, + task_adj_coll: TaskAdjustmentDFCollection, + enriched_wall: EnrichedWallMerge, + pg_config: PostgresConfig, + client: Client, + client_resources=None, + ) -> None: + """ + TaskAdjustments are always partial because they could be revoked + at any moment + """ + + ir: pd.Interval = self.interval + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + + LOG.warning(f"EnrichedReconMergeItem.build({ir})") + + # --- Task Adjustments --- + LOG.warning(f"EnrichedReconMergeItem: get session_collection") + task_adj_coll_items = [ + w for w in task_adj_coll.items if w.interval.overlaps(ir) + ] + + if len(task_adj_coll_items) == 0: + raise Exception("TaskAdjColl item collection failed") + + ddf: dd.DataFrame = task_adj_coll.ddf( + items=task_adj_coll_items, + include_partial=True, + force_rr_latest=False, + columns=[ + "adjusted_status", + "amount", + "user_id", + "wall_uuid", + "source", + "survey_id", + "alerted", + "started", + ], + filters=[ + ("adjusted_status", "in", ("af", "ac")), + ("started", ">=", start), + ("started", "<", end), + ], + ) + # Naked compute... don't log + # LOG.info(f"TaskAdjustmentDetailMergeCollectionItem.rows: {len(ddf.index)}") + + # --- Join on the wall table --- # + + ew_items = [ew for ew in enriched_wall.items if ew.interval.overlaps(ir)] + + if len(ew_items) == 0: + raise Exception( + "EnrichedWall item collection failed for EnrichedTaskAdjColl" + ) + + wall_uuids = set( + client.compute(collections=ddf.wall_uuid.dropna().values, sync=True) + ) + wall_ddf = enriched_wall.ddf( + # I try to take the adjustments within this IntervalRange and + # figure out the respective range of when the surveys they're + # for were started. This should help limit how many enriched + # wall mergers needed to be loaded up + items=ew_items, + include_partial=True, + force_rr_latest=False, + columns=[ + "buyer_id", + "country_iso", + "device_type", + ], + filters=[("uuid", "in", wall_uuids)], + ) + + assert str(ddf.wall_uuid.dtype) == "string" + assert str(wall_ddf.index.dtype) == "string" + ddf = ddf.merge( + wall_ddf, + left_on="wall_uuid", + right_on="uuid", + how="left", + ) + + df = ( + ddf.sort_values("alerted") + .compute() + .groupby(["wall_uuid", "user_id", "source", "survey_id"]) + .agg( + amount=("amount", "sum"), + adjusted_status=("adjusted_status", "first"), + adjusted_status_last=("adjusted_status", "last"), + alerted=("alerted", "first"), + alerted_last=("alerted", "last"), + started=("started", "last"), + buyer_id=("buyer_id", "last"), + # Shouldn't matter (but some variation as of Sep 2024 -Max) + country_iso=("country_iso", "last"), + # Shouldn't matter (but some variation as of Sep 2024 -Max) + device_type=("device_type", "last"), + # Shouldn't matter (but some variation as of Sep 2024 -Max) + adjustments=("amount", "count"), + ) + .reset_index() + ) + df.index.rename("uuid") + + # --- Add the product_id + product_user_id --- + df = annotate_product_and_team_id(df=df, pg_config=pg_config) + + ddf = dd.from_pandas(df, npartitions=5) + self.to_archive_symlink( + client=client, + ddf=ddf, + is_partial=True, + validate_after=False, + client_resources=client_resources, + ) + + +class EnrichedTaskAdjustMerge(MergeCollection): + merge_type: Literal[MergeType.ENRICHED_TASK_ADJUST] = MergeType.ENRICHED_TASK_ADJUST + _schema = EnrichedTaskAdjustSchema + collection_item_class: Literal[EnrichedTaskAdjustMergeItem] = ( + EnrichedTaskAdjustMergeItem + ) + + def build( + self, + client: Client, + task_adjust_coll: TaskAdjustmentDFCollection, + enriched_wall: EnrichedWallMerge, + pg_config: PostgresConfig, + ) -> None: + """The Enriched TaskAdjustMerge is treated differently than most Merge + Collections because it requires some special consideration: + + - Due to Duplicate Removal issues - where the same task is Adjusted + multiple times, and due to the way Dask works.. we cannot break + this out into Items. The Task that is Tasked multiple times may + not be in the same Item so the aggregation would fail. + + - The thl_taskadjustment db table, and the task_adj DF Collection + are updated sequentially based on the + + """ + + LOG.info( + f"EnrichedTaskAdjustMerge.build(task_adj_coll={task_adjust_coll.signature()}, " + f"pg_config={pg_config})" + ) + + assert isinstance(task_adjust_coll, TaskAdjustmentDFCollection) + assert isinstance(enriched_wall, EnrichedWallMerge) + assert isinstance(pg_config, PostgresConfig) + + assert ( + len(self.items) == 1 + ), "EnrichedTaskAdjustMerge should only have 1 CollectionItem" + item: EnrichedTaskAdjustMergeItem = self.items[0] + + # item.build(client, user_coll=user_coll, client_resources=client_resources) + try: + item.build( + client=client, + task_adj_coll=task_adjust_coll, + enriched_wall=enriched_wall, + pg_config=pg_config, + ) + except (Exception,) as e: + capture_exception(error=e) + pass diff --git a/generalresearch/incite/mergers/foundations/enriched_wall.py b/generalresearch/incite/mergers/foundations/enriched_wall.py new file mode 100644 index 0000000..241a239 --- /dev/null +++ b/generalresearch/incite/mergers/foundations/enriched_wall.py @@ -0,0 +1,336 @@ +import logging +from datetime import timedelta +from typing import Literal, Optional, TYPE_CHECKING, List + +import dask.dataframe as dd +import pandas as pd +from distributed import Client + +from generalresearch.incite.collections.thl_web import ( + SessionDFCollection, + WallDFCollection, +) +from generalresearch.incite.mergers import ( + MergeCollection, + MergeCollectionItem, + MergeType, +) +from generalresearch.incite.mergers.foundations import annotate_product_id +from generalresearch.incite.schemas.admin_responses import ( + AdminPOPWallSchema, +) +from generalresearch.incite.schemas.mergers.foundations.enriched_wall import ( + EnrichedWallSchema, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig + +if TYPE_CHECKING: + from generalresearch.models.admin.request import ReportRequest + +LOG = logging.getLogger("incite") + + +class EnrichedWallMergeItem(MergeCollectionItem): + + def build( + self, + wall_coll: WallDFCollection, + session_coll: SessionDFCollection, + pg_config: PostgresConfig, + client: Optional[Client] = None, + client_resources=None, + ) -> None: + + ir: pd.Interval = self.interval + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + + LOG.warning(f"EnrichedWallMergeItem.build({ir})") + + # Skip which already exist + if self.has_archive(include_empty=True): + return None + + # --- Wall --- + LOG.warning(f"EnrichedWallMergeItem: get wall_collection") + wall_items = [w for w in wall_coll.items if w.interval.overlaps(ir)] + if len(wall_items) == 0: + LOG.warning(f"EnrichedWallMergeItem: no wall items. set_empty.") + if self.should_archive(): + self.set_empty() + return None + + wdf = wall_coll.ddf( + items=wall_items, + include_partial=True, + force_rr_latest=False, + columns=[ + "source", + "buyer_id", + "survey_id", + "session_id", + "started", + "finished", + "status", + "status_code_1", + "status_code_2", + "cpi", + "report_value", + "ext_status_code_1", + "ext_status_code_2", + "ext_status_code_3", + ], + filters=[("started", ">=", start), ("started", "<", end)], + ) + + if wdf is None: + return None + + wdf = wdf.repartition(npartitions=1) + wdf = wdf.reset_index(drop=False) + + # --- Sessions --- + LOG.warning(f"EnrichedWallMergeItem: merge session_collection") + session_items = [ + s + for s in session_coll.items + if s.interval.overlaps( + pd.Interval( + ir.left - timedelta(hours=6), + ir.right + timedelta(hours=6), + closed="both", + ) + ) + ] + + if len(session_items) == 0: + LOG.error(f"EnrichedWallMergeItem: no session items. breaking early.") + return None + + sdf = session_coll.ddf( + items=session_items, + include_partial=True, + force_rr_latest=False, + columns=["user_id", "country_iso", "device_type", "payout"], + filters=[ + ("started", ">=", start - timedelta(hours=6)), + ("started", "<", end + timedelta(hours=6)), + ], + ) + sdf = sdf.repartition(npartitions=1) + + mddf = dd.merge( + wdf, + sdf, + left_on="session_id", + right_index=True, + how="left", + npartitions=1, + ) + client.persist(mddf) + + # --- Add product_id for the user --- # + expected_df = mddf.copy() + expected_df["product_id"] = pd.Series(dtype="str") + res: dd.DataFrame = mddf.map_partitions( + annotate_product_id, pg_config, meta=expected_df + ) + + # --- cleanup --- + df: pd.DataFrame = client.compute(collections=res, sync=True) + df = df[df["started"].between(start, end)] + df = df.set_index("uuid") + + # is_missing = df[['product_id', 'session_id']].isna().sum().sum() > 0 + is_missing = False + df = df.dropna(subset=["product_id", "session_id"], how="any") + + wall_is_partial = any([w.should_archive() is False for w in wall_items]) + is_partial = is_missing or wall_is_partial + + # Lots of downstream issues with this... + df = df[df.product_id.notna()] + df.product_id = df.product_id.astype(str) + + if is_partial: + ddf = dd.from_pandas(df, npartitions=5) + self.to_archive_symlink( + client, + ddf=ddf, + is_partial=True, + validate_after=False, + client_resources=client_resources, + ) + else: + df = self.validate_df(df=df) + ddf = dd.from_pandas(df, npartitions=5) + self.to_archive( + client, + ddf=ddf, + is_partial=False, + client_resources=client_resources, + ) + + +class EnrichedWallMerge(MergeCollection): + merge_type: Literal[MergeType.ENRICHED_WALL] = MergeType.ENRICHED_WALL + _schema = EnrichedWallSchema + collection_item_class: Literal[EnrichedWallMergeItem] = EnrichedWallMergeItem + + def build( + self, + client: Client, + wall_coll: WallDFCollection, + session_coll: SessionDFCollection, + pg_config: PostgresConfig, + ) -> None: + + LOG.info( + f"EnrichedWallMerge.build(wall_coll={wall_coll.signature()}, " + f"session_coll={session_coll.signature()}, " + f"pg_config={pg_config})" + ) + + assert isinstance(wall_coll, WallDFCollection) + assert isinstance(session_coll, SessionDFCollection) + assert isinstance(pg_config, PostgresConfig) + + for item in reversed(self.items): + if item.has_archive(include_empty=True): + continue + + LOG.info(item) + item.build( + client=client, + wall_coll=wall_coll, + session_coll=session_coll, + pg_config=pg_config, + ) + + # This does not work. deadlocks. I need to submit them gradually or something + # fs = [] + # for item in self.items: + # if item.has_archive(include_empty=True): + # continue + # if not item.should_archive(): + # continue + # f = dask.delayed(item.build)(wall_coll=wall_coll, session_coll=session_coll, + # user_id_product=user_id_product) + # fs.append(f) + # + # # self = enriched_wall + # # item = self.items[0] + # # fs = [dask.delayed(item.build)(wall_coll=wall_coll, session_coll=session_coll, + # # user_id_product=user_id_product)] + # res = client.compute(collections=fs, sync=True, priority=1) + # return res + + def to_admin_response( + self, + rr: "ReportRequest", + client: Client, + product_ids: Optional[List[UUIDStr]] = None, + user: Optional[User] = None, + ) -> pd.DataFrame: + """We don't have the concept of a Team yet so product_ids will be a list""" + + filters = [] + + if user: + assert ( + len(product_ids) <= 1 + ), "Can't search more than 1 Product ID for a specific User" + assert ( + user.product_id in product_ids + ), "The provided user must be associated with the Product ID" + filters.append( + ("user_id", "==", user.user_id), + ) + + if product_ids: + assert ( + len(product_ids) >= 1 + ), "Don't provide an empty list. Pass None if SELECT ALL is desired" + filters.append( + ("product_id", "in", product_ids), + ) + + ddf = self.ddf( + force_rr_latest=False, + include_partial=True, + columns=[ + "product_id", + "user_id", + "source", + "buyer_id", + "survey_id", + "session_id", + "started", + "finished", + "status", + "status_code_1", + "status_code_2", + "country_iso", + "device_type", + "payout", + ], + filters=filters, + ) + + if ddf is None: + from generalresearch.incite.schemas import ( + empty_dataframe_from_schema, + ) + + return empty_dataframe_from_schema(schema=EnrichedWallSchema) + + ddf["elapsed"] = (ddf["finished"] - ddf["started"]).dt.total_seconds() + + ddf["status"] = ddf.status.fillna("e") + ddf["status_code_1"] = ddf.status_code_1.fillna(0) + ddf["status_code_2"] = ddf.status_code_2.fillna(0) + ddf["complete"] = ddf.status.eq("c") + + dfa = client.compute( + collections=ddf, + sync=True, + priority=1, + ) + + # --- Add wall index per session -- + assert rr.interval == "5min" + group_arr = [pd.Grouper(key="started", freq=rr.interval), rr.index1] + + df = dfa.groupby(group_arr).aggregate( + elapsed_avg=("elapsed", "mean"), + elapsed_total=("elapsed", "sum"), + payout_total=("payout", "sum"), + entrances=("complete", "size"), + completes=("complete", "sum"), + users=("user_id", "nunique"), + buyers=("buyer_id", "nunique"), + surveys=("survey_id", "nunique"), + sessions=("session_id", "nunique"), + ) + + # Completes only + df_completes = ( + dfa[(dfa.status == "c") & (dfa.payout > 0)] + .groupby(group_arr) + .aggregate( + payout_avg=("payout", "mean"), + elapsed_avg=("elapsed", "mean"), + elapsed_total=("elapsed", "sum"), + ) + ) + + df["payout_avg"] = df_completes.payout_avg + df["conversion"] = df.completes / df.entrances # system conversion + df["epc"] = df.payout_total / df.entrances # earnings per click + df["eph"] = df.payout_total / (df.elapsed_total / 3_600) # earnings per hour + # df["eph"] = df.payout_total / (df_completes.elapsed_total / 3_600) # earnings per hour + df["cpc"] = df_completes.payout_avg * df.conversion + + df.index = df.index.set_names(names=["index0", "index1"]) + return AdminPOPWallSchema.validate(df).fillna(0) diff --git a/generalresearch/incite/mergers/foundations/user_id_product.py b/generalresearch/incite/mergers/foundations/user_id_product.py new file mode 100644 index 0000000..e467179 --- /dev/null +++ b/generalresearch/incite/mergers/foundations/user_id_product.py @@ -0,0 +1,49 @@ +import logging +from typing import Literal + +from distributed import Client + +from generalresearch.incite.collections.thl_web import UserDFCollection +from generalresearch.incite.mergers import ( + MergeCollectionItem, + MergeCollection, + MergeType, +) + +LOG = logging.getLogger("incite") + + +class UserIdProductMergeItem(MergeCollectionItem): + + def build( + self, client: Client, user_coll: UserDFCollection, client_resources=None + ) -> None: + LOG.warning(f"UserIdProductMergeItem.build({self.interval})") + + udf = user_coll.ddf( + include_partial=True, force_rr_latest=False, columns=["product_id"] + ) + udf = udf.repartition(npartitions=40) + udf = udf.categorize(columns=["product_id"]) + # This is the best way I think. Each worker can read, categorize, + # and write its own chunk, and data doesn't have to be sent back + # and forth. We can validate the df afterward! + self.to_archive_symlink(client, client_resources=client_resources, ddf=udf) + + +class UserIdProductMerge(MergeCollection): + merge_type: Literal[MergeType.USER_ID_PRODUCT] = MergeType.USER_ID_PRODUCT + collection_item_class: Literal[UserIdProductMergeItem] = UserIdProductMergeItem + offset: None = None + + def build( + self, client: Client, user_coll: UserDFCollection, client_resources=None + ) -> None: + LOG.info(f"UserIdProductMerge.build(user_coll={user_coll.signature()})") + + assert ( + len(self.items) == 1 + ), "UserIdProductMerge should only have 1 CollectionItem" + item: UserIdProductMergeItem = self.items[0] + + item.build(client, user_coll=user_coll, client_resources=client_resources) diff --git a/generalresearch/incite/mergers/nginx_core.py b/generalresearch/incite/mergers/nginx_core.py new file mode 100644 index 0000000..1f471a6 --- /dev/null +++ b/generalresearch/incite/mergers/nginx_core.py @@ -0,0 +1,146 @@ +import json +import logging +from datetime import datetime, timedelta +from typing import Literal, List +from urllib.parse import parse_qs, urlsplit + +import dask.bag as db +import pandas as pd +from sentry_sdk import capture_exception + +from generalresearch.incite.mergers import ( + MergeCollection, + MergeCollectionItem, + MergeType, +) +from generalresearch.incite.schemas.mergers.nginx import NGINXCoreSchema +from generalresearch.models.thl.definitions import ReservedQueryParameters + +LOG = logging.getLogger("incite") + +uuid4hex = r"[a-f0-9]{8}-?[a-f0-9]{4}-?4[a-f0-9]{3}-?[89ab][a-f0-9]{3}-?[a-f0-9]{12}" + + +class NginxCoreMergeItem(MergeCollectionItem): + + def build(self) -> None: + ir: pd.Interval = self.interval + is_partial = not self.should_archive() + coll: MergeCollection = self._collection + + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + __name__ = coll.merge_type.value + + reserved_kwargs = set([e.value for e in ReservedQueryParameters]) | set( + ["AC5AD0DDBC0C", "66482fb"] + ) + LOG.info(f"{__name__}: {self._collection._client} {reserved_kwargs=}") + + # --- READ --- + _start = start.replace(hour=0) + _end = end.replace(hour=0) + days: List[str] = [ + (start + timedelta(days=i)).strftime("%Y-%m-%d") + for i in range((end - start).days + 1) + ] + LOG.info(f"{__name__}: READ start") + lines = db.read_text( + urlpath=[f"/tmp/thl-core-logs/access.log-{day}-*.gz" for day in days], + compression="gzip", + include_path=False, + ) + + # --- PROCESS --- + LOG.info(f"{__name__}: PROCESS start") + + def process_core_entry(x: dict) -> dict: + request: str = x["request"].split(" ")[1] # GET full_url_path HTTP/1.1.1 + referer: str = x["referer"] + + request_split = urlsplit(request) + request_query_dict = parse_qs(request_split.query) + + # -- couldn't get to work well with .astype. I know the 0 or 0.0 isn't good but too frustrating for now + try: + upstream_status = int(x["upstream_status"]) + except (Exception,): + upstream_status = 0 + + try: + status = int(x["status"]) + except (Exception,): + status = 0 + + try: + request_time = float(x["request_time"]) + except (Exception,): + request_time = 0.0 + + try: + upstream_response_time = float(x["upstream_response_time"]) + except (Exception,): + upstream_response_time = 0.0 + + return { + "time": datetime.fromtimestamp(float(x["time"])), + "method": x.get("method", None), + "user_agent": x.get("user_agent", None), + "upstream_route": x.get("upstream_route", None), + "host": x.get("host", None), + "upstream_status": upstream_status, + "status": status, + "request_time": request_time, + "upstream_response_time": upstream_response_time, + "upstream_cache_hit": x.get("upstream_cache_hit") == "True", + # GRL custom + "request_path": request_split.path, + "referer": referer, + "session_id": request_query_dict.get("AC5AD0DDBC0C", [None])[0], + "request_id": request_query_dict.get("66482fb", [None])[0], + "nudge_id": request_query_dict.get("5e0e0323", [None])[0], + "request_custom_query_params": ",".join( + [ + qk + for qk in request_query_dict.keys() + if qk not in reserved_kwargs + ] + ), + } + + LOG.info(f"{__name__}: PROCESS - records maps") + records = lines.map(json.loads).map(process_core_entry) + LOG.info(f"{__name__}: PROCESS - .to_dataframe()") + ddf = records.to_dataframe() + + # --- for "partition_on" --- + ddf = ddf[ + ddf["time"].between(ir.left.to_datetime64(), ir.right.to_datetime64()) + ] + ddf = ddf.repartition(npartitions=1) + + # --- SAVE --- + LOG.info(f"{__name__}: SAVE start") + self.to_archive(ddf=ddf, is_partial=is_partial) + LOG.info(f"{__name__}: SAVE end") + + return None + + +class NginxCoreMerge(MergeCollection): + merge_type: Literal[MergeType.NGINX_CORE] = MergeType.NGINX_CORE + _schema = NGINXCoreSchema + collection_item_class = NginxCoreMergeItem + + def build(self) -> None: + LOG.info(f"NginxCoreMerge.build()") + + for item in reversed(self.items): + item: NginxCoreMergeItem + + try: + item.build() + except (Exception,) as e: + capture_exception(error=e) + pass + + return None diff --git a/generalresearch/incite/mergers/nginx_fsb.py b/generalresearch/incite/mergers/nginx_fsb.py new file mode 100644 index 0000000..1f1039b --- /dev/null +++ b/generalresearch/incite/mergers/nginx_fsb.py @@ -0,0 +1,151 @@ +import json +import logging + +from sentry_sdk import capture_exception +import re +from datetime import datetime, timedelta +from typing import Literal, List +from urllib.parse import parse_qs, urlsplit + +import dask.bag as db +import pandas as pd + +from generalresearch.incite.mergers import ( + MergeCollection, + MergeCollectionItem, + MergeType, +) +from generalresearch.incite.schemas.mergers.nginx import NGINXFSBSchema +from generalresearch.models.thl.definitions import ReservedQueryParameters + +LOG = logging.getLogger("incite") + +uuid4hex = r"[a-f0-9]{8}-?[a-f0-9]{4}-?4[a-f0-9]{3}-?[89ab][a-f0-9]{3}-?[a-f0-9]{12}" + + +class NginxFSBMergeItem(MergeCollectionItem): + + def build(self) -> None: + ir: pd.Interval = self.interval + is_partial = not self.should_archive() + coll: MergeCollection = self._collection + + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + __name__ = coll.merge_type.value + + reserved_kwargs = set([e.value for e in ReservedQueryParameters]) + LOG.info(f"{__name__}: {coll._client} {reserved_kwargs=}") + + # --- READ --- + _start = start.replace(hour=0) + _end = end.replace(hour=0) + days: List[str] = [ + (start + timedelta(days=i)).strftime("%Y-%m-%d") + for i in range((end - start).days + 1) + ] + LOG.info(f"{__name__}: READ start: {days=}") + lines = db.read_text( + urlpath=[f"/tmp/fsb-logs/access.log-{day}-*.gz" for day in days], + compression="gzip", + include_path=False, + ) + + # --- PROCESS --- + LOG.info(f"{__name__}: PROCESS start") + + def process_fsb_entry(x: dict) -> dict: + request: str = x["request"].split(" ")[1] # GET full_url_path HTTP/1.1.1 + url_split = urlsplit(request) + query_dict = parse_qs(url_split.query) + product_ids = re.findall(uuid4hex, request) + product_id = ( + product_ids[0] if len(product_ids) else "-" + ) # Cannot (categorize) convert non-finite values + is_offerwall = "/offerwall/" in url_split.path + offerwall = "-" # Cannot (categorize) convert non-finite values + if is_offerwall: + offerwall = url_split.path.split("/offerwall/")[1][:-1] or "-" + is_report = "/report/" in url_split.path + + # -- couldn't get to work well with .astype. I know the 0 or 0.0 isn't good but too frustrating for now + try: + status = int(x["status"]) + except (Exception,): + status = 0 + + try: + upstream_status = int(x["upstream_status"]) + except (Exception,): + upstream_status = 0 + + try: + request_time = float(x["request_time"]) + except (Exception,): + request_time = 0.0 + + try: + upstream_response_time = float(x["upstream_response_time"]) + except (Exception,): + upstream_response_time = 0.0 + + return { + "time": datetime.fromtimestamp(float(x["time"])), + "method": x.get("method", None), + "user_agent": x.get("user_agent", None), + "upstream_route": x.get("upstream_route", None), + "host": x.get("host", None), + "status": status, + "upstream_status": upstream_status, + "request_time": request_time, + "upstream_response_time": upstream_response_time, + "upstream_cache_hit": x.get("upstream_cache_hit") == "True", + # GRL custom + "product_id": product_id, + "product_user_id": query_dict.get("bpuid", [None])[0], + "n_bins": query_dict.get("n_bins", [None])[0], + "is_offerwall": is_offerwall, + "offerwall": offerwall, + "is_report": is_report, + "custom_query_params": ",".join( + [qk for qk in query_dict.keys() if qk not in reserved_kwargs] + ), + } + + LOG.info(f"{__name__}: PROCESS - records maps") + records = lines.map(json.loads).map(process_fsb_entry) + LOG.info(f"{__name__}: PROCESS - .to_dataframe()") + ddf = records.to_dataframe() + + # -- for "partition_on" + LOG.info(f"{__name__}: PROCESS - cleanup") + ddf = ddf[ + ddf["time"].between(ir.left.to_datetime64(), ir.right.to_datetime64()) + ] + ddf = ddf.repartition(npartitions=1) + + # --- SAVE --- + LOG.info(f"{__name__}: SAVE start") + self.to_archive(ddf=ddf, is_partial=is_partial) + LOG.info(f"{__name__}: SAVE finish") + + return None + + +class NginxFSBMerge(MergeCollection): + merge_type: Literal[MergeType.NGINX_FSB] = MergeType.NGINX_FSB + _schema = NGINXFSBSchema + collection_item_class = NginxFSBMergeItem + + def build(self) -> None: + LOG.info(f"NginxFSBMerge.build()") + + for item in reversed(self.items): + item: NginxFSBMergeItem + + try: + item.build() + except (Exception,) as e: + capture_exception(error=e) + pass + + return None diff --git a/generalresearch/incite/mergers/nginx_grs.py b/generalresearch/incite/mergers/nginx_grs.py new file mode 100644 index 0000000..0242b0b --- /dev/null +++ b/generalresearch/incite/mergers/nginx_grs.py @@ -0,0 +1,141 @@ +import json +import logging +from datetime import datetime, timedelta +from typing import Literal, List +from urllib.parse import parse_qs, urlsplit + +import dask.bag as db +import dask.dataframe as dd +from sentry_sdk import capture_exception + +from generalresearch.incite.mergers import ( + MergeCollection, + MergeCollectionItem, + MergeType, +) +from generalresearch.incite.schemas.mergers.nginx import NGINXGRSSchema + +LOG = logging.getLogger("incite") + + +uuid4hex = r"[a-f0-9]{8}-?[a-f0-9]{4}-?4[a-f0-9]{3}-?[89ab][a-f0-9]{3}-?[a-f0-9]{12}" + + +class NginxGRSMergeItem(MergeCollectionItem): + + def build(self) -> None: + ir = self.interval + coll: MergeCollection = self._collection + is_partial = not self.should_archive() + + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + __name__ = self._collection.merge_type.value + + reserved_kwargs = set(["39057c8b", "c184efc0", "0bb50182"]) + LOG.info(f"{__name__}: {coll._client} {reserved_kwargs=}") + + # --- READ --- + _start = start.replace(hour=0) + _end = end.replace(hour=0) + days: List[str] = [ + (start + timedelta(days=i)).strftime("%Y-%m-%d") + for i in range((end - start).days + 1) + ] + LOG.info(f"{__name__}: READ start: {days}") + lines = db.read_text( + urlpath=[f"/tmp/grs-logs/access.log-{day}-*.gz" for day in days], + compression="gzip", + include_path=False, + ) + + # --- PROCESS --- + LOG.info(f"{MergeType.NGINX_GRS.value}: PROCESS start") + + def process_grs_entry(x: dict) -> dict: + request: str = x["request"].split(" ")[1] # GET full_url_path HTTP/1.1.1 + + referer_split = urlsplit(x["referer"]) + referer_query_dict = parse_qs(referer_split.query) + + # -- couldn't get to work well with .astype. I know the 0 or 0.0 isn't good but too frustrating for now + try: + upstream_status = int(x["upstream_status"]) + except (Exception,): + upstream_status = 0 + + try: + status = int(x["status"]) + except (Exception,): + status = 0 + + try: + request_time = float(x["request_time"]) + except (Exception,): + request_time = 0.00 + + try: + upstream_response_time = float(x["upstream_response_time"]) + except (Exception,): + upstream_response_time = 0.0 + return { + "time": datetime.fromtimestamp(float(x["time"])), + "method": x.get("method", None), + "user_agent": x.get("user_agent", None), + "upstream_route": x.get("upstream_route", None), + "host": x.get("host", None), + "status": status, + "upstream_status": upstream_status, + "request_time": request_time, + "upstream_response_time": upstream_response_time, + "upstream_cache_hit": x.get("upstream_cache_hit") == "True", + # GRL custom + "product_id": referer_query_dict.get("39057c8b", [None])[0], + "product_user_id": referer_query_dict.get("c184efc0", [None])[0], + "wall_uuid": referer_query_dict.get("0bb50182", [None])[0], + "custom_query_params": ",".join( + [ + qk + for qk in referer_query_dict.keys() + if qk not in reserved_kwargs + ] + ), + } + + LOG.info(f"{__name__}: PROCESS - records maps") + records = lines.map(json.loads).map(process_grs_entry) + LOG.info(f"{__name__}: PROCESS - .to_dataframe()") + ddf: dd.DataFrame = records.to_dataframe() + + # -- for "partition_on" + LOG.info(f"{__name__}: PROCESS - cleanup") + ddf = ddf[ + ddf["time"].between(ir.left.to_datetime64(), ir.right.to_datetime64()) + ] + ddf = ddf.repartition(npartitions=1) + + # --- SAVE --- + LOG.info(f"{__name__}: SAVE start") + self.to_archive(ddf=ddf, is_partial=is_partial) + LOG.info(f"{__name__}: SAVE finish") + + return None + + +class NginxGRSMerge(MergeCollection): + merge_type: Literal[MergeType.NGINX_GRS] = MergeType.NGINX_GRS + _schema = NGINXGRSSchema + collection_item_class = NginxGRSMergeItem + + def build(self) -> None: + LOG.info(f"NginxGRSMerge.build()") + + for item in reversed(self.items): + item: NginxGRSMergeItem + + try: + item.build() + except (Exception,) as e: + capture_exception(e) + pass + + return None diff --git a/generalresearch/incite/mergers/pop_ledger.py b/generalresearch/incite/mergers/pop_ledger.py new file mode 100644 index 0000000..4915abb --- /dev/null +++ b/generalresearch/incite/mergers/pop_ledger.py @@ -0,0 +1,131 @@ +import logging +from typing import Literal, Optional + +import dask.dataframe as dd +import pandas as pd +from distributed import Client +from more_itertools import flatten + +from generalresearch.incite.collections.thl_web import LedgerDFCollection +from generalresearch.incite.mergers import ( + MergeCollection, + MergeCollectionItem, + MergeType, +) +from generalresearch.incite.schemas.mergers.pop_ledger import PopLedgerSchema +from generalresearch.models.thl.ledger import Direction, TransactionType + +LOG = logging.getLogger("incite") + + +class PopLedgerMergeItem(MergeCollectionItem): + + def build( + self, + ledger_coll: LedgerDFCollection, + client: Optional[Client] = None, + client_resources=None, + ) -> None: + ir: pd.Interval = self.interval + + is_partial = not self.should_archive() + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + + ledger_items = [ + s + for s in ledger_coll.items + if s.interval.overlaps(pd.Interval(ir.left, ir.right, closed="left")) + ] + + ddf = ledger_coll.ddf( + items=ledger_items, + include_partial=True, + force_rr_latest=False, + ) + + if ddf is None: + return None + + ddf = ddf[ddf["created"].between(start, end)] + df: pd.DataFrame = client.compute(ddf, resources=client_resources, sync=True) + + if df.empty: + # self.set_empty() + return None + + df["direction_name"] = df["direction"].apply(lambda x: Direction(x).name) + + # The smallest "unit time interval" supported by this merge. It can be + # resampled to anything larger in the future, but not smaller. We use + # the dt.floor so the intervals do not overlap + df["time_idx"] = df["created"].dt.floor("1min") + + # For each time interval and Ledger Account (this is different from a + # product_id), we want the raw amounts and their respective direction + # for every type of transaction that is possible + x = ( + df.groupby(by=["time_idx", "account_id", "tx_type", "direction_name"]) + .amount.sum() + .reset_index() + ) + + # We want to keep the positive and negatives for each type. For example, + # for bp_adjustment, we want to know the amount increased and the + # amount decreased, not just the net. + x["tx_type.direction"] = x["tx_type"] + "." + x["direction_name"] + s = ( + x.pivot_table( + index=["time_idx", "account_id"], + columns="tx_type.direction", + values="amount", + aggfunc="sum", + ) + .fillna(0) + .reset_index() + ) + + columns = set( + flatten( + [[e.value + ".CREDIT", e.value + ".DEBIT"] for e in TransactionType] + ) + ) + s = s.reindex(columns=columns | set(s.columns)).fillna(0) + s = s.reset_index(drop=True) + s.index.name = "id" + # The "columns were named" tx_type.direction from the above pivot. This + # made it confusing when viewing in a console, so we rename it here, + # it doesn't provide any functional change + s.columns.name = None + + s = PopLedgerSchema.validate(s) + ddf = dd.from_pandas(s, npartitions=1) + + if is_partial: + self.to_archive_symlink( + client=client, + ddf=ddf, + is_partial=True, + validate_after=False, + client_resources=client_resources, + ) + else: + self.to_archive(client=client, ddf=ddf) + + +class PopLedgerMerge(MergeCollection): + merge_type: Literal[MergeType.POP_LEDGER] = MergeType.POP_LEDGER + _schema = PopLedgerSchema + collection_item_class: Literal[PopLedgerMergeItem] = PopLedgerMergeItem + + def build(self, client: Client, ledger_coll: LedgerDFCollection) -> None: + + LOG.info(f"PopLedgerMerge.build(wall_coll={ledger_coll.signature()}") + + assert isinstance(ledger_coll, LedgerDFCollection) + + for item in reversed(self.items): + if item.has_archive(include_empty=True): + continue + + LOG.debug(msg=item) + item.build(client=client, ledger_coll=ledger_coll) diff --git a/generalresearch/incite/mergers/ym_survey_wall.py b/generalresearch/incite/mergers/ym_survey_wall.py new file mode 100644 index 0000000..7f6c31c --- /dev/null +++ b/generalresearch/incite/mergers/ym_survey_wall.py @@ -0,0 +1,149 @@ +import logging +from datetime import timedelta +from typing import Optional, Literal + +import dask.dataframe as dd +import pandas as pd +from distributed import Client +from sentry_sdk import capture_exception + +from generalresearch.incite.collections.thl_web import WallDFCollection +from generalresearch.incite.mergers import ( + MergeCollection, + MergeType, + MergeCollectionItem, +) +from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, +) +from generalresearch.incite.schemas.mergers.ym_survey_wall import ( + YMSurveyWallSchema, +) +from generalresearch.models.custom_types import AwareDatetimeISO + +LOG = logging.getLogger("incite") + + +class YMSurveyWallMergeCollectionItem(MergeCollectionItem): + + def build( + self, + wall_coll: WallDFCollection, + enriched_session: EnrichedSessionMerge, + client: Optional[Client] = None, + client_resources=None, + ) -> None: + LOG.info(f"YMSurveyWallMerge.build({self.start=}, {self.finish=})") + ir: pd.Interval = self.interval + start, _ = self.start, self.finish + ddf = wall_coll.ddf( + items=wall_coll.get_items(start), + force_rr_latest=False, + include_partial=True, + columns=[ + "source", + "buyer_id", + "survey_id", + "session_id", + "started", + "finished", + "status", + "status_code_1", + "status_code_2", + "cpi", + "report_value", + "ext_status_code_1", + "ext_status_code_2", + "ext_status_code_3", + ], + filters=[("started", ">=", start)], + ) + ddf = ddf[ddf["started"] > start] + + LOG.warning(f"YMSurveyWallMerge: merge session_collection") + session_items = [ + s + for s in enriched_session.items + if s.interval.overlaps( + pd.Interval(ir.left - timedelta(hours=2), ir.right, closed="both") + ) + ] + sdf = enriched_session.ddf( + items=session_items, + include_partial=True, + force_rr_latest=False, + columns=[ + "user_id", + "country_iso", + "device_type", + "payout", + "product_id", + "team_id", + ], + filters=[ + ("started", ">=", start - timedelta(hours=2)), + ], + ) + ddf = dd.merge( + ddf, sdf, left_on="session_id", right_index=True, how="left", npartitions=5 + ) + + df = client.compute(ddf, resources=client_resources, sync=True) + df["elapsed"] = (df["finished"] - df["started"]).dt.total_seconds() + df["elapsed"] = df["elapsed"].round().astype("Int64") + df = df.drop(columns={"finished", "payout"}, errors="ignore") + df.dropna(subset="user_id", how="any", inplace=True) + df.dropna(subset="product_id", how="any", inplace=True) + df.sort_values(by="started", inplace=True) + + LOG.debug(f"YMSurveyWallMerge.build() validation") + + df = self.validate_df(df=df) + if df is not None: + ddf = dd.from_pandas(df, npartitions=4) + LOG.info(f"YMSurveyWallMerge.build() saving") + self.to_archive_symlink(client=client, ddf=ddf) + else: + LOG.warning("YMSurveyWallMerge failed validation") + + return None + + +class YMSurveyWallMerge(MergeCollection): + merge_type: Literal[MergeType.YM_SURVEY_WALL] = MergeType.YM_SURVEY_WALL + collection_item_class: Literal[YMSurveyWallMergeCollectionItem] = ( + YMSurveyWallMergeCollectionItem + ) + start: Optional[AwareDatetimeISO] = None + offset: str = "10D" + _schema = YMSurveyWallSchema + + def build( + self, + client: Client, + wall_coll: WallDFCollection, + enriched_session: EnrichedSessionMerge, + client_resources=None, + ) -> None: + + LOG.info( + f"YMSurveyWallMerge.build(wall_coll={wall_coll.signature()}, " + f"enriched_session={enriched_session.signature()})" + ) + assert ( + len(self.items) == 1 + ), "YMSurveyWallMerge can't have more than 1 CollectionItem." + item: YMSurveyWallMergeCollectionItem = self.items[0] + + try: + item.build( + client=client, + client_resources=client_resources, + wall_coll=wall_coll, + enriched_session=enriched_session, + ) + except (Exception,) as e: + capture_exception(error=e) + pass + + item.delete_dangling_partials(keep_latest=2, target_path=item.path) diff --git a/generalresearch/incite/mergers/ym_wall_summary.py b/generalresearch/incite/mergers/ym_wall_summary.py new file mode 100644 index 0000000..419994d --- /dev/null +++ b/generalresearch/incite/mergers/ym_wall_summary.py @@ -0,0 +1,195 @@ +from datetime import timedelta, datetime, time +from typing import Literal, List, Optional, Type + +import dask.dataframe as dd +import pandas as pd +from pydantic import Field, field_validator +from sentry_sdk import capture_exception + +from generalresearch.incite.collections.thl_web import ( + SessionDFCollection, + WallDFCollection, +) +from generalresearch.incite.mergers import ( + MergeCollection, + MergeType, + MergeCollectionItem, +) +from generalresearch.incite.mergers.foundations.user_id_product import ( + UserIdProductMerge, +) +from generalresearch.incite.schemas.mergers.ym_wall_summary import ( + YMWallSummarySchema, +) +from generalresearch.models.thl.definitions import Status, StatusCode1 + + +class YMWallSummaryMergeItem(MergeCollectionItem): + + def fetch( + self, + wall_collection: WallDFCollection, + session_collection: SessionDFCollection, + user_id_product: UserIdProductMerge, + ): + ir = self.interval + start, end = ir.left.to_pydatetime(), ir.right.to_pydatetime() + + wall_items = [w for w in wall_collection.items if w.interval.overlaps(ir)] + ddf = wall_collection.ddf( + items=wall_items, force_rr_latest=False, include_partial=True + ) + ddf = ddf[ddf["started"].between(start, end)] + + # Then we need the sessions for these wall events. They'll have started + # up to 90 min before the wall event. + session_ir = pd.Interval( + ir.left - timedelta(minutes=90), ir.right, closed="left" + ) + session_items = [ + s for s in session_collection.items if s.interval.overlaps(session_ir) + ] + sddf = session_collection.ddf( + items=session_items, + force_rr_latest=False, + include_partial=True, + columns=["user_id", "country_iso", "device_type"], + ) + df: pd.DataFrame = self.compute(ddf.join(sddf, on="session_id", how="left")) + user_ids = set(df.user_id.dropna().unique()) + + udf = self.compute( + user_id_product.ddf(filters=[("id", "in", user_ids)], include_partial=True) + ) + + x = udf.loc[udf.index.isin(user_ids)].copy() + x["product_id"] = x["product_id"].astype(str) + df = df.join(x, on="user_id", how="left") + + df["date"] = df["started"].dt.strftime("%Y-%m-%d") + df = YMWallSummaryMerge.build_groupbys(df) + + self._collection._schema.validate(df) + + is_partial = not self.should_archive() + ddf = dd.from_pandas(df, npartitions=4) + self.to_archive(ddf, is_partial=is_partial) + return df + + +class YMWallSummaryMerge(MergeCollection): + merge_type: Literal[MergeType.YM_WALL_SUMMARY] = MergeType.YM_WALL_SUMMARY + _schema = YMWallSummarySchema + collection_item_class: Type[YMWallSummaryMergeItem] = YMWallSummaryMergeItem + items: List[YMWallSummaryMergeItem] = Field(default_factory=list) + + @field_validator("offset") + def check_offset_ym_wall_summary(cls, v: Optional[str]): + # the offset MUST be on a whole day, no hourly + assert v.endswith("D"), "offset must be in days" + return v + + @field_validator("start") + def check_start_ym_wall_summary(cls, v: Optional[datetime]): + # the start MUST be start on midnight exactly + assert v.time() == time(0, 0, 0, 0), "start must no have a time component" + return v + + def build( + self, + wall_collection: WallDFCollection, + session_collection: SessionDFCollection, + user_id_product: UserIdProductMerge, + ) -> None: + + for item in reversed(self.items): + item: YMWallSummaryMergeItem + + # Skip which already exist + if item.has_archive(): + continue + + try: + # TODO: How should we do this generically? + # We're going to assume that we want to update the "latest" + # item every time build is run even if it isn't closed + # if item.should_archive(): + item.fetch(wall_collection, session_collection, user_id_product) + except (Exception,) as e: + capture_exception(e) + pass + + @staticmethod + def build_groupbys(df: pd.DataFrame) -> pd.DataFrame: + gb_cols = ["date", "product_id", "buyer_id", "country_iso", "source"] + status_cols = [ + "Status.ABANDON", + "Status.COMPLETE", + "Status.FAIL", + "Status.TIMEOUT", + "StatusCode1.BUYER_FAIL", + ] + df.loc[df.status.isnull(), "status"] = Status.TIMEOUT.value + gbs = [ + ["date", "source"], + ["date", "source", "country_iso"], + ["date", "source", "product_id"], + ["date", "source", "buyer_id"], + ["date", "source", "product_id", "country_iso"], + ["date", "source", "buyer_id", "country_iso"], + ] + + gdf = pd.DataFrame(columns=gb_cols + status_cols) + gdf = gdf.astype({k: "string" for k in gb_cols} | {k: int for k in status_cols}) + for gb in gbs: + s = df.groupby(gb)["status"].value_counts() + s = s.reset_index().pivot_table(index=gb, columns="status", values="count") + bf = ( + df[df["status_code_1"] == StatusCode1.BUYER_FAIL.value] + .groupby(gb) + .size() + .rename("StatusCode1.BUYER_FAIL") + ) + s = s.join(bf) + s = ( + s.rename( + columns={ + "a": "Status.ABANDON", + "c": "Status.COMPLETE", + "f": "Status.FAIL", + "t": "Status.TIMEOUT", + } + ) + .reset_index() + .rename_axis(None, axis=1) + ) + s = s.reindex(columns=list(set(status_cols) | set(s.columns))).fillna(0) + s = s.reindex(columns=list(set(gb_cols) | set(s.columns))) + s = s.astype({k: "string" for k in gb_cols} | {k: int for k in status_cols}) + gdf = pd.concat([gdf, s]) + return gdf + + def save(self) -> None: + # Once we build all the daily files, we can package them all up into 1 file + # df = pq.ParquetDataset(self.archive_path).read().to_pandas() + # df.to_parquet(str(self.archive_path) + ".all.parquet") + pass + + def get_counts(self, product_id): + # examples... + product_id = "" + df = dd.read_parquet( + str(self.archive_path) + ".all.parquet", + filters=[ + ("product_id", "=", product_id), + ("country_iso", "is", None), + ], + ).compute() + country_iso = "de" + df = dd.read_parquet( + str(self.archive_path) + ".all.parquet", + filters=[ + ("product_id", "=", product_id), + ("country_iso", "=", country_iso), + ], + ).compute() diff --git a/generalresearch/incite/schemas/__init__.py b/generalresearch/incite/schemas/__init__.py new file mode 100644 index 0000000..c0000d1 --- /dev/null +++ b/generalresearch/incite/schemas/__init__.py @@ -0,0 +1,29 @@ +from typing import List + +import pandas as pd +import pandera.pandas as pa + +ORDER_KEY = "order_key" +# How long after an DFCollectionItem's .finish can we archive it ? Should be +# 90 min for Wall / Session, typically if rows are never modified, we'll +# use 1 min. +ARCHIVE_AFTER = "archive_after" +PARTITION_ON = "partition_on" + + +def empty_dataframe_from_schema(schema: pa.DataFrameSchema) -> "pd.DataFrame": + index_names: List[str] = schema.index.names + columns = set(schema.dtypes.keys()) + + if len(index_names) > 1: + columns = columns | set(index_names) + + df = pd.DataFrame(columns=list(columns)).astype( + {col: str(dtype) for col, dtype in schema.dtypes.items()} + ) + + if len(index_names) > 1: + df.set_index(keys=index_names, inplace=True) + + df.index.set_names(names=index_names, inplace=True) + return df diff --git a/generalresearch/incite/schemas/admin_responses.py b/generalresearch/incite/schemas/admin_responses.py new file mode 100644 index 0000000..73c0aaa --- /dev/null +++ b/generalresearch/incite/schemas/admin_responses.py @@ -0,0 +1,186 @@ +from datetime import datetime + +from pandera import ( + DataFrameSchema, + Column, + Check, + Parser, + MultiIndex, + Index, + Timestamp, +) + +BIG_INT32 = 2_147_483_647 +SIX_HOUR_SECONDS = 6 * 60 * 6 +ROUNDING = 2 + +AdminPOPSchema = DataFrameSchema( + # Generic: used for Session or Wall + index=MultiIndex( + indexes=[ + # It seems to be impossible to create a list of optional names, + # and given that we allow the index1 to be different depending + # on the split_by, let's just use generic names for now. However, + # we know the first is also an iso string for now (19 chars) + Index( + name="index0", + dtype=Timestamp, + parsers=[Parser(lambda i: i.dt.tz_localize(None))], + checks=[ + Check.less_than( + max_value=datetime(year=datetime.now().year + 1, month=1, day=1) + ) + ], + ), + Index( + name="index1", + dtype=str, + checks=[Check.str_length(max_value=255)], + ), + ], + coerce=True, + ), + columns={ + "elapsed_avg": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.clip(lower=0, upper=SIX_HOUR_SECONDS)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=SIX_HOUR_SECONDS), + ), + "elapsed_total": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "payout_avg": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=100), + ), + "payout_total": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "entrances": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "completes": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "users": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "conversion": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0.00, max_value=1.00), + ), + "epc": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=100), + ), + "eph": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "cpc": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=250), + ), + }, + coerce=True, +) + +admin_pop_index = AdminPOPSchema.index +admin_pop_columns = AdminPOPSchema.columns.copy() + +AdminPOPWallSchema = DataFrameSchema( + index=admin_pop_index, + columns=admin_pop_columns + | { + "buyers": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "surveys": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + "sessions": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + }, + coerce=True, +) + +AdminPOPSessionSchema = DataFrameSchema( + index=admin_pop_index, + columns=admin_pop_columns + | { + "attempts_avg": Column( + dtype=float, + parsers=[ + Parser(lambda s: s.fillna(value=0.00)), + Parser(lambda s: s.round(decimals=ROUNDING)), + ], + checks=Check.between(min_value=0, max_value=25), + ), + "attempts_total": Column( + dtype=int, + parsers=[ + Parser(lambda s: s.fillna(value=0)), + ], + checks=Check.between(min_value=0, max_value=BIG_INT32), + ), + }, + coerce=True, +) diff --git a/generalresearch/incite/schemas/mergers/__init__.py b/generalresearch/incite/schemas/mergers/__init__.py new file mode 100644 index 0000000..db25b96 --- /dev/null +++ b/generalresearch/incite/schemas/mergers/__init__.py @@ -0,0 +1,27 @@ +IP_REGEX_PATTERN = ( + r"^((([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][" + r"0-9]|25[0-5])$|^(([a-fA-F]|[a-fA-F][a-fA-F0-9\-]*[a-fA-F0-9])\.)*([A-Fa-f]|[A-Fa-f][" + r"A-Fa-f0-9\-]*[A-Fa-f0-9])$|^(?:(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){6})(?:(?:(?:(?:(?:[" + r"0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[" + r"0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:::(?:(?:(?:[0-9a-fA-F]{1," + r"4})):){5})(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(" + r"?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(" + r"?:(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){4})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"1}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){3})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"2}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){2})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"3}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:[0-9a-fA-F]{1,4})):)(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"4}(?:(?:[0-9a-fA-F]{1,4})))?::)(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1," + r"4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[" + r"0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0,5}(?:(?:[0-9a-fA-F]{1," + r"4})))?::)(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"6}(?:(?:[0-9a-fA-F]{1,4})))?::)))))$" +) +BIGINT = 9223372036854775807 diff --git a/generalresearch/incite/schemas/mergers/foundations/__init__.py b/generalresearch/incite/schemas/mergers/foundations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_session.py b/generalresearch/incite/schemas/mergers/foundations/enriched_session.py new file mode 100644 index 0000000..4badfac --- /dev/null +++ b/generalresearch/incite/schemas/mergers/foundations/enriched_session.py @@ -0,0 +1,36 @@ +from datetime import timedelta + +from pandera import DataFrameSchema, Column, Check + +from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY, PARTITION_ON +from generalresearch.incite.schemas.thl_web import THLSessionSchema + +thl_session_columns = THLSessionSchema.columns.copy() + +EnrichedSessionSchema = DataFrameSchema( + index=THLSessionSchema.index, + columns=thl_session_columns + | { + # --- From thl_user MySQL-RR + "product_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + ), + # -- nullable until it can be back-filled + "team_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=True, + ), + # --- Calculated from WallCollection --- + "attempt_count": Column(dtype="Int64", nullable=False), + }, + checks=[], + coerce=True, + metadata={ + ORDER_KEY: "started", + ARCHIVE_AFTER: timedelta(minutes=90), + PARTITION_ON: None, + }, +) diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py b/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py new file mode 100644 index 0000000..35b579b --- /dev/null +++ b/generalresearch/incite/schemas/mergers/foundations/enriched_task_adjust.py @@ -0,0 +1,98 @@ +import pandas as pd +from pandera import DataFrameSchema, Column, Check, Index +from typing import Set + +from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY +from generalresearch.incite.schemas.thl_web import THLTaskAdjustmentSchema +from generalresearch.locales import Localelator +from generalresearch.models import DeviceType, Source +from generalresearch.models.thl.definitions import ( + WallAdjustedStatus, +) + +thl_task_adj_columns = THLTaskAdjustmentSchema.columns.copy() + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +kosovo = "xk" +COUNTRY_ISOS.add(kosovo) +BIGINT = 9223372036854775807 + +EnrichedTaskAdjustSchema = DataFrameSchema( + index=Index(dtype=int, checks=Check.greater_than_or_equal_to(0)), + columns={ + "wall_uuid": Column( + dtype=str, + checks=[ + Check.str_length(min_value=32, max_value=32), + ], + ), + "user_id": Column( + dtype="Int32", + checks=Check.between(min_value=0, max_value=BIGINT), + nullable=False, + ), + "source": Column( + dtype=str, + checks=[ + Check.str_length(max_value=2), + Check.isin([e.value for e in Source]), + ], + nullable=False, + ), + "survey_id": Column( + dtype=str, checks=[Check.str_length(max_value=32)], nullable=False + ), + "amount": Column(dtype=float), + "adjusted_status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + Check.isin([e.value for e in WallAdjustedStatus]), + ], + ), + "adjusted_status_last": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + Check.isin([e.value for e in WallAdjustedStatus]), + ], + ), + "alerted": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "alerted_last": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "started": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "buyer_id": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=2), + Check.isin(COUNTRY_ISOS), # 2 letter, lowercase + ], + nullable=True, + ), + "device_type": Column( + dtype="Int32", + checks=Check.isin([e.value for e in DeviceType]), + nullable=True, + ), + "adjustments": Column( + dtype="Int32", + checks=Check.between(min_value=0, max_value=BIGINT), + nullable=False, + ), + "product_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + ), + "team_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=True, + ), + }, + checks=[], + coerce=True, + metadata={ORDER_KEY: "alerted", ARCHIVE_AFTER: None}, +) diff --git a/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py b/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py new file mode 100644 index 0000000..1b78fde --- /dev/null +++ b/generalresearch/incite/schemas/mergers/foundations/enriched_wall.py @@ -0,0 +1,144 @@ +from datetime import timedelta + +import pandas as pd +from pandera import DataFrameSchema, Column, Check, Index + +from generalresearch.incite.schemas import PARTITION_ON, ARCHIVE_AFTER +from generalresearch.locales import Localelator +from generalresearch.models import DeviceType, Source +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + ReportValue, + WallStatusCode2, +) + +IP_REGEX_PATTERN = ( + r"^((([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][" + r"0-9]|25[0-5])$|^(([a-fA-F]|[a-fA-F][a-fA-F0-9\-]*[a-fA-F0-9])\.)*([A-Fa-f]|[A-Fa-f][" + r"A-Fa-f0-9\-]*[A-Fa-f0-9])$|^(?:(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){6})(?:(?:(?:(?:(?:[" + r"0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[" + r"0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:::(?:(?:(?:[0-9a-fA-F]{1," + r"4})):){5})(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(" + r"?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(" + r"?:(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){4})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"1}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){3})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"2}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){2})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"3}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:[0-9a-fA-F]{1,4})):)(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"4}(?:(?:[0-9a-fA-F]{1,4})))?::)(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1," + r"4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[" + r"0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0,5}(?:(?:[0-9a-fA-F]{1," + r"4})))?::)(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"6}(?:(?:[0-9a-fA-F]{1,4})))?::)))))$" +) +BIGINT = 9223372036854775807 + +COUNTRY_ISOS = Localelator().get_all_countries() +kosovo = "xk" +COUNTRY_ISOS.add(kosovo) + +EnrichedWallSchema = DataFrameSchema( + index=Index( + name="uuid", # this is the wall event's uuid + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + ), + columns={ + # --- Wall based --- + "source": Column( + dtype=str, + checks=[ + Check.str_length(max_value=2), + Check.isin([e.value for e in Source]), + ], + nullable=False, + ), + "buyer_id": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "survey_id": Column( + dtype=str, checks=[Check.str_length(max_value=32)], nullable=False + ), + "started": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + "finished": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True), + "status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=1), + Check.isin([e.value for e in Status]), + ], + nullable=True, + ), + "status_code_1": Column( + dtype="Int32", + checks=[Check.isin([e.value for e in StatusCode1])], + nullable=True, + ), + "status_code_2": Column( + dtype="Int32", + checks=[Check.isin([e.value for e in WallStatusCode2])], + nullable=True, + ), + "cpi": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=False, + ), + "ext_status_code_1": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "ext_status_code_2": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "ext_status_code_3": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "report_value": Column( + dtype="Int64", + checks=Check.isin([e.value for e in ReportValue]), + nullable=True, + ), + # --- Session based --- + "session_id": Column(dtype=int, checks=Check.greater_than(0), nullable=False), + "country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=2), + Check.isin(COUNTRY_ISOS), # 2 letter, lowercase + ], + nullable=True, + ), + "device_type": Column( + dtype="Int32", + checks=Check.isin([e.value for e in DeviceType]), + nullable=True, + ), + "payout": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + # --- User based --- + "user_id": Column( + dtype="Int32", + checks=Check.between(min_value=0, max_value=BIGINT), + nullable=False, + ), + "product_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + ), + }, + checks=[], + coerce=True, + metadata={PARTITION_ON: None, ARCHIVE_AFTER: timedelta(minutes=90)}, +) diff --git a/generalresearch/incite/schemas/mergers/foundations/user_id_product.py b/generalresearch/incite/schemas/mergers/foundations/user_id_product.py new file mode 100644 index 0000000..780a3f2 --- /dev/null +++ b/generalresearch/incite/schemas/mergers/foundations/user_id_product.py @@ -0,0 +1,27 @@ +from datetime import timedelta + +from pandera import Column, Check, Index, Category, DataFrameSchema + +from generalresearch.incite.schemas import ARCHIVE_AFTER + +BIGINT = 9223372036854775807 + +UserIdIndex = Index( + name="id", + dtype=int, + checks=Check.between(min_value=0, max_value=BIGINT), + unique=True, +) + +""" +Simply stores a mapping between user ID and product ID. product_id is a category +which is much smaller.""" +UserIdProductSchema = DataFrameSchema( + index=UserIdIndex, + columns={ + "product_id": Column(dtype=Category, nullable=False), + }, + checks=[], + coerce=False, + metadata={ARCHIVE_AFTER: timedelta(minutes=0)}, +) diff --git a/generalresearch/incite/schemas/mergers/nginx.py b/generalresearch/incite/schemas/mergers/nginx.py new file mode 100644 index 0000000..30e6fec --- /dev/null +++ b/generalresearch/incite/schemas/mergers/nginx.py @@ -0,0 +1,140 @@ +# MergeType.NGINX_GRS: NGINXGRSSchema, +# MergeType.NGINX_FSB: NGINXFSBSchema, +# MergeType.NGINX_CORE: NGINXCoreSchema, + +from datetime import timedelta + +import pandas as pd +from pandera import DataFrameSchema, Column, Check, Index + +from generalresearch.incite.schemas import PARTITION_ON, ARCHIVE_AFTER + +NGINXBaseSchema = DataFrameSchema( + columns={ + "time": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + "method": Column( + dtype=str, checks=[Check.str_length(max_value=8)], nullable=True + ), + "user_agent": Column( + dtype=str, checks=[Check.str_length(max_value=3_000)], nullable=True + ), + "upstream_route": Column( + dtype=str, checks=[Check.str_length(max_value=255)], nullable=True + ), + "host": Column( + dtype=str, checks=[Check.str_length(max_value=255)], nullable=True + ), + "status": Column( + dtype="Int32", + checks=[Check.between(min_value=0, max_value=600)], + nullable=False, + ), + "upstream_status": Column( + dtype="Int32", + checks=[Check.between(min_value=0, max_value=600)], + nullable=False, + ), + "request_time": Column( + dtype=float, + checks=[Check.greater_than_or_equal_to(min_value=0)], + nullable=False, + ), + "upstream_response_time": Column( + dtype=float, + checks=[Check.greater_than_or_equal_to(min_value=0)], + nullable=False, + ), + "upstream_cache_hit": Column(dtype=bool, nullable=False), + } +) + +NGINXGRSSchema = DataFrameSchema( + index=Index(dtype=int, checks=Check.greater_than_or_equal_to(0)), + columns=NGINXBaseSchema.columns + | { + # --- GRL Custom + "product_id": Column( + dtype=str, checks=Check.str_length(min_value=1, max_value=32), nullable=True + ), + "product_user_id": Column( + dtype=str, + checks=Check.str_length(min_value=1, max_value=128), + nullable=True, + ), + "wall_uuid": Column( + dtype=str, + # It's modified by some people and so this breaks.. + # checks=[Check.str_length(min_value=32, max_value=32)], + nullable=True, + ), + "custom_query_params": Column( + dtype=str, checks=[Check.str_length(max_value=3_000)], nullable=True + ), + }, + checks=[], + coerce=True, + metadata={PARTITION_ON: ["product_id"], ARCHIVE_AFTER: timedelta(minutes=1)}, +) + +NGINXCoreSchema = DataFrameSchema( + index=Index(dtype=int, checks=Check.greater_than_or_equal_to(0)), + columns=NGINXBaseSchema.columns + | { + # --- GRL Custom + "request_path": Column( + dtype=str, + checks=Check.str_length(min_value=1, max_value=3_000), + nullable=False, + ), + "referer": Column( + dtype=str, + checks=Check.str_length(min_value=1, max_value=128), + nullable=True, + ), + "session_id": Column( + dtype=str, checks=Check.str_length(max_value=3_000), nullable=True + ), + "request_id": Column( + dtype=str, checks=Check.str_length(max_value=3_000), nullable=True + ), + "nudge_id": Column( + dtype=str, checks=Check.str_length(max_value=3_000), nullable=True + ), + "request_custom_query_params": Column( + dtype=str, checks=[Check.str_length(max_value=3_000)], nullable=True + ), + }, + checks=[], + coerce=True, + metadata={PARTITION_ON: None, ARCHIVE_AFTER: timedelta(minutes=1)}, +) + +NGINXFSBSchema = DataFrameSchema( + index=Index(dtype=int, checks=Check.greater_than_or_equal_to(0)), + columns=NGINXBaseSchema.columns + | { + # --- GRL Custom + "product_id": Column( + dtype=str, checks=Check.str_length(min_value=1, max_value=32), nullable=True + ), + "product_user_id": Column( + dtype=str, + checks=Check.str_length(min_value=1, max_value=128), + nullable=True, + ), + "n_bins": Column( + dtype="Int32", + checks=Check.greater_than_or_equal_to(min_value=0), + nullable=True, + ), + "is_offerwall": Column(dtype=bool, nullable=False), + "offerwall": Column(dtype=bool, nullable=False), + "is_report": Column(dtype=bool, nullable=False), + "custom_query_params": Column( + dtype=str, checks=[Check.str_length(max_value=3_000)], nullable=True + ), + }, + checks=[], + coerce=True, + metadata={PARTITION_ON: ["product_id"], ARCHIVE_AFTER: timedelta(minutes=1)}, +) diff --git a/generalresearch/incite/schemas/mergers/pop_ledger.py b/generalresearch/incite/schemas/mergers/pop_ledger.py new file mode 100644 index 0000000..25c7e68 --- /dev/null +++ b/generalresearch/incite/schemas/mergers/pop_ledger.py @@ -0,0 +1,64 @@ +from datetime import timedelta + +import pandas as pd +from more_itertools import flatten +from pandera import DataFrameSchema, Column, Check, Index + +from generalresearch.incite.schemas import ARCHIVE_AFTER, ORDER_KEY, PARTITION_ON +from generalresearch.incite.schemas.thl_web import TxSchema +from generalresearch.models.thl.ledger import TransactionType, Direction + +""" +- In reality, a multi-index would be appropriate here, but dask does not support this, so we're keeping it flat. + As such, the index in this schema is simply an autoindex and has no meaning. + +- The "virtual" index (conceptually) is (time, account_id), and the columns are all combinations of + '{TransactionType}.{Direction}'. + +- We want both credit + debit amounts so we know, for e.g., an account got $+10 of positive recons + and $-20 of negative recons. +""" + +# If an amount is "very" large, something is def wrong. Defining "very" somewhat arbitrarily here. +SUSPICIOUSLY_LARGE_NUMBER = (2**32 / 2) - 1 # 2147483647 + +NonNegativeAmount = Column( + dtype="Int32", + nullable=True, + checks=Check.between( + min_value=0, max_value=SUSPICIOUSLY_LARGE_NUMBER, include_min=True + ), +) + +numerical_col_names = list( + flatten( + [ + [ + e.value + "." + Direction.CREDIT.name, + e.value + "." + Direction.DEBIT.name, + ] + for e in TransactionType + ] + ) +) +numerical_cols = {k: NonNegativeAmount for k in numerical_col_names} + +PopLedgerSchema = DataFrameSchema( + index=Index(name="id", dtype=int, checks=Check.greater_than_or_equal_to(0)), + columns=numerical_cols + | { + "time_idx": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), + checks=Check(lambda x: (x.dt.second == 0) & (x.dt.microsecond == 0)), + nullable=False, + ), + "account_id": TxSchema.columns["account_id"], + }, + checks=[], + coerce=True, + metadata={ + ORDER_KEY: None, + ARCHIVE_AFTER: timedelta(minutes=90), + PARTITION_ON: None, + }, +) diff --git a/generalresearch/incite/schemas/mergers/ym_survey_wall.py b/generalresearch/incite/schemas/mergers/ym_survey_wall.py new file mode 100644 index 0000000..2b2d266 --- /dev/null +++ b/generalresearch/incite/schemas/mergers/ym_survey_wall.py @@ -0,0 +1,101 @@ +from datetime import timedelta + +from pandera import DataFrameSchema, Column, Check, Index + +from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER +from generalresearch.incite.schemas.thl_web import THLWallSchema, THLSessionSchema + +thl_wall_columns = THLWallSchema.columns.copy() + +thl_wall_columns = { + k: v + for k, v in thl_wall_columns.items() + if k + in { + "source", + "buyer_id", + "started", + "session_id", + "survey_id", + "cpi", + "status", + "status_code_1", + "status_code_2", + "ext_status_code_1", + "ext_status_code_2", + "ext_status_code_3", + "report_value", + } +} +thl_session_columns = THLSessionSchema.columns.copy() +thl_session_columns = { + k: v + for k, v in thl_session_columns.items() + if k in {"user_id", "country_iso", "device_type"} +} + +""" +This is used by YM-survey-predict and train. It is mostly THLWall with: + - Adjusted columns removed, (YM will get this info from the + TaskAdjustment collection) + + - Fields from the session joined in (user_id, country_iso, device_type, + session's uuid) + + - Product_id and blocked (from User). Blocked means blocked NOW (latest), + not when the session was attempted. +""" + +YMSurveyWallSchema = DataFrameSchema( + # index is the wall's uuid + index=Index( + name="uuid", dtype=str, checks=Check.str_length(min_value=32, max_value=32) + ), + columns=thl_wall_columns + | thl_session_columns + | { + "product_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + ), + # -- nullable until it can be back-filled + "team_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=True, + ), + "elapsed": Column(dtype="Int64", nullable=True), + "in_progress": Column( + dtype=bool, + required=False, + description="This is time-sensitive, so will not be included in archived files. True if" + "the entrance started less than 90 min ago and has not yet returned.", + ), + "pass_ps": Column( + dtype=bool, + required=False, + description="Did this entrance pass the pre-screener and actually enter the client?" + "Note: we mark abandonments as True." + "Note: there is no 'in-progress' determination here. A user who 'just' entered" + "and hasn't come back yet is also marked as True", + ), + "quality_fail": Column( + dtype=bool, + required=False, + description="Did the user fail for quality reasons? We generally want to exclude these for" + "yield-management.", + ), + "abandon": Column( + dtype=bool, + required=False, + description="In-progress is not considered. A user who is in-progress and hasn't come back" + "is still marked as abandon.", + ), + }, + checks=[], + coerce=True, + strict=True, + unique=["session_id", "source", "survey_id"], + metadata={ORDER_KEY: "started", ARCHIVE_AFTER: timedelta(minutes=90)}, +) diff --git a/generalresearch/incite/schemas/mergers/ym_wall_summary.py b/generalresearch/incite/schemas/mergers/ym_wall_summary.py new file mode 100644 index 0000000..fb34dec --- /dev/null +++ b/generalresearch/incite/schemas/mergers/ym_wall_summary.py @@ -0,0 +1,74 @@ +from datetime import timedelta +from typing import Set + +from pandera import DataFrameSchema, Column, Check, Index + +from generalresearch.incite.schemas import ARCHIVE_AFTER +from generalresearch.locales import Localelator +from generalresearch.models import Source + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +kosovo = "xk" +COUNTRY_ISOS.add(kosovo) + +""" +A single file containing, over the past year, one row per: + date (YYYY-MM-DD), product_id (optional), buyer_id (optional), country_iso, source +with counts for this aggregation for the following: + Status.COMPLETE, Status.FAIL, ..., StatusNULL, StatusCode1.BUYER_FAIL, ... +For e.g: + 2024-01-01, 70bXXXXXXXXXXX, NULL, 'us', 'm', 100, 234, 123, +""" + +YMWallSummarySchema = DataFrameSchema( + # index is meaningless + index=Index(dtype=int), + columns={ + "date": Column( + dtype=str, + checks=Check.str_matches("20[0-9][0-9]-[0-9]{2}-[0-9]{2}"), + nullable=False, + ), + "product_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=True, + ), + "buyer_id": Column( + dtype=str, + checks=Check.str_length(min_value=1, max_value=32), + nullable=True, + ), + "country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=2), + Check.isin(COUNTRY_ISOS), # 2 letter, lowercase + ], + nullable=True, + ), + "source": Column( + dtype=str, + checks=[ + Check.str_length(max_value=2), + Check.isin([e.value for e in Source]), + ], + ), + "Status.COMPLETE": Column(dtype=int, checks=Check.greater_than_or_equal_to(0)), + "Status.FAIL": Column(dtype=int, checks=Check.greater_than_or_equal_to(0)), + "Status.ABANDON": Column(dtype=int, checks=Check.greater_than_or_equal_to(0)), + "Status.TIMEOUT": Column( + dtype=int, + checks=Check.greater_than_or_equal_to(0), + description="this includes those where the status is None", + ), + "StatusCode1.BUYER_FAIL": Column( + dtype=int, checks=Check.greater_than_or_equal_to(0) + ), + }, + checks=[], + coerce=True, + strict=True, + unique=["date", "product_id", "buyer_id", "country_iso", "source"], + metadata={ARCHIVE_AFTER: timedelta(minutes=90)}, +) diff --git a/generalresearch/incite/schemas/thl_marketplaces.py b/generalresearch/incite/schemas/thl_marketplaces.py new file mode 100644 index 0000000..286db6a --- /dev/null +++ b/generalresearch/incite/schemas/thl_marketplaces.py @@ -0,0 +1,64 @@ +import copy +from datetime import timedelta + +import pandas as pd +from pandera import Column, Check, Index, DataFrameSchema + +from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER + +BIGINT = 9223372036854775807 + +SurveyHistorySchemaMeta = { + "index": Index( + name="id", dtype=int, checks=Check.between(min_value=0, max_value=BIGINT) + ), + "columns": { + # "survey_id": # fill this in in implementations + "key": Column( + dtype="Int32", + checks=Check.between(min_value=0, max_value=5), + nullable=False, + ), + "value": Column(dtype="Int64", nullable=True), + "date": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + }, + "checks": [], + "coerce": True, + "metadata": {ORDER_KEY: "date", ARCHIVE_AFTER: timedelta(minutes=1)}, +} + +InnovateSurveyHistorySchemaDict = copy.deepcopy(SurveyHistorySchemaMeta) +InnovateSurveyHistorySchemaDict["columns"]["survey_id"] = Column( + dtype=str, checks=Check.str_length(min_value=1, max_value=32), nullable=False +) +# global_conversion (5) is a float +InnovateSurveyHistorySchemaDict["columns"]["value"] = Column(dtype=float, nullable=True) +InnovateSurveyHistorySchema = DataFrameSchema(**InnovateSurveyHistorySchemaDict) + +MorningSurveyTimeseriesSchemaDict = copy.deepcopy(SurveyHistorySchemaMeta) +MorningSurveyTimeseriesSchemaDict["columns"]["bid_id"] = Column( + dtype=str, checks=Check.str_length(min_value=1, max_value=32), nullable=False +) +MorningSurveyTimeseriesSchema = DataFrameSchema(**MorningSurveyTimeseriesSchemaDict) + +SagoSurveyHistorySchemaDict = copy.deepcopy(SurveyHistorySchemaMeta) +SagoSurveyHistorySchemaDict["columns"]["survey_id"] = Column( + dtype=str, checks=Check.str_length(min_value=1, max_value=32), nullable=False +) +# They send us the client_conversion as a float +SagoSurveyHistorySchemaDict["columns"]["value"] = Column(dtype=float, nullable=True) +# We added 3 new keys: [3, 4, 5]. We don't need [0, 1, 2]. +SagoSurveyHistorySchemaDict["columns"]["key"] = Column( + dtype="Int32", checks=Check.between(min_value=0, max_value=5), nullable=False +) +SagoSurveyHistorySchema = DataFrameSchema(**SagoSurveyHistorySchemaDict) + +SpectrumSurveyTimeseriesSchemaDict = copy.deepcopy(SurveyHistorySchemaMeta) +SpectrumSurveyTimeseriesSchemaDict["columns"]["survey_id"] = Column( + dtype=str, checks=Check.str_length(min_value=1, max_value=32), nullable=False +) +# Keys 1 & 3 (ir) are floats +SpectrumSurveyTimeseriesSchemaDict["columns"]["value"] = Column( + dtype=float, nullable=True +) +SpectrumSurveyTimeseriesSchema = DataFrameSchema(**SpectrumSurveyTimeseriesSchemaDict) diff --git a/generalresearch/incite/schemas/thl_web.py b/generalresearch/incite/schemas/thl_web.py new file mode 100644 index 0000000..a644720 --- /dev/null +++ b/generalresearch/incite/schemas/thl_web.py @@ -0,0 +1,803 @@ +from datetime import timezone, datetime, timedelta + +import pandas as pd +from pandera import DataFrameSchema, Column, Check, Index, MultiIndex + +from generalresearch.incite.schemas import ORDER_KEY, ARCHIVE_AFTER +from generalresearch.locales import Localelator +from generalresearch.models import DeviceType, Source +from generalresearch.models.thl.definitions import ( + StatusCode1, + WallStatusCode2, + ReportValue, + WallAdjustedStatus, + Status, + SessionStatusCode2, + SessionAdjustedStatus, +) +from generalresearch.models.thl.ledger import TransactionMetadataColumns +from generalresearch.models.thl.maxmind.definitions import UserType + +IP_REGEX_PATTERN = ( + r"^((([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][" + r"0-9]|25[0-5])$|^(([a-fA-F]|[a-fA-F][a-fA-F0-9\-]*[a-fA-F0-9])\.)*([A-Fa-f]|[A-Fa-f][" + r"A-Fa-f0-9\-]*[A-Fa-f0-9])$|^(?:(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){6})(?:(?:(?:(?:(?:[" + r"0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[" + r"0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:::(?:(?:(?:[0-9a-fA-F]{1," + r"4})):){5})(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(" + r"?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(" + r"?:(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){4})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"1}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){3})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"2}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:(?:[0-9a-fA-F]{1,4})):){2})(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"3}(?:(?:[0-9a-fA-F]{1,4})))?::(?:(?:[0-9a-fA-F]{1,4})):)(?:(?:(?:(?:(?:[0-9a-fA-F]{1," + r"4})):(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(" + r"?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"4}(?:(?:[0-9a-fA-F]{1,4})))?::)(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):(?:(?:[0-9a-fA-F]{1," + r"4})))|(?:(?:(?:(?:(?:25[0-5]|(?:[1-9]|1[0-9]|2[0-4])?[0-9]))\.){3}(?:(?:25[0-5]|(?:[1-9]|1[" + r"0-9]|2[0-4])?[0-9])))))))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0,5}(?:(?:[0-9a-fA-F]{1," + r"4})))?::)(?:(?:[0-9a-fA-F]{1,4})))|(?:(?:(?:(?:(?:(?:[0-9a-fA-F]{1,4})):){0," + r"6}(?:(?:[0-9a-fA-F]{1,4})))?::)))))$" +) +BIGINT = 9223372036854775807 + +COUNTRY_ISOS = Localelator().get_all_countries() +kosovo = "xk" +COUNTRY_ISOS.add(kosovo) + +THLUserSchema = DataFrameSchema( + index=Index( + name="id", dtype=int, checks=Check.between(min_value=0, max_value=BIGINT) + ), + columns={ + "uuid": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + ), + "product_id": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + ), + "product_user_id": Column( + dtype=str, + checks=Check.str_length(min_value=3, max_value=128), + nullable=False, + ), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True), + }, + checks=[], + coerce=True, + # This may be an issue with how we handle updates... and reading from + # last_seen as multiple of the same user could be in a dataframe, and we + # only want the latest record. + # unique=["product_id", "product_user_id"], + metadata={ + ORDER_KEY: "created", + ARCHIVE_AFTER: timedelta(minutes=1), + }, +) + +THLWallSchema = DataFrameSchema( + index=Index( + name="uuid", dtype=str, checks=Check.str_length(min_value=32, max_value=32) + ), + columns={ + "source": Column( + dtype=str, + checks=[ + Check.str_length(max_value=2), + Check.isin([e.value for e in Source]), + ], + ), + "buyer_id": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "req_survey_id": Column(dtype=str, checks=Check.str_length(max_value=32)), + "req_cpi": Column( + dtype=float, checks=Check.between(min_value=0, max_value=1_000) + ), + "started": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), + checks=[Check(lambda x: x < datetime.now(tz=timezone.utc))], + nullable=False, + ), + "session_id": Column( + dtype="Int32", checks=Check.between(min_value=0, max_value=BIGINT) + ), + "survey_id": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "cpi": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + "finished": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True), + "status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=1), + Check.isin([e.value for e in Status]), + ], + nullable=True, + ), + "status_code_1": Column( + dtype="Int64", + checks=Check.isin([e.value for e in StatusCode1]), + nullable=True, + ), + "status_code_2": Column( + dtype="Int64", + checks=Check.isin([e.value for e in WallStatusCode2]), + nullable=True, + ), + "ext_status_code_1": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "ext_status_code_2": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "ext_status_code_3": Column( + dtype=str, checks=Check.str_length(max_value=32), nullable=True + ), + "report_value": Column( + dtype="Int64", + checks=Check.isin([e.value for e in ReportValue]), + nullable=True, + ), + "report_notes": Column( + dtype=str, checks=Check.str_length(max_value=255), nullable=True + ), + "adjusted_status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + Check.isin([e.value for e in WallAdjustedStatus]), + ], + nullable=True, + ), + "adjusted_cpi": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + "adjusted_timestamp": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True), + }, + checks=[ + # Lets require more than a few Sources + # Check(check_fn=lambda df: df.source.unique().size > 3, + # error="Issue with the distribution of Sources") + # Check(check_fn=lambda df: df['started'] <= df['finished'], + # element_wise=True, + # ignore_na=True, + # error='"Finished" must be greater than "started"'), + # If adjusted, ensure all adjusted_* fields are set + # If status !=e, sure finished is set + ], + coerce=True, + unique=["session_id", "source", "survey_id"], + metadata={ + ORDER_KEY: "started", + ARCHIVE_AFTER: timedelta(minutes=90), + }, +) + +THLSessionSchema = DataFrameSchema( + index=Index(name="id", dtype=int, checks=Check.greater_than(0)), + columns={ + "uuid": Column( + dtype=str, + checks=Check.str_length(min_value=32, max_value=32), + nullable=False, + unique=True, + ), + "user_id": Column( + dtype="Int32", + checks=Check.between(min_value=0, max_value=BIGINT), + nullable=False, + ), + "started": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), + checks=[Check(lambda x: x < datetime.now(tz=timezone.utc))], + nullable=True, + ), + "finished": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), + checks=[Check(lambda x: x < datetime.now(tz=timezone.utc))], + nullable=True, + ), + "loi_min": Column(dtype="Int64", nullable=True), + "loi_max": Column(dtype="Int64", nullable=True), + "user_payout_min": Column(dtype=float, nullable=True), + "user_payout_max": Column(dtype=float, nullable=True), + "country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=2), + Check.isin(COUNTRY_ISOS), # 2 letter, lowercase + ], + nullable=True, + ), + "device_type": Column( + dtype="Int64", + checks=Check.isin([e.value for e in DeviceType]), + nullable=True, + ), + "ip": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + "status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=1), + Check.isin([e.value for e in Status]), + ], + nullable=True, + ), + "status_code_1": Column( + dtype="Int64", + checks=Check.isin([e.value for e in StatusCode1]), + nullable=True, + ), + "status_code_2": Column( + dtype="Int64", + checks=Check.isin([e.value for e in SessionStatusCode2]), + nullable=True, + ), + "payout": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + "user_payout": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + "adjusted_status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + Check.isin([e.value for e in SessionAdjustedStatus]), + ], + nullable=True, + ), + "adjusted_payout": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + "adjusted_user_payout": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=True, + ), + "adjusted_timestamp": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True), + "url_metadata": Column(dtype=str, nullable=True), + }, + checks=[ + # Check(lambda df: df['started'] <= df['finished'], + # element_wise=True, + # ignore_na=True, + # error='"Finished" should be greater than "started"'), + # Check(check_fn=lambda df: df.source.unique().size > 3, + # error="Issue with the distribution of Sources") + ], + coerce=True, + metadata={ORDER_KEY: "started", ARCHIVE_AFTER: timedelta(minutes=90)}, +) + +THLIPInfoSchema = DataFrameSchema( + index=Index( + name="ip", + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + ), + columns={ + "geoname_id": Column( + dtype=str, + checks=[ + Check.str_length(min_value=5, max_value=8), + ], + ), + "country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + ], + nullable=False, + ), + "registered_country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + ], + nullable=True, + ), + "is_anonymous": Column(dtype=bool), + "is_anonymous_vpn": Column(dtype=bool), + "is_hosting_provider": Column(dtype=bool), + "is_public_proxy": Column(dtype=bool), + "is_tor_exit_node": Column(dtype=bool), + "is_residential_proxy": Column(dtype=bool), + "autonomous_system_number": Column( + dtype="Int64", + checks=[ + Check.greater_than(min_value=0), + ], + nullable=True, + ), + "autonomous_system_organization": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=255), + ], + nullable=True, + ), + "domain": Column( + dtype=str, + checks=[ + Check.str_length(min_value=3, max_value=255), + ], + nullable=True, + ), + "isp": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=255), + ], + nullable=True, + ), + # Don't know what this is.. + "mobile_country_code": Column(dtype=str, nullable=True), + # Don't know what this is.. + "mobile_network_code": Column(dtype=str, nullable=True), + "network": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7, max_value=255), + ], + nullable=True, + ), + "organization": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=255), + ], + nullable=True, + ), + "static_ip_score": Column( + dtype=float, + checks=[ + Check.greater_than(min_value=0), + ], + nullable=True, + ), + "user_type": Column( + dtype=str, + checks=[ + Check.str_length(min_value=3, max_value=255), + Check.isin([e.value for e in UserType]), + ], + nullable=True, + ), + "postal_code": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=9), + ], + nullable=True, + ), + "latitude": Column(dtype=float, nullable=True), + "longitude": Column(dtype=float, nullable=True), + "accuracy_radius": Column( + dtype="Int64", + checks=[ + # Checked on 2024-02-24 Max + Check.between(min_value=0, max_value=1_000), + ], + nullable=True, + ), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + }, + checks=[], + coerce=True, + metadata={ + ORDER_KEY: "updated", + ARCHIVE_AFTER: timedelta(minutes=1), + }, +) + +THLTaskAdjustmentSchema = DataFrameSchema( + index=Index( + name="uuid", dtype=str, checks=Check.str_length(min_value=32, max_value=32) + ), + columns={ + "adjusted_status": Column( + dtype=str, + checks=[ + Check.str_length(min_value=2, max_value=2), + Check.isin([e.value for e in WallAdjustedStatus]), + ], + ), + "ext_status_code": Column(dtype=str, checks=[], nullable=True), + "amount": Column(dtype=float), + "alerted": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "user_id": Column( + dtype="Int32", checks=Check.between(min_value=0, max_value=BIGINT) + ), + "wall_uuid": Column( + dtype=str, + checks=[ + Check.str_length(min_value=32, max_value=32), + ], + ), + "started": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), + checks=[Check(lambda x: x < datetime.now(tz=timezone.utc))], + ), + "source": Column( + dtype=str, + checks=[ + Check.str_length(max_value=2), + Check.isin([e.value for e in Source]), + ], + ), + "survey_id": Column( + dtype=str, + checks=Check.str_length(max_value=32), + ), + }, + checks=[ + # started < created + ], + coerce=True, + metadata={ + ORDER_KEY: "created", + ARCHIVE_AFTER: timedelta(minutes=1), + }, +) + +UserHealthAuditLogSchema = DataFrameSchema( + index=Index( + name="id", + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + columns={ + "user_id": Column( + dtype="Int32", + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + nullable=False, + ), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + "level": Column( + dtype="Int32", + checks=[ + Check.between(min_value=0, max_value=32767), + ], + nullable=False, + ), + "event_type": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=64), + ], + nullable=False, + ), + "event_msg": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=256), + ], + nullable=True, + ), + "event_value": Column(dtype=float, nullable=True), + }, + checks=[], + coerce=True, + metadata={ + ORDER_KEY: "created", + ARCHIVE_AFTER: timedelta(minutes=1), + }, +) + +UserHealthIPHistorySchema = DataFrameSchema( + index=Index( + name="id", + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + columns={ + "user_id": Column( + dtype="Int32", + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + nullable=False, + ), + "ip": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=False, + ), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + "forwarded_ip1": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + "forwarded_ip2": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + "forwarded_ip3": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + "forwarded_ip4": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + "forwarded_ip5": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + "forwarded_ip6": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=True, + ), + }, + checks=[], + coerce=True, + metadata={ORDER_KEY: "created", ARCHIVE_AFTER: timedelta(minutes=1)}, +) + +UserHealthIPHistoryWSSchema = DataFrameSchema( + index=Index( + name="id", + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + columns={ + "user_id": Column( + dtype="Int32", + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + nullable=False, + ), + "ip": Column( + dtype=str, + checks=[ + Check.str_length(min_value=7), + Check.str_matches(pattern=IP_REGEX_PATTERN), + ], + nullable=False, + ), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + "last_seen": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + }, + checks=[], + coerce=True, + metadata={ORDER_KEY: "last_seen", ARCHIVE_AFTER: timedelta(minutes=1)}, +) + +TxSchema = DataFrameSchema( + index=Index( + name="entry_id", + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + columns={ + # ----------------- + # ledger_transaction + # ----------------- + "tx_id": Column( + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=False), + "ext_description": Column( + dtype=str, + checks=[Check.str_length(min_value=1, max_value=255)], + nullable=True, + ), + "tag": Column( + dtype=str, + checks=[Check.str_length(min_value=1, max_value=255)], + nullable=True, + ), + # ----------------- + # ledger_entry + # ----------------- + "direction": Column( + dtype="Int32", + checks=[ + Check.isin([-1, 1]), + ], + nullable=False, + ), + "amount": Column( + dtype="Int32", + checks=[Check.between(min_value=1, max_value=BIGINT)], + nullable=False, + ), + "account_id": Column( + dtype=str, + checks=[ + Check.str_length(min_value=32, max_value=32), + ], + nullable=False, + ), + # ----------------- + # ledger_account + # ----------------- + "display_name": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=64), + ], + nullable=False, + ), + "qualified_name": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=255), + ], + nullable=False, + # I don't think this can be unique in Pandera bc of the MultiIndex makes + # it show up twice... + unique=False, + ), + "account_type": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=30), + ], + nullable=True, + ), + "normal_balance": Column( + dtype="Int32", + checks=[ + Check.isin([-1, 1]), + ], + nullable=False, + ), + "reference_type": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=30), + ], + nullable=True, + ), + "reference_uuid": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=32), + ], + nullable=True, + ), + "currency": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=32), + ], + nullable=False, + ), + }, + checks=[], + coerce=True, + metadata={ORDER_KEY: "created", ARCHIVE_AFTER: timedelta(minutes=1)}, +) + +TxMetaSchema = DataFrameSchema( + index=MultiIndex( + indexes=[ + Index( + name="tx_id", + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + Index( + name="tx_metadata_id", + dtype=int, + checks=[ + Check.between(min_value=1, max_value=BIGINT), + ], + ), + ] + ), + columns={ + "key": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=30), + Check.isin([e.value for e in TransactionMetadataColumns]), + ], + nullable=False, + ), + "value": Column( + dtype=str, + checks=[Check.str_length(min_value=1, max_value=255)], + nullable=False, + ), + }, + checks=[], + coerce=True, + metadata={ARCHIVE_AFTER: timedelta(minutes=1)}, +) + +meta_obj = {} +for e in TransactionMetadataColumns: + meta_obj[e.value] = Column( + dtype=str, checks=[Check.str_length(min_value=1, max_value=255)], nullable=True + ) + +# The weird hybrid DF that actually gets saved out +LedgerSchema = DataFrameSchema( + index=TxSchema.index, + columns=TxSchema.columns | meta_obj, + checks=[], + coerce=True, + metadata={ + ARCHIVE_AFTER: timedelta(minutes=1), + ORDER_KEY: "created", + }, +) diff --git a/generalresearch/locales/__init__.py b/generalresearch/locales/__init__.py new file mode 100644 index 0000000..88b72e6 --- /dev/null +++ b/generalresearch/locales/__init__.py @@ -0,0 +1,96 @@ +""" +THL/GR is using: + +country codes: ISO 3166-1 alpha-2 (two-letter codes) +https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2 +https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes + +language codes: ISO 639-2/B (three-letter codes) +https://en.wikipedia.org/wiki/ISO_639-2 +https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes +""" + +import json +import pkgutil +from typing import Set + + +class Localelator: + """ + EVERYTHING IS LOWERCASE!!! (except this comment) + """ + + lang_alpha2_to_alpha3b = dict() + lang_alpha3_to_alpha3b = dict() + languages = set() + + def __init__(self): + d = json.loads(pkgutil.get_data(__name__, "iso639-3.json")) + self.lang_alpha2_to_alpha3b = {x["alpha_2"]: x["alpha_3b"] for x in d} + self.lang_alpha3_to_alpha3b = {x["alpha_3"]: x["alpha_3b"] for x in d} + self.languages = ( + set(self.lang_alpha2_to_alpha3b.keys()) + | set(self.lang_alpha2_to_alpha3b.values()) + | set(self.lang_alpha3_to_alpha3b.keys()) + ) + d = json.loads(pkgutil.get_data(__name__, "iso3166-1.json")) + self.country_alpha3_to_alpha2 = {x["alpha_3"]: x["alpha_2"] for x in d} + self.countries = set(self.country_alpha3_to_alpha2.keys()) | set( + self.country_alpha3_to_alpha2.values() + ) + + self.country_default_lang = json.loads( + pkgutil.get_data(__name__, "country_default_lang.json") + ) + + def get_all_languages(self) -> Set[str]: + # returns only the ISO 639-2/B (three-letter codes) + return set(self.lang_alpha2_to_alpha3b.values()) + + def get_all_countries(self) -> Set[str]: + # returns only the ISO 3166-1 alpha-2 (two-letter codes) + return set(self.country_alpha3_to_alpha2.values()) + + def get_language_iso(self, input_iso: str) -> str: + # input_iso is a 2 (ISO 639-1) or 3 (ISO 639-2/T) char language ISO + # output is a 3 char ISO 639-2/B + assert len(input_iso) in { + 2, + 3, + }, f"input_iso must be len 2 or 3, got: {input_iso}" + assert input_iso.lower() == input_iso, "input_iso must be lowercase" + assert ( + input_iso in self.languages + ), f"language input_iso: {input_iso} not recognized" + + return ( + self.lang_alpha2_to_alpha3b.get(input_iso) + or self.lang_alpha3_to_alpha3b.get(input_iso) + or input_iso + ) + + def get_country_iso(self, input_iso: str) -> str: + # input_iso is a 2 (ISO 3166-1 alpha-2) or 3 (ISO 3166-1 alpha-3) char country ISO + # output is a 2 char ISO 3166-1 alpha-2 + assert len(input_iso) in { + 2, + 3, + }, f"input_iso must be len 2 or 3, got: {input_iso}" + assert input_iso.lower() == input_iso, "input_iso must be lowercase" + assert ( + input_iso in self.countries + ), f"country input_iso: {input_iso} not recognized" + return self.country_alpha3_to_alpha2.get(input_iso) or input_iso + + def get_default_lang_from_country(self, input_iso): + country_iso = self.get_country_iso(input_iso) + return self.country_default_lang.get(country_iso) + + def run_tests(self): + assert self.get_language_iso("de") == "ger" + assert self.get_language_iso("deu") == "ger" + assert self.get_language_iso("ger") == "ger" + assert self.get_country_iso("deu") == "de" + assert self.get_country_iso("de") == "de" + assert self.get_default_lang_from_country("deu") == "ger" + assert self.get_default_lang_from_country("de") == "ger" diff --git a/generalresearch/locales/country_default_lang.json b/generalresearch/locales/country_default_lang.json new file mode 100644 index 0000000..fdb7738 --- /dev/null +++ b/generalresearch/locales/country_default_lang.json @@ -0,0 +1,250 @@ +{ + "ad": "cat", + "ae": "ara", + "af": "per", + "ag": "eng", + "ai": "eng", + "al": "alb", + "am": "arm", + "ao": "por", + "aq": "eng", + "ar": "spa", + "as": "eng", + "at": "ger", + "au": "eng", + "aw": "dut", + "ax": "swe", + "az": "aze", + "ba": "bos", + "bb": "eng", + "bd": "ben", + "be": "dut", + "bf": "fre", + "bg": "bul", + "bh": "ara", + "bi": "fre", + "bj": "fre", + "bl": "fre", + "bm": "eng", + "bn": "may", + "bo": "spa", + "bq": "dut", + "br": "por", + "bs": "eng", + "bt": "dzo", + "bv": "eng", + "bw": "eng", + "by": "bel", + "bz": "eng", + "ca": "eng", + "cc": "may", + "cd": "fre", + "cf": "fre", + "cg": "fre", + "ch": "ger", + "ci": "fre", + "ck": "eng", + "cl": "spa", + "cm": "eng", + "cn": "chi", + "co": "spa", + "cr": "spa", + "cu": "spa", + "cv": "por", + "cw": "dut", + "cx": "eng", + "cy": "gre", + "cz": "cze", + "de": "ger", + "dj": "fre", + "dk": "dan", + "dm": "eng", + "do": "spa", + "dz": "ara", + "ec": "spa", + "ee": "est", + "eg": "ara", + "eh": "ara", + "er": "aar", + "es": "spa", + "et": "amh", + "fi": "fin", + "fj": "eng", + "fk": "eng", + "fm": "eng", + "fo": "fao", + "fr": "fre", + "ga": "fre", + "gb": "eng", + "gd": "eng", + "ge": "geo", + "gf": "fre", + "gg": "eng", + "gh": "eng", + "gi": "eng", + "gl": "kal", + "gm": "eng", + "gn": "fre", + "gp": "fre", + "gq": "spa", + "gr": "gre", + "gs": "eng", + "gt": "spa", + "gu": "eng", + "gw": "por", + "gy": "eng", + "hk": "chi", + "hm": "eng", + "hn": "spa", + "hr": "hrv", + "ht": "hat", + "hu": "hun", + "id": "ind", + "ie": "eng", + "il": "heb", + "im": "eng", + "in": "eng", + "io": "eng", + "iq": "ara", + "ir": "per", + "is": "ice", + "it": "ita", + "je": "eng", + "jm": "eng", + "jo": "ara", + "jp": "jpn", + "ke": "eng", + "kg": "kir", + "kh": "khm", + "ki": "eng", + "km": "ara", + "kn": "eng", + "kp": "kor", + "kr": "kor", + "kw": "ara", + "ky": "eng", + "kz": "kaz", + "la": "lao", + "lb": "ara", + "lc": "eng", + "li": "ger", + "lk": "sin", + "lr": "eng", + "ls": "eng", + "lt": "lit", + "lu": "ltz", + "lv": "lav", + "ly": "ara", + "ma": "ara", + "mc": "fre", + "md": "rum", + "me": "srp", + "mf": "fre", + "mg": "fre", + "mh": "mah", + "mk": "mac", + "ml": "fre", + "mm": "bur", + "mn": "mon", + "mo": "chi", + "mp": "eng", + "mq": "fre", + "mr": "ara", + "ms": "eng", + "mt": "mlt", + "mu": "eng", + "mv": "div", + "mw": "nya", + "mx": "spa", + "my": "may", + "mz": "por", + "nc": "fre", + "ne": "fre", + "nf": "eng", + "ng": "eng", + "ni": "spa", + "nl": "dut", + "no": "nor", + "np": "nep", + "nr": "nau", + "nu": "eng", + "nz": "eng", + "om": "ara", + "pa": "spa", + "pe": "spa", + "pf": "fre", + "pg": "eng", + "ph": "tgl", + "pk": "urd", + "pl": "pol", + "pm": "fre", + "pn": "eng", + "pr": "eng", + "ps": "ara", + "pt": "por", + "pw": "eng", + "py": "spa", + "qa": "ara", + "re": "fre", + "ro": "rum", + "rs": "srp", + "ru": "rus", + "rw": "kin", + "sa": "ara", + "sb": "eng", + "sc": "eng", + "sd": "ara", + "ss": "eng", + "se": "swe", + "sg": "eng", + "sh": "eng", + "si": "slv", + "sj": "nor", + "sk": "slo", + "sl": "eng", + "sm": "ita", + "sn": "fre", + "so": "som", + "sr": "dut", + "st": "por", + "sv": "spa", + "sx": "dut", + "sy": "ara", + "sz": "eng", + "tc": "eng", + "td": "fre", + "tf": "fre", + "tg": "fre", + "th": "tha", + "tj": "tgk", + "tk": "eng", + "tl": "eng", + "tm": "tuk", + "tn": "ara", + "to": "ton", + "tr": "tur", + "tt": "eng", + "tv": "eng", + "tw": "chi", + "tz": "swa", + "ua": "ukr", + "ug": "eng", + "um": "eng", + "us": "eng", + "uy": "spa", + "uz": "uzb", + "va": "lat", + "vc": "eng", + "ve": "spa", + "vg": "eng", + "vi": "eng", + "vn": "vie", + "vu": "bis", + "wf": "eng", + "ws": "smo", + "ye": "ara", + "yt": "fre", + "za": "zul", + "zm": "eng", + "zw": "eng" +} \ No newline at end of file diff --git a/generalresearch/locales/iso3166-1.json b/generalresearch/locales/iso3166-1.json new file mode 100644 index 0000000..b140d93 --- /dev/null +++ b/generalresearch/locales/iso3166-1.json @@ -0,0 +1,1675 @@ +[ + { + "alpha_2": "aw", + "alpha_3": "abw", + "name": "Aruba", + "numeric": "533" + }, + { + "alpha_2": "af", + "alpha_3": "afg", + "name": "Afghanistan", + "numeric": "004", + "official_name": "Islamic Republic of Afghanistan" + }, + { + "alpha_2": "ao", + "alpha_3": "ago", + "name": "Angola", + "numeric": "024", + "official_name": "Republic of Angola" + }, + { + "alpha_2": "ai", + "alpha_3": "aia", + "name": "Anguilla", + "numeric": "660" + }, + { + "alpha_2": "ax", + "alpha_3": "ala", + "name": "\u00c5land Islands", + "numeric": "248" + }, + { + "alpha_2": "al", + "alpha_3": "alb", + "name": "Albania", + "numeric": "008", + "official_name": "Republic of Albania" + }, + { + "alpha_2": "ad", + "alpha_3": "and", + "name": "Andorra", + "numeric": "020", + "official_name": "Principality of Andorra" + }, + { + "alpha_2": "ae", + "alpha_3": "are", + "name": "United Arab Emirates", + "numeric": "784" + }, + { + "alpha_2": "ar", + "alpha_3": "arg", + "name": "Argentina", + "numeric": "032", + "official_name": "Argentine Republic" + }, + { + "alpha_2": "am", + "alpha_3": "arm", + "name": "Armenia", + "numeric": "051", + "official_name": "Republic of Armenia" + }, + { + "alpha_2": "as", + "alpha_3": "asm", + "name": "American Samoa", + "numeric": "016" + }, + { + "alpha_2": "aq", + "alpha_3": "ata", + "name": "Antarctica", + "numeric": "010" + }, + { + "alpha_2": "tf", + "alpha_3": "atf", + "name": "French Southern Territories", + "numeric": "260" + }, + { + "alpha_2": "ag", + "alpha_3": "atg", + "name": "Antigua and Barbuda", + "numeric": "028" + }, + { + "alpha_2": "au", + "alpha_3": "aus", + "name": "Australia", + "numeric": "036" + }, + { + "alpha_2": "at", + "alpha_3": "aut", + "name": "Austria", + "numeric": "040", + "official_name": "Republic of Austria" + }, + { + "alpha_2": "az", + "alpha_3": "aze", + "name": "Azerbaijan", + "numeric": "031", + "official_name": "Republic of Azerbaijan" + }, + { + "alpha_2": "bi", + "alpha_3": "bdi", + "name": "Burundi", + "numeric": "108", + "official_name": "Republic of Burundi" + }, + { + "alpha_2": "be", + "alpha_3": "bel", + "name": "Belgium", + "numeric": "056", + "official_name": "Kingdom of Belgium" + }, + { + "alpha_2": "bj", + "alpha_3": "ben", + "name": "Benin", + "numeric": "204", + "official_name": "Republic of Benin" + }, + { + "alpha_2": "bq", + "alpha_3": "bes", + "name": "Bonaire, Sint Eustatius and Saba", + "numeric": "535", + "official_name": "Bonaire, Sint Eustatius and Saba" + }, + { + "alpha_2": "bf", + "alpha_3": "bfa", + "name": "Burkina Faso", + "numeric": "854" + }, + { + "alpha_2": "bd", + "alpha_3": "bgd", + "name": "Bangladesh", + "numeric": "050", + "official_name": "People's Republic of Bangladesh" + }, + { + "alpha_2": "bg", + "alpha_3": "bgr", + "name": "Bulgaria", + "numeric": "100", + "official_name": "Republic of Bulgaria" + }, + { + "alpha_2": "bh", + "alpha_3": "bhr", + "name": "Bahrain", + "numeric": "048", + "official_name": "Kingdom of Bahrain" + }, + { + "alpha_2": "bs", + "alpha_3": "bhs", + "name": "Bahamas", + "numeric": "044", + "official_name": "Commonwealth of the Bahamas" + }, + { + "alpha_2": "ba", + "alpha_3": "bih", + "name": "Bosnia and Herzegovina", + "numeric": "070", + "official_name": "Republic of Bosnia and Herzegovina" + }, + { + "alpha_2": "bl", + "alpha_3": "blm", + "name": "Saint Barth\u00e9lemy", + "numeric": "652" + }, + { + "alpha_2": "by", + "alpha_3": "blr", + "name": "Belarus", + "numeric": "112", + "official_name": "Republic of Belarus" + }, + { + "alpha_2": "bz", + "alpha_3": "blz", + "name": "Belize", + "numeric": "084" + }, + { + "alpha_2": "bm", + "alpha_3": "bmu", + "name": "Bermuda", + "numeric": "060" + }, + { + "alpha_2": "bo", + "alpha_3": "bol", + "common_name": "Bolivia", + "name": "Bolivia, Plurinational State of", + "numeric": "068", + "official_name": "Plurinational State of Bolivia" + }, + { + "alpha_2": "br", + "alpha_3": "bra", + "name": "Brazil", + "numeric": "076", + "official_name": "Federative Republic of Brazil" + }, + { + "alpha_2": "bb", + "alpha_3": "brb", + "name": "Barbados", + "numeric": "052" + }, + { + "alpha_2": "bn", + "alpha_3": "brn", + "name": "Brunei Darussalam", + "numeric": "096" + }, + { + "alpha_2": "bt", + "alpha_3": "btn", + "name": "Bhutan", + "numeric": "064", + "official_name": "Kingdom of Bhutan" + }, + { + "alpha_2": "bv", + "alpha_3": "bvt", + "name": "Bouvet Island", + "numeric": "074" + }, + { + "alpha_2": "bw", + "alpha_3": "bwa", + "name": "Botswana", + "numeric": "072", + "official_name": "Republic of Botswana" + }, + { + "alpha_2": "cf", + "alpha_3": "caf", + "name": "Central African Republic", + "numeric": "140" + }, + { + "alpha_2": "ca", + "alpha_3": "can", + "name": "Canada", + "numeric": "124" + }, + { + "alpha_2": "cc", + "alpha_3": "cck", + "name": "Cocos (Keeling) Islands", + "numeric": "166" + }, + { + "alpha_2": "ch", + "alpha_3": "che", + "name": "Switzerland", + "numeric": "756", + "official_name": "Swiss Confederation" + }, + { + "alpha_2": "cl", + "alpha_3": "chl", + "name": "Chile", + "numeric": "152", + "official_name": "Republic of Chile" + }, + { + "alpha_2": "cn", + "alpha_3": "chn", + "name": "China", + "numeric": "156", + "official_name": "People's Republic of China" + }, + { + "alpha_2": "ci", + "alpha_3": "civ", + "name": "C\u00f4te d'Ivoire", + "numeric": "384", + "official_name": "Republic of C\u00f4te d'Ivoire" + }, + { + "alpha_2": "cm", + "alpha_3": "cmr", + "name": "Cameroon", + "numeric": "120", + "official_name": "Republic of Cameroon" + }, + { + "alpha_2": "cd", + "alpha_3": "cod", + "name": "Congo, The Democratic Republic of the", + "numeric": "180" + }, + { + "alpha_2": "cg", + "alpha_3": "cog", + "name": "Congo", + "numeric": "178", + "official_name": "Republic of the Congo" + }, + { + "alpha_2": "ck", + "alpha_3": "cok", + "name": "Cook Islands", + "numeric": "184" + }, + { + "alpha_2": "co", + "alpha_3": "col", + "name": "Colombia", + "numeric": "170", + "official_name": "Republic of Colombia" + }, + { + "alpha_2": "km", + "alpha_3": "com", + "name": "Comoros", + "numeric": "174", + "official_name": "Union of the Comoros" + }, + { + "alpha_2": "cv", + "alpha_3": "cpv", + "name": "Cabo Verde", + "numeric": "132", + "official_name": "Republic of Cabo Verde" + }, + { + "alpha_2": "cr", + "alpha_3": "cri", + "name": "Costa Rica", + "numeric": "188", + "official_name": "Republic of Costa Rica" + }, + { + "alpha_2": "cu", + "alpha_3": "cub", + "name": "Cuba", + "numeric": "192", + "official_name": "Republic of Cuba" + }, + { + "alpha_2": "cw", + "alpha_3": "cuw", + "name": "Cura\u00e7ao", + "numeric": "531", + "official_name": "Cura\u00e7ao" + }, + { + "alpha_2": "cx", + "alpha_3": "cxr", + "name": "Christmas Island", + "numeric": "162" + }, + { + "alpha_2": "ky", + "alpha_3": "cym", + "name": "Cayman Islands", + "numeric": "136" + }, + { + "alpha_2": "cy", + "alpha_3": "cyp", + "name": "Cyprus", + "numeric": "196", + "official_name": "Republic of Cyprus" + }, + { + "alpha_2": "cz", + "alpha_3": "cze", + "name": "Czechia", + "numeric": "203", + "official_name": "Czech Republic" + }, + { + "alpha_2": "de", + "alpha_3": "deu", + "name": "Germany", + "numeric": "276", + "official_name": "Federal Republic of Germany" + }, + { + "alpha_2": "dj", + "alpha_3": "dji", + "name": "Djibouti", + "numeric": "262", + "official_name": "Republic of Djibouti" + }, + { + "alpha_2": "dm", + "alpha_3": "dma", + "name": "Dominica", + "numeric": "212", + "official_name": "Commonwealth of Dominica" + }, + { + "alpha_2": "dk", + "alpha_3": "dnk", + "name": "Denmark", + "numeric": "208", + "official_name": "Kingdom of Denmark" + }, + { + "alpha_2": "do", + "alpha_3": "dom", + "name": "Dominican Republic", + "numeric": "214" + }, + { + "alpha_2": "dz", + "alpha_3": "dza", + "name": "Algeria", + "numeric": "012", + "official_name": "People's Democratic Republic of Algeria" + }, + { + "alpha_2": "ec", + "alpha_3": "ecu", + "name": "Ecuador", + "numeric": "218", + "official_name": "Republic of Ecuador" + }, + { + "alpha_2": "eg", + "alpha_3": "egy", + "name": "Egypt", + "numeric": "818", + "official_name": "Arab Republic of Egypt" + }, + { + "alpha_2": "er", + "alpha_3": "eri", + "name": "Eritrea", + "numeric": "232", + "official_name": "the State of Eritrea" + }, + { + "alpha_2": "eh", + "alpha_3": "esh", + "name": "Western Sahara", + "numeric": "732" + }, + { + "alpha_2": "es", + "alpha_3": "esp", + "name": "Spain", + "numeric": "724", + "official_name": "Kingdom of Spain" + }, + { + "alpha_2": "ee", + "alpha_3": "est", + "name": "Estonia", + "numeric": "233", + "official_name": "Republic of Estonia" + }, + { + "alpha_2": "et", + "alpha_3": "eth", + "name": "Ethiopia", + "numeric": "231", + "official_name": "Federal Democratic Republic of Ethiopia" + }, + { + "alpha_2": "fi", + "alpha_3": "fin", + "name": "Finland", + "numeric": "246", + "official_name": "Republic of Finland" + }, + { + "alpha_2": "fj", + "alpha_3": "fji", + "name": "Fiji", + "numeric": "242", + "official_name": "Republic of Fiji" + }, + { + "alpha_2": "fk", + "alpha_3": "flk", + "name": "Falkland Islands (Malvinas)", + "numeric": "238" + }, + { + "alpha_2": "fr", + "alpha_3": "fra", + "name": "France", + "numeric": "250", + "official_name": "French Republic" + }, + { + "alpha_2": "fo", + "alpha_3": "fro", + "name": "Faroe Islands", + "numeric": "234" + }, + { + "alpha_2": "fm", + "alpha_3": "fsm", + "name": "Micronesia, Federated States of", + "numeric": "583", + "official_name": "Federated States of Micronesia" + }, + { + "alpha_2": "ga", + "alpha_3": "gab", + "name": "Gabon", + "numeric": "266", + "official_name": "Gabonese Republic" + }, + { + "alpha_2": "gb", + "alpha_3": "gbr", + "name": "United Kingdom", + "numeric": "826", + "official_name": "United Kingdom of Great Britain and Northern Ireland" + }, + { + "alpha_2": "ge", + "alpha_3": "geo", + "name": "Georgia", + "numeric": "268" + }, + { + "alpha_2": "gg", + "alpha_3": "ggy", + "name": "Guernsey", + "numeric": "831" + }, + { + "alpha_2": "gh", + "alpha_3": "gha", + "name": "Ghana", + "numeric": "288", + "official_name": "Republic of Ghana" + }, + { + "alpha_2": "gi", + "alpha_3": "gib", + "name": "Gibraltar", + "numeric": "292" + }, + { + "alpha_2": "gn", + "alpha_3": "gin", + "name": "Guinea", + "numeric": "324", + "official_name": "Republic of Guinea" + }, + { + "alpha_2": "gp", + "alpha_3": "glp", + "name": "Guadeloupe", + "numeric": "312" + }, + { + "alpha_2": "gm", + "alpha_3": "gmb", + "name": "Gambia", + "numeric": "270", + "official_name": "Republic of the Gambia" + }, + { + "alpha_2": "gw", + "alpha_3": "gnb", + "name": "Guinea-Bissau", + "numeric": "624", + "official_name": "Republic of Guinea-Bissau" + }, + { + "alpha_2": "gq", + "alpha_3": "gnq", + "name": "Equatorial Guinea", + "numeric": "226", + "official_name": "Republic of Equatorial Guinea" + }, + { + "alpha_2": "gr", + "alpha_3": "grc", + "name": "Greece", + "numeric": "300", + "official_name": "Hellenic Republic" + }, + { + "alpha_2": "gd", + "alpha_3": "grd", + "name": "Grenada", + "numeric": "308" + }, + { + "alpha_2": "gl", + "alpha_3": "grl", + "name": "Greenland", + "numeric": "304" + }, + { + "alpha_2": "gt", + "alpha_3": "gtm", + "name": "Guatemala", + "numeric": "320", + "official_name": "Republic of Guatemala" + }, + { + "alpha_2": "gf", + "alpha_3": "guf", + "name": "French Guiana", + "numeric": "254" + }, + { + "alpha_2": "gu", + "alpha_3": "gum", + "name": "Guam", + "numeric": "316" + }, + { + "alpha_2": "gy", + "alpha_3": "guy", + "name": "Guyana", + "numeric": "328", + "official_name": "Republic of Guyana" + }, + { + "alpha_2": "hk", + "alpha_3": "hkg", + "name": "Hong Kong", + "numeric": "344", + "official_name": "Hong Kong Special Administrative Region of China" + }, + { + "alpha_2": "hm", + "alpha_3": "hmd", + "name": "Heard Island and McDonald Islands", + "numeric": "334" + }, + { + "alpha_2": "hn", + "alpha_3": "hnd", + "name": "Honduras", + "numeric": "340", + "official_name": "Republic of Honduras" + }, + { + "alpha_2": "hr", + "alpha_3": "hrv", + "name": "Croatia", + "numeric": "191", + "official_name": "Republic of Croatia" + }, + { + "alpha_2": "ht", + "alpha_3": "hti", + "name": "Haiti", + "numeric": "332", + "official_name": "Republic of Haiti" + }, + { + "alpha_2": "hu", + "alpha_3": "hun", + "name": "Hungary", + "numeric": "348", + "official_name": "Hungary" + }, + { + "alpha_2": "id", + "alpha_3": "idn", + "name": "Indonesia", + "numeric": "360", + "official_name": "Republic of Indonesia" + }, + { + "alpha_2": "im", + "alpha_3": "imn", + "name": "Isle of Man", + "numeric": "833" + }, + { + "alpha_2": "in", + "alpha_3": "ind", + "name": "India", + "numeric": "356", + "official_name": "Republic of India" + }, + { + "alpha_2": "io", + "alpha_3": "iot", + "name": "British Indian Ocean Territory", + "numeric": "086" + }, + { + "alpha_2": "ie", + "alpha_3": "irl", + "name": "Ireland", + "numeric": "372" + }, + { + "alpha_2": "ir", + "alpha_3": "irn", + "name": "Iran, Islamic Republic of", + "numeric": "364", + "official_name": "Islamic Republic of Iran" + }, + { + "alpha_2": "iq", + "alpha_3": "irq", + "name": "Iraq", + "numeric": "368", + "official_name": "Republic of Iraq" + }, + { + "alpha_2": "is", + "alpha_3": "isl", + "name": "Iceland", + "numeric": "352", + "official_name": "Republic of Iceland" + }, + { + "alpha_2": "il", + "alpha_3": "isr", + "name": "Israel", + "numeric": "376", + "official_name": "State of Israel" + }, + { + "alpha_2": "it", + "alpha_3": "ita", + "name": "Italy", + "numeric": "380", + "official_name": "Italian Republic" + }, + { + "alpha_2": "jm", + "alpha_3": "jam", + "name": "Jamaica", + "numeric": "388" + }, + { + "alpha_2": "je", + "alpha_3": "jey", + "name": "Jersey", + "numeric": "832" + }, + { + "alpha_2": "jo", + "alpha_3": "jor", + "name": "Jordan", + "numeric": "400", + "official_name": "Hashemite Kingdom of Jordan" + }, + { + "alpha_2": "jp", + "alpha_3": "jpn", + "name": "Japan", + "numeric": "392" + }, + { + "alpha_2": "kz", + "alpha_3": "kaz", + "name": "Kazakhstan", + "numeric": "398", + "official_name": "Republic of Kazakhstan" + }, + { + "alpha_2": "ke", + "alpha_3": "ken", + "name": "Kenya", + "numeric": "404", + "official_name": "Republic of Kenya" + }, + { + "alpha_2": "kg", + "alpha_3": "kgz", + "name": "Kyrgyzstan", + "numeric": "417", + "official_name": "Kyrgyz Republic" + }, + { + "alpha_2": "kh", + "alpha_3": "khm", + "name": "Cambodia", + "numeric": "116", + "official_name": "Kingdom of Cambodia" + }, + { + "alpha_2": "ki", + "alpha_3": "kir", + "name": "Kiribati", + "numeric": "296", + "official_name": "Republic of Kiribati" + }, + { + "alpha_2": "kn", + "alpha_3": "kna", + "name": "Saint Kitts and Nevis", + "numeric": "659" + }, + { + "alpha_2": "kr", + "alpha_3": "kor", + "name": "Korea, Republic of", + "numeric": "410" + }, + { + "alpha_2": "kw", + "alpha_3": "kwt", + "name": "Kuwait", + "numeric": "414", + "official_name": "State of Kuwait" + }, + { + "alpha_2": "la", + "alpha_3": "lao", + "name": "Lao People's Democratic Republic", + "numeric": "418" + }, + { + "alpha_2": "lb", + "alpha_3": "lbn", + "name": "Lebanon", + "numeric": "422", + "official_name": "Lebanese Republic" + }, + { + "alpha_2": "lr", + "alpha_3": "lbr", + "name": "Liberia", + "numeric": "430", + "official_name": "Republic of Liberia" + }, + { + "alpha_2": "ly", + "alpha_3": "lby", + "name": "Libya", + "numeric": "434", + "official_name": "Libya" + }, + { + "alpha_2": "lc", + "alpha_3": "lca", + "name": "Saint Lucia", + "numeric": "662" + }, + { + "alpha_2": "li", + "alpha_3": "lie", + "name": "Liechtenstein", + "numeric": "438", + "official_name": "Principality of Liechtenstein" + }, + { + "alpha_2": "lk", + "alpha_3": "lka", + "name": "Sri Lanka", + "numeric": "144", + "official_name": "Democratic Socialist Republic of Sri Lanka" + }, + { + "alpha_2": "ls", + "alpha_3": "lso", + "name": "Lesotho", + "numeric": "426", + "official_name": "Kingdom of Lesotho" + }, + { + "alpha_2": "lt", + "alpha_3": "ltu", + "name": "Lithuania", + "numeric": "440", + "official_name": "Republic of Lithuania" + }, + { + "alpha_2": "lu", + "alpha_3": "lux", + "name": "Luxembourg", + "numeric": "442", + "official_name": "Grand Duchy of Luxembourg" + }, + { + "alpha_2": "lv", + "alpha_3": "lva", + "name": "Latvia", + "numeric": "428", + "official_name": "Republic of Latvia" + }, + { + "alpha_2": "mo", + "alpha_3": "mac", + "name": "Macao", + "numeric": "446", + "official_name": "Macao Special Administrative Region of China" + }, + { + "alpha_2": "mf", + "alpha_3": "maf", + "name": "Saint Martin (French part)", + "numeric": "663" + }, + { + "alpha_2": "ma", + "alpha_3": "mar", + "name": "Morocco", + "numeric": "504", + "official_name": "Kingdom of Morocco" + }, + { + "alpha_2": "mc", + "alpha_3": "mco", + "name": "Monaco", + "numeric": "492", + "official_name": "Principality of Monaco" + }, + { + "alpha_2": "md", + "alpha_3": "mda", + "common_name": "Moldova", + "name": "Moldova, Republic of", + "numeric": "498", + "official_name": "Republic of Moldova" + }, + { + "alpha_2": "mg", + "alpha_3": "mdg", + "name": "Madagascar", + "numeric": "450", + "official_name": "Republic of Madagascar" + }, + { + "alpha_2": "mv", + "alpha_3": "mdv", + "name": "Maldives", + "numeric": "462", + "official_name": "Republic of Maldives" + }, + { + "alpha_2": "mx", + "alpha_3": "mex", + "name": "Mexico", + "numeric": "484", + "official_name": "United Mexican States" + }, + { + "alpha_2": "mh", + "alpha_3": "mhl", + "name": "Marshall Islands", + "numeric": "584", + "official_name": "Republic of the Marshall Islands" + }, + { + "alpha_2": "mk", + "alpha_3": "mkd", + "name": "North Macedonia", + "numeric": "807", + "official_name": "Republic of North Macedonia" + }, + { + "alpha_2": "ml", + "alpha_3": "mli", + "name": "Mali", + "numeric": "466", + "official_name": "Republic of Mali" + }, + { + "alpha_2": "mt", + "alpha_3": "mlt", + "name": "Malta", + "numeric": "470", + "official_name": "Republic of Malta" + }, + { + "alpha_2": "mm", + "alpha_3": "mmr", + "name": "Myanmar", + "numeric": "104", + "official_name": "Republic of Myanmar" + }, + { + "alpha_2": "me", + "alpha_3": "mne", + "name": "Montenegro", + "numeric": "499", + "official_name": "Montenegro" + }, + { + "alpha_2": "mn", + "alpha_3": "mng", + "name": "Mongolia", + "numeric": "496" + }, + { + "alpha_2": "mp", + "alpha_3": "mnp", + "name": "Northern Mariana Islands", + "numeric": "580", + "official_name": "Commonwealth of the Northern Mariana Islands" + }, + { + "alpha_2": "mz", + "alpha_3": "moz", + "name": "Mozambique", + "numeric": "508", + "official_name": "Republic of Mozambique" + }, + { + "alpha_2": "mr", + "alpha_3": "mrt", + "name": "Mauritania", + "numeric": "478", + "official_name": "Islamic Republic of Mauritania" + }, + { + "alpha_2": "ms", + "alpha_3": "msr", + "name": "Montserrat", + "numeric": "500" + }, + { + "alpha_2": "mq", + "alpha_3": "mtq", + "name": "Martinique", + "numeric": "474" + }, + { + "alpha_2": "mu", + "alpha_3": "mus", + "name": "Mauritius", + "numeric": "480", + "official_name": "Republic of Mauritius" + }, + { + "alpha_2": "mw", + "alpha_3": "mwi", + "name": "Malawi", + "numeric": "454", + "official_name": "Republic of Malawi" + }, + { + "alpha_2": "my", + "alpha_3": "mys", + "name": "Malaysia", + "numeric": "458" + }, + { + "alpha_2": "yt", + "alpha_3": "myt", + "name": "Mayotte", + "numeric": "175" + }, + { + "alpha_2": "na", + "alpha_3": "nam", + "name": "Namibia", + "numeric": "516", + "official_name": "Republic of Namibia" + }, + { + "alpha_2": "nc", + "alpha_3": "ncl", + "name": "New Caledonia", + "numeric": "540" + }, + { + "alpha_2": "ne", + "alpha_3": "ner", + "name": "Niger", + "numeric": "562", + "official_name": "Republic of the Niger" + }, + { + "alpha_2": "nf", + "alpha_3": "nfk", + "name": "Norfolk Island", + "numeric": "574" + }, + { + "alpha_2": "ng", + "alpha_3": "nga", + "name": "Nigeria", + "numeric": "566", + "official_name": "Federal Republic of Nigeria" + }, + { + "alpha_2": "ni", + "alpha_3": "nic", + "name": "Nicaragua", + "numeric": "558", + "official_name": "Republic of Nicaragua" + }, + { + "alpha_2": "nu", + "alpha_3": "niu", + "name": "Niue", + "numeric": "570", + "official_name": "Niue" + }, + { + "alpha_2": "nl", + "alpha_3": "nld", + "name": "Netherlands", + "numeric": "528", + "official_name": "Kingdom of the Netherlands" + }, + { + "alpha_2": "no", + "alpha_3": "nor", + "name": "Norway", + "numeric": "578", + "official_name": "Kingdom of Norway" + }, + { + "alpha_2": "np", + "alpha_3": "npl", + "name": "Nepal", + "numeric": "524", + "official_name": "Federal Democratic Republic of Nepal" + }, + { + "alpha_2": "nr", + "alpha_3": "nru", + "name": "Nauru", + "numeric": "520", + "official_name": "Republic of Nauru" + }, + { + "alpha_2": "nz", + "alpha_3": "nzl", + "name": "New Zealand", + "numeric": "554" + }, + { + "alpha_2": "om", + "alpha_3": "omn", + "name": "Oman", + "numeric": "512", + "official_name": "Sultanate of Oman" + }, + { + "alpha_2": "pk", + "alpha_3": "pak", + "name": "Pakistan", + "numeric": "586", + "official_name": "Islamic Republic of Pakistan" + }, + { + "alpha_2": "pa", + "alpha_3": "pan", + "name": "Panama", + "numeric": "591", + "official_name": "Republic of Panama" + }, + { + "alpha_2": "pn", + "alpha_3": "pcn", + "name": "Pitcairn", + "numeric": "612" + }, + { + "alpha_2": "pe", + "alpha_3": "per", + "name": "Peru", + "numeric": "604", + "official_name": "Republic of Peru" + }, + { + "alpha_2": "ph", + "alpha_3": "phl", + "name": "Philippines", + "numeric": "608", + "official_name": "Republic of the Philippines" + }, + { + "alpha_2": "pw", + "alpha_3": "plw", + "name": "Palau", + "numeric": "585", + "official_name": "Republic of Palau" + }, + { + "alpha_2": "pg", + "alpha_3": "png", + "name": "Papua New Guinea", + "numeric": "598", + "official_name": "Independent State of Papua New Guinea" + }, + { + "alpha_2": "pl", + "alpha_3": "pol", + "name": "Poland", + "numeric": "616", + "official_name": "Republic of Poland" + }, + { + "alpha_2": "pr", + "alpha_3": "pri", + "name": "Puerto Rico", + "numeric": "630" + }, + { + "alpha_2": "kp", + "alpha_3": "prk", + "name": "Korea, Democratic People's Republic of", + "numeric": "408", + "official_name": "Democratic People's Republic of Korea" + }, + { + "alpha_2": "pt", + "alpha_3": "prt", + "name": "Portugal", + "numeric": "620", + "official_name": "Portuguese Republic" + }, + { + "alpha_2": "py", + "alpha_3": "pry", + "name": "Paraguay", + "numeric": "600", + "official_name": "Republic of Paraguay" + }, + { + "alpha_2": "ps", + "alpha_3": "pse", + "name": "Palestine, State of", + "numeric": "275", + "official_name": "the State of Palestine" + }, + { + "alpha_2": "pf", + "alpha_3": "pyf", + "name": "French Polynesia", + "numeric": "258" + }, + { + "alpha_2": "qa", + "alpha_3": "qat", + "name": "Qatar", + "numeric": "634", + "official_name": "State of Qatar" + }, + { + "alpha_2": "re", + "alpha_3": "reu", + "name": "R\u00e9union", + "numeric": "638" + }, + { + "alpha_2": "ro", + "alpha_3": "rou", + "name": "Romania", + "numeric": "642" + }, + { + "alpha_2": "ru", + "alpha_3": "rus", + "name": "Russian Federation", + "numeric": "643" + }, + { + "alpha_2": "rw", + "alpha_3": "rwa", + "name": "Rwanda", + "numeric": "646", + "official_name": "Rwandese Republic" + }, + { + "alpha_2": "sa", + "alpha_3": "sau", + "name": "Saudi Arabia", + "numeric": "682", + "official_name": "Kingdom of Saudi Arabia" + }, + { + "alpha_2": "sd", + "alpha_3": "sdn", + "name": "Sudan", + "numeric": "729", + "official_name": "Republic of the Sudan" + }, + { + "alpha_2": "sn", + "alpha_3": "sen", + "name": "Senegal", + "numeric": "686", + "official_name": "Republic of Senegal" + }, + { + "alpha_2": "sg", + "alpha_3": "sgp", + "name": "Singapore", + "numeric": "702", + "official_name": "Republic of Singapore" + }, + { + "alpha_2": "gs", + "alpha_3": "sgs", + "name": "South Georgia and the South Sandwich Islands", + "numeric": "239" + }, + { + "alpha_2": "sh", + "alpha_3": "shn", + "name": "Saint Helena, Ascension and Tristan da Cunha", + "numeric": "654" + }, + { + "alpha_2": "sj", + "alpha_3": "sjm", + "name": "Svalbard and Jan Mayen", + "numeric": "744" + }, + { + "alpha_2": "sb", + "alpha_3": "slb", + "name": "Solomon Islands", + "numeric": "090" + }, + { + "alpha_2": "sl", + "alpha_3": "sle", + "name": "Sierra Leone", + "numeric": "694", + "official_name": "Republic of Sierra Leone" + }, + { + "alpha_2": "sv", + "alpha_3": "slv", + "name": "El Salvador", + "numeric": "222", + "official_name": "Republic of El Salvador" + }, + { + "alpha_2": "sm", + "alpha_3": "smr", + "name": "San Marino", + "numeric": "674", + "official_name": "Republic of San Marino" + }, + { + "alpha_2": "so", + "alpha_3": "som", + "name": "Somalia", + "numeric": "706", + "official_name": "Federal Republic of Somalia" + }, + { + "alpha_2": "pm", + "alpha_3": "spm", + "name": "Saint Pierre and Miquelon", + "numeric": "666" + }, + { + "alpha_2": "rs", + "alpha_3": "srb", + "name": "Serbia", + "numeric": "688", + "official_name": "Republic of Serbia" + }, + { + "alpha_2": "ss", + "alpha_3": "ssd", + "name": "South Sudan", + "numeric": "728", + "official_name": "Republic of South Sudan" + }, + { + "alpha_2": "st", + "alpha_3": "stp", + "name": "Sao Tome and Principe", + "numeric": "678", + "official_name": "Democratic Republic of Sao Tome and Principe" + }, + { + "alpha_2": "sr", + "alpha_3": "sur", + "name": "Suriname", + "numeric": "740", + "official_name": "Republic of Suriname" + }, + { + "alpha_2": "sk", + "alpha_3": "svk", + "name": "Slovakia", + "numeric": "703", + "official_name": "Slovak Republic" + }, + { + "alpha_2": "si", + "alpha_3": "svn", + "name": "Slovenia", + "numeric": "705", + "official_name": "Republic of Slovenia" + }, + { + "alpha_2": "se", + "alpha_3": "swe", + "name": "Sweden", + "numeric": "752", + "official_name": "Kingdom of Sweden" + }, + { + "alpha_2": "sz", + "alpha_3": "swz", + "name": "Eswatini", + "numeric": "748", + "official_name": "Kingdom of Eswatini" + }, + { + "alpha_2": "sx", + "alpha_3": "sxm", + "name": "Sint Maarten (Dutch part)", + "numeric": "534", + "official_name": "Sint Maarten (Dutch part)" + }, + { + "alpha_2": "sc", + "alpha_3": "syc", + "name": "Seychelles", + "numeric": "690", + "official_name": "Republic of Seychelles" + }, + { + "alpha_2": "sy", + "alpha_3": "syr", + "name": "Syrian Arab Republic", + "numeric": "760" + }, + { + "alpha_2": "tc", + "alpha_3": "tca", + "name": "Turks and Caicos Islands", + "numeric": "796" + }, + { + "alpha_2": "td", + "alpha_3": "tcd", + "name": "Chad", + "numeric": "148", + "official_name": "Republic of Chad" + }, + { + "alpha_2": "tg", + "alpha_3": "tgo", + "name": "Togo", + "numeric": "768", + "official_name": "Togolese Republic" + }, + { + "alpha_2": "th", + "alpha_3": "tha", + "name": "Thailand", + "numeric": "764", + "official_name": "Kingdom of Thailand" + }, + { + "alpha_2": "tj", + "alpha_3": "tjk", + "name": "Tajikistan", + "numeric": "762", + "official_name": "Republic of Tajikistan" + }, + { + "alpha_2": "tk", + "alpha_3": "tkl", + "name": "Tokelau", + "numeric": "772" + }, + { + "alpha_2": "tm", + "alpha_3": "tkm", + "name": "Turkmenistan", + "numeric": "795" + }, + { + "alpha_2": "tl", + "alpha_3": "tls", + "name": "Timor-Leste", + "numeric": "626", + "official_name": "Democratic Republic of Timor-Leste" + }, + { + "alpha_2": "to", + "alpha_3": "ton", + "name": "Tonga", + "numeric": "776", + "official_name": "Kingdom of Tonga" + }, + { + "alpha_2": "tt", + "alpha_3": "tto", + "name": "Trinidad and Tobago", + "numeric": "780", + "official_name": "Republic of Trinidad and Tobago" + }, + { + "alpha_2": "tn", + "alpha_3": "tun", + "name": "Tunisia", + "numeric": "788", + "official_name": "Republic of Tunisia" + }, + { + "alpha_2": "tr", + "alpha_3": "tur", + "name": "Turkey", + "numeric": "792", + "official_name": "Republic of Turkey" + }, + { + "alpha_2": "tv", + "alpha_3": "tuv", + "name": "Tuvalu", + "numeric": "798" + }, + { + "alpha_2": "tw", + "alpha_3": "twn", + "common_name": "Taiwan", + "name": "Taiwan, Province of China", + "numeric": "158", + "official_name": "Taiwan, Province of China" + }, + { + "alpha_2": "tz", + "alpha_3": "tza", + "common_name": "Tanzania", + "name": "Tanzania, United Republic of", + "numeric": "834", + "official_name": "United Republic of Tanzania" + }, + { + "alpha_2": "ug", + "alpha_3": "uga", + "name": "Uganda", + "numeric": "800", + "official_name": "Republic of Uganda" + }, + { + "alpha_2": "ua", + "alpha_3": "ukr", + "name": "Ukraine", + "numeric": "804" + }, + { + "alpha_2": "um", + "alpha_3": "umi", + "name": "United States Minor Outlying Islands", + "numeric": "581" + }, + { + "alpha_2": "uy", + "alpha_3": "ury", + "name": "Uruguay", + "numeric": "858", + "official_name": "Eastern Republic of Uruguay" + }, + { + "alpha_2": "us", + "alpha_3": "usa", + "name": "United States", + "numeric": "840", + "official_name": "United States of America" + }, + { + "alpha_2": "uz", + "alpha_3": "uzb", + "name": "Uzbekistan", + "numeric": "860", + "official_name": "Republic of Uzbekistan" + }, + { + "alpha_2": "va", + "alpha_3": "vat", + "name": "Holy See (Vatican City State)", + "numeric": "336" + }, + { + "alpha_2": "vc", + "alpha_3": "vct", + "name": "Saint Vincent and the Grenadines", + "numeric": "670" + }, + { + "alpha_2": "ve", + "alpha_3": "ven", + "common_name": "Venezuela", + "name": "Venezuela, Bolivarian Republic of", + "numeric": "862", + "official_name": "Bolivarian Republic of Venezuela" + }, + { + "alpha_2": "vg", + "alpha_3": "vgb", + "name": "Virgin Islands, British", + "numeric": "092", + "official_name": "British Virgin Islands" + }, + { + "alpha_2": "vi", + "alpha_3": "vir", + "name": "Virgin Islands, U.S.", + "numeric": "850", + "official_name": "Virgin Islands of the United States" + }, + { + "alpha_2": "vn", + "alpha_3": "vnm", + "common_name": "Vietnam", + "name": "Viet Nam", + "numeric": "704", + "official_name": "Socialist Republic of Viet Nam" + }, + { + "alpha_2": "vu", + "alpha_3": "vut", + "name": "Vanuatu", + "numeric": "548", + "official_name": "Republic of Vanuatu" + }, + { + "alpha_2": "wf", + "alpha_3": "wlf", + "name": "Wallis and Futuna", + "numeric": "876" + }, + { + "alpha_2": "ws", + "alpha_3": "wsm", + "name": "Samoa", + "numeric": "882", + "official_name": "Independent State of Samoa" + }, + { + "alpha_2": "ye", + "alpha_3": "yem", + "name": "Yemen", + "numeric": "887", + "official_name": "Republic of Yemen" + }, + { + "alpha_2": "za", + "alpha_3": "zaf", + "name": "South Africa", + "numeric": "710", + "official_name": "Republic of South Africa" + }, + { + "alpha_2": "zm", + "alpha_3": "zmb", + "name": "Zambia", + "numeric": "894", + "official_name": "Republic of Zambia" + }, + { + "alpha_2": "zw", + "alpha_3": "zwe", + "name": "Zimbabwe", + "numeric": "716", + "official_name": "Republic of Zimbabwe" + } +] \ No newline at end of file diff --git a/generalresearch/locales/iso639-3.json b/generalresearch/locales/iso639-3.json new file mode 100644 index 0000000..6693a2a --- /dev/null +++ b/generalresearch/locales/iso639-3.json @@ -0,0 +1,1117 @@ +[ + { + "alpha_2": "aa", + "alpha_3": "aar", + "name": "Afar", + "alpha_3b": "aar" + }, + { + "alpha_2": "ab", + "alpha_3": "abk", + "name": "Abkhazian", + "alpha_3b": "abk" + }, + { + "alpha_2": "af", + "alpha_3": "afr", + "name": "Afrikaans", + "alpha_3b": "afr" + }, + { + "alpha_2": "ak", + "alpha_3": "aka", + "name": "Akan", + "alpha_3b": "aka" + }, + { + "alpha_2": "am", + "alpha_3": "amh", + "name": "Amharic", + "alpha_3b": "amh" + }, + { + "alpha_2": "ar", + "alpha_3": "ara", + "name": "Arabic", + "alpha_3b": "ara" + }, + { + "alpha_2": "an", + "alpha_3": "arg", + "name": "Aragonese", + "alpha_3b": "arg" + }, + { + "alpha_2": "as", + "alpha_3": "asm", + "name": "Assamese", + "alpha_3b": "asm" + }, + { + "alpha_2": "av", + "alpha_3": "ava", + "name": "Avaric", + "alpha_3b": "ava" + }, + { + "alpha_2": "ae", + "alpha_3": "ave", + "name": "Avestan", + "alpha_3b": "ave" + }, + { + "alpha_2": "ay", + "alpha_3": "aym", + "name": "Aymara", + "alpha_3b": "aym" + }, + { + "alpha_2": "az", + "alpha_3": "aze", + "name": "Azerbaijani", + "alpha_3b": "aze" + }, + { + "alpha_2": "ba", + "alpha_3": "bak", + "name": "Bashkir", + "alpha_3b": "bak" + }, + { + "alpha_2": "bm", + "alpha_3": "bam", + "name": "Bambara", + "alpha_3b": "bam" + }, + { + "alpha_2": "be", + "alpha_3": "bel", + "name": "Belarusian", + "alpha_3b": "bel" + }, + { + "alpha_2": "bn", + "alpha_3": "ben", + "common_name": "Bangla", + "name": "Bengali", + "alpha_3b": "ben" + }, + { + "alpha_2": "bi", + "alpha_3": "bis", + "name": "Bislama", + "alpha_3b": "bis" + }, + { + "alpha_2": "bo", + "alpha_3": "bod", + "name": "Tibetan", + "alpha_3b": "tib" + }, + { + "alpha_2": "bs", + "alpha_3": "bos", + "name": "Bosnian", + "alpha_3b": "bos" + }, + { + "alpha_2": "br", + "alpha_3": "bre", + "name": "Breton", + "alpha_3b": "bre" + }, + { + "alpha_2": "bg", + "alpha_3": "bul", + "name": "Bulgarian", + "alpha_3b": "bul" + }, + { + "alpha_2": "ca", + "alpha_3": "cat", + "name": "Catalan", + "alpha_3b": "cat" + }, + { + "alpha_2": "cs", + "alpha_3": "ces", + "name": "Czech", + "alpha_3b": "cze" + }, + { + "alpha_2": "ch", + "alpha_3": "cha", + "name": "Chamorro", + "alpha_3b": "cha" + }, + { + "alpha_2": "ce", + "alpha_3": "che", + "name": "Chechen", + "alpha_3b": "che" + }, + { + "alpha_2": "cu", + "alpha_3": "chu", + "inverted_name": "Slavic, Church", + "name": "Church Slavic", + "alpha_3b": "chu" + }, + { + "alpha_2": "cv", + "alpha_3": "chv", + "name": "Chuvash", + "alpha_3b": "chv" + }, + { + "alpha_2": "kw", + "alpha_3": "cor", + "name": "Cornish", + "alpha_3b": "cor" + }, + { + "alpha_2": "co", + "alpha_3": "cos", + "name": "Corsican", + "alpha_3b": "cos" + }, + { + "alpha_2": "cr", + "alpha_3": "cre", + "name": "Cree", + "alpha_3b": "cre" + }, + { + "alpha_2": "cy", + "alpha_3": "cym", + "name": "Welsh", + "alpha_3b": "wel" + }, + { + "alpha_2": "da", + "alpha_3": "dan", + "name": "Danish", + "alpha_3b": "dan" + }, + { + "alpha_2": "de", + "alpha_3": "deu", + "name": "German", + "alpha_3b": "ger" + }, + { + "alpha_2": "dv", + "alpha_3": "div", + "name": "Dhivehi", + "alpha_3b": "div" + }, + { + "alpha_2": "dz", + "alpha_3": "dzo", + "name": "Dzongkha", + "alpha_3b": "dzo" + }, + { + "alpha_2": "el", + "alpha_3": "ell", + "inverted_name": "Greek, Modern (1453-)", + "name": "Modern Greek (1453-)", + "alpha_3b": "gre" + }, + { + "alpha_2": "en", + "alpha_3": "eng", + "name": "English", + "alpha_3b": "eng" + }, + { + "alpha_2": "eo", + "alpha_3": "epo", + "name": "Esperanto", + "alpha_3b": "epo" + }, + { + "alpha_2": "et", + "alpha_3": "est", + "name": "Estonian", + "alpha_3b": "est" + }, + { + "alpha_2": "eu", + "alpha_3": "eus", + "name": "Basque", + "alpha_3b": "baq" + }, + { + "alpha_2": "ee", + "alpha_3": "ewe", + "name": "Ewe", + "alpha_3b": "ewe" + }, + { + "alpha_2": "fo", + "alpha_3": "fao", + "name": "Faroese", + "alpha_3b": "fao" + }, + { + "alpha_2": "fa", + "alpha_3": "fas", + "name": "Persian", + "alpha_3b": "per" + }, + { + "alpha_2": "fj", + "alpha_3": "fij", + "name": "Fijian", + "alpha_3b": "fij" + }, + { + "alpha_2": "fi", + "alpha_3": "fin", + "name": "Finnish", + "alpha_3b": "fin" + }, + { + "alpha_2": "fr", + "alpha_3": "fra", + "name": "French", + "alpha_3b": "fre" + }, + { + "alpha_2": "fy", + "alpha_3": "fry", + "inverted_name": "Frisian, Western", + "name": "Western Frisian", + "alpha_3b": "fry" + }, + { + "alpha_2": "ff", + "alpha_3": "ful", + "name": "Fulah", + "alpha_3b": "ful" + }, + { + "alpha_2": "gd", + "alpha_3": "gla", + "inverted_name": "Gaelic, Scottish", + "name": "Scottish Gaelic", + "alpha_3b": "gla" + }, + { + "alpha_2": "ga", + "alpha_3": "gle", + "name": "Irish", + "alpha_3b": "gle" + }, + { + "alpha_2": "gl", + "alpha_3": "glg", + "name": "Galician", + "alpha_3b": "glg" + }, + { + "alpha_2": "gv", + "alpha_3": "glv", + "name": "Manx", + "alpha_3b": "glv" + }, + { + "alpha_2": "gn", + "alpha_3": "grn", + "name": "Guarani", + "alpha_3b": "grn" + }, + { + "alpha_2": "gu", + "alpha_3": "guj", + "name": "Gujarati", + "alpha_3b": "guj" + }, + { + "alpha_2": "ht", + "alpha_3": "hat", + "name": "Haitian", + "alpha_3b": "hat" + }, + { + "alpha_2": "ha", + "alpha_3": "hau", + "name": "Hausa", + "alpha_3b": "hau" + }, + { + "alpha_2": "sh", + "alpha_3": "hbs", + "name": "Serbo-Croatian", + "alpha_3b": "hbs" + }, + { + "alpha_2": "he", + "alpha_3": "heb", + "name": "Hebrew", + "alpha_3b": "heb" + }, + { + "alpha_2": "hz", + "alpha_3": "her", + "name": "Herero", + "alpha_3b": "her" + }, + { + "alpha_2": "hi", + "alpha_3": "hin", + "name": "Hindi", + "alpha_3b": "hin" + }, + { + "alpha_2": "ho", + "alpha_3": "hmo", + "name": "Hiri Motu", + "alpha_3b": "hmo" + }, + { + "alpha_2": "hr", + "alpha_3": "hrv", + "name": "Croatian", + "alpha_3b": "hrv" + }, + { + "alpha_2": "hu", + "alpha_3": "hun", + "name": "Hungarian", + "alpha_3b": "hun" + }, + { + "alpha_2": "hy", + "alpha_3": "hye", + "name": "Armenian", + "alpha_3b": "arm" + }, + { + "alpha_2": "ig", + "alpha_3": "ibo", + "name": "Igbo", + "alpha_3b": "ibo" + }, + { + "alpha_2": "io", + "alpha_3": "ido", + "name": "Ido", + "alpha_3b": "ido" + }, + { + "alpha_2": "ii", + "alpha_3": "iii", + "inverted_name": "Yi, Sichuan", + "name": "Sichuan Yi", + "alpha_3b": "iii" + }, + { + "alpha_2": "iu", + "alpha_3": "iku", + "name": "Inuktitut", + "alpha_3b": "iku" + }, + { + "alpha_2": "ie", + "alpha_3": "ile", + "name": "Interlingue", + "alpha_3b": "ile" + }, + { + "alpha_2": "ia", + "alpha_3": "ina", + "name": "Interlingua (International Auxiliary Language Association)", + "alpha_3b": "ina" + }, + { + "alpha_2": "id", + "alpha_3": "ind", + "name": "Indonesian", + "alpha_3b": "ind" + }, + { + "alpha_2": "ik", + "alpha_3": "ipk", + "name": "Inupiaq", + "alpha_3b": "ipk" + }, + { + "alpha_2": "is", + "alpha_3": "isl", + "name": "Icelandic", + "alpha_3b": "ice" + }, + { + "alpha_2": "it", + "alpha_3": "ita", + "name": "Italian", + "alpha_3b": "ita" + }, + { + "alpha_2": "jv", + "alpha_3": "jav", + "name": "Javanese", + "alpha_3b": "jav" + }, + { + "alpha_2": "ja", + "alpha_3": "jpn", + "name": "Japanese", + "alpha_3b": "jpn" + }, + { + "alpha_2": "kl", + "alpha_3": "kal", + "name": "Kalaallisut", + "alpha_3b": "kal" + }, + { + "alpha_2": "kn", + "alpha_3": "kan", + "name": "Kannada", + "alpha_3b": "kan" + }, + { + "alpha_2": "ks", + "alpha_3": "kas", + "name": "Kashmiri", + "alpha_3b": "kas" + }, + { + "alpha_2": "ka", + "alpha_3": "kat", + "name": "Georgian", + "alpha_3b": "geo" + }, + { + "alpha_2": "kr", + "alpha_3": "kau", + "name": "Kanuri", + "alpha_3b": "kau" + }, + { + "alpha_2": "kk", + "alpha_3": "kaz", + "name": "Kazakh", + "alpha_3b": "kaz" + }, + { + "alpha_2": "km", + "alpha_3": "khm", + "inverted_name": "Khmer, Central", + "name": "Central Khmer", + "alpha_3b": "khm" + }, + { + "alpha_2": "ki", + "alpha_3": "kik", + "name": "Kikuyu", + "alpha_3b": "kik" + }, + { + "alpha_2": "rw", + "alpha_3": "kin", + "name": "Kinyarwanda", + "alpha_3b": "kin" + }, + { + "alpha_2": "ky", + "alpha_3": "kir", + "name": "Kirghiz", + "alpha_3b": "kir" + }, + { + "alpha_2": "kv", + "alpha_3": "kom", + "name": "Komi", + "alpha_3b": "kom" + }, + { + "alpha_2": "kg", + "alpha_3": "kon", + "name": "Kongo", + "alpha_3b": "kon" + }, + { + "alpha_2": "ko", + "alpha_3": "kor", + "name": "Korean", + "alpha_3b": "kor" + }, + { + "alpha_2": "kj", + "alpha_3": "kua", + "name": "Kuanyama", + "alpha_3b": "kua" + }, + { + "alpha_2": "ku", + "alpha_3": "kur", + "name": "Kurdish", + "alpha_3b": "kur" + }, + { + "alpha_2": "lo", + "alpha_3": "lao", + "name": "Lao", + "alpha_3b": "lao" + }, + { + "alpha_2": "la", + "alpha_3": "lat", + "name": "Latin", + "alpha_3b": "lat" + }, + { + "alpha_2": "lv", + "alpha_3": "lav", + "name": "Latvian", + "alpha_3b": "lav" + }, + { + "alpha_2": "li", + "alpha_3": "lim", + "name": "Limburgan", + "alpha_3b": "lim" + }, + { + "alpha_2": "ln", + "alpha_3": "lin", + "name": "Lingala", + "alpha_3b": "lin" + }, + { + "alpha_2": "lt", + "alpha_3": "lit", + "name": "Lithuanian", + "alpha_3b": "lit" + }, + { + "alpha_2": "lb", + "alpha_3": "ltz", + "name": "Luxembourgish", + "alpha_3b": "ltz" + }, + { + "alpha_2": "lu", + "alpha_3": "lub", + "name": "Luba-Katanga", + "alpha_3b": "lub" + }, + { + "alpha_2": "lg", + "alpha_3": "lug", + "name": "Ganda", + "alpha_3b": "lug" + }, + { + "alpha_2": "mh", + "alpha_3": "mah", + "name": "Marshallese", + "alpha_3b": "mah" + }, + { + "alpha_2": "ml", + "alpha_3": "mal", + "name": "Malayalam", + "alpha_3b": "mal" + }, + { + "alpha_2": "mr", + "alpha_3": "mar", + "name": "Marathi", + "alpha_3b": "mar" + }, + { + "alpha_2": "mk", + "alpha_3": "mkd", + "name": "Macedonian", + "alpha_3b": "mac" + }, + { + "alpha_2": "mg", + "alpha_3": "mlg", + "name": "Malagasy", + "alpha_3b": "mlg" + }, + { + "alpha_2": "mt", + "alpha_3": "mlt", + "name": "Maltese", + "alpha_3b": "mlt" + }, + { + "alpha_2": "mn", + "alpha_3": "mon", + "name": "Mongolian", + "alpha_3b": "mon" + }, + { + "alpha_2": "mi", + "alpha_3": "mri", + "name": "Maori", + "alpha_3b": "mao" + }, + { + "alpha_2": "ms", + "alpha_3": "msa", + "name": "Malay (macrolanguage)", + "alpha_3b": "may" + }, + { + "alpha_2": "my", + "alpha_3": "mya", + "name": "Burmese", + "alpha_3b": "bur" + }, + { + "alpha_2": "na", + "alpha_3": "nau", + "name": "Nauru", + "alpha_3b": "nau" + }, + { + "alpha_2": "nv", + "alpha_3": "nav", + "name": "Navajo", + "alpha_3b": "nav" + }, + { + "alpha_2": "nr", + "alpha_3": "nbl", + "inverted_name": "Ndebele, South", + "name": "South Ndebele", + "alpha_3b": "nbl" + }, + { + "alpha_2": "nd", + "alpha_3": "nde", + "inverted_name": "Ndebele, North", + "name": "North Ndebele", + "alpha_3b": "nde" + }, + { + "alpha_2": "ng", + "alpha_3": "ndo", + "name": "Ndonga", + "alpha_3b": "ndo" + }, + { + "alpha_2": "ne", + "alpha_3": "nep", + "name": "Nepali (macrolanguage)", + "alpha_3b": "nep" + }, + { + "alpha_2": "nl", + "alpha_3": "nld", + "name": "Dutch", + "alpha_3b": "dut" + }, + { + "alpha_2": "nn", + "alpha_3": "nno", + "name": "Norwegian Nynorsk", + "alpha_3b": "nno" + }, + { + "alpha_2": "nb", + "alpha_3": "nob", + "name": "Norwegian Bokm\u00e5l", + "alpha_3b": "nob" + }, + { + "alpha_2": "no", + "alpha_3": "nor", + "name": "Norwegian", + "alpha_3b": "nor" + }, + { + "alpha_2": "ny", + "alpha_3": "nya", + "name": "Nyanja", + "alpha_3b": "nya" + }, + { + "alpha_2": "oc", + "alpha_3": "oci", + "name": "Occitan (post 1500)", + "alpha_3b": "oci" + }, + { + "alpha_2": "oj", + "alpha_3": "oji", + "name": "Ojibwa", + "alpha_3b": "oji" + }, + { + "alpha_2": "or", + "alpha_3": "ori", + "name": "Oriya (macrolanguage)", + "alpha_3b": "ori" + }, + { + "alpha_2": "om", + "alpha_3": "orm", + "name": "Oromo", + "alpha_3b": "orm" + }, + { + "alpha_2": "os", + "alpha_3": "oss", + "name": "Ossetian", + "alpha_3b": "oss" + }, + { + "alpha_2": "pa", + "alpha_3": "pan", + "name": "Panjabi", + "alpha_3b": "pan" + }, + { + "alpha_2": "pi", + "alpha_3": "pli", + "name": "Pali", + "alpha_3b": "pli" + }, + { + "alpha_2": "pl", + "alpha_3": "pol", + "name": "Polish", + "alpha_3b": "pol" + }, + { + "alpha_2": "pt", + "alpha_3": "por", + "name": "Portuguese", + "alpha_3b": "por" + }, + { + "alpha_2": "ps", + "alpha_3": "pus", + "name": "Pushto", + "alpha_3b": "pus" + }, + { + "alpha_2": "qu", + "alpha_3": "que", + "name": "Quechua", + "alpha_3b": "que" + }, + { + "alpha_2": "rm", + "alpha_3": "roh", + "name": "Romansh", + "alpha_3b": "roh" + }, + { + "alpha_2": "ro", + "alpha_3": "ron", + "name": "Romanian", + "alpha_3b": "rum" + }, + { + "alpha_2": "rn", + "alpha_3": "run", + "name": "Rundi", + "alpha_3b": "run" + }, + { + "alpha_2": "ru", + "alpha_3": "rus", + "name": "Russian", + "alpha_3b": "rus" + }, + { + "alpha_2": "sg", + "alpha_3": "sag", + "name": "Sango", + "alpha_3b": "sag" + }, + { + "alpha_2": "sa", + "alpha_3": "san", + "name": "Sanskrit", + "alpha_3b": "san" + }, + { + "alpha_2": "si", + "alpha_3": "sin", + "name": "Sinhala", + "alpha_3b": "sin" + }, + { + "alpha_2": "sk", + "alpha_3": "slk", + "name": "Slovak", + "alpha_3b": "slo" + }, + { + "alpha_2": "sl", + "alpha_3": "slv", + "name": "Slovenian", + "alpha_3b": "slv" + }, + { + "alpha_2": "se", + "alpha_3": "sme", + "inverted_name": "Sami, Northern", + "name": "Northern Sami", + "alpha_3b": "sme" + }, + { + "alpha_2": "sm", + "alpha_3": "smo", + "name": "Samoan", + "alpha_3b": "smo" + }, + { + "alpha_2": "sn", + "alpha_3": "sna", + "name": "Shona", + "alpha_3b": "sna" + }, + { + "alpha_2": "sd", + "alpha_3": "snd", + "name": "Sindhi", + "alpha_3b": "snd" + }, + { + "alpha_2": "so", + "alpha_3": "som", + "name": "Somali", + "alpha_3b": "som" + }, + { + "alpha_2": "st", + "alpha_3": "sot", + "inverted_name": "Sotho, Southern", + "name": "Southern Sotho", + "alpha_3b": "sot" + }, + { + "alpha_2": "es", + "alpha_3": "spa", + "name": "Spanish", + "alpha_3b": "spa" + }, + { + "alpha_2": "sq", + "alpha_3": "sqi", + "name": "Albanian", + "alpha_3b": "alb" + }, + { + "alpha_2": "sc", + "alpha_3": "srd", + "name": "Sardinian", + "alpha_3b": "srd" + }, + { + "alpha_2": "sr", + "alpha_3": "srp", + "name": "Serbian", + "alpha_3b": "srp" + }, + { + "alpha_2": "ss", + "alpha_3": "ssw", + "name": "Swati", + "alpha_3b": "ssw" + }, + { + "alpha_2": "su", + "alpha_3": "sun", + "name": "Sundanese", + "alpha_3b": "sun" + }, + { + "alpha_2": "sw", + "alpha_3": "swa", + "name": "Swahili (macrolanguage)", + "alpha_3b": "swa" + }, + { + "alpha_2": "sv", + "alpha_3": "swe", + "name": "Swedish", + "alpha_3b": "swe" + }, + { + "alpha_2": "ty", + "alpha_3": "tah", + "name": "Tahitian", + "alpha_3b": "tah" + }, + { + "alpha_2": "ta", + "alpha_3": "tam", + "name": "Tamil", + "alpha_3b": "tam" + }, + { + "alpha_2": "tt", + "alpha_3": "tat", + "name": "Tatar", + "alpha_3b": "tat" + }, + { + "alpha_2": "te", + "alpha_3": "tel", + "name": "Telugu", + "alpha_3b": "tel" + }, + { + "alpha_2": "tg", + "alpha_3": "tgk", + "name": "Tajik", + "alpha_3b": "tgk" + }, + { + "alpha_2": "tl", + "alpha_3": "tgl", + "name": "Tagalog", + "alpha_3b": "tgl" + }, + { + "alpha_2": "th", + "alpha_3": "tha", + "name": "Thai", + "alpha_3b": "tha" + }, + { + "alpha_2": "ti", + "alpha_3": "tir", + "name": "Tigrinya", + "alpha_3b": "tir" + }, + { + "alpha_2": "to", + "alpha_3": "ton", + "name": "Tonga (Tonga Islands)", + "alpha_3b": "ton" + }, + { + "alpha_2": "tn", + "alpha_3": "tsn", + "name": "Tswana", + "alpha_3b": "tsn" + }, + { + "alpha_2": "ts", + "alpha_3": "tso", + "name": "Tsonga", + "alpha_3b": "tso" + }, + { + "alpha_2": "tk", + "alpha_3": "tuk", + "name": "Turkmen", + "alpha_3b": "tuk" + }, + { + "alpha_2": "tr", + "alpha_3": "tur", + "name": "Turkish", + "alpha_3b": "tur" + }, + { + "alpha_2": "tw", + "alpha_3": "twi", + "name": "Twi", + "alpha_3b": "twi" + }, + { + "alpha_2": "ug", + "alpha_3": "uig", + "name": "Uighur", + "alpha_3b": "uig" + }, + { + "alpha_2": "uk", + "alpha_3": "ukr", + "name": "Ukrainian", + "alpha_3b": "ukr" + }, + { + "alpha_2": "ur", + "alpha_3": "urd", + "name": "Urdu", + "alpha_3b": "urd" + }, + { + "alpha_2": "uz", + "alpha_3": "uzb", + "name": "Uzbek", + "alpha_3b": "uzb" + }, + { + "alpha_2": "ve", + "alpha_3": "ven", + "name": "Venda", + "alpha_3b": "ven" + }, + { + "alpha_2": "vi", + "alpha_3": "vie", + "name": "Vietnamese", + "alpha_3b": "vie" + }, + { + "alpha_2": "vo", + "alpha_3": "vol", + "name": "Volap\u00fck", + "alpha_3b": "vol" + }, + { + "alpha_2": "wa", + "alpha_3": "wln", + "name": "Walloon", + "alpha_3b": "wln" + }, + { + "alpha_2": "wo", + "alpha_3": "wol", + "name": "Wolof", + "alpha_3b": "wol" + }, + { + "alpha_2": "xh", + "alpha_3": "xho", + "name": "Xhosa", + "alpha_3b": "xho" + }, + { + "alpha_2": "yi", + "alpha_3": "yid", + "name": "Yiddish", + "alpha_3b": "yid" + }, + { + "alpha_2": "yo", + "alpha_3": "yor", + "name": "Yoruba", + "alpha_3b": "yor" + }, + { + "alpha_2": "za", + "alpha_3": "zha", + "name": "Zhuang", + "alpha_3b": "zha" + }, + { + "alpha_2": "zh", + "alpha_3": "zho", + "name": "Chinese", + "alpha_3b": "chi" + }, + { + "alpha_2": "zu", + "alpha_3": "zul", + "name": "Zulu", + "alpha_3b": "zul" + } +] \ No newline at end of file diff --git a/generalresearch/locales/setup_json.py b/generalresearch/locales/setup_json.py new file mode 100644 index 0000000..356084a --- /dev/null +++ b/generalresearch/locales/setup_json.py @@ -0,0 +1,61 @@ +import json + + +def country_default_lang(): + """ + Some marketplaces have no language specified. Surveys are in the "default + language for that country", whatever that means. This helper is meant to + provide a reasonable guess as to what language it is. + + Derived from: http://download.geonames.org/export/dump/countryInfo.txt + """ + raise ValueError("no need to run this, I already ran it.") + import pandas as pd + from generalresearch.locales import Localelator + + l = Localelator() + + df = pd.read_csv( + "http://download.geonames.org/export/dump/countryInfo.txt", + sep="\t", + skiprows=49, + ) + df["default_lang"] = df.Languages.str.split(",").str[0].str.split("-").str[0] + df.default_lang = df.default_lang.fillna("en") + df.default_lang = df.default_lang.map( + lambda x: l.get_language_iso(x) if x in l.languages else "eng" + ) + df["#ISO"] = df["#ISO"].str.lower() + df["country_iso"] = df["#ISO"].map( + lambda x: l.get_country_iso(x) if x in l.countries else None + ) + df = df[df.country_iso.notnull()] + d = df.set_index("country_iso").default_lang.to_dict() + with open("country_default_lang.json", "w") as f: + json.dump(d, f, indent=2) + return d + + +def setup_json(): + # pycountry is 30mb, which makes using this package on AWS lambda problematic. + # These JSONs are stolen from pycountry and adapted. + + raise ValueError("no need to run this, I already ran it.") + + # languages + d = json.load(open("iso639-3.json")) + d["639-3"] = [x for x in d["639-3"] if "alpha_2" in x] + for x in d["639-3"]: + x["alpha_3b"] = x.pop("bibliographic", None) or x["alpha_3"] + del x["scope"] + del x["type"] + with open("iso639-3.json", "w") as f: + json.dump(d["639-3"], f, indent=2) + + # countries + d = json.load(open("iso3166-1.json"))["3166-1"] + for x in d: + x["alpha_2"] = x["alpha_2"].lower() + x["alpha_3"] = x["alpha_3"].lower() + with open("iso3166-1.json", "w") as f: + json.dump(d, f, indent=2) diff --git a/generalresearch/locales/timezone.py b/generalresearch/locales/timezone.py new file mode 100644 index 0000000..50d539d --- /dev/null +++ b/generalresearch/locales/timezone.py @@ -0,0 +1,77 @@ +from typing import Optional + +from pytz import country_timezones + + +def get_default_timezone(country_iso: str) -> Optional[str]: + # to list all: + # from pytz import country_names, country_timezones + # [country_timezones.get(country) for country in country_names] + + # country_iso can be upper or lower, doesn't matter + return country_timezones.get(country_iso, [None])[0] + + +# There is no official list for this .... +country_default_locale = { + "af": "fa-AF", + "al": "sq-AL", + "dz": "ar-DZ", + "ar": "es-AR", + "au": "en-AU", + "at": "de-AT", + "br": "pt-BR", + "ca": "en-CA", + "cn": "zh-CN", + "eg": "ar-EG", + "fr": "fr-FR", + "de": "de-DE", + "in": "hi-IN", + "jp": "ja-JP", + "ke": "sw-KE", + "mx": "es-MX", + "ru": "ru-RU", + "kr": "ko-KR", + "gb": "en-GB", + "us": "en-US", + "lt": "lt-LT", + "lu": "lb-LU", + "mg": "mg-MG", + "my": "ms-MY", + "mv": "dv-MV", + "ml": "fr-ML", + "mt": "mt-MT", + "mn": "mn-MN", + "ma": "ar-MA", + "np": "ne-NP", + "nl": "nl-NL", + "nz": "en-NZ", + "ng": "en-NG", + "no": "no-NO", + "pk": "ur-PK", + "pa": "es-PA", + "pe": "es-PE", + "ph": "tl-PH", + "pl": "pl-PL", + "pt": "pt-PT", + "qa": "ar-QA", + "ro": "ro-RO", + "sa": "ar-SA", + "sg": "en-SG", + "za": "en-ZA", + "es": "es-ES", + "lk": "si-LK", + "se": "sv-SE", + "ch": "de-CH", + "th": "th-TH", + "tr": "tr-TR", + "ua": "uk-UA", + "ae": "ar-AE", + "vn": "vi-VN", + "zw": "en-ZW", +} + + +def get_default_locale(country_iso: str) -> Optional[str]: + # todo: "https://cdn.simplelocalize.io/public/v1/locales" to fill in the rest? + return country_default_locale.get(country_iso, None) diff --git a/generalresearch/logging.py b/generalresearch/logging.py new file mode 100644 index 0000000..9b72e0b --- /dev/null +++ b/generalresearch/logging.py @@ -0,0 +1,21 @@ +import decimal +import json +from datetime import date + + +class ThlJsonEncoder(json.JSONEncoder): + """ + Converts: + Decimal to str + set to sorted list + datetime/date to isoformat + """ + + def default(self, o): + if isinstance(o, decimal.Decimal): + return str(o) + if isinstance(o, set): + return sorted(list(o)) + if isinstance(o, date): + return o.isoformat() + return super().default(o) diff --git a/generalresearch/managers/__init__.py b/generalresearch/managers/__init__.py new file mode 100644 index 0000000..8af988c --- /dev/null +++ b/generalresearch/managers/__init__.py @@ -0,0 +1,16 @@ +def parse_order_by(order_by_str: str): + """ + Converts django-rest-framework ordering str to mysql clause + :param order_by_str: e.g. 'created,-name' + :return: mysql clause e.g. ORDER BY created ASC, name DESC + """ + fields = order_by_str.split(",") + + order_clause = [] + for field in fields: + if field.startswith("-"): + order_clause.append(f"{field[1:]} DESC") + else: + order_clause.append(f"{field} ASC") + + return "ORDER BY " + ", ".join(order_clause) diff --git a/generalresearch/managers/base.py b/generalresearch/managers/base.py new file mode 100644 index 0000000..bb9ca75 --- /dev/null +++ b/generalresearch/managers/base.py @@ -0,0 +1,91 @@ +from enum import Enum +from typing import Collection, Optional + +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig +from generalresearch.sql_helper import SqlHelper + + +class Permission(int, Enum): + READ = 1 + UPDATE = 2 + CREATE = 3 + DELETE = 4 + + +class Manager: + pass + + +class SqlManager(Manager): + def __init__( + self, + sql_helper: SqlHelper, + permissions: Optional[Collection[Permission]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.sql_helper = sql_helper + self.permissions = set(permissions) if permissions else set() + # This is susceptible to sql injection, so don't ever pass arbitrary input into it + # (https://stackoverflow.com/a/64412951/1991066) + self.db_name = self.sql_helper.db_name + + +class PostgresManager(Manager): + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.pg_config = pg_config + self.permissions = set(permissions) if permissions else set() + + +class RedisManager(Manager): + CACHE_PREFIX = None + + def __init__( + self, + redis_config: RedisConfig, + cache_prefix: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.redis_config = redis_config + self.cache_prefix = cache_prefix or self.CACHE_PREFIX or "" + self.redis_client = self.redis_config.create_redis_client() + + +class SqlManagerWithRedis(SqlManager, RedisManager): + def __init__( + self, + sql_helper: SqlHelper, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + ): + super().__init__( + sql_helper=sql_helper, + redis_config=redis_config, + permissions=permissions, + cache_prefix=cache_prefix, + ) + + +class PostgresManagerWithRedis(PostgresManager, RedisManager): + def __init__( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + ): + super().__init__( + pg_config=pg_config, + permissions=permissions, + redis_config=redis_config, + cache_prefix=cache_prefix, + ) diff --git a/generalresearch/managers/cint/__init__.py b/generalresearch/managers/cint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/cint/profiling.py b/generalresearch/managers/cint/profiling.py new file mode 100644 index 0000000..4a7fc69 --- /dev/null +++ b/generalresearch/managers/cint/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.cint.question import CintQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[CintQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_key, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', 'us', 'eng'), ('AGE', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `{sql_helper.db}`.`cint_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [CintQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/cint/survey.py b/generalresearch/managers/cint/survey.py new file mode 100644 index 0000000..b7045c9 --- /dev/null +++ b/generalresearch/managers/cint/survey.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional, Set + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.cint.survey import CintSurvey, CintCondition + +logger = logging.getLogger() + + +class CintCriteriaManager(CriteriaManager): + CONDITION_MODEL = CintCondition + TABLE_NAME = "cint_criterion" + + +class CintSurveyManager(SurveyManager): + SURVEY_FIELDS = (set(CintSurvey.model_fields.keys()) | {"is_live"}) - { + "country_isos", + "language_isos", + "source", + "conditions", + "gross_cpi", + "is_live_raw", + } + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + exclude_fields: Optional[Set[str]] = None, + ) -> List[CintSurvey]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> last_updated" + :param exclude_fields: Optionally exclude fields from query. This only supports + nullable fields, as the CintSurvey model validation will fail otherwise. + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + + if updated_since is not None: + params["updated"] = updated_since + filters.append("last_updated > %(updated)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + fields = set(self.SURVEY_FIELDS) + if exclude_fields: + fields -= exclude_fields + fields_str = ", ".join([f"`{v}`" for v in fields]) + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT {fields_str} + FROM `thl-cint`.`cint_survey` survey + {filter_str} + """, + params=params, + ) + surveys = [CintSurvey.from_mysql(x) for x in res] + return surveys + + def create(self, survey: CintSurvey) -> bool: + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = list(self.SURVEY_FIELDS) + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + c.execute( + query=f""" + INSERT INTO `thl-cint`.`cint_survey` + ({fields_str}) VALUES ({values_str}) + """, + args=survey_data, + ) + return True + + def update(self, surveys: List[CintSurvey]) -> bool: + now = datetime.now(tz=timezone.utc) + for survey in surveys: + survey.last_updated = now + + survey_fields = list(self.SURVEY_FIELDS) + data = [survey.to_mysql() for survey in surveys] + survey_data = [[d[k] for k in survey_fields] for d in data] + self.sql_helper.bulk_update("cint_survey", survey_fields, survey_data) + return True + + def create_or_update(self, surveys: List[CintSurvey]): + surveys = {s.survey_id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["survey_id"] + for x in self.sql_helper.execute_sql_query( + query=""" + SELECT survey_id + FROM `thl-cint`.`cint_survey` + WHERE survey_id IN %s; + """, + params=[sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + self.update([surveys[sn] for sn in existing_sns]) diff --git a/generalresearch/managers/cint/user_pid.py b/generalresearch/managers/cint/user_pid.py new file mode 100644 index 0000000..4f749a0 --- /dev/null +++ b/generalresearch/managers/cint/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class CintUserPidManager(UserPidManager): + TABLE_NAME = "cint_userpid" + SOURCE = Source.CINT diff --git a/generalresearch/managers/criteria.py b/generalresearch/managers/criteria.py new file mode 100644 index 0000000..4cf3a3e --- /dev/null +++ b/generalresearch/managers/criteria.py @@ -0,0 +1,105 @@ +from abc import ABC +from datetime import datetime, timezone +from typing import Collection, Dict, Set + +from more_itertools import chunked + +from generalresearch.managers.base import SqlManager +from generalresearch.models.thl.survey import MarketplaceCondition + + +class CriteriaManager(SqlManager, ABC): + """ + Using the terms "criteria" & "condition" interchangeably! + """ + + DB_FIELDS = [ + "hash", + "question_id", + "logical_operator", + "values", + "value_type", + "negate", + ] + CONDITION_MODEL = None + TABLE_NAME = "" + + def create(self, criterion: MarketplaceCondition) -> bool: + """ + Create a single criterion + """ + ... + + def filter(self, hashes: Collection[str]) -> Dict[str, MarketplaceCondition]: + """ + Filter for criterion from the db + """ + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT {self.mysql_fields} + FROM {self.mysql_db_table} + WHERE `hash` IN %s; + """, + params=[hashes], + ) + return {x["hash"]: self.CONDITION_MODEL.from_mysql(x) for x in res} + + def filter_exists(self, hashes: Set[str]) -> Set[str]: + """Returns hashes that exist in the db""" + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT `hash` + FROM {self.mysql_db_table} + WHERE `hash` IN %s; + """, + params=[hashes], + ) + return {x["hash"] for x in res} + + def update(self, conditions: Collection[MarketplaceCondition]) -> None: + # Add any new hashes into the DB + this_hashes = set([condition.criterion_hash for condition in conditions]) + known_hashes = self.filter_exists(this_hashes) + new_hashes = this_hashes - known_hashes + + if new_hashes: + now = datetime.now(tz=timezone.utc) + values = [ + condition.to_mysql() + for condition in conditions + if condition.criterion_hash in new_hashes + ] + values = [ + v + | { + "created": now, + "last_used": now, + "hash": v["criterion_hash"], + } + for v in values + ] + values_str = ",".join( + [f"%({k})s" for k in self.DB_FIELDS + ["created", "last_used"]] + ) + conn = self.sql_helper.make_connection() + c = conn.cursor() + for chunk in chunked(values, 100): + c.executemany( + query=f""" + INSERT INTO {self.mysql_db_table} + ({self.mysql_fields}, `created`, `last_used`) + VALUES ({values_str}); + """, + args=chunk, + ) + conn.commit() + + return None + + @property + def mysql_fields(self) -> str: + return ", ".join([f"`{k}`" for k in self.DB_FIELDS]) + + @property + def mysql_db_table(self) -> str: + return f"`{self.sql_helper.db}`.`{self.TABLE_NAME}`" diff --git a/generalresearch/managers/dynata/__init__.py b/generalresearch/managers/dynata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/dynata/profiling.py b/generalresearch/managers/dynata/profiling.py new file mode 100644 index 0000000..10d6c69 --- /dev/null +++ b/generalresearch/managers/dynata/profiling.py @@ -0,0 +1,63 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.dynata.question import DynataQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[DynataQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-dynata`.`dynata_rexquestion` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + x["parent_dependencies"] = json.loads(x["parent_dependencies"]) + qs = [DynataQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/dynata/survey.py b/generalresearch/managers/dynata/survey.py new file mode 100644 index 0000000..a20345a --- /dev/null +++ b/generalresearch/managers/dynata/survey.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.dynata.survey import DynataSurvey, DynataCondition + +logger = logging.getLogger() + + +class DynataCriteriaManager(CriteriaManager): + CONDITION_MODEL = DynataCondition + TABLE_NAME = "dynata_criterion" + + +class DynataSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + "survey_id", + "status", + "is_live", + "client_id", + "bid_loi", + "bid_ir", + "country_iso", + "language_iso", + "cpi", + "expected_count", + "project_id", + "group_id", + "calculation_type", + "days_in_field", + "order_number", + "requirements", + "allowed_devices", + "category_exclusions", + "project_exclusions", + "live_link", + "category_ids", + "filters", + "quotas", + "used_question_ids", + "created", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + ) -> List[DynataSurvey]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> last_updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if updated_since is not None: + params["updated_since"] = updated_since + filters.append("last_updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = self.sql_helper.execute_sql_query( + f""" + SELECT survey_id, status, is_live, client_id, bid_loi, bid_ir, cpi, expected_count, + project_id, group_id, calculation_type, days_in_field, country_iso, language_iso, + created, order_number, requirements, allowed_devices, category_exclusions, category_ids, + filters, project_exclusions, quotas, used_question_ids, live_link, last_updated + FROM `thl-dynata`.`dynata_survey` survey + {filter_str} + """, + params, + ) + surveys = [DynataSurvey.from_db(x) for x in res] + return surveys + + def create(self, survey: DynataSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = ["id"] + self.SURVEY_FIELDS + ["last_updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"last_updated": now, "id": None}) + c.execute( + f""" + INSERT INTO `thl-dynata`.`dynata_survey` + ({fields_str}) VALUES ({values_str}) + """, + survey_data, + ) + return True + + def update(self, surveys: List[DynataSurvey]) -> bool: + now = datetime.now(tz=timezone.utc) + update_fields = self.SURVEY_FIELDS + ["last_updated"] + + data = [survey.to_mysql() for survey in surveys] + survey_data = [[d[k] for k in self.SURVEY_FIELDS] + [now] for d in data] + self.sql_helper.bulk_update("dynata_survey", update_fields, survey_data) + return True + + def create_or_update(self, surveys: List[DynataSurvey]): + surveys = {s.survey_id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["survey_id"] + for x in self.sql_helper.execute_sql_query( + """ + SELECT survey_id + FROM `thl-dynata`.`dynata_survey` + WHERE survey_id IN %s""", + [sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + self.update([surveys[sn] for sn in existing_sns]) diff --git a/generalresearch/managers/dynata/user_pid.py b/generalresearch/managers/dynata/user_pid.py new file mode 100644 index 0000000..aefed34 --- /dev/null +++ b/generalresearch/managers/dynata/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class DynataUserPidManager(UserPidManager): + TABLE_NAME = "dynata_userpid" + SOURCE = Source.DYNATA diff --git a/generalresearch/managers/events.py b/generalresearch/managers/events.py new file mode 100644 index 0000000..3e87879 --- /dev/null +++ b/generalresearch/managers/events.py @@ -0,0 +1,863 @@ +import logging +import socket +import threading +import time +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Set, Optional, TYPE_CHECKING, Dict, List + +import math +from redis.client import PubSub, Redis + +from generalresearch.managers.base import RedisManager +from generalresearch.models import Source +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.events import ( + StatsMessage, + EventMessage, + EventEnvelope, + EventType, + TaskEnterPayload, + ServerToClientMessageAdapter, + ServerToClientMessage, + TaskFinishPayload, + SessionEnterPayload, + SessionFinishPayload, + AggregateBySource, + MaxGaugeBySource, + TaskStatsSnapshot, +) +from generalresearch.models.thl.definitions import Status +from generalresearch.models.thl.session import Wall, Session +from generalresearch.models.thl.user import User + +if TYPE_CHECKING: + from influxdb import InfluxDBClient +else: + InfluxDBClient = object + +# Sums all the values in a single hashmap +SUM_HASH_LUA_SCRIPT = """ +local vals = redis.call("HVALS", KEYS[1]) +local sum = 0 +for i = 1, #vals do + sum = sum + tonumber(vals[i]) +end +return sum +""" + +# Returns the max of all the values in a single hashmap +MAX_HASH_LUA_SCRIPT = """ +local vals = redis.call("HVALS", KEYS[1]) +local max = nil +for i = 1, #vals do + local v = tonumber(vals[i]) + if v then + if not max or v > max then + max = v + end + end +end +return max +""" + + +class UserStatsManager(RedisManager): + """ + We store a hashmap for each of last 1hr, last 24hr, for + global counts and per BP. The hashmap's key is the user, + and it expires N hours after it gets set. + To calculate the active user count, we simply get + the number of keys in the hashmap. + """ + + def handle_user(self, user: User): + self.mark_user_active(user=user) + self.handle_user_signup(user=user) + + def get_user_stats(self, product_id: UUIDStr): + r = self.redis_client + pipe = r.pipeline(transaction=False) + + keys = [ + f"active_users_last_1h:{product_id}", + f"active_users_last_24h:{product_id}", + f"signups_last_24h:{product_id}", + f"in_progress_users:{product_id}", + ] + keys_out = [ + "active_users_last_1h", + "active_users_last_24h", + "signups_last_24h", + "in_progress_users", + ] + + for k in keys: + pipe.hlen(k) + + return {k: v for k, v in zip(keys_out, pipe.execute())} + + def get_global_user_stats(self): + r = self.redis_client + pipe = r.pipeline(transaction=False) + + keys = [ + "active_users_last_1h", + "active_users_last_24h", + "signups_last_24h", + "in_progress_users", + ] + for k in keys: + pipe.hlen(k) + + return {k: v for k, v in zip(keys, pipe.execute())} + + def handle_user_signup(self, user: User) -> None: + # Use the user.created timestamp instead of "now". This allows + # us to test this function also, and we also can avoid + # having the caller do any sort of logic, this just always gets called. + # The key is the user's ID so this can be called multiple times + # without side effects. + if user.created is None: + return None + now = round(time.time()) + sec_24hr = round(timedelta(hours=24).total_seconds()) + + minute = int(user.created.timestamp() // 60) * 60 + expires_at = minute + sec_24hr + ttl = expires_at - now + if ttl <= 0: + return None + + pipe = self.redis_client.pipeline() + name = "signups_last_24h" + pipe.hset(name, user.uuid, now) + pipe.hexpire(name, ttl, user.uuid) + name = f"signups_last_24h:{user.product_id}" + pipe.hset(name, user.product_user_id, now) + pipe.hexpire(name, ttl, user.product_user_id) + pipe.execute() + return None + + def mark_user_active(self, user: User) -> None: + now = datetime.now(tz=timezone.utc).isoformat() + r = self.redis_client + + pipe = r.pipeline(transaction=False) + + name_1h_bpid = f"active_users_last_1h:{user.product_id}" + key = user.product_user_id # I could use user.uuid here also + # store last-seen timestamp (value (now) is informational, we could just store "1") + pipe.hset(name_1h_bpid, key, now) + # refresh TTL for this user only + pipe.hexpire(name_1h_bpid, timedelta(hours=1), key) + + name_1h_global = "active_users_last_1h" + key = user.uuid # must use user.uuid here b/c user.bpuid might not be unique + pipe.hset(name_1h_global, key, now) + pipe.hexpire(name_1h_global, timedelta(hours=1), key) + + name_24h_bpid = f"active_users_last_24h:{user.product_id}" + key = user.product_user_id + pipe.hset(name_24h_bpid, key, now) + pipe.hexpire(name_24h_bpid, timedelta(hours=24), key) + + name_24h_global = "active_users_last_24h" + key = user.uuid + pipe.hset(name_24h_global, key, now) + pipe.hexpire(name_24h_global, timedelta(hours=24), key) + + pipe.execute() + + def mark_user_inprogress(self, user: User) -> None: + # Call when a user enters a Session + # This call is idempotent; it can be called multiple times (for the + # same user) and won't falsely increase a counter; it will just + # reset the expiration for this user (times out after 60 min) + now = datetime.now(tz=timezone.utc).isoformat() + r = self.redis_client + pipe = r.pipeline(transaction=False) + + name = f"in_progress_users:{user.product_id}" + key = user.product_user_id # I could use user.uuid here also + # store last-seen timestamp (value (now) is informational, we could just store "1") + pipe.hset(name, key, now) + # Expire after 1 hr + pipe.hexpire(name, timedelta(hours=1), key) + + name = "in_progress_users" + key = user.uuid # must use user.uuid here b/c user.bpuid might not be unique + pipe.hset(name, key, now) + pipe.hexpire(name, timedelta(hours=1), key) + pipe.execute() + return None + + def unmark_user_inprogress(self, user: User): + # Call when a user exits a Session + # This call is idempotent; it can be called multiple times (for the same user) + # and won't falsely decrease a counter. + r = self.redis_client + pipe = r.pipeline(transaction=False) + + name = f"in_progress_users:{user.product_id}" + # Delete the key, whether it exists or not + pipe.hdel(name, user.product_user_id) + name = "in_progress_users" + pipe.hdel(name, user.uuid) + pipe.execute() + return None + + def clear_global_user_stats(self) -> None: + # For testing + r = self.redis_client + r.delete("active_users_last_1h") + r.delete("active_users_last_24h") + r.delete("signups_last_24h") + r.delete("in_progress_users") + return None + + +class TaskStatsManager(RedisManager): + task_stats = [ + "task_created_count_last_1h", + "task_created_count_last_24h", + "live_task_count", + "live_tasks_max_payout", + "TaskStatsManager:latest", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.SUM_HASH_LUA = self.redis_client.register_script(SUM_HASH_LUA_SCRIPT) + self.MAX_HASH_LUA = self.redis_client.register_script(MAX_HASH_LUA_SCRIPT) + + def set_source_task_stats( + self, + source: Source, + live_task_count: int, + live_tasks_max_payout: Decimal, + created_count: int = 0, + ): + self._incr_task_created_count(source=source, created_count=created_count) + self._set_live_task_stats( + source=source, + live_task_count=live_task_count, + live_tasks_max_payout=live_tasks_max_payout, + ) + self.refresh_latest_task_stats() + + def refresh_latest_task_stats(self): + res = TaskStatsSnapshot.model_validate(self.get_task_stats_raw()) + self.redis_client.set( + "TaskStatsManager:latest", res.model_dump_json(), ex=timedelta(hours=24) + ) + + def get_latest_task_stats(self) -> Optional[TaskStatsSnapshot]: + res = self.redis_client.get("TaskStatsManager:latest") + if res is not None: + return TaskStatsSnapshot.model_validate_json(res) + return None + + def _incr_task_created_count(self, source: Source, created_count: int) -> None: + now = round(time.time()) + minute = int(now // 60) * 60 + key = str(minute) + sec_24hr = round(timedelta(hours=24).total_seconds()) + sec_1hr = round(timedelta(hours=1).total_seconds()) + + expires_at_24hr = minute + sec_24hr + ttl_24hr = expires_at_24hr - now + expires_at_1hr = minute + sec_1hr + ttl_1hr = expires_at_1hr - now + + pipe = self.redis_client.pipeline(transaction=False) + name_1h_all = "task_created_count_last_1h" + pipe.hincrby(name_1h_all, key, created_count) + pipe.hexpire(name_1h_all, ttl_1hr, key) + name_24h_all = "task_created_count_last_24h" + pipe.hincrby(name_24h_all, key, created_count) + pipe.hexpire(name_24h_all, ttl_24hr, key) + + name_1h_source = f"task_created_count_last_1h:{source.value}" + pipe.hincrby(name_1h_source, key, created_count) + pipe.hexpire(name_1h_source, ttl_1hr, key) + name_24h_source = f"task_created_count_last_24h:{source.value}" + pipe.hincrby(name_24h_source, key, created_count) + pipe.hexpire(name_24h_source, ttl_24hr, key) + + pipe.execute() + return None + + def _set_live_task_stats( + self, source: Source, live_task_count: int, live_tasks_max_payout: Decimal + ): + # Keep the live stats per source. The total is the sum across all sources + pipe = self.redis_client.pipeline(transaction=False) + + # keys are source, value is the live task count + name = "live_task_count" + pipe.hset(name, source.value, live_task_count) + pipe.hexpire(name, timedelta(hours=24), source.value) + + # keys are source, value is the max_payout + name = "live_tasks_max_payout" + pipe.hset(name, source.value, round(live_tasks_max_payout * 100)) + pipe.hexpire(name, timedelta(hours=24), source.value) + + pipe.execute() + + def get_active_sources(self) -> List[Source]: + return [Source(x) for x in self.redis_client.hkeys("live_task_count")] + + def get_task_stats_raw(self): + sources = self.get_active_sources() + + pipe = self.redis_client.pipeline(transaction=False) + pipe.hgetall("live_task_count") + pipe.hgetall("live_tasks_max_payout") + for source in sources: + self.SUM_HASH_LUA( + keys=[f"task_created_count_last_1h:{source.value}"], + client=pipe, + ) + self.SUM_HASH_LUA( + keys=[f"task_created_count_last_24h:{source.value}"], + client=pipe, + ) + pipe_res = pipe.execute() + live_task_count_raw = pipe_res.pop(0) + live_tasks_max_payout_raw = pipe_res.pop(0) + + live_task_count_by_source = { + Source(k): int(v) for k, v in live_task_count_raw.items() + } + live_task_count = AggregateBySource( + total=sum(live_task_count_by_source.values()), + by_source=live_task_count_by_source, + ) + + live_tasks_max_payout_by_source = { + Source(k): int(v) for k, v in live_tasks_max_payout_raw.items() + } + live_tasks_max_payout = MaxGaugeBySource( + value=max(live_tasks_max_payout_by_source.values(), default=None), + by_source=live_tasks_max_payout_by_source, + ) + + task_created_count_last_1h = dict() + task_created_count_last_24h = dict() + for source in sources: + task_created_count_last_1h[source] = pipe_res.pop(0) + task_created_count_last_24h[source] = pipe_res.pop(0) + task_created_count_last_1h = AggregateBySource( + total=sum(task_created_count_last_1h.values()), + by_source=task_created_count_last_1h, + ) + task_created_count_last_24h = AggregateBySource( + total=sum(task_created_count_last_24h.values()), + by_source=task_created_count_last_24h, + ) + + return { + "task_created_count_last_1h": task_created_count_last_1h, + "task_created_count_last_24h": task_created_count_last_24h, + "live_task_count": live_task_count, + "live_tasks_max_payout": live_tasks_max_payout, + } + + def clear_task_stats(self): + keys = self.task_stats.copy() + keys.extend([f"task_created_count_last_1h:{source.value}" for source in Source]) + keys.extend( + [f"task_created_count_last_24h:{source.value}" for source in Source] + ) + self.redis_client.delete(*keys) + + +class SessionStatsManager(RedisManager): + """ + Each hashmap name stores keys where each key is a unix epoch minute. + The key expires in now - bucket's time period seconds. The value + is a counter. To get the sum, we just sum all the values. Any key + older than 1 hr (in the 1 hr bucket) will expire. + """ + + # Must be ordered. Don't change this + global_keys = [ + "session_enters_last_1h", + "session_enters_last_24h", + "session_fails_last_1h", + "session_fails_last_24h", + "session_completes_last_1h", + "session_completes_last_24h", + "sum_payouts_last_1h", + "sum_payouts_last_24h", + "sum_user_payouts_last_1h", + "sum_user_payouts_last_24h", + # "session_fail_loi_sum_last_1h", + "session_fail_loi_sum_last_24h", + # "session_complete_loi_sum_last_1h", + "session_complete_loi_sum_last_24h", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.SUM_HASH_LUA = self.redis_client.register_script(SUM_HASH_LUA_SCRIPT) + + def session_on_finish(self, session: Session, user: User): + if session.status == Status.COMPLETE: + self.session_on_complete(session=session, user=user) + else: + self.session_on_fail(session=session, user=user) + return None + + def session_on_fail(self, session: Session, user: User): + r = self.redis_client + + now = int(time.time()) + assert session.status != Status.COMPLETE + assert session.finished + + ts = int(session.finished.timestamp()) + bucket_ts = (ts // 60) * 60 # minute-aligned epoch + key = str(bucket_ts) + + pipe = r.pipeline(transaction=True) + + for name_postfix, window in [ + ("last_1h", timedelta(hours=1).total_seconds()), + ("last_24h", timedelta(hours=24).total_seconds()), + ]: + ttl = round(window - (now - bucket_ts)) + if ttl <= 0: + continue + + name = "session_fails_" + name_postfix + # Global tracker + pipe.hincrby(name, key, 1) + pipe.hexpire(name, ttl, key, nx=True) + # BP-specific tracker + pipe.hincrby(name + ":" + user.product_id, key, 1) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + # We're not returning this, but keep the sums, so we can + # calculate the avg + name = "session_fail_loi_sum_" + name_postfix + value = round(session.elapsed.total_seconds()) + pipe.hincrby(name, key, value) + pipe.hexpire(name, ttl, key, nx=True) + pipe.hincrby(name + ":" + user.product_id, key, value) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + pipe.execute() + + def session_on_complete(self, session: Session, user: User): + r = self.redis_client + + now = int(time.time()) + assert session.status == Status.COMPLETE + assert session.finished + assert session.payout is not None + + ts = int(session.finished.timestamp()) + bucket_ts = (ts // 60) * 60 # minute-aligned epoch + key = str(bucket_ts) + + pipe = r.pipeline(transaction=True) + + for name_postfix, window in [ + ("last_1h", timedelta(hours=1).total_seconds()), + ("last_24h", timedelta(hours=24).total_seconds()), + ]: + ttl = round(window - (now - bucket_ts)) + if ttl <= 0: + continue + + name = "session_completes_" + name_postfix + # Global tracker + pipe.hincrby(name, key, 1) + pipe.hexpire(name, ttl, key, nx=True) + # BP-specific tracker + pipe.hincrby(name + ":" + user.product_id, key, 1) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + name = "sum_payouts_" + name_postfix + amount = round(session.payout * 100) + pipe.hincrby(name, key, amount) + pipe.hexpire(name, ttl, key, nx=True) + pipe.hincrby(name + ":" + user.product_id, key, amount) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + if session.user_payout: + name = "sum_user_payouts_" + name_postfix + amount = round(session.user_payout * 100) + pipe.hincrby(name, key, amount) + pipe.hexpire(name, ttl, key, nx=True) + pipe.hincrby(name + ":" + user.product_id, key, amount) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + # We're not returning this, but keep the sums, so we can calculate the avg + name = "session_complete_loi_sum_" + name_postfix + value = round(session.elapsed.total_seconds()) + pipe.hincrby(name, key, value) + pipe.hexpire(name, ttl, key, nx=True) + pipe.hincrby(name + ":" + user.product_id, key, value) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + pipe.execute() + + def session_on_enter(self, session: Session, user: User): + r = self.redis_client + + now = int(time.time()) + assert session.status is None + + ts = int(session.started.timestamp()) + bucket_ts = (ts // 60) * 60 # minute-aligned epoch + key = str(bucket_ts) + + pipe = r.pipeline(transaction=True) + + for name_postfix, window in [ + ("last_1h", timedelta(hours=1).total_seconds()), + ("last_24h", timedelta(hours=24).total_seconds()), + ]: + ttl = round(window - (now - bucket_ts)) + if ttl <= 0: + continue + + name = "session_enters_" + name_postfix + # Global tracker + pipe.hincrby(name, key, 1) + pipe.hexpire(name, ttl, key, nx=True) + # BP-specific tracker + pipe.hincrby(name + ":" + user.product_id, key, 1) + pipe.hexpire(name + ":" + user.product_id, ttl, key, nx=True) + + pipe.execute() + + def get_session_stats(self, product_id: UUIDStr): + r = self.redis_client + key_map = {k: k + ":" + product_id for k in self.global_keys} + keys = list(key_map.values()) + + pipe = r.pipeline(transaction=False) + for k in keys: + self.SUM_HASH_LUA(keys=[k], client=pipe) + + res = {k: v for k, v in zip(list(key_map.keys()), pipe.execute())} + self.calculate_avg_stats(res) + return res + + def get_global_session_stats(self): + r = self.redis_client + + pipe = r.pipeline(transaction=False) + for k in self.global_keys: + self.SUM_HASH_LUA(keys=[k], client=pipe) + + res = {k: v for k, v in zip(self.global_keys, pipe.execute())} + self.calculate_avg_stats(res) + return res + + def calculate_avg_stats(self, res: Dict[str, Optional[float | int]]): + res["session_avg_payout_last_24h"] = None + res["session_avg_user_payout_last_24h"] = None + res["session_complete_avg_loi_last_24h"] = None + res["session_fail_avg_loi_last_24h"] = None + if res["session_completes_last_24h"]: + res["session_avg_payout_last_24h"] = math.ceil( + res["sum_payouts_last_24h"] / res["session_completes_last_24h"] + ) + res["session_avg_user_payout_last_24h"] = math.ceil( + res["sum_user_payouts_last_24h"] / res["session_completes_last_24h"] + ) + res["session_complete_avg_loi_last_24h"] = round( + res["session_complete_loi_sum_last_24h"] + / res["session_completes_last_24h"] + ) + if res["session_fails_last_24h"]: + res["session_fail_avg_loi_last_24h"] = round( + res["session_fail_loi_sum_last_24h"] / res["session_fails_last_24h"] + ) + res.pop("session_complete_loi_sum_last_24h") + res.pop("session_fail_loi_sum_last_24h") + return res + + def clear_global_session_stats(self): + # For testing + r = self.redis_client + for k in self.global_keys: + r.delete(k) + + +class StatsManager(UserStatsManager, SessionStatsManager, TaskStatsManager): + + def get_stats_message(self, product_id: UUIDStr) -> StatsMessage: + res = self.get_session_stats(product_id=product_id) + res.update(self.get_user_stats(product_id=product_id)) + ts = self.get_latest_task_stats() + if ts is not None: + res.update(ts.model_dump()) + return StatsMessage.model_validate({"data": res}) + + def clear_stats(self): + self.clear_task_stats() + self.clear_global_user_stats() + self.clear_global_session_stats() + + +class EventManager(StatsManager): + CACHE_PREFIX = "EventManager" + + def __init__(self, *args, influx_client: InfluxDBClient = None, **kwargs): + super().__init__(*args, **kwargs) + self.influx_client = influx_client + self.stats_worker_thread = None + # Don't bother starting this thread if the influx_client is not set + if self.influx_client is not None: + self.stats_worker_thread = threading.Thread( + target=self.stats_worker, daemon=True + ) + self.stats_worker_thread.start() + + def get_channel_name(self, product_id: UUIDStr): + return f"{self.cache_prefix}:event-channel:{product_id}" + + def get_replay_channel_name(self, product_id: UUIDStr): + return f"{self.cache_prefix}:event-channel-replay:{product_id}" + + def get_last_stats_key(self, product_id: UUIDStr): + return f"{self.cache_prefix}:last_stats:{product_id}" + + def get_active_subscribers(self) -> Set[UUIDStr]: + res = self.redis_client.pubsub_channels(f"{self.cache_prefix}:event-channel:*") + product_ids = {x.rsplit(":", 1)[-1] for x in res} + # product_ids.update( + # {"fc14e741b5004581b30e6478363414df", "888dbc589987425fa846d6e2a8daed04"} + # ) + return product_ids + + def stats_worker(self): + while True: + try: + self.stats_worker_task() + except Exception as e: + logging.exception(e) + finally: + time.sleep(60) + + def stats_worker_task(self): + """ + Only a single worker will be running. It'll be responsible + for periodic publication of summary/stats messages. + + active_subscribers are product_ids that have a pubsub subscription + """ + # Make sure only whoever grabs the lock first runs + now = time.monotonic() + lock_key = f"{self.cache_prefix}:event-channel-lock" + res = self.redis_client.set(lock_key, 1, ex=120, nx=True) + if not res: + logging.debug("failed to acquire stats_worker_task lock") + return None + logging.info("Acquired stats_worker_task lock") + + for product_id in self.get_active_subscribers(): + if time.monotonic() - now > 120: + logging.exception("stats_worker_task is taking too long") + break + channel = self.get_channel_name(product_id) + msg = self.get_stats_message(product_id=product_id) + self.redis_client.publish(channel, msg.model_dump_json()) + self.redis_client.set( + self.get_last_stats_key(product_id), + msg.model_dump_json(), + ex=timedelta(hours=24), + ) + if self.influx_client: + _, numsub = self.redis_client.pubsub_numsub(channel)[0] + point = self.make_influx_point(channel, numsub) + self.influx_client.write_points([point]) + + self.redis_client.delete(lock_key) + + def make_influx_point(self, channel: str, numsub: int): + return { + "measurement": "redis_pubsub_subscribers", + "tags": {"hostname": socket.gethostname(), "channel": channel}, + "fields": { + "subscribers": float(numsub), + }, + } + + def publish_event(self, msg: EventMessage, product_id: UUIDStr): + channel = self.get_channel_name(product_id) + replay = self.get_replay_channel_name(product_id) + print(f"publish: {self.get_channel_name(product_id)} {msg.kind=}") + + msg_json = msg.model_dump_json() + pipe = self.redis_client.pipeline() + pipe.publish(channel, msg_json) + + # replay buffer + pipe.lpush(replay, msg_json) + pipe.ltrim(replay, 0, 9) # keep last 10 + pipe.expire(replay, 86400) # 24h since last message + + # Last stats + + pipe.execute() + + def handle_task_enter(self, wall: Wall, session: Session, user: User): + self.handle_user(user=user) + + msg = EventMessage( + data=EventEnvelope( + event_type=EventType.TASK_ENTER, + timestamp=wall.started, + product_id=user.product_id, + product_user_id=user.product_user_id, + payload=TaskEnterPayload( + source=wall.source, + survey_id=wall.req_survey_id, + country_iso=session.country_iso, + ), + ) + ) + self.publish_event(msg, product_id=user.product_id) + return None + + def handle_task_finish(self, wall: Wall, session: Session, user: User): + self.mark_user_active(user=user) + + msg = EventMessage( + data=EventEnvelope( + event_type=EventType.TASK_FINISH, + timestamp=wall.finished, + product_id=user.product_id, + product_user_id=user.product_user_id, + payload=TaskFinishPayload( + source=wall.source, + survey_id=wall.req_survey_id, + country_iso=session.country_iso, + duration_sec=wall.elapsed.total_seconds(), + status=wall.status, + status_code_1=wall.status_code_1, + status_code_2=wall.status_code_2, + cpi=round(wall.cpi * 100), + ), + ) + ) + self.publish_event(msg, product_id=user.product_id) + + def handle_session_enter(self, session: Session, user: User): + self.handle_user(user=user) + self.mark_user_inprogress(user=user) + self.session_on_enter(session=session, user=user) + + msg = EventMessage( + data=EventEnvelope( + event_type=EventType.SESSION_ENTER, + timestamp=session.started, + product_id=user.product_id, + product_user_id=user.product_user_id, + payload=SessionEnterPayload( + country_iso=session.country_iso, + ), + ) + ) + self.publish_event(msg, product_id=user.product_id) + + def handle_session_finish(self, session: Session, user: User): + self.mark_user_active(user=user) + self.unmark_user_inprogress(user=user) + self.session_on_finish(session=session, user=user) + + msg = EventMessage( + data=EventEnvelope( + event_type=EventType.SESSION_FINISH, + timestamp=session.finished, + product_id=user.product_id, + product_user_id=user.product_user_id, + payload=SessionFinishPayload( + country_iso=session.country_iso, + duration_sec=session.elapsed.total_seconds(), + status=session.status, + status_code_1=session.status_code_1, + status_code_2=session.status_code_2, + user_payout=( + round(session.user_payout * 100) + if session.user_payout + else None + ), + ), + ) + ) + self.publish_event(msg, product_id=user.product_id) + + +class EventSubscriber(RedisManager): + """ + Initialize this class once per websocket connection. This subscribes that client + to a redis PubSub and handles any filtering and parsing of the messages. + """ + + CACHE_PREFIX = "EventManager" + + def __init__(self, *args, product_id: UUIDStr, **kwargs): + super().__init__(*args, **kwargs) + self.product_id = product_id + self.pubsub_client: Optional[Redis] = None + self.pubsub: Optional[PubSub] = None + self._subscribe() + + def _subscribe(self): + if self.pubsub is not None: + raise ValueError("Already subscribed") + r = self.redis_config.create_redis_client() + p = r.pubsub(ignore_subscribe_messages=True) + p.subscribe(self.get_channel_name()) + self.pubsub_client = r + self.pubsub = p + return None + + def get_channel_name(self): + return f"{self.cache_prefix}:event-channel:{self.product_id}" + + def get_replay_channel_name(self): + return f"{self.cache_prefix}:event-channel-replay:{self.product_id}" + + def get_last_stats_key(self): + return f"{self.cache_prefix}:last_stats:{self.product_id}" + + def get_last_stats_msg(self) -> Optional[StatsMessage]: + raw = self.redis_client.get(self.get_last_stats_key()) + if raw is not None: + return StatsMessage.model_validate_json(raw) + return None + + def get_replay_messages(self) -> list[ServerToClientMessage]: + key = self.get_replay_channel_name() + raw = self.redis_client.lrange(key, 0, -1) + # messages are newest -> oldest; reverse for playback + raw.reverse() + return [ServerToClientMessageAdapter.validate_json(x) for x in raw] + + def poll_message(self) -> Optional[ServerToClientMessage]: + res = self.pubsub.get_message(ignore_subscribe_messages=True) + if res is None: + return None + return ServerToClientMessageAdapter.validate_json(res["data"]) + + def get_next_message(self) -> ServerToClientMessage: + while True: + res = self.poll_message() + if res is None: + time.sleep(0.1) + continue + return res + + def clear_replay_messages(self): + # For testing + self.redis_client.delete(self.get_replay_channel_name()) diff --git a/generalresearch/managers/gr/__init__.py b/generalresearch/managers/gr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/gr/authentication.py b/generalresearch/managers/gr/authentication.py new file mode 100644 index 0000000..f851cfa --- /dev/null +++ b/generalresearch/managers/gr/authentication.py @@ -0,0 +1,331 @@ +import binascii +import logging +import os +from datetime import datetime, timezone +from typing import Optional, List, TYPE_CHECKING, Dict, Union +from uuid import uuid4 + +from psycopg import sql +from pydantic import AnyHttpUrl +from pydantic import PositiveInt + +from generalresearch.managers.base import PostgresManagerWithRedis, PostgresManager +from generalresearch.models.custom_types import UUIDStr +from generalresearch.redis_helper import RedisConfig +from generalresearch.pg_helper import PostgresConfig + +LOG = logging.getLogger("gr") + +if TYPE_CHECKING: + from generalresearch.models.gr.authentication import GRUser, GRToken + + +class GRUserManager(PostgresManagerWithRedis): + + def create_dummy( + self, + sub: Optional[str] = None, + is_superuser: bool = False, + ) -> "GRUser": + sub = sub or f"{uuid4().hex}-{uuid4().hex}" + + return self.create( + sub=sub, + is_superuser=is_superuser, + ) + + def create( + self, + sub: str, + is_superuser: bool = False, + ) -> "GRUser": + from generalresearch.models.gr.authentication import GRUser + + now = datetime.now(tz=timezone.utc) + + instance = GRUser.model_validate( + { + "sub": sub, + "is_superuser": is_superuser, + "date_joined": now, + } + ) + data = instance.model_dump(mode="json") + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + query = sql.SQL( + """ + INSERT INTO gr_user + (sub, is_superuser, date_joined) + VALUES (%(sub)s, %(is_superuser)s, %(date_joined)s) + RETURNING id + """ + ) + c.execute(query=query, params=data) + gr_user_id: int = c.fetchone()["id"] + conn.commit() + + instance.id = gr_user_id + return instance + + def get_by_id(self, gr_user_id: int) -> Optional["GRUser"]: + from generalresearch.models.gr.authentication import GRUser + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + SELECT u.* + FROM gr_user AS u + WHERE u.id = %s + LIMIT 1; + """, + params=(gr_user_id,), + ) + res = c.fetchone() + + if res is None: + raise ValueError("GRUser not found") + assert isinstance( + res, dict + ), "GRUserManager.get_by_id query returned invalid results" + + # We can return None if no MySQL results were found... but raise an + # error if returning failed for a different reason + gr_user = GRUser.from_postgresql(res) + assert isinstance(gr_user, GRUser), "GRUser not serialized correctly" + return gr_user + + def get_by_sub(self, sub: str, raises=True) -> Optional["GRUser"]: + from generalresearch.models.gr.authentication import GRUser + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + SELECT u.* + FROM gr_user AS u + WHERE u.sub = %s + LIMIT 1; + """, + params=(sub,), + ) + res = c.fetchone() + + if raises and res is None: + raise ValueError("GRUser not found") + + if res is None: + return None + + assert isinstance( + res, dict + ), "GRUserManager.get_by_id query returned invalid results" + + # We can return None if no MySQL results were found... but raise an + # error if returning failed for a different reason + gr_user = GRUser.from_postgresql(res) + assert isinstance(gr_user, GRUser), "GRUser not serialized correctly" + return gr_user + + def get_by_sub_or_create(self, sub: str) -> "GRUser": + return self.get_by_sub(sub=sub, raises=False) or self.create(sub=sub) + + def get_all(self) -> List["GRUser"]: + from generalresearch.models.gr.authentication import GRUser + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + SELECT u.* + FROM gr_user AS u + """ + ) + res = c.fetchall() + + return [GRUser.from_postgresql(i) for i in res] + + def get_by_team(self, team_id: PositiveInt) -> List["GRUser"]: + from generalresearch.models.gr.authentication import GRUser + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + SELECT gru.* + FROM common_membership AS membership + INNER JOIN gr_user AS gru + ON gru.id = membership.user_id + WHERE membership.team_id = %s + """, + params=(team_id,), + ) + res = c.fetchall() + + for item in res: + for k, v in item.items(): + if isinstance(item[k], datetime): + item[k] = item[k].replace(tzinfo=timezone.utc) + + return [GRUser.model_validate(item) for item in res] + + def list_product_uuids( + self, user: "GRUser", thl_pg_config: PostgresConfig + ) -> Optional[List[UUIDStr]]: + if user.business_uuids is None: + LOG.warning("prefetch not run") + return None + + res = thl_pg_config.execute_sql_query( + query=f""" + SELECT bp.id + FROM userprofile_brokerageproduct AS bp + WHERE bp.business_id = ANY(%s) + """, + params=[user.business_uuids], + ) + return [item["uuid"] for item in res] + + +class GRTokenManager(PostgresManager): + + def get_by_key( + self, + api_key: str, + jwks: Optional[Dict] = None, + audience: Optional[str] = None, + issuer: Optional[Union[AnyHttpUrl, str]] = None, + gr_redis_config: Optional[RedisConfig] = None, + ) -> "GRToken": + """Return the GRToken for this API Token. + + :param api_key: an api value from http header + :param jwks: a jwts dict from sso provider + :param audience: an oidc client id + :param issuer: a jwks_uri for sso provider + :param gr_redis_config: redis + + :return GRToken instance (minified version, no relationships) + :raises NotFoundException + """ + from generalresearch.models.gr.authentication import GRToken, Claims + + # SSO Key + if GRToken.is_sso(api_key): + from jose import jwt + + payload = jwt.decode( + token=api_key, + key=jwks, + algorithms=["RS256"], + audience=audience, + issuer=issuer, + ) + claims = Claims.model_validate(payload) + + gr_um = GRUserManager( + pg_config=self.pg_config, redis_config=gr_redis_config + ) + gr_user = gr_um.get_by_sub_or_create(sub=claims.subject) + gr_user.claims = claims + + gr_token = GRToken.model_validate( + { + "key": api_key, + "user_id": gr_user.id, + "user": gr_user, + "created": datetime.now(tz=timezone.utc), + } + ) + + return gr_token + + # API Key + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + query = sql.SQL( + """ + SELECT grk.* + FROM gr_token AS grk + WHERE grk.key = %s + LIMIT 1 + """ + ) + c.execute(query=query, params=(api_key,)) + res = c.fetchall() + + if len(res) == 0: + raise Exception(f"No GRUser with token of '{api_key}'") + + if len(res) > 1: + raise Exception(f"Too many GRUsers found with token of '{api_key}'") + + item = res[0] + + return GRToken.model_validate(item) + + def create(self, user_id: PositiveInt) -> None: + # Taken directly from the DRF Token + # https://github.com/encode/django-rest-framework/blob/0f39e0124d358b0098261f070175fa8e0359b739/rest_framework/authtoken/models.py#L35-L37 + from generalresearch.models.gr.authentication import GRToken + + token = GRToken.model_validate( + { + "key": binascii.hexlify(os.urandom(20)).decode(), + "created": datetime.now(tz=timezone.utc), + "user_id": user_id, + } + ) + + data = token.model_dump() + data["user_id"] = token.user_id + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO gr_token (key, user_id, created) + VALUES (%(key)s, %(user_id)s, %(created)s) + """ + ), + params=data, + ) + conn.commit() + + return None + + def get_by_user_id(self, user_id: PositiveInt) -> Optional["GRToken"]: + # django authtoken_token table has (user_id) UNIQUE constraint + # therefore, this will only return 0 or 1 GRTokens + from generalresearch.models.gr.authentication import GRToken + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + query = sql.SQL( + """ + SELECT grt.* + FROM gr_token AS grt + LEFT JOIN gr_user AS u + ON u.id = grt.user_id + WHERE u.id = %s + LIMIT 1; + """ + ) + + c.execute(query=query, params=(user_id,)) + + result = c.fetchall() + + if not result: + return None + + res = result[0] + + for k, v in res.items(): + if isinstance(res[k], datetime): + res[k] = res[k].replace(tzinfo=timezone.utc) + + return GRToken.model_validate(res) diff --git a/generalresearch/managers/gr/business.py b/generalresearch/managers/gr/business.py new file mode 100644 index 0000000..001a9e8 --- /dev/null +++ b/generalresearch/managers/gr/business.py @@ -0,0 +1,529 @@ +from typing import Optional, List, TYPE_CHECKING +from uuid import UUID, uuid4 + +from psycopg import sql +from pydantic import PositiveInt +from pydantic_extra_types.phone_numbers import PhoneNumber + +from generalresearch.managers.base import ( + PostgresManagerWithRedis, + PostgresManager, +) +from generalresearch.models.custom_types import UUIDStr + +if TYPE_CHECKING: + from generalresearch.models.gr.team import Team + from generalresearch.models.gr.business import ( + Business, + BusinessType, + BusinessAddress, + BusinessBankAccount, + TransferMethod, + ) + + +class BusinessBankAccountManager(PostgresManager): + + def create_dummy( + self, + business_id: PositiveInt, + uuid: Optional[UUID] = None, + transfer_method: Optional["TransferMethod"] = None, + account_number: Optional[str] = None, + routing_number: Optional[str] = None, + iban: Optional[str] = None, + swift: Optional[str] = None, + ): + from generalresearch.models.gr.business import TransferMethod + + uuid = uuid or uuid4().hex + transfer_method = transfer_method or TransferMethod.ACH + account_number = account_number or uuid4().hex[:6] + routing_number = routing_number or uuid4().hex[:6] + iban = iban or uuid4().hex[:6] + swift = swift or uuid4().hex[:6] + + return self.create( + business_id=business_id, + uuid=uuid, + transfer_method=transfer_method, + account_number=account_number, + routing_number=routing_number, + iban=iban, + swift=swift, + ) + + def create( + self, + business_id: PositiveInt, + uuid: UUIDStr, + transfer_method: "TransferMethod", + account_number: Optional[str] = None, + routing_number: Optional[str] = None, + iban: Optional[str] = None, + swift: Optional[str] = None, + ) -> "BusinessBankAccount": + from generalresearch.models.gr.business import BusinessBankAccount + + ba = BusinessBankAccount.model_validate( + { + "business_id": business_id, + "uuid": uuid, + "transfer_method": transfer_method, + "account_number": account_number, + "routing_number": routing_number, + "iban": iban, + "swift": swift, + } + ) + + data = ba.model_dump(mode="json") + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO common_bankaccount + (uuid, transfer_method, account_number, + routing_number, iban, swift, business_id) + VALUES + (%(uuid)s, %(transfer_method)s, %(account_number)s, + %(routing_number)s, %(iban)s, %(swift)s, %(business_id)s) + RETURNING id + """ + ), + params=data, + ) + ba_id = c.fetchone()["id"] + conn.commit() + + ba.id = ba_id + return ba + + def get_by_business_id(self, business_id: UUIDStr) -> List["BusinessBankAccount"]: + from generalresearch.models.gr.business import BusinessBankAccount + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT ba.* + FROM common_bankaccount AS ba + WHERE ba.business_id = %s + """ + ), + params=(business_id,), + ) + res = c.fetchall() + + return [BusinessBankAccount.model_validate(item) for item in res] + + +class BusinessAddressManager(PostgresManager): + + def create_dummy( + self, + business_id: PositiveInt, + uuid: Optional[UUIDStr] = None, + line_1: Optional[str] = None, + line_2: Optional[str] = None, + city: Optional[str] = None, + state: Optional[str] = None, + postal_code: Optional[str] = None, + phone_number: Optional[PhoneNumber] = None, + country: Optional[str] = None, + ): + uuid = uuid or uuid4().hex + line_1 = line_1 or "abc" + line_2 = line_2 or "bczx" + city = city or "Downingtown" + state = state or "CA" + postal_code = postal_code or "94041" + phone_number = None + country = country or "US" + + return self.create( + business_id=business_id, + uuid=uuid, + line_1=line_1, + line_2=line_2, + city=city, + state=state, + postal_code=postal_code, + phone_number=phone_number, + country=country, + ) + + def create( + self, + business_id: PositiveInt, + uuid: UUIDStr, + line_1: Optional[str] = None, + line_2: Optional[str] = None, + city: Optional[str] = None, + state: Optional[str] = None, + postal_code: Optional[str] = None, + phone_number: Optional[PhoneNumber] = None, + country: Optional[str] = None, + ) -> "BusinessAddress": + from generalresearch.models.gr.business import BusinessAddress + + ba = BusinessAddress.model_validate( + { + "business_id": business_id, + "uuid": uuid, + "line_1": line_1, + "line_2": line_2, + "city": city, + "state": state, + "postal_code": postal_code, + "phone_number": phone_number, + "country": country, + } + ) + data = ba.model_dump() + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO common_businessaddress + (uuid, line_1, line_2, city, country, state, + postal_code, phone_number, business_id) + VALUES + (%(uuid)s, %(line_1)s, %(line_2)s, %(city)s, %(country)s, %(state)s, + %(postal_code)s, %(phone_number)s, %(business_id)s) + RETURNING id + """ + ), + params=data, + ) + ba_id = c.fetchone()["id"] + conn.commit() + + ba.id = ba_id + return ba + + +class BusinessManager(PostgresManagerWithRedis): + """This can and often references many data sources so it's important + to stay organized. + + - The GR-* project maintains its own PostgresSQL + database with Business metadata, contact information, relationship + to Teams and authentication details + - The thl-web brokerage table is ultimately our sense of truth + for which businesses exist and live Products under that + business + - The gr-redis instance stores cached values that may be commonly + referenced by the gr-api services + + """ + + def get_or_create( + self, + uuid: UUIDStr, + name: Optional[str] = None, + team: Optional["Team"] = None, + kind: Optional["BusinessType"] = None, + tax_number: Optional[str] = None, + ) -> "Business": + """ + Warning: this ** does not ** update the name, team, kind, tax_number + values if they differ from what was passed in for the + respective uuid + """ + + business = self.get_by_uuid(business_uuid=uuid) + + if business: + return business + + assert name, "Must provide Business name if creating" + return self.create( + uuid=uuid, name=name, team=team, kind=kind, tax_number=tax_number + ) + + def create_dummy( + self, + uuid: Optional[UUIDStr] = None, + name: Optional[str] = None, + team: Optional["Team"] = None, + kind: Optional["BusinessType"] = None, + tax_number: Optional[str] = None, + ) -> "Business": + from random import randint + + uuid = uuid or uuid4().hex + name = name or "< Unknown >" + tax_number = tax_number or str(randint(1, 999_999_999)) + + return self.create( + uuid=uuid, name=name, team=team, kind=kind, tax_number=tax_number + ) + + def create( + self, + name: str, + kind: Optional["BusinessType"] = None, + uuid: Optional[UUIDStr] = None, + team: Optional["Team"] = None, + tax_number: Optional[str] = None, + ) -> "Business": + """ + Behavior: does this raise on duplicate? + """ + from generalresearch.models.gr.business import ( + Business, + BusinessType, + ) + + business = Business.model_validate( + { + "uuid": uuid or uuid4().hex, + "name": name, + "kind": kind or BusinessType.COMPANY, + "tax_number": tax_number, + } + ) + data = business.model_dump() + data["tax_number"] = business.tax_number + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO common_business (uuid, kind, name, tax_number) + VALUES (%(uuid)s, %(kind)s, %(name)s, %(tax_number)s) + RETURNING id + """ + ), + params=data, + ) + business_id = c.fetchone()["id"] + conn.commit() + business.id = business_id + + if team: + from generalresearch.managers.gr.team import TeamManager + + tm = TeamManager(pg_config=self.pg_config, redis_config=self.redis_config) + tm.add_business(team=team, business=business) + + return business + + def get_all(self) -> List["Business"]: + """WARNING: This should be access by the /god/ page only, and only + used by GRUser.is_staff as it doesn't provide any authentication + on it's own. This is used because the .get_by_team_id() and + .get_by_user_id() use the table relationships, and it's often too + tedious to ensure every GRL admin is manually added to each and + every Team in order to manage or view details about it. + + :return: + """ + from generalresearch.models.gr.business import Business + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT b.id, b.uuid, b.kind, b.name, b.tax_number + FROM common_business AS b + """ + ) + ) + res = c.fetchall() + + response = [] + for i in res: + # i["contact"] = BusinessContact.model_validate(i) + # i["address"] = BusinessAddress.model_validate(i) + i["contact"] = None + i["address"] = None + + response.append(Business.model_validate(i)) + + return response + + def get_by_team( + self, + team_id: PositiveInt, + ) -> List["Business"]: + + # conn: psycopg.Connection = GR_POSTGRES_C.make_connection() + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT b.id, b.uuid, b.kind, b.name, b.tax_number + FROM common_business AS b + INNER JOIN common_team_businesses as tb + ON tb.business_id = b.id + WHERE tb.team_id = %s + """ + ), + params=(team_id,), + ) + + res = c.fetchall() + + response = [] + from generalresearch.models.gr.business import Business + + for i in res: + # i["contact"] = BusinessContact.model_validate(i) + # i["address"] = BusinessAddress.model_validate(i) + response.append(Business.model_validate(i)) + + return response + + def get_by_user_id( + self, + user_id: PositiveInt, + ) -> List["Business"]: + from generalresearch.models.gr.business import Business + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT b.id, b.uuid, b.kind, b.name, b.tax_number + FROM common_business AS b + INNER JOIN common_team_businesses AS tb + ON tb.business_id = b.id + INNER JOIN common_membership AS m + ON m.team_id = tb.team_id + WHERE m.user_id = %s + """ + ), + params=(user_id,), + ) + + res = c.fetchall() + + response = [] + for i in res: + # i["contact"] = BusinessContact.model_validate(i) + # i["address"] = BusinessAddress.model_validate(i) + response.append(Business.model_validate(i)) + + return response + + def get_ids_by_user_id(self, user_id: PositiveInt) -> List[PositiveInt]: + """ + :return: Every Business UUIDStr that this GRUser has permission to view + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT b.id + FROM common_business AS b + INNER JOIN common_team_businesses AS tb + ON tb.business_id = b.id + INNER JOIN common_membership AS cm + ON tb.team_id = cm.team_id + WHERE cm.user_id = %s + """ + ), + params=(user_id,), + ) + + res = c.fetchall() + + return [i["id"] for i in res] + + def get_uuids_by_user_id(self, user_id: PositiveInt) -> List[UUIDStr]: + """ + :return: Every Business UUIDStr that this GRUser has permission to view + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT b.uuid + FROM common_business AS b + INNER JOIN common_team_businesses AS tb + ON tb.business_id = b.id + INNER JOIN common_membership AS cm + ON tb.team_id = cm.team_id + WHERE cm.user_id = %s + """ + ), + params=(user_id,), + ) + + res = c.fetchall() + + return [i["uuid"] for i in res] + + def get_by_uuid( + self, + business_uuid: UUIDStr, + ) -> Optional["Business"]: + from generalresearch.models.gr.business import Business + + assert UUID(hex=business_uuid).hex == business_uuid + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT id, uuid, kind, name, tax_number + FROM common_business + WHERE uuid = %s + LIMIT 1; + """ + ), + params=(business_uuid,), + ) + + res = c.fetchall() + + if len(res) == 0: + return None + + assert len(res) == 1, "BusinessManager.get_by_uuid returned invalid results" + data = res[0] + # data["address"] = BusinessAddress.model_validate(data) + # data["contact"] = BusinessContact.model_validate(data) + return Business.model_validate(data) + + def get_by_id(self, business_id: PositiveInt) -> Optional["Business"]: + from generalresearch.models.gr.business import Business + + assert isinstance(business_id, int) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT id, uuid, kind, name, tax_number + FROM common_business + WHERE id = %s + LIMIT 1; + """ + ), + params=(business_id,), + ) + + res = c.fetchall() + + if len(res) == 0: + return None + + return Business.model_validate(res[0]) diff --git a/generalresearch/managers/gr/team.py b/generalresearch/managers/gr/team.py new file mode 100644 index 0000000..a57ef5f --- /dev/null +++ b/generalresearch/managers/gr/team.py @@ -0,0 +1,312 @@ +from datetime import datetime, timezone +from typing import Optional, List, TYPE_CHECKING +from uuid import uuid4 + +from psycopg import sql +from pydantic import PositiveInt + +from generalresearch.managers.base import ( + PostgresManager, + PostgresManagerWithRedis, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.gr.team import Membership, MembershipPrivilege + +if TYPE_CHECKING: + from generalresearch.models.gr.team import ( + Membership, + Team, + MembershipPrivilege, + ) + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.business import Business + + +class MembershipManager(PostgresManager): + """The Membership Manager controls the relationships between a + GR User and a Team. + + GRUsers do not have direct connections to Businesses or Products, + they're all connected through a Team and a GRUser's relationship to + a Team can have various levels of permissions and rights. + """ + + def create( + self, + team: "Team", + gr_user: "GRUser", + privilege: MembershipPrivilege = MembershipPrivilege.READ, + ) -> Membership: + membership = Membership( + uuid=uuid4().hex, + privilege=MembershipPrivilege.READ, + owner=False, + team_id=team.id, + user_id=gr_user.id, + created=datetime.now(tz=timezone.utc), + ) + + data = membership.model_dump(by_alias=True) + data["team_id"] = team.id + data["user_id"] = gr_user.id + + # 'user_id' = {int} 5774 + # 'team_id' = {int} 20736 + + assert gr_user.id, "GR User must be saved" + assert team.id, "Team must be saved" + existing = self.exists(gr_user_id=gr_user.id, team_id=team.id) + if existing: + return existing + + gr_user_memberships = self.get_by_gr_user_id(gr_user_id=gr_user.id) + if len(gr_user_memberships) > 5: + raise ValueError("Should this GR User really be in more than 5 Teams?") + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO common_membership + (uuid, privilege, owner, team_id, user_id, created) + VALUES (%(uuid)s, %(privilege)s, %(owner)s, %(team_id)s, + %(user_id)s, %(created)s) + RETURNING id + """ + ), + params=data, + ) + membership_id: int = c.fetchone()["id"] + conn.commit() + + membership.id = membership_id + return membership + + def exists( + self, gr_user_id: PositiveInt, team_id: PositiveInt + ) -> Optional[Membership]: + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT id, uuid, privilege, owner, created, + user_id, team_id + FROM common_membership + WHERE team_id = %s AND user_id = %s + LIMIT 1 + """ + ), + params=(team_id, gr_user_id), + ) + res = c.fetchone() + + if not res: + return None + + return Membership.model_validate(res) + + def get_by_team_id(self, team_id: PositiveInt) -> List[Membership]: + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT id, uuid, privilege, owner, created, + user_id, team_id + FROM common_membership + WHERE team_id = %s + LIMIT 250 + """ + ), + params=(team_id,), + ) + res = c.fetchall() + + return [Membership.model_validate(i) for i in res] + + def get_by_gr_user_id(self, gr_user_id: PositiveInt) -> List[Membership]: + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT id, uuid, privilege, owner, created, + user_id, team_id + FROM common_membership + WHERE user_id = %s + LIMIT 250 + """ + ), + params=(gr_user_id,), + ) + res = c.fetchall() + + return [Membership.model_validate(i) for i in res] + + +class TeamManager(PostgresManagerWithRedis): + + def get_or_create( + self, uuid: Optional[UUIDStr] = None, name: Optional[str] = None + ) -> "Team": + + team = self.get_by_uuid(team_uuid=uuid) + + if team: + return team + + return self.create(uuid=uuid, name=name or "< Unknown >") + + def get_all(self) -> List["Team"]: + from generalresearch.models.gr.team import Team + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT t.id, t.uuid, t.name + FROM common_team AS t + """ + ) + ) + res = c.fetchall() + + return [Team.model_validate(i) for i in res] + + def create_dummy( + self, uuid: Optional[UUIDStr] = None, name: Optional[str] = None + ) -> "Team": + uuid = uuid or uuid4().hex + name = name or f"name-{uuid4().hex[:12]}" + + return self.create(uuid=uuid, name=name) + + def create( + self, + name: str, + uuid: Optional[UUIDStr] = None, + ) -> "Team": + from generalresearch.models.gr.team import Team + + team = Team.model_validate({"uuid": uuid or uuid4().hex, "name": name}) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO common_team (uuid, name) + VALUES (%s, %s) + RETURNING id + """ + ), + params=[team.uuid, team.name], + ) + team_id = c.fetchone()["id"] + conn.commit() + team.id = team_id + + return team + + def add_user(self, team: "Team", gr_user: "GRUser") -> "Membership": + """Create a Membership between a GRUser and a Team""" + + team.prefetch_gr_users(pg_config=self.pg_config, redis_config=self.redis_config) + + assert gr_user not in team.gr_users, ( + "Can't create multiple Memberships for " "the same User to the same Team" + ) + mm = MembershipManager(pg_config=self.pg_config) + + return mm.create(team=team, gr_user=gr_user) + + def add_business(self, team: "Team", business: "Business") -> None: + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + INSERT INTO common_team_businesses + (team_id, business_id) + VALUES (%s, %s) + """ + ), + params=( + team.id, + business.id, + ), + ) + conn.commit() + + return None + + def get_by_uuid(self, team_uuid: UUIDStr) -> Optional["Team"]: + from generalresearch.models.gr.team import Team + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + SELECT t.* + FROM common_team AS t + WHERE t.uuid = %s + LIMIT 1; + """, + params=(team_uuid,), + ) + + res = c.fetchone() + + if not isinstance(res, dict): + return None + + return Team.model_validate(res) + + def get_by_id(self, team_id: PositiveInt) -> Optional["Team"]: + from generalresearch.models.gr.team import Team + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT t.id, t.uuid, t.name + FROM common_team AS t + WHERE t.id = %s + LIMIT 1; + """ + ), + params=(team_id,), + ) + + res = c.fetchone() + + if not isinstance(res, dict): + return None + + return Team.model_validate(res) + + def get_by_user(self, gr_user: "GRUser") -> List["Team"]: + from generalresearch.models.gr.team import Team + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=sql.SQL( + """ + SELECT team.* + FROM common_team AS team + INNER JOIN common_membership AS mem + ON mem.team_id = team.id + WHERE mem.user_id = %s + """ + ), + params=(gr_user.id,), + ) + + res = c.fetchall() + + return [Team.model_validate(item) for item in res] diff --git a/generalresearch/managers/innovate/__init__.py b/generalresearch/managers/innovate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/innovate/profiling.py b/generalresearch/managers/innovate/profiling.py new file mode 100644 index 0000000..a3939a7 --- /dev/null +++ b/generalresearch/managers/innovate/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.innovate.question import InnovateQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_keys: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[InnovateQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_keys: filters on question_key field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_key, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('CORE_AUTOMOTIVE_0002', 'us', 'eng'), ('AGE', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_keys: + params["question_keys"] = question_keys + filters.append("question_key IN %(question_keys)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_key, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `{sql_helper.db}`.`innovate_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [InnovateQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/innovate/survey.py b/generalresearch/managers/innovate/survey.py new file mode 100644 index 0000000..0e19065 --- /dev/null +++ b/generalresearch/managers/innovate/survey.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional, Set + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.innovate.survey import ( + InnovateSurvey, + InnovateCondition, +) + +logger = logging.getLogger() + + +class InnovateCriteriaManager(CriteriaManager): + CONDITION_MODEL = InnovateCondition + TABLE_NAME = "innovate_criterion" + + +class InnovateSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + "survey_id", + "status", + "country_iso", + "language_iso", + "cpi", + "buyer_id", + "job_id", + "survey_name", + "desired_count", + "remaining_count", + "supplier_completes_achieved", + "global_completes", + "global_starts", + "global_median_loi", + "global_conversion", + "bid_loi", + "bid_ir", + "allowed_devices", + "entry_link", + "category", + "requires_pii", + "excluded_surveys", + "duplicate_check_level", + "exclude_pids", + "include_pids", + "is_revenue_sharing", + "group_type", + "off_hour_traffic", + "qualifications", + "quotas", + "used_question_ids", + "is_live", + "modified_api", + "created_api", + "expected_end_date", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + exclude_fields: Optional[Set[str]] = None, + ) -> List[InnovateSurvey]: + """ + Accepts lots of optional filters. + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: + if is_live: + filters.append("is_live") + else: + filters.append("NOT is_live") + if updated_since is not None: + params["updated_since"] = updated_since + filters.append("updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + fields = set(self.SURVEY_FIELDS) | {"created", "updated"} + if exclude_fields: + fields -= exclude_fields + fields_str = ", ".join([f"`{v}`" for v in fields]) + res = self.sql_helper.execute_sql_query( + f""" + SELECT {fields_str} + FROM `{self.sql_helper.db}`.`innovate_survey` survey + {filter_str} + """, + params, + ) + surveys = [InnovateSurvey.from_db(x) for x in res] + return surveys + + def create(self, survey: InnovateSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = self.SURVEY_FIELDS + ["created", "updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"created": now, "updated": now}) + c.execute( + f""" + INSERT INTO `{self.sql_helper.db}`.`innovate_survey` + ({fields_str}) VALUES ({values_str}) + """, + survey_data, + ) + return True + + def update(self, surveys: List[InnovateSurvey]) -> bool: + now = datetime.now(tz=timezone.utc) + update_fields = self.SURVEY_FIELDS + ["updated"] + + data = [survey.to_mysql() for survey in surveys] + survey_data = [[d[k] for k in self.SURVEY_FIELDS] + [now] for d in data] + self.sql_helper.bulk_update( + table_name="innovate_survey", + field_names=update_fields, + values_to_insert=survey_data, + ) + + return True + + def create_or_update(self, surveys: List[InnovateSurvey]) -> None: + surveys = {s.survey_id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["survey_id"] + for x in self.sql_helper.execute_sql_query( + query=f""" + SELECT survey_id + FROM `{self.sql_helper.db}`.`innovate_survey` + WHERE survey_id IN %s; + """, + params=[sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + self.update([surveys[sn] for sn in existing_sns]) + + return None diff --git a/generalresearch/managers/innovate/user_pid.py b/generalresearch/managers/innovate/user_pid.py new file mode 100644 index 0000000..100b0ca --- /dev/null +++ b/generalresearch/managers/innovate/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class InnovateUserPidManager(UserPidManager): + TABLE_NAME = "innovate_userpid" + SOURCE = Source.INNOVATE diff --git a/generalresearch/managers/leaderboard/__init__.py b/generalresearch/managers/leaderboard/__init__.py new file mode 100644 index 0000000..8468cdc --- /dev/null +++ b/generalresearch/managers/leaderboard/__init__.py @@ -0,0 +1,37 @@ +from typing import Dict +from zoneinfo import ZoneInfo + +import pytz +from cachetools import cached, LRUCache + + +@cached(cache=LRUCache(maxsize=1)) +def country_timezone() -> Dict[str, ZoneInfo]: + """ + Most countries only have 1 tz. I am picking the most populous for the rest. + A timezone is unique for a country, as in America/New_York and America/Toronto + both have the same UTC offset, but one is for US only and one is for CA only, + Each timezone is unique per country. + """ + ct = dict(pytz.country_timezones) + ct = {k: v[0] for k, v in ct.items()} + ct["US"] = "America/New_York" + ct["AR"] = "America/Argentina/Buenos_Aires" + ct["AU"] = "Australia/Sydney" + ct["BR"] = "America/Sao_Paulo" + ct["CA"] = "America/Toronto" + ct["CL"] = "America/Santiago" + ct["CN"] = "Asia/Shanghai" + ct["DE"] = "Europe/Berlin" + ct["EC"] = "America/Guayaquil" + ct["ES"] = "Europe/Madrid" + ct["ID"] = "Asia/Jakarta" + ct["KZ"] = "Asia/Almaty" + ct["MX"] = "America/Mexico_City" + ct["MY"] = "Asia/Kuala_Lumpur" + ct["NZ"] = "Pacific/Auckland" + ct["PT"] = "Europe/Lisbon" + ct["RU"] = "Europe/Moscow" + ct["UA"] = "Europe/Kiev" + ct = {k.lower(): ZoneInfo(v) for k, v in ct.items()} + return ct diff --git a/generalresearch/managers/leaderboard/manager.py b/generalresearch/managers/leaderboard/manager.py new file mode 100644 index 0000000..86b3d80 --- /dev/null +++ b/generalresearch/managers/leaderboard/manager.py @@ -0,0 +1,212 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from functools import cached_property +from typing import Optional, TYPE_CHECKING, List + +import pandas as pd +from pandas import Period +from pydantic import NaiveDatetime, AwareDatetime +from redis import Redis + +from generalresearch.managers.leaderboard import country_timezone +from generalresearch.models.thl.leaderboard import ( + LeaderboardCode, + LeaderboardFrequency, + LeaderboardRow, + Leaderboard, +) + +if TYPE_CHECKING: + from generalresearch.models.thl.session import Session + + +class LeaderboardManager: + def __init__( + self, + redis_client: Redis, + board_code: LeaderboardCode, + freq: LeaderboardFrequency, + product_id: str, + country_iso: str, + within_time: Optional[NaiveDatetime | AwareDatetime] = None, + ): + """ + :param within_time: Any local datetime falling within the desired leaderboard period. + e.g. (if freq=daily) within_time = 2024-04-12 01:02:03 will get the '2024-04-12' board + """ + self.redis_client = redis_client + self.timezone = country_timezone()[country_iso] + + self.board_code = board_code + self.freq = freq + self.product_id = product_id + self.country_iso = country_iso + self.within_time_aware = None + if within_time is None: + self.within_time_aware = datetime.now(tz=timezone.utc).astimezone( + self.timezone + ) + elif within_time.tzinfo is not None: + self.within_time_aware = within_time.astimezone(self.timezone) + else: + self.within_time_aware = within_time.replace(tzinfo=self.timezone) + self.key = self.board_key() + + @cached_property + def period(self) -> Period: + local_ts = self.within_time_aware + assert local_ts.tzinfo != timezone.utc and local_ts.tzinfo is not None + t = pd.Timestamp(local_ts).tz_localize(tz=None) + freq_pd = { + LeaderboardFrequency.WEEKLY: "W-SUN", + LeaderboardFrequency.DAILY: "D", + LeaderboardFrequency.MONTHLY: "M", + }[self.freq] + return t.to_period(freq_pd) + + @cached_property + def expiration(self) -> int: + # When the redis key for this board should expire + return { + LeaderboardFrequency.DAILY: int(timedelta(days=90).total_seconds()), + LeaderboardFrequency.WEEKLY: int(timedelta(days=365).total_seconds()), + LeaderboardFrequency.MONTHLY: int(timedelta(days=365 * 2).total_seconds()), + }[self.freq] + + def board_key(self) -> str: + product_id = self.product_id + country_iso = self.country_iso + freq = self.freq + board_code = self.board_code + date_str = self.period.start_time.to_pydatetime().strftime("%Y-%m-%d") + return f"leaderboard:{product_id}:{country_iso}:{freq.value}:{date_str}:{board_code.value}" + + def get_row_count(self) -> int: + # How many rows (unique users) does this leaderboard have? + return self.redis_client.zcard(self.key) or 0 + + def get_leaderboard_rows( + self, + limit: Optional[int] = None, + ) -> List[LeaderboardRow]: + limit = limit if limit else 0 + res = self.redis_client.zrange( + self.key, start=0, end=limit - 1, withscores=True, desc=True + ) + # We re-rank using pandas min value for ties. Redis does not consider ties in ranking. + s = pd.DataFrame(res, columns=["bpuid", "value"]).sort_values( + by="value", ascending=False + ) + s["rank"] = s["value"].rank(method="min", ascending=False) + return [ + LeaderboardRow(bpuid=r.bpuid, value=r.value, rank=r.rank) + for r in s.itertuples() + ] + + def get_personal_leaderboard_rows( + self, bp_user_id: str, limit: Optional[int] = 5 + ) -> List[LeaderboardRow]: + # We can't just grab this user's rank and nearby rows b/c redis does + # not handle ties the same way we do (in redis, each value is a + # unique rank, we use lowest rank for all ties). So we have to just + # grab everything, then filter + limit = limit if limit is not None else 5 + rows = self.get_leaderboard_rows() + rows = sorted(rows, key=lambda x: x.value, reverse=True) + user_indices = [ + (i, row) for i, row in enumerate(rows) if row.bpuid == bp_user_id + ] + if not user_indices: + return rows[: limit * 2] + user_idx = user_indices[0][0] + user_row = user_indices[0][1] + if user_row.rank == max([row.rank for row in rows]): + user_idx = [i for i, row in enumerate(rows) if row.rank == user_row.rank][0] + start: int = max(user_idx - limit, 0) + end: int = min(user_idx + limit + 1, len(rows)) + + return rows[start:end] + + def get_leaderboard( + self, + limit: Optional[int] = None, + bp_user_id=None, + ) -> Leaderboard: + + if bp_user_id: + rows = self.get_personal_leaderboard_rows( + bp_user_id=bp_user_id, limit=limit + ) + else: + rows = self.get_leaderboard_rows( + limit=limit, + ) + total = self.get_row_count() + + tz = self.timezone + + return Leaderboard( + board_code=self.board_code, + country_iso=self.country_iso, + bpid=self.product_id, + freq=self.freq, + row_count=total, + rows=rows, + period_start_local=self.period.start_time.to_pydatetime().replace( + tzinfo=self.timezone + ), + period_end_local=self.period.end_time.replace(nanosecond=0) + .to_pydatetime() + .replace(tzinfo=self.timezone), + timezone_name=str(tz), + ) + + def hit_complete_count(self, product_user_id: str) -> None: + assert ( + self.board_code == LeaderboardCode.COMPLETE_COUNT + ), "wrong kind of leaderboard" + self.redis_client.zincrby(self.key, amount=1, value=product_user_id) + self.redis_client.expire(self.key, time=self.expiration) + + return None + + def hit_sum_payouts(self, product_user_id: str, user_payout: Decimal) -> None: + assert ( + self.board_code == LeaderboardCode.SUM_PAYOUTS + ), "wrong kind of leaderboard" + self.redis_client.zincrby( + self.key, amount=round(user_payout * 100), value=product_user_id + ) + self.redis_client.expire(self.key, time=self.expiration) + + return None + + def hit_largest_payout(self, product_user_id: str, user_payout: Decimal) -> None: + assert ( + self.board_code == LeaderboardCode.LARGEST_PAYOUT + ), "wrong kind of leaderboard" + # Only sets the value if the new value is greater than the existing + self.redis_client.zadd( + self.key, {product_user_id: round(user_payout * 100)}, gt=True + ) + self.redis_client.expire(self.key, time=self.expiration) + + return None + + def hit(self, session: "Session") -> None: + user = session.user + match self.board_code: + case LeaderboardCode.COMPLETE_COUNT: + return self.hit_complete_count(product_user_id=user.product_user_id) + case LeaderboardCode.SUM_PAYOUTS: + return self.hit_sum_payouts( + product_user_id=user.product_user_id, + user_payout=session.user_payout, + ) + case LeaderboardCode.LARGEST_PAYOUT: + return self.hit_largest_payout( + product_user_id=user.product_user_id, + user_payout=session.user_payout, + ) + + return None diff --git a/generalresearch/managers/leaderboard/tasks.py b/generalresearch/managers/leaderboard/tasks.py new file mode 100644 index 0000000..e25c4d8 --- /dev/null +++ b/generalresearch/managers/leaderboard/tasks.py @@ -0,0 +1,59 @@ +import logging + +from redis import Redis + +from generalresearch.managers.leaderboard.manager import LeaderboardManager +from generalresearch.models.thl.session import Session +from generalresearch.models.thl.leaderboard import ( + LeaderboardFrequency, + LeaderboardCode, +) + +logger = logging.getLogger() + + +def hit_leaderboards(redis_client: Redis, session: Session): + user = session.user + assert user.product is not None, "prefetch user.product first" + product_id = user.product_id + country_iso = session.country_iso + ts = session.started + + for freq in [ + LeaderboardFrequency.DAILY, + LeaderboardFrequency.WEEKLY, + LeaderboardFrequency.MONTHLY, + ]: + lbm = LeaderboardManager( + redis_client=redis_client, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=ts, + ) + logger.info(lbm.key) + lbm.hit(session) + + if user.product.payout_config.payout_transformation: + lbm = LeaderboardManager( + redis_client=redis_client, + board_code=LeaderboardCode.SUM_PAYOUTS, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=ts, + ) + logger.info(lbm.key) + lbm.hit(session) + + lbm = LeaderboardManager( + redis_client=redis_client, + board_code=LeaderboardCode.LARGEST_PAYOUT, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=ts, + ) + logger.info(lbm.key) + lbm.hit(session) diff --git a/generalresearch/managers/lucid/__init__.py b/generalresearch/managers/lucid/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/lucid/profiling.py b/generalresearch/managers/lucid/profiling.py new file mode 100644 index 0000000..5f084b1 --- /dev/null +++ b/generalresearch/managers/lucid/profiling.py @@ -0,0 +1,83 @@ +import json +from typing import List, Collection, Optional, Tuple +from generalresearch.decorators import LOG +from pydantic import ValidationError + +from generalresearch.models.lucid.question import LucidQuestion, LucidQuestionType +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + pks: Optional[Collection[Tuple[str | int, str, str]]] = None, +) -> List[LucidQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = ["`q`.question_type != 'o'"] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`q`.`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`q`.`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if pks: + # In this table, the question_id is an int + pks = [(int(x[0]), x[1], x[2]) for x in pks] + params["pks"] = pks + filters.append("(q.question_id, q.country_iso, q.language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + db_name = sql_helper.db_name + res = sql_helper.execute_sql_query( + query=f""" + SELECT q.question_id, q.question_type, q.question_text, + q.country_iso, q.language_iso, + JSON_ARRAYAGG( + JSON_OBJECT('id', qo.precode, 'text', qo.option_text) + ) AS options + FROM `{db_name}`.`lucid_question` q + LEFT JOIN `{db_name}`.lucid_questionoption qo + ON q.question_id = qo.question_id + AND q.country_iso = qo.country_iso + AND q.language_iso = qo.language_iso + {filter_str} + GROUP BY q.question_id, q.country_iso, q.language_iso + """, + params=params, + ) + for x in res: + x["question_id"] = str(x["question_id"]) + x["options"] = json.loads(x["options"]) if x["options"] else None + # the mysql JSON_ARRAYAGG returns this if there are no options + x["options"] = ( + x["options"] if x["options"] != [{"id": None, "text": None}] else [] + ) + for n, y in enumerate(x["options"]): + y["order"] = n + # Special hack... These don't have options, but they should + # (CBSA, MSA, DMA), + if x["question_id"] in {"116", "120", "121"}: + x["question_type"] = LucidQuestionType.TEXT_ENTRY + qs = [] + for x in res: + try: + qs.append(LucidQuestion.from_db(x)) + except ValidationError as e: + LOG.warning(f"{x['question_id']}: {e}") + # print(x) + return qs diff --git a/generalresearch/managers/marketplace/__init__.py b/generalresearch/managers/marketplace/__init__.py new file mode 100644 index 0000000..3349434 --- /dev/null +++ b/generalresearch/managers/marketplace/__init__.py @@ -0,0 +1,23 @@ +from generalresearch.managers.cint.user_pid import CintUserPidManager +from generalresearch.managers.dynata.user_pid import DynataUserPidManager +from generalresearch.managers.innovate.user_pid import InnovateUserPidManager +from generalresearch.managers.morning.user_pid import MorningUserPidManager +from generalresearch.managers.precision.user_pid import PrecisionUserPidManager +from generalresearch.managers.prodege.user_pid import ProdegeUserPidManager +from generalresearch.managers.repdata.user_pid import RepdataUserPidManager +from generalresearch.managers.sago.user_pid import SagoUserPidManager +from generalresearch.managers.spectrum.user_pid import SpectrumUserPidManager + +_managers = [ + CintUserPidManager, + DynataUserPidManager, + InnovateUserPidManager, + MorningUserPidManager, + PrecisionUserPidManager, + ProdegeUserPidManager, + RepdataUserPidManager, + SagoUserPidManager, + SpectrumUserPidManager, +] + +USER_PID_MANAGERS = {x.SOURCE: x for x in _managers} diff --git a/generalresearch/managers/marketplace/user_pid.py b/generalresearch/managers/marketplace/user_pid.py new file mode 100644 index 0000000..f731179 --- /dev/null +++ b/generalresearch/managers/marketplace/user_pid.py @@ -0,0 +1,96 @@ +from abc import ABC +from typing import Collection, Optional, List, Dict +from uuid import UUID + +from generalresearch.managers.base import SqlManager +from generalresearch.models import Source +from generalresearch.sql_helper import SqlHelper + + +class UserPidManager(SqlManager, ABC): + """ + For getting user pids across marketplaces + """ + + SOURCE: Source = None + TABLE_NAME = None + + def filter( + self, + user_ids: Optional[Collection[int]] = None, + pids: Optional[Collection[str]] = None, + ) -> List[Dict[str, str]]: + """ + Filter by user_id or user_pid + """ + assert (user_ids or pids) and not ( + user_ids and pids + ), "Must pass ONE of user_ids, pids" + + params = [] + if user_ids: + assert len(user_ids) <= 500, "limit 500 user_ids" + assert isinstance( + user_ids, (list, set) + ), "must pass a collection of user_ids" + filter_str = "user_id IN %s" + params.append(set(user_ids)) + else: + assert len(pids) <= 500, "limit 500 pids" + assert isinstance(pids, (list, set)), "must pass a collection of pids" + pids = {UUID(x).hex for x in pids} + filter_str = "pid IN %s" + params.append(pids) + query = f""" + SELECT user_id, pid + FROM {self.mysql_db_table} + WHERE {filter_str} + LIMIT 500;""" + res = self.sql_helper.execute_sql_query( + query=query, + params=params, + ) + for x in res: + x["pid"] = UUID(x["pid"]).hex + return sorted(res, key=lambda x: x["user_id"]) + + @property + def mysql_db_table(self): + assert self.TABLE_NAME, "must subclass and set TABLE_NAME" + return f"`{self.sql_helper.db}`.`{self.TABLE_NAME}`" + + +class UserPidMultiManager: + """ + For looking up marketplace user_pids by user_id across multiple marketplaces + """ + + def __init__(self, sql_helper: SqlHelper, managers: List[UserPidManager]): + self.sql_helper = sql_helper + self.managers = managers + + def filter(self, user_ids: Optional[Collection[int]] = None): + # You can only query across all marketplaces by user_id. + # If you are looking by user_pid, it is assumed + # you know which marketplace you are looking in. + assert len(user_ids) <= 100, "limit 100 user_ids" + assert isinstance(user_ids, (list, set)), "must pass a collection of user_ids" + + params = [set(user_ids)] * len(self.managers) + queries = [ + f""" + SELECT user_id, pid, '{m.SOURCE.value}' as source + FROM {m.mysql_db_table} + WHERE user_id IN %s + """ + for m in self.managers + ] + query = "\nUNION ".join(queries) + res = self.sql_helper.execute_sql_query(query=query, params=params) + for x in res: + x["pid"] = UUID(x["pid"]).hex + x["source"] = Source(x["source"]) + + # Note: the wxet user pid is just the thl_user.uuid. Whatever uses this + # should insert that in. + return sorted(res, key=lambda x: (x["user_id"], x["source"].value)) diff --git a/generalresearch/managers/morning/__init__.py b/generalresearch/managers/morning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/morning/profiling.py b/generalresearch/managers/morning/profiling.py new file mode 100644 index 0000000..c397748 --- /dev/null +++ b/generalresearch/managers/morning/profiling.py @@ -0,0 +1,67 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.morning.question import MorningQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + source: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[MorningQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param source: can be 'api' or 'exp-single-use' + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('employer_size', 'us', 'eng'), ('employer_size', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if source: + params["source"] = source + filters.append("source = %(source)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-morning`.`morning_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [MorningQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/morning/survey.py b/generalresearch/managers/morning/survey.py new file mode 100644 index 0000000..9b08a65 --- /dev/null +++ b/generalresearch/managers/morning/survey.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import json +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.morning.survey import MorningBid, MorningCondition + +logger = logging.getLogger() + + +class MorningCriteriaManager(CriteriaManager): + CONDITION_MODEL = MorningCondition + TABLE_NAME = "morning_criterion" + + +class MorningSurveyManager(SurveyManager): + STAT_FIELDS = [ + "obs_median_loi", + "qualified_conversion", + "num_available", + "num_completes", + "num_failures", + "num_in_progress", + "num_over_quotas", + "num_qualified", + "num_quality_terminations", + "num_timeouts", + ] + STAT_EXTENDED_FIELDS = ["system_conversion", "num_entrants", "num_screenouts"] + BID_FIELDS = ( + [ + "id", + "status", + "country_iso", + "language_isos", + "buyer_account_id", + "buyer_id", + "name", + "supplier_exclusive", + "survey_type", + "timeout", + "topic_id", + "bid_loi", + "exclusions", + "used_question_ids", + "expected_end", + "created_api", + "is_live", + ] + + STAT_FIELDS + + STAT_EXTENDED_FIELDS + ) + QUOTA_FIELDS = [ + "id", + "cpi", + "condition_hashes", + ] + STAT_FIELDS + BID_DB_SOURCE = "`thl-morning`.`morning_surveybid`" + QUOTA_DB_SOURCE = "`thl-morning`.`morning_surveyquota`" + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + ) -> List[MorningBid]: + """ + Accepts lots of optional filters. + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("bid.id IN %(survey_ids)s") + if is_live is not None: + if is_live: + filters.append("is_live") + else: + filters.append("NOT is_live") + if updated_since is not None: + params["updated_since"] = updated_since + filters.append("updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" + fields_str = """ + bid.*, + JSON_ARRAYAGG( + JSON_OBJECT( + 'id', quota.id, + 'cpi', quota.cpi, + 'condition_hashes', quota.condition_hashes, + 'num_available', quota.num_available, + 'num_completes', quota.num_completes, + 'num_failures', quota.num_failures, + 'num_in_progress', quota.num_in_progress, + 'num_over_quotas', quota.num_over_quotas, + 'num_qualified', quota.num_qualified, + 'num_quality_terminations', quota.num_quality_terminations, + 'num_timeouts', quota.num_timeouts, + 'obs_median_loi', quota.obs_median_loi, + 'qualified_conversion', quota.qualified_conversion + ) + ) AS quotas + """ + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = self.sql_helper.execute_sql_query( + f""" + SELECT {fields_str} + FROM {self.BID_DB_SOURCE} AS bid + JOIN {self.QUOTA_DB_SOURCE} AS quota ON bid.id = quota.bid_id + {filter_str} + GROUP BY bid.id; + """, + params, + ) + for bid in res: + bid["quotas"] = json.loads(bid["quotas"]) + bids = [MorningBid.from_db(x) for x in res] + return bids + + def create(self, bid: MorningBid) -> bool: + now = datetime.now(tz=timezone.utc) + d = bid.to_mysql() + create_fields = self.BID_FIELDS + ["created", "updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + bid_data = {k: v for k, v in d.items() if k in create_fields} + bid_data.update({"updated": now, "created": now}) + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + c.execute( + query=f""" + INSERT INTO {self.BID_DB_SOURCE} + ({fields_str}) + VALUES ({values_str}) + """, + args=bid_data, + ) + + quotas = d["quotas"] + create_fields = self.QUOTA_FIELDS + ["bid_id"] + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + quota_data = [ + {k: v for k, v in quota.items() if k in create_fields} | {"bid_id": bid.id} + for quota in quotas + ] + c = conn.cursor() + c.executemany( + query=f""" + INSERT INTO {self.QUOTA_DB_SOURCE} + ({fields_str}) + VALUES ({values_str}) + """, + args=quota_data, + ) + + return True + + def update(self, surveys: List[MorningBid]) -> None: + now = datetime.now(tz=timezone.utc) + + for survey in surveys: + self.update_one(survey, now=now) + + def update_one(self, bid: MorningBid, now=None) -> bool: + if now is None: + now = datetime.now(tz=timezone.utc) + d = bid.to_mysql() + d["updated"] = now + + bid_data = {k: v for k, v in d.items() if k in self.BID_FIELDS + ["updated"]} + set_str = ", ".join( + [ + f"`{k}` = %({k})s" + for k, v in d.items() + if k in self.BID_FIELDS + ["updated"] and k != "id" + ] + ) + + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(False) + c = conn.cursor() + c.execute( + f""" + UPDATE {self.BID_DB_SOURCE} + SET {set_str} + WHERE `id`=%(id)s + LIMIT 1""", + bid_data, + ) + + quota_data = [ + {k: v for k, v in quota.items() if k in self.QUOTA_FIELDS} + for quota in d["quotas"] + ] + set_str = ", ".join( + [ + f"`{k}` = %({k})s" + for k, v in d["quotas"][0].items() + if k in self.QUOTA_FIELDS and k != "id" + ] + ) + for quota in quota_data: + c.execute( + f""" + UPDATE {self.QUOTA_DB_SOURCE} + SET {set_str} + WHERE `id`=%(id)s + LIMIT 1""", + quota, + ) + + conn.commit() + return bool(c.rowcount >= 1) + + def create_or_update(self, surveys: List[MorningBid]): + surveys = {s.id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["id"] + for x in self.sql_helper.execute_sql_query( + """ + SELECT id + FROM `thl-morning`.`morning_surveybid` + WHERE id IN %s""", + [sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + self.update([surveys[sn] for sn in existing_sns]) diff --git a/generalresearch/managers/morning/user_pid.py b/generalresearch/managers/morning/user_pid.py new file mode 100644 index 0000000..78de3bd --- /dev/null +++ b/generalresearch/managers/morning/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class MorningUserPidManager(UserPidManager): + TABLE_NAME = "morning_userpid" + SOURCE = Source.MORNING_CONSULT diff --git a/generalresearch/managers/pollfish/__init__.py b/generalresearch/managers/pollfish/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/pollfish/profiling.py b/generalresearch/managers/pollfish/profiling.py new file mode 100644 index 0000000..735e824 --- /dev/null +++ b/generalresearch/managers/pollfish/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.pollfish.question import PollfishQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[PollfishQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-pollfish`.`pollfish_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [PollfishQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/pollfish/user_pid.py b/generalresearch/managers/pollfish/user_pid.py new file mode 100644 index 0000000..1068405 --- /dev/null +++ b/generalresearch/managers/pollfish/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class PollfishUserPidManager(UserPidManager): + TABLE_NAME = "pollfish_userpid" + SOURCE = Source.POLLFISH diff --git a/generalresearch/managers/precision/__init__.py b/generalresearch/managers/precision/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/precision/profiling.py b/generalresearch/managers/precision/profiling.py new file mode 100644 index 0000000..5f687ce --- /dev/null +++ b/generalresearch/managers/precision/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.precision.question import PrecisionQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[PrecisionQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-precision`.`precision_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [PrecisionQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/precision/survey.py b/generalresearch/managers/precision/survey.py new file mode 100644 index 0000000..fc4e037 --- /dev/null +++ b/generalresearch/managers/precision/survey.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.precision.survey import ( + PrecisionSurvey, + PrecisionCondition, +) + +logger = logging.getLogger() + + +class PrecisionCriteriaManager(CriteriaManager): + CONDITION_MODEL = PrecisionCondition + TABLE_NAME = "precision_criterion" + + +class PrecisionSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + # 'country_iso', 'language_iso', # these come from join table + "survey_id", + "is_live", + "status", + "cpi", + "group_id", + "name", + "survey_guid", + "buyer_id", + "category_id", + "bid_loi", + "bid_ir", + "global_conversion", + "desired_count", + "achieved_count", + "allowed_devices", + "entry_link", + "excluded_surveys", + "quotas", + "used_question_ids", + "expected_end_date", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + ) -> List[PrecisionSurvey]: + """ + Accepts lots of optional filters. + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> last_updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("s.survey_id IN %(survey_ids)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if updated_since is not None: + params["updated"] = updated_since + filters.append("updated > %(updated)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = self.sql_helper.execute_sql_query( + f""" + SELECT *, + GROUP_CONCAT(DISTINCT country_iso SEPARATOR ',') as country_isos, + GROUP_CONCAT(DISTINCT language_iso SEPARATOR ',') as language_isos + FROM `thl-precision`.`precision_survey` s + LEFT JOIN `thl-precision`.`precision_survey_country` sc on s.survey_id=sc.survey_id AND sc.is_active + LEFT JOIN `thl-precision`.`precision_survey_language` sl on s.survey_id=sl.survey_id AND sl.is_active + {filter_str} + GROUP BY s.survey_id + """, + params, + ) + for x in res: + x["country_isos"] = x["country_isos"].split(",") + x["language_isos"] = x["language_isos"].split(",") + surveys = [PrecisionSurvey.from_db(x) for x in res] + return surveys + + def create(self, survey: PrecisionSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(False) + c = conn.cursor() + create_fields = self.SURVEY_FIELDS + ["created", "updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"created": now, "updated": now}) + c.execute( + f""" + INSERT INTO `thl-precision`.`precision_survey` + ({fields_str}) VALUES ({values_str}) + """, + survey_data, + ) + + country_data = [(survey.survey_id, c) for c in survey.country_isos] + c.executemany( + f""" + INSERT INTO `thl-precision`.`precision_survey_country` + (survey_id, country_iso, is_active) VALUES + (%s, %s, TRUE) + """, + country_data, + ) + lang_data = [(survey.survey_id, c) for c in survey.language_isos] + c.executemany( + f""" + INSERT INTO `thl-precision`.`precision_survey_language` + (survey_id, language_iso, is_active) VALUES + (%s, %s, TRUE) + """, + lang_data, + ) + conn.commit() + + return True + + def update(self, surveys: List[PrecisionSurvey]) -> bool: + for survey in surveys: + self.update_one(survey) + return True + + def update_one(self, survey: PrecisionSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + d["updated"] = now + + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(False) + c = conn.cursor() + + # Update survey table + set_str = ", ".join( + [ + f"`{k}` = %({k})s" + for k, v in d.items() + if k not in {"survey_id", "created"} + ] + ) + c.execute( + f""" + UPDATE `thl-precision`.precision_survey + SET {set_str} + WHERE `survey_id`=%(survey_id)s + LIMIT 1""", + d, + ) + + # Turn off countries not in the current list, for this survey + c.execute( + """ + UPDATE `thl-precision`.`precision_survey_country` + SET is_active = FALSE + WHERE survey_id = %(survey_id)s AND country_iso NOT IN %(country_isos)s; + """, + {"survey_id": survey.survey_id, "country_isos": survey.country_isos}, + ) + country_data = [(survey.survey_id, c) for c in survey.country_isos] + # Turn ON countries in this survey's list of countries, insert row, if already exists, set active. + c.executemany( + f""" + INSERT INTO `thl-precision`.`precision_survey_country` + (survey_id, country_iso, is_active) VALUES + (%s, %s, TRUE) ON DUPLICATE KEY UPDATE is_active = TRUE; + """, + country_data, + ) + + # Same thing with languages + c.execute( + """ + UPDATE `thl-precision`.`precision_survey_language` + SET is_active = FALSE + WHERE survey_id = %(survey_id)s AND language_iso NOT IN %(language_isos)s; + """, + {"survey_id": survey.survey_id, "language_isos": survey.language_isos}, + ) + language_data = [(survey.survey_id, c) for c in survey.language_isos] + c.executemany( + f""" + INSERT INTO `thl-precision`.`precision_survey_language` + (survey_id, language_iso, is_active) VALUES + (%s, %s, TRUE) ON DUPLICATE KEY UPDATE is_active = TRUE; + """, + language_data, + ) + conn.commit() + + return True + + def create_or_update(self, surveys: List[PrecisionSurvey]): + surveys = {s.survey_id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["survey_id"] + for x in self.sql_helper.execute_sql_query( + """ + SELECT survey_id + FROM `thl-precision`.`precision_survey` + WHERE survey_id IN %s""", + [sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + self.update([surveys[sn] for sn in existing_sns]) diff --git a/generalresearch/managers/precision/user_pid.py b/generalresearch/managers/precision/user_pid.py new file mode 100644 index 0000000..50e97e6 --- /dev/null +++ b/generalresearch/managers/precision/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class PrecisionUserPidManager(UserPidManager): + TABLE_NAME = "precision_userpid" + SOURCE = Source.PRECISION diff --git a/generalresearch/managers/prodege/__init__.py b/generalresearch/managers/prodege/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/prodege/profiling.py b/generalresearch/managers/prodege/profiling.py new file mode 100644 index 0000000..a33364b --- /dev/null +++ b/generalresearch/managers/prodege/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.prodege.question import ProdegeQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[ProdegeQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-prodege`.`prodege_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [ProdegeQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/prodege/survey.py b/generalresearch/managers/prodege/survey.py new file mode 100644 index 0000000..ec5665e --- /dev/null +++ b/generalresearch/managers/prodege/survey.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from datetime import timezone, datetime +from typing import List, Collection, Optional + +import pymysql + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.prodege.survey import ProdegeSurvey, ProdegeCondition + + +class ProdegeCriteriaManager(CriteriaManager): + CONDITION_MODEL = ProdegeCondition + TABLE_NAME = "prodege_criterion" + + +class ProdegeSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + "survey_id", + "survey_name", + "status", + "country_iso", + "language_iso", + "cpi", + "desired_count", + "remaining_count", + "achieved_completes", + "bid_loi", + "bid_ir", + "actual_loi", + "actual_ir", + "conversion_rate", + "entrance_url", + "max_clicks_settings", + "past_participation", + "include_psids", + "exclude_psids", + "quotas", + "used_question_ids", + "is_live", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + ) -> List[ProdegeSurvey]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: + if is_live: + filters.append("status = 'LIVE'") + else: + filters.append("status != 'LIVE'") + if updated_since is not None: + params["updated_since"] = updated_since + filters.append("updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = self.sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-prodege`.`prodege_survey` survey + {filter_str} + """, + params, + ) + surveys = [ProdegeSurvey.from_db(x) for x in res] + return surveys + + def create(self, survey: ProdegeSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = self.SURVEY_FIELDS + ["created", "updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"created": now, "updated": now}) + c.execute( + f""" + INSERT INTO `thl-prodege`.`prodege_survey` + ({fields_str}) VALUES ({values_str}) + """, + survey_data, + ) + return True + + def update(self, surveys: List[ProdegeSurvey]) -> None: + now = datetime.now(tz=timezone.utc) + + # Do to stupidity with bid/actual loi/ir values (see ProdegeSurvey.to_mysql), we now + # can't do a bulk update b/c the fields may be different in different rows. Just do + # one at a time, there shouldn't be that many. + for survey in surveys: + self.update_one(survey, now=now) + + def update_one(self, survey: ProdegeSurvey, now=None) -> bool: + if now is None: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + # We have to have special logic for bid/actual loi/ir here. The api is + # stupid and only returns one set of them. If we just do the db + # update it'll overwrite the other with NULL. So, exclude them if + # they are null. + + for k in ["bid_loi", "bid_ir", "actual_loi", "actual_ir"]: + if d[k] is None: + d.pop(k) + d["updated"] = now + set_str = ", ".join( + [ + f"`{k}` = %({k})s" + for k, v in d.items() + if k not in {"survey_id", "created"} + ] + ) + + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + c.execute( + f""" + UPDATE `thl-prodege`.prodege_survey + SET {set_str} + WHERE `survey_id`=%(survey_id)s + LIMIT 1""", + d, + ) + return c.rowcount == 1 diff --git a/generalresearch/managers/prodege/user_pid.py b/generalresearch/managers/prodege/user_pid.py new file mode 100644 index 0000000..7c92e28 --- /dev/null +++ b/generalresearch/managers/prodege/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class ProdegeUserPidManager(UserPidManager): + TABLE_NAME = "prodege_userpid" + SOURCE = Source.PRODEGE diff --git a/generalresearch/managers/repdata/__init__.py b/generalresearch/managers/repdata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/repdata/profiling.py b/generalresearch/managers/repdata/profiling.py new file mode 100644 index 0000000..c508764 --- /dev/null +++ b/generalresearch/managers/repdata/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.repdata.question import RepDataQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[RepDataQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("lucid_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(lucid_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + query=f""" + SELECT * + FROM `thl-repdata`.`repdata_question` q + {filter_str} + """, + params=params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [RepDataQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/repdata/survey.py b/generalresearch/managers/repdata/survey.py new file mode 100644 index 0000000..05465e9 --- /dev/null +++ b/generalresearch/managers/repdata/survey.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import json +from datetime import timezone, datetime +from typing import List, Collection, Optional + +import pymysql + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.repdata.survey import ( + RepDataSurvey, + RepDataSurveyHashed, + RepDataStreamHashed, + RepDataCondition, +) + + +class RepDataCriteriaManager(CriteriaManager): + CONDITION_MODEL = RepDataCondition + TABLE_NAME = "repdata_criterion" + + +class RepDataSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + "survey_id", + "survey_uuid", + "survey_name", + "project_uuid", + "survey_status", + "country_iso", + "language_iso", + "estimated_loi", + "estimated_ir", + "collects_pii", + "allowed_devices", + ] + STREAM_FIELDS = [ + "stream_id", + "stream_uuid", + "stream_name", + "stream_status", + "calculation_type", + "qualification_hashes", + "hashed_quotas", + "expected_count", + "cpi", + "days_in_field", + "actual_ir", + "actual_loi", + "actual_conversion", + "actual_complete_count", + "actual_count", + "used_question_ids", + "survey_id", + "remaining_count", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + ) -> List[RepDataSurveyHashed]: + """ + Accepts lots of optional filters. + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> last_updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: + if is_live: + filters.append("survey_status = 'LIVE'") + else: + filters.append("survey_status != 'LIVE'") + if updated_since is not None: + params["updated_since"] = updated_since + filters.append("last_updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT * + FROM `thl-repdata`.`repdata_survey` survey + {filter_str} + """, + params=params, + ) + surveys = [RepDataSurveyHashed.from_db(x) for x in res] + surveys = {s.survey_id: s for s in surveys} + if surveys: + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT * + FROM `thl-repdata`.`repdata_surveystream` + WHERE survey_id IN %s + """, + params=[list(surveys.keys())], + ) + for x in res: + x["qualification_hashes"] = json.loads(x["qualification_hashes"]) + x["hashed_quotas"] = json.loads(x["hashed_quotas"]) + x["used_question_ids"] = json.loads(x["used_question_ids"]) + for x in res: + survey = surveys[x["survey_id"]] + survey.hashed_streams.append(RepDataStreamHashed.from_db(x, survey)) + return list(surveys.values()) + + def create(self, survey: RepDataSurvey | RepDataSurveyHashed) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = self.SURVEY_FIELDS + ["created", "last_updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"created": now, "last_updated": now}) + c.execute( + query=f""" + INSERT INTO `thl-repdata`.`repdata_survey` + ({fields_str}) VALUES ({values_str}) + """, + args=survey_data, + ) + + fields_str = ", ".join([f"`{x}`" for x in self.STREAM_FIELDS]) + values_str = ", ".join([f"%({x})s" for x in self.STREAM_FIELDS]) + stream_data = [ + {k: v for k, v in stream.items() if k in self.STREAM_FIELDS} + for stream in d["streams"] + ] + for sd in stream_data: + sd.update({"survey_id": survey.survey_id}) + c.executemany( + query=f""" + INSERT INTO `thl-repdata`.`repdata_surveystream` + ({fields_str}) + VALUES ({values_str}) + """, + args=stream_data, + ) + return True + + def update(self, surveys: List[RepDataSurveyHashed]) -> bool: + now = datetime.now(tz=timezone.utc) + update_fields = self.SURVEY_FIELDS + ["last_updated"] + + data = [survey.to_mysql() for survey in surveys] + survey_data = [[d[k] for k in self.SURVEY_FIELDS] + [now] for d in data] + self.sql_helper.bulk_update( + table_name="repdata_survey", + field_names=update_fields, + values_to_insert=survey_data, + ) + + stream_data = [] + for d in data: + for stream in d["streams"]: + stream["survey_id"] = d["survey_id"] + stream_data.append([stream[k] for k in self.STREAM_FIELDS]) + + self.sql_helper.bulk_update( + table_name="repdata_surveystream", + field_names=self.STREAM_FIELDS, + values_to_insert=stream_data, + ) + return True diff --git a/generalresearch/managers/repdata/user_pid.py b/generalresearch/managers/repdata/user_pid.py new file mode 100644 index 0000000..9d53897 --- /dev/null +++ b/generalresearch/managers/repdata/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class RepdataUserPidManager(UserPidManager): + TABLE_NAME = "repdata_userpid" + SOURCE = Source.REPDATA diff --git a/generalresearch/managers/sago/__init__.py b/generalresearch/managers/sago/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/sago/profiling.py b/generalresearch/managers/sago/profiling.py new file mode 100644 index 0000000..e1ed97f --- /dev/null +++ b/generalresearch/managers/sago/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.sago.question import SagoQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[SagoQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `thl-sago`.`sago_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [SagoQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/sago/survey.py b/generalresearch/managers/sago/survey.py new file mode 100644 index 0000000..535d8bb --- /dev/null +++ b/generalresearch/managers/sago/survey.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional, Set + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.sago.survey import SagoSurvey, SagoCondition + +logger = logging.getLogger() + + +class SagoCriteriaManager(CriteriaManager): + CONDITION_MODEL = SagoCondition + TABLE_NAME = "sago_criterion" + + +class SagoSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + "survey_id", + "is_live", + "status", + "country_iso", + "language_iso", + "cpi", + "buyer_id", + "account_id", + "study_type_id", + "industry_id", + "allowed_devices", + "collects_pii", + "bid_loi", + "bid_ir", + "live_link", + "survey_exclusions", + "ip_exclusions", + "remaining_count", + "qualifications", + "quotas", + "used_question_ids", + "modified_api", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + exclude_fields: Optional[Set[str]] = None, + ) -> List[SagoSurvey]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> last_updated" + :param exclude_fields: Optionally exclude fields from query. This + only supports nullable fields, as the SagoSurvey model validation + will fail otherwise. + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if updated_since is not None: + params["updated"] = updated_since + filters.append("updated > %(updated)s") + assert filters, "Must set at least 1 filter" + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + fields = set(self.SURVEY_FIELDS) | {"created", "updated"} + if exclude_fields: + fields -= exclude_fields + fields_str = ", ".join([f"`{v}`" for v in fields]) + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT {fields_str} + FROM `thl-sago`.`sago_survey` survey + {filter_str} + """, + params=params, + ) + surveys = [SagoSurvey.from_db(x) for x in res] + return surveys + + def create(self, survey: SagoSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = self.SURVEY_FIELDS + ["created", "updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"created": now, "updated": now}) + c.execute( + query=f""" + INSERT INTO `thl-sago`.`sago_survey` + ({fields_str}) VALUES ({values_str}) + """, + args=survey_data, + ) + return True + + def update(self, surveys: List[SagoSurvey]) -> bool: + now = datetime.now(tz=timezone.utc) + update_fields = self.SURVEY_FIELDS + ["updated"] + + data = [survey.to_mysql() for survey in surveys] + survey_data = [[d[k] for k in self.SURVEY_FIELDS] + [now] for d in data] + self.sql_helper.bulk_update("sago_survey", update_fields, survey_data) + return True + + def update_field(self, survey: SagoSurvey, field: str) -> bool: + now = datetime.now(tz=timezone.utc) + conn: pymysql.Connection = self.sql_helper.make_connection() + value = survey.to_mysql()[field] + c = conn.cursor() + c.execute( + f""" + UPDATE `thl-sago`.`sago_survey` + SET `{field}` = %(value)s, + updated = %(now)s + WHERE survey_id = %(survey_id)s + LIMIT 2 + """, + {"now": now, "value": value, "survey_id": survey.survey_id}, + ) + conn.commit() + if c.rowcount == 0: + raise ValueError( + f"SagoSurveyManager.update_field: " + f"survey {survey.survey_id} not found in db!" + ) + elif c.rowcount == 2: + raise ValueError("this should never happen") + return True + + def create_or_update(self, surveys: List[SagoSurvey]): + surveys = {s.survey_id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["survey_id"] + for x in self.sql_helper.execute_sql_query( + query=""" + SELECT survey_id + FROM `thl-sago`.`sago_survey` + WHERE survey_id IN %s + """, + params=[sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + + self.update([surveys[sn] for sn in existing_sns]) + + return None diff --git a/generalresearch/managers/sago/user_pid.py b/generalresearch/managers/sago/user_pid.py new file mode 100644 index 0000000..311abb7 --- /dev/null +++ b/generalresearch/managers/sago/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class SagoUserPidManager(UserPidManager): + TABLE_NAME = "sago_userpid" + SOURCE = Source.SAGO diff --git a/generalresearch/managers/spectrum/__init__.py b/generalresearch/managers/spectrum/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/spectrum/profiling.py b/generalresearch/managers/spectrum/profiling.py new file mode 100644 index 0000000..21575c6 --- /dev/null +++ b/generalresearch/managers/spectrum/profiling.py @@ -0,0 +1,62 @@ +import json +from typing import List, Collection, Optional, Tuple + +from generalresearch.models.spectrum.question import SpectrumQuestion +from generalresearch.sql_helper import SqlHelper + + +def get_profiling_library( + sql_helper: SqlHelper, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + question_ids: Optional[Collection[str]] = None, + max_options: Optional[int] = None, + is_live: Optional[bool] = None, + pks: Optional[Collection[Tuple[str, str, str]]] = None, +) -> List[SpectrumQuestion]: + """ + Accepts lots of optional filters. + + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param question_ids: filters on question_id field, accepts multiple values + :param max_options: filters on max_options field + :param is_live: filters on is_live field + :param pks: The pk is (question_id, country_iso, language_iso). pks accepts a collection of + len(3) tuples. e.g. [('123', 'us', 'eng'), ('123', 'us', 'spa')] + :return: + """ + filters = ["is_valid"] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if question_ids: + params["question_ids"] = question_ids + filters.append("question_id IN %(question_ids)s") + if max_options is not None: + params["max_options"] = max_options + filters.append("COALESCE(num_options, 0) <= %(max_options)s") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + if pks: + params["pks"] = pks + filters.append("(question_id, country_iso, language_iso) IN %(pks)s") + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + res = sql_helper.execute_sql_query( + f""" + SELECT * + FROM `{sql_helper.db_name}`.`spectrum_question` q + {filter_str} + """, + params, + ) + for x in res: + x["options"] = json.loads(x["options"]) if x["options"] else None + qs = [SpectrumQuestion.from_db(x) for x in res] + return qs diff --git a/generalresearch/managers/spectrum/survey.py b/generalresearch/managers/spectrum/survey.py new file mode 100644 index 0000000..0dcc232 --- /dev/null +++ b/generalresearch/managers/spectrum/survey.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import logging +from datetime import timezone, datetime +from typing import List, Collection, Optional + +import pymysql +from pymysql import IntegrityError + +from generalresearch.managers.criteria import CriteriaManager +from generalresearch.managers.survey import SurveyManager +from generalresearch.models.spectrum.survey import ( + SpectrumSurvey, + SpectrumCondition, +) + +logger = logging.getLogger() + + +class SpectrumCriteriaManager(CriteriaManager): + CONDITION_MODEL = SpectrumCondition + TABLE_NAME = "spectrum_criterion" + + +class SpectrumSurveyManager(SurveyManager): + SURVEY_FIELDS = [ + "survey_id", + "survey_name", + "status", + "country_iso", + "language_iso", + "cpi", + "field_end_date", + "category_code", + "calculation_type", + "requires_pii", + "buyer_id", + "survey_exclusions", + "exclusion_period", + "bid_loi", + "bid_ir", + "last_block_loi", + "last_block_ir", + "overall_ir", + "overall_loi", + "project_last_complete_date", + "include_psids", + "exclude_psids", + "qualifications", + "quotas", + "used_question_ids", + "is_live", + "modified_api", + "created_api", + ] + + def get_survey_library( + self, + country_iso: Optional[str] = None, + language_iso: Optional[str] = None, + survey_ids: Optional[Collection[str]] = None, + is_live: Optional[bool] = None, + updated_since: Optional[datetime] = None, + fields=None, + ) -> List[SpectrumSurvey]: + """ + Accepts lots of optional filters. + :param country_iso: filters on country_iso field + :param language_iso: filters on language_iso field + :param is_live: filters on is_live field + :param updated_since: filters on "> updated" + """ + filters = [] + params = {} + if country_iso: + params["country_iso"] = country_iso + filters.append("`country_iso` = %(country_iso)s") + if language_iso: + params["language_iso"] = language_iso + filters.append("`language_iso` = %(language_iso)s") + if survey_ids is not None: + params["survey_ids"] = survey_ids + filters.append("survey_id IN %(survey_ids)s") + if is_live is not None: + if is_live: + filters.append("is_live") + else: + filters.append("NOT is_live") + if updated_since is not None: + params["updated_since"] = updated_since + filters.append("updated > %(updated_since)s") + assert filters, "Must set at least 1 filter" + fields_str = "*" + if fields: + fields_str = ",".join(fields) + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + + res = self.sql_helper.execute_sql_query( + query=f""" + SELECT {fields_str} + FROM `{self.sql_helper.db_name}`.`spectrum_survey` survey + {filter_str} + """, + params=params, + ) + + surveys = [SpectrumSurvey.from_db(x) for x in res] + return surveys + + def create(self, survey: SpectrumSurvey) -> bool: + now = datetime.now(tz=timezone.utc) + d = survey.to_mysql() + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + create_fields = self.SURVEY_FIELDS + ["updated"] + + fields_str = ", ".join([f"`{x}`" for x in create_fields]) + values_str = ", ".join([f"%({x})s" for x in create_fields]) + survey_data = {k: v for k, v in d.items() if k in create_fields} + survey_data.update({"updated": now}) + + c.execute( + query=f""" + INSERT INTO `{self.sql_helper.db_name}`.`spectrum_survey` + ({fields_str}) + VALUES ({values_str}) + """, + args=survey_data, + ) + + return True + + def update(self, surveys: List[SpectrumSurvey]) -> bool: + now = datetime.now(tz=timezone.utc) + + # Due to stupidity with bid/actual loi/ir values (last block nonsense), + # we can't do a bulk update b/c the fields may be different in + # different rows. Just do one at a time, there shouldn't be that many. + for survey in surveys: + self.update_one(survey, now=now) + + return True + + def update_one(self, survey: SpectrumSurvey, now=None) -> bool: + if now is None: + now = datetime.now(tz=timezone.utc) + + d = survey.to_mysql() + # We have to have special logic for bid/actual loi/ir here. The api + # is stupid and only returns one set of them. If we just do the db + # update it'll overwrite the other with NULL. So, exclude them if + # they are null. + + for k in [ + "bid_loi", + "bid_ir", + "overall_loi", + "overall_ir", + "last_block_loi", + "last_block_ir", + ]: + if d[k] is None: + d.pop(k) + d["updated"] = now + set_str = ", ".join( + [ + f"`{k}` = %({k})s" + for k, v in d.items() + if k not in {"survey_id", "created"} + ] + ) + + conn: pymysql.Connection = self.sql_helper.make_connection() + conn.autocommit(True) + c = conn.cursor() + c.execute( + query=f""" + UPDATE `{self.sql_helper.db_name}`.spectrum_survey + SET {set_str} + WHERE `survey_id`=%(survey_id)s + LIMIT 1 + """, + args=d, + ) + + return c.rowcount == 1 + + def create_or_update(self, surveys: List[SpectrumSurvey]) -> None: + surveys = {s.survey_id: s for s in surveys} + sns = set(surveys.keys()) + existing_sns = { + x["survey_id"] + for x in self.sql_helper.execute_sql_query( + query=f""" + SELECT ss.survey_id + FROM `{self.sql_helper.db_name}`.`spectrum_survey` AS ss + WHERE ss.survey_id IN %s; + """, + params=[sns], + ) + } + create_sns = sns - existing_sns + for sn in create_sns: + survey = surveys[sn] + try: + self.create(survey) + except IntegrityError as e: + logger.info(e) + if e.args[0] == 1062: + existing_sns.add(sn) + else: + raise e + + self.update([surveys[sn] for sn in existing_sns]) + + return None diff --git a/generalresearch/managers/spectrum/user_pid.py b/generalresearch/managers/spectrum/user_pid.py new file mode 100644 index 0000000..495e73c --- /dev/null +++ b/generalresearch/managers/spectrum/user_pid.py @@ -0,0 +1,7 @@ +from generalresearch.managers.marketplace.user_pid import UserPidManager +from generalresearch.models import Source + + +class SpectrumUserPidManager(UserPidManager): + TABLE_NAME = "spectrum_userpid" + SOURCE = Source.SPECTRUM diff --git a/generalresearch/managers/survey.py b/generalresearch/managers/survey.py new file mode 100644 index 0000000..3e3f4ee --- /dev/null +++ b/generalresearch/managers/survey.py @@ -0,0 +1,27 @@ +from abc import ABC +from typing import List + +from generalresearch.managers.base import SqlManager +from generalresearch.models.thl.survey import MarketplaceTask + + +class SurveyManager(SqlManager, ABC): + + def create(self, survey: MarketplaceTask) -> bool: + """ + Create a single survey + """ + ... + + def update(self, surveys: List[MarketplaceTask]) -> bool: + """ + Update a list of surveys. Depending on the implementation, this may + operate one by one or as a bulk update. + """ + ... + + def update_field(self, survey: MarketplaceTask, field: str) -> bool: + """ + Update only `field` from `survey`. The survey must already exist. We expect + that you've already checked that the field's value is different. + """ diff --git a/generalresearch/managers/thl/__init__.py b/generalresearch/managers/thl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/thl/buyer.py b/generalresearch/managers/thl/buyer.py new file mode 100644 index 0000000..dd0b4f2 --- /dev/null +++ b/generalresearch/managers/thl/buyer.py @@ -0,0 +1,113 @@ +from datetime import datetime, timezone +from typing import Collection, Dict, Optional + +from generalresearch.managers.base import PostgresManager, Permission +from generalresearch.models import Source +from generalresearch.models.thl.survey.buyer import Buyer +from generalresearch.pg_helper import PostgresConfig + + +class BuyerManager(PostgresManager): + + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + ): + super().__init__(pg_config=pg_config, permissions=permissions) + # self.buyer_pk: Dict[Buyer, int] = dict() + self.source_code_buyer: Dict[str, Buyer] = dict() + self.source_code_pk: Dict[str, int] = dict() + self.populate_caches() + + def populate_caches(self): + query = """ + SELECT id, code, source, label, created + FROM marketplace_buyer;""" + res = self.pg_config.execute_sql_query(query) + buyers = [Buyer.model_validate(d) for d in res] + self.source_code_buyer = {b.source_code: b for b in buyers} + self.source_code_pk = {b.source_code: b.id for b in buyers} + + def update_caches(self, buyers: Collection[Buyer]): + self.source_code_buyer.update({b.source_code: b for b in buyers}) + self.source_code_pk.update({b.source_code: b.id for b in buyers}) + + def get(self, source: Source, code: str) -> Buyer: + return self.source_code_buyer[f"{source.value}:{code}"] + + def get_if_exists(self, source: Source, code: str) -> Optional[Buyer]: + try: + return self.get(source=source, code=code) + except KeyError: + return None + + def bulk_get_or_create(self, source: Source, codes: Collection[str]): + now = datetime.now(tz=timezone.utc) + buyers = [] + params_seq = [] + for code in codes: + source_code = f"{source.value}:{code}" + if source_code in self.source_code_buyer: + buyers.append(self.source_code_buyer[source_code]) + else: + params_seq.append({"source": source, "code": code, "created": now}) + + # Insert those not in the cache. If the cache is stale, it doesn't + # really matter b/c we won't insert a dupe, and we'll fetch it + # back right after + query = """ + INSERT INTO marketplace_buyer ( + source, code, created + ) VALUES ( + %(source)s, %(code)s, %(created)s + ) ON CONFLICT (source, code) DO NOTHING;""" + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=params_seq) + conn.commit() + + lookup = [x["code"] for x in params_seq] + query = """ + SELECT id, source, code, label, created + FROM marketplace_buyer + WHERE source = %(source)s AND + code = ANY(%(lookup)s); + """ + res = self.pg_config.execute_sql_query( + query, params={"lookup": lookup, "source": source.value} + ) + new_buyers = [Buyer.model_validate(d) for d in res] + self.update_caches(new_buyers) + buyers.extend(new_buyers) + # Not required, just for ease of testing/deterministic + buyers = sorted(buyers, key=lambda x: (x.source, x.code)) + assert len(buyers) == len(codes), "something went wrong" + return buyers + + def update(self, buyer: Buyer): + # label is the only thing that can be updated + query = """ + UPDATE marketplace_buyer + SET label = %(label)s + WHERE source = %(source)s + AND code = %(code)s + RETURNING id; + """ + params = { + "source": buyer.source.value, + "code": buyer.code, + "label": buyer.label, + } + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params=params) + assert c.rowcount == 1 + pk = c.fetchone()["id"] + if buyer.id is not None: + assert buyer.id == pk + else: + buyer.id = pk + conn.commit() + + return None diff --git a/generalresearch/managers/thl/cashout_method.py b/generalresearch/managers/thl/cashout_method.py new file mode 100644 index 0000000..98e8e66 --- /dev/null +++ b/generalresearch/managers/thl/cashout_method.py @@ -0,0 +1,295 @@ +from copy import copy +from datetime import timezone, datetime +from typing import List, Optional, Collection, Dict +from uuid import uuid4, UUID + +from generalresearch.managers.base import PostgresManager +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashoutMethod, + CashMailCashoutMethodData, + PaypalCashoutMethodData, +) + + +class CashoutMethodManager(PostgresManager): + + def create(self, cm: CashoutMethod) -> None: + now = datetime.now(tz=timezone.utc) + query = """ + INSERT INTO accounting_cashoutmethod ( + id, last_updated, is_live, provider, + ext_id, name, data, user_id + ) VALUES ( + %(id)s, %(last_updated)s, %(is_live)s, %(provider)s, + %(ext_id)s, %(name)s, %(data)s, %(user_id)s + ); + """ + values = { + "id": cm.id, + "last_updated": now, + "is_live": True, + "provider": cm.type.value, + "ext_id": cm.ext_id, + "name": cm.name, + "data": cm.model_dump_json(exclude={"user"}), + "user_id": cm.user.user_id if cm.user else None, + } + + self.pg_config.execute_write(query, values) + + return None + + def delete_cashout_method(self, cm_id: str): + db_res = self.pg_config.execute_sql_query( + f""" + SELECT id::uuid, user_id + FROM accounting_cashoutmethod + WHERE id = %s AND is_live + LIMIT 1;""", + [cm_id], + ) + res = next(iter(db_res), None) + assert res, f"cashout method id {cm_id} not found" + # Don't let anyone delete a non-user-scoped cashout method + assert ( + res["user_id"] is not None + ), f"error trying to delete non user-scoped cashout method" + self.pg_config.execute_write( + f""" + UPDATE accounting_cashoutmethod SET is_live = FALSE + WHERE id = %s;""", + [cm_id], + ) + + def create_cash_in_mail_cashout_method( + self, data: CashMailCashoutMethodData, user: User + ) -> str: + """ + Each user can create 1 or more "cash in mail" cashout method. This + stores their address and possible shipping requests ? Each address + must be unique. + + :return: the uuid of the created cashout method + """ + # todo: validate shipping address? + + cm = CashoutMethod( + name="Cash in Mail", + description="USPS delivery of cash", + id=uuid4().hex, + currency="USD", + image_url="https://www.shutterstock.com/shutterstock/photos/2175413929/display_1500/stock-vector-opened" + "-envelope-with-money-dollar-bills-salary-earning-and-savings-concept-d-web-vector-2175413929.jpg", + min_value=500, # $5.00 + max_value=25000, # $250.00 + data=data, + type=PayoutType.CASH_IN_MAIL, + user=user, + ext_id=data.delivery_address.md5sum(), + ) + + # Make sure this user doesn't already have an identical cashout + # method (same address) + res = self.filter( + user=user, + is_live=True, + payout_types=[PayoutType.CASH_IN_MAIL], + ext_id=data.delivery_address.md5sum(), + ) + if res: + # Already exists with the same address + assert len(res) == 1 + return res[0].id + + self.create(cm) + + return cm.id + + def create_paypal_cashout_method( + self, data: PaypalCashoutMethodData, user: User + ) -> str: + """ + If it already exists, and the emails are the same, do nothing. If the + email is different, raises an error + + :param data: + :param user: + :return: the uuid of the created cashout method + """ + cm = CashoutMethod( + name="PayPal", + description="Cashout via PayPal", + id=uuid4().hex, + currency="USD", + image_url="https://cdn.mmfwcl.com/images/brands/p439786-1200w-326ppi.png", + min_value=100, # $1.00 + max_value=25_000, # $250.00 + data=data, + type=PayoutType.PAYPAL, + user=user, + ext_id=data.email, + ) + # Make sure this user doesn't already have one + res = self.filter(user=user, payout_types=[PayoutType.PAYPAL], is_live=True) + if res: + assert len(res) == 1 + if res[0].data.email == data.email: + # Already exists with the same email, just return it + return res[0].id + else: + raise ValueError( + "User already has a cashout method of this type. " + "Delete the existing one and try again." + ) + else: + self.create(cm) + return cm.id + + @staticmethod + def make_filter_str( + uuid: Optional[str] = None, + user: Optional[User] = None, + ext_id: Optional[str] = None, + payout_types: Optional[Collection[PayoutType]] = None, + is_live: Optional[bool] = True, + ): + filters = [] + params = dict() + if uuid is not None: + params["uuid"] = uuid + filters.append("id = %(uuid)s") + if user is not None: + params["user_id"] = user.user_id + filters.append("user_id = %(user_id)s") + if ext_id is not None: + params["ext_id"] = ext_id + filters.append("ext_id = %(ext_id)s") + if payout_types is not None: + assert isinstance(payout_types, (list, set, tuple)) + params["payout_types"] = [x.value for x in payout_types] + filters.append("provider = ANY(%(payout_types)s)") + if is_live is not None: + params["is_live"] = is_live + filters.append("is_live = %(is_live)s") + assert filters, "must pass at least one filter" + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + return filter_str, params + + def filter_count( + self, + uuid: Optional[str] = None, + user: Optional[User] = None, + ext_id: Optional[str] = None, + payout_types: Optional[Collection[PayoutType]] = None, + is_live: Optional[bool] = True, + ) -> int: + filter_str, params = self.make_filter_str( + uuid=uuid, + user=user, + ext_id=ext_id, + payout_types=payout_types, + is_live=is_live, + ) + res = self.pg_config.execute_sql_query( + query=f""" + SELECT COUNT(1) as cnt + FROM accounting_cashoutmethod + {filter_str} + """, + params=params, + ) + return res[0]["cnt"] + + def filter( + self, + uuid: Optional[str] = None, + user: Optional[User] = None, + ext_id: Optional[str] = None, + payout_types: Optional[Collection[PayoutType]] = None, + is_live: Optional[bool] = True, + ) -> List[CashoutMethod]: + filter_str, params = self.make_filter_str( + uuid=uuid, + user=user, + ext_id=ext_id, + payout_types=payout_types, + is_live=is_live, + ) + res = self.pg_config.execute_sql_query( + query=f""" + SELECT id::uuid, provider, ext_id, data::jsonb as _data_, user_id + FROM accounting_cashoutmethod + {filter_str} + """, + params=params, + ) + return [self.format_from_db(x, user=user) for x in res] + + def get_cashout_methods(self, user: User) -> List[CashoutMethod]: + """ + The provider column is PayoutType. Some are only user-scoped, + and some are global. + + :param user: The user whose cashout methods we are requesting. + """ + user.prefetch_product(pg_config=self.pg_config) + product = user.product + + supported_payout_types = copy(product.user_wallet_config.supported_payout_types) + if product.user_wallet_config.amt: + supported_payout_types.add(PayoutType.AMT) + + user_scoped_payout_types = [PayoutType.PAYPAL, PayoutType.CASH_IN_MAIL] + params = { + "user_scoped_payout_types": [x.value for x in user_scoped_payout_types], + "supported_payout_types": [x.value for x in supported_payout_types], + "user_id": user.user_id, + } + query = f""" + SELECT id::uuid, provider, ext_id, data::jsonb as _data_, user_id + FROM accounting_cashoutmethod + WHERE is_live + AND ( + (provider = ANY(%(user_scoped_payout_types)s) AND user_id = %(user_id)s) + OR (provider != ANY(%(user_scoped_payout_types)s) AND user_id IS NULL) + ) + AND provider = ANY(%(supported_payout_types)s) + LIMIT 1000;""" + + res = self.pg_config.execute_sql_query(query, params=params) + if len(res) >= 1000: + raise ValueError(f"Unexpectedly large number of cashout_methods: {user=}") + + cms = [self.format_from_db(x, user=user) for x in res] + + # Only allow AMT if the BP is marked as AMT (already should have been + # filtered in query) + cms = [ + x + for x in cms + if (x.type == PayoutType.AMT and product.user_wallet_config.amt) + or (x.type != PayoutType.AMT) + ] + return cms + + @staticmethod + def format_from_db(x: Dict, user: Optional[User] = None) -> CashoutMethod: + x["id"] = UUID(x["id"]).hex + # The data column here is inconsistent. Pulling keys from the mysql 'data' col + # and putting them into the base level. Renamed so that we don't overwrite + # a col called "data" within the "_data_" field. + for k in list(x["_data_"].keys()): + if k in CashoutMethod.model_fields: + x[k] = x["_data_"].pop(k) + x["type"] = PayoutType(x["provider"].upper()) + if "data" not in x: + x["data"] = dict() + x["data"].update(x.pop("_data_")) + x["data"]["type"] = x["type"] + if user and x["type"] in {PayoutType.PAYPAL, PayoutType.CASH_IN_MAIL}: + x["user"] = user + + return CashoutMethod.model_validate(x) diff --git a/generalresearch/managers/thl/category.py b/generalresearch/managers/thl/category.py new file mode 100644 index 0000000..15c309b --- /dev/null +++ b/generalresearch/managers/thl/category.py @@ -0,0 +1,56 @@ +from typing import Collection, Dict + +from generalresearch.managers.base import Permission, PostgresManager +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.category import Category +from generalresearch.pg_helper import PostgresConfig + + +class CategoryManager(PostgresManager): + categories = dict() + category_label_map = dict() + + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + ): + super().__init__(pg_config=pg_config, permissions=permissions) + self.categories: Dict[UUIDStr, Category] = dict() + self.category_label_map: Dict[str, Category] = dict() + self.populate_caches() + + def populate_caches(self): + query = """ + SELECT + c.id, c.uuid, c.adwords_vertical_id, c.label, c.path, c.parent_id, + p.uuid AS parent_uuid + FROM marketplace_category AS c + LEFT JOIN marketplace_category AS p + ON p.id = c.parent_id;""" + res = self.pg_config.execute_sql_query(query) + self.categories = {d["uuid"]: Category.model_validate(d) for d in res} + self.category_label_map = {c.label: c for c in self.categories.values()} + + def get_by_label(self, label: str) -> Category: + return self.category_label_map[label] + + def get_top_level(self, category: Category) -> Category: + return self.category_label_map[category.root_label] + + def get_category_root(self, category: Category) -> Category: + # These are the categories we'd display. Almost all are just the top-level + # of all paths, but we have a couple we pull out separately + # Alcoholic Beverages, Tobacco Use, Mature Content, Social Research, Demographic, Politics + custom_root = { + "4fd8381d5a1c4409ab007ca254ced084", + "90f92a5d192848ad9a230587c219b82c", + "21536f160f784189be6194ca894f3a65", + "7aa8bf4e71a84dc3b2035f93f9f9c77e", + "c82cf98c578a43218334544ab376b00e", + "87b6d819f3ca4815bf1f135b1e829cc6", + } + if category.uuid in custom_root: + return category + else: + return self.get_top_level(category) diff --git a/generalresearch/managers/thl/contest_manager.py b/generalresearch/managers/thl/contest_manager.py new file mode 100644 index 0000000..4dc02e3 --- /dev/null +++ b/generalresearch/managers/thl/contest_manager.py @@ -0,0 +1,1080 @@ +from datetime import timezone, datetime +from typing import List, Optional, Literal, cast, Collection, Tuple, Dict +from uuid import UUID + +import redis +from pydantic import PositiveInt, NonNegativeInt +from redis import Redis + +from generalresearch.managers.base import PostgresManager +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.contest import ( + ContestWinner, + ContestPrize, +) +from generalresearch.models.thl.contest.contest import ( + Contest, + ContestUserView, +) +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestType, +) +from generalresearch.models.thl.contest.exceptions import ContestError +from generalresearch.models.thl.contest.io import ( + ContestCreate, + contest_create_to_contest, + model_cls, + user_model_cls, +) +from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestUserView, + LeaderboardContest, +) +from generalresearch.models.thl.contest.milestone import ( + MilestoneUserView, + MilestoneEntry, + ContestEntryTrigger, + MilestoneContest, +) +from generalresearch.models.thl.contest.raffle import ( + ContestEntry, + ContestEntryType, + RaffleUserView, + RaffleContest, +) +from generalresearch.models.thl.user import User + +CONTEST_SELECT = """ + c.id, + c.uuid::uuid, + c.product_id::uuid, + c.name, + c.description, + c.country_isos, + c.contest_type, + c.status, + c.starts_at::timestamptz, + c.terms_and_conditions, + c.end_condition::jsonb, + c.prizes::jsonb, + c.ended_at::timestamptz, + c.end_reason, + c.entry_type, + c.entry_rule::jsonb, + c.current_participants, + c.current_amount, + c.milestone_config::jsonb, + c.win_count, + c.leaderboard_key, + c.created_at::timestamptz, + c.updated_at::timestamptz""" + +USER_SELECT = """ + u.id as user_id, + u.uuid::uuid as user_uuid, + u.product_id::uuid, + u.product_user_id""" + +USER_WINNINGS_JOIN = """ +LEFT JOIN LATERAL ( +SELECT + jsonb_agg( + jsonb_build_object( + 'uuid', cw.uuid::uuid, + 'prize', cw.prize::jsonb, + 'created_at', cw.created_at::timestamptz + ) + ) AS user_winnings, + MAX(cw.created_at) AS last_won +FROM contest_contestwinner cw +WHERE cw.contest_id = c.id + AND cw.user_id = %(user_id)s +) cw_json ON TRUE""" + +USER_ENTRIES_JOIN = """ +LEFT JOIN LATERAL ( + SELECT + COALESCE(SUM(ce.amount), 0) AS user_amount, + COALESCE( + SUM( + CASE WHEN ce.created_at > NOW() - INTERVAL '24 hours' + THEN ce.amount ELSE 0 END + ), 0 + ) AS user_amount_today, + MAX(ce.created_at)::timestamptz AS entry_last_created + FROM contest_contestentry ce + WHERE ce.contest_id = c.id + AND ce.user_id = %(user_id)s +) ce_agg ON TRUE +""" + + +class ContestBaseManager(PostgresManager): + + def create(self, product_id: UUIDStr, contest_create: ContestCreate) -> Contest: + contest = contest_create_to_contest( + product_id=product_id, contest_create=contest_create + ) + data = contest.model_dump_mysql() + fields = set(data.keys()) + + fields_str = ", ".join(fields) + values_str = ", ".join([f"%({x})s" for x in fields]) + query = f""" + INSERT INTO contest_contest ({fields_str}) + VALUES ({values_str}) + RETURNING id; + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=query, + params=data, + ) + pk = c.fetchone()["id"] + conn.commit() + + contest.id = pk + return contest + + def get(self, contest_uuid: UUIDStr) -> Contest: + contest_uuid = UUID(contest_uuid).hex + res = self.pg_config.execute_sql_query( + query=f""" + SELECT {CONTEST_SELECT} + FROM contest_contest c + WHERE c.uuid = %(contest_uuid)s + LIMIT 2; + """, + params={"contest_uuid": contest_uuid}, + ) + # uuid column has a unique constraint. there can't possibly be >1 + if len(res) == 0: + raise ValueError("Contest not found") + + d = res[0] + return model_cls[d["contest_type"]].model_validate_mysql(d) + + def get_if_exists(self, contest_uuid: UUIDStr) -> Optional[Contest]: + try: + return self.get(contest_uuid=contest_uuid) + + except ValueError as e: + if e.args[0] == "Contest not found": + return None + raise e + + @staticmethod + def make_filter_str( + product_id: Optional[str] = None, + status: Optional[ContestStatus] = None, + contest_type: Optional[ContestType] = None, + starts_at_before: Optional[datetime | bool] = None, + name: Optional[str] = None, + name_contains: Optional[str] = None, + uuids: Optional[Collection[str]] = None, + has_participants: Optional[bool] = None, + ) -> Tuple[str, Dict]: + filters = [] + params = dict() + if product_id: + params["product_id"] = product_id + filters.append("product_id = %(product_id)s") + if status: + params["status"] = status.value + filters.append("status = %(status)s") + if contest_type: + params["contest_type"] = contest_type.value + filters.append("contest_type = %(contest_type)s") + if starts_at_before is True: + params["starts_at"] = datetime.now(tz=timezone.utc) + filters.append("starts_at < %(starts_at)s") + elif starts_at_before: + assert starts_at_before.tzinfo == timezone.utc + params["starts_at"] = starts_at_before + filters.append("starts_at < %(starts_at)s") + if name is not None: + params["name"] = name + filters.append("name = %(name)s") + if name_contains is not None: + params["name_contains"] = f"%{name_contains}%" + filters.append("name ILIKE %(name_contains)s") + if uuids is not None: + if len(uuids) == 0: + # If we pass an empty list, the sql query will have a syntax error. Make it + # instead a legal filter, that will return nothing. + uuids = ["0" * 32] + params["uuids"] = uuids + filters.append("uuid = ANY(%(uuids)s)") + if has_participants: + filters.append("current_participants > 0") + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + return filter_str, params + + def get_many( + self, + product_id: Optional[str] = None, + status: Optional[ContestStatus] = None, + contest_type: Optional[ContestType] = None, + starts_at_before: Optional[datetime | bool] = None, + name: Optional[str] = None, + name_contains: Optional[str] = None, + uuids: Optional[Collection[str]] = None, + has_participants: Optional[bool] = None, + page: Optional[int] = None, + size: Optional[int] = None, + include_winners: bool = True, + ) -> List[Contest]: + + filter_str, params = self.make_filter_str( + product_id=product_id, + status=status, + contest_type=contest_type, + starts_at_before=starts_at_before, + name=name, + name_contains=name_contains, + uuids=uuids, + has_participants=has_participants, + ) + + paginated_filter_str = "" + if page is not None: + assert page != 0, "page starts at 1" + size = size if size is not None else 100 + params["offset"] = (page - 1) * size + params["limit"] = size + paginated_filter_str = " LIMIT %(limit)s OFFSET %(offset)s" + + # set "order by" as a param? Would like "ending soonest", but that is not easy to query + order_by_str = "ORDER BY created_at DESC" + + if include_winners: + query = f""" + SELECT {CONTEST_SELECT}, + COALESCE(cw_json.all_winners, '[]'::jsonb) AS all_winners + FROM contest_contest c + LEFT JOIN ( + SELECT + cw.contest_id, + jsonb_agg( + jsonb_build_object( + 'uuid', cw.uuid, + 'prize', cw.prize, + 'created_at', cw.created_at, + 'user_id', cw.user_id, + 'user_uuid', u.uuid::uuid, + 'product_id', u.product_id::uuid, + 'product_user_id', u.product_user_id + ) + ) AS all_winners + FROM contest_contestwinner cw + JOIN thl_user u ON u.id = cw.user_id + GROUP BY cw.contest_id + ) AS cw_json ON cw_json.contest_id = c.id + {filter_str} + {order_by_str} {paginated_filter_str} ; + """ + + else: + query = f""" + SELECT {CONTEST_SELECT} + FROM contest_contest c + {filter_str} + {order_by_str} {paginated_filter_str} ; + """ + + # print(query) + sql_res = self.pg_config.execute_sql_query(query=query, params=params) + res = [] + for d in sql_res: + if include_winners: + for x in d["all_winners"]: + x["uuid"] = UUID(x["uuid"]).hex + x["created_at"] = datetime.fromisoformat(x["created_at"]) + x["user"] = self.parse_user_from_row(x) + c: Contest = model_cls[d["contest_type"]].model_validate_mysql(d) + res.append(c) + return res + + def get_many_by_user_eligible_raffle( + self, user: User, country_iso: str + ) -> List[RaffleUserView]: + # Seems like this is a known pycharm bug. Doing it this way to be explicit. + # https://youtrack.jetbrains.com/issue/PY-42473/Type-inference-broken-for-Literal-with-Enum + cs = self.get_many_by_user_eligible( + user=user, country_iso=country_iso, contest_type=ContestType.RAFFLE + ) + return cast(List[RaffleUserView], cs) + + def get_many_by_user_eligible_milestone( + self, + user: User, + country_iso: str, + entry_trigger: Optional[ContestEntryTrigger] = None, + ) -> List[MilestoneUserView]: + cs = self.get_many_by_user_eligible( + user=user, + country_iso=country_iso, + contest_type=ContestType.MILESTONE, + entry_trigger=entry_trigger, + ) + return cast(List[MilestoneUserView], cs) + + def get_many_by_user_eligible( + self, + user: User, + country_iso: str, + contest_type: Optional[ContestType] = None, + entry_trigger: Optional[ContestEntryTrigger] = None, + ) -> List[ContestUserView]: + # Get by product_id, and status OPEN. Then we have to filter in python. + # (could also add country filter into mysql) + assert user.user_id, "invalid user" + assert user.product_id, "invalid user" + + if entry_trigger: + assert contest_type == ContestType.MILESTONE + + params = {"user_id": user.user_id, "product_id": user.product_id} + filters = [] + if contest_type: + params["contest_type"] = contest_type.value + filters.append("contest_type = %(contest_type)s") + if entry_trigger: + params["entry_trigger"] = entry_trigger.value + filters.append( + "milestone_config::jsonb->>'entry_trigger' = %(entry_trigger)s" + ) + filter_str = " AND " + " AND ".join(filters) if filters else "" + sql_res = self.pg_config.execute_sql_query( + query=f""" + SELECT + {CONTEST_SELECT}, + ce_agg.user_amount, + ce_agg.user_amount_today + FROM contest_contest c + {USER_ENTRIES_JOIN} + WHERE product_id = %(product_id)s AND status = 'active' + {filter_str}; + """, + params=params, + ) + + res = [] + for d in sql_res: + d["product_user_id"] = user.product_user_id + c: ContestUserView = user_model_cls[d["contest_type"]].model_validate_mysql( + d + ) + passes, _ = c.is_user_eligible(country_iso=country_iso) + if passes: + res.append(c) + + return res + + def get_many_by_user_entered( + self, + user: User, + limit: Optional[PositiveInt] = 100, + order_by: Literal["recent_enter", "ending_soon"] = "recent_enter", + ) -> List[ContestUserView]: + """ + This sets the user_contest_info field as well, which calculates the + user's entry count and win percentages. + We need: user_amount, and user_winnings + """ + assert user.user_id, "invalid user" + params = {"user_id": user.user_id} + + if order_by == "recent_enter": + order_by_str = "ORDER BY entry_last_created DESC" + else: + # don't really have a good way of doing this yet ... lol. Sort + # by oldest contest instead + order_by_str = "ORDER BY c.created_at ASC" + + query = f""" + SELECT + {CONTEST_SELECT}, + ce_agg.user_amount, + ce_agg.user_amount_today, + ce_agg.entry_last_created, + COALESCE(cw_json.user_winnings, '[]'::jsonb) AS user_winnings, + {USER_SELECT} + FROM contest_contest c + JOIN thl_user u + ON u.id = %(user_id)s + JOIN contest_contestentry ce0 + ON ce0.contest_id = c.id + AND ce0.user_id = %(user_id)s + {USER_ENTRIES_JOIN} + {USER_WINNINGS_JOIN} + {order_by_str} + LIMIT {limit} + """ + sql_res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + + res = [] + for d in sql_res: + for x in d["user_winnings"]: + x["uuid"] = UUID(x["uuid"]).hex + x["created_at"] = datetime.fromisoformat(x["created_at"]) + x["user"] = user + c: ContestUserView = user_model_cls[d["contest_type"]].model_validate_mysql( + d + ) + res.append(c) + + return res + + def get_many_by_user_won( + self, + user: User, + limit: Optional[PositiveInt] = 100, + ) -> List[ContestUserView]: + """ + This sets the user_contest_info field as well, which calculates the + user's entry count and win percentages. + """ + assert user.user_id, "invalid user" + params = {"user_id": user.user_id} + query = f""" + SELECT + {CONTEST_SELECT}, + ce_agg.user_amount, + ce_agg.user_amount_today, + COALESCE(cw_json.user_winnings, '[]'::jsonb) AS user_winnings, + cw_json.last_won AS contest_last_won, + {USER_SELECT} + FROM contest_contest c + JOIN thl_user u + ON u.id = %(user_id)s + {USER_ENTRIES_JOIN} + {USER_WINNINGS_JOIN} + WHERE EXISTS ( + SELECT 1 + FROM contest_contestwinner w + WHERE w.contest_id = c.id + AND w.user_id = %(user_id)s + ) + ORDER BY contest_last_won DESC + LIMIT {limit} + """ + sql_res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + res = [] + for d in sql_res: + for x in d["user_winnings"]: + x["uuid"] = UUID(x["uuid"]).hex + x["created_at"] = datetime.fromisoformat(x["created_at"]) + x["user"] = user + c: ContestUserView = user_model_cls[d["contest_type"]].model_validate_mysql( + d + ) + res.append(c) + + return res + + @staticmethod + def parse_user_from_row(d: Dict): + return User( + uuid=UUID(d["user_uuid"]).hex, + user_id=d["user_id"], + product_user_id=d["product_user_id"], + product_id=UUID(d["product_id"]).hex, + ) + + def get_winnings_by_user(self, user: User) -> List[ContestWinner]: + assert user.user_id, "invalid user" + sql_res = self.pg_config.execute_sql_query( + query=f""" + SELECT + cw.id, + cw.uuid::uuid, + cw.contest_id, + cw.prize::jsonb, + cw.awarded_cash_amount, + cw.created_at::timestamptz, + {USER_SELECT} + FROM contest_contestwinner cw + JOIN thl_user u + ON u.id = cw.user_id + WHERE user_id = %(user_id)s + """, + params={"user_id": user.user_id}, + ) + + res = [] + for x in sql_res: + x["uuid"] = UUID(x["uuid"]).hex + x["prize"] = ContestPrize.model_validate(x["prize"]) + x["user"] = user + res.append(ContestWinner.model_validate(x)) + + return res + + def get_entries_by_contest_id(self, contest_id: PositiveInt) -> List[ContestEntry]: + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + ce.id, + ce.uuid::uuid, + ce.contest_id, + ce.amount, + ce.user_id, + ce.created_at::timestamptz, + ce.updated_at::timestamptz, + c.entry_type, + {USER_SELECT} + FROM contest_contestentry ce + JOIN contest_contest c + ON c.id = ce.contest_id + JOIN thl_user u + ON u.id = ce.user_id + WHERE ce.contest_id = %(contest_id)s + """, + params={"contest_id": contest_id}, + ) + for x in res: + x["user"] = self.parse_user_from_row(x) + return [ContestEntry.model_validate(x) for x in res] + + def end_contest_with_winners( + self, contest: Contest, ledger_manager: ThlLedgerManager + ) -> None: + assert contest.status == ContestStatus.COMPLETED, "status must be completed" + data = { + "status": contest.status.value, + "ended_at": contest.ended_at, + "end_reason": contest.end_reason, + "contest_uuid": contest.uuid, + } + winners = contest.all_winners + + assert contest.id + rows = [w.model_dump_mysql(contest_id=contest.id) for w in winners] + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany( + query=""" + INSERT INTO contest_contestwinner + (uuid, created_at, user_id, + contest_id, prize, awarded_cash_amount) + VALUES (%(uuid)s, %(created_at)s, %(user_id)s, + %(contest_id)s, %(prize)s, %(awarded_cash_amount)s) + """, + params_seq=rows, + ) + c.execute( + query=""" + UPDATE contest_contest + SET status = %(status)s, + ended_at = %(ended_at)s, + end_reason = %(end_reason)s + WHERE uuid = %(contest_uuid)s + AND status = 'active' + """, + params=data, + ) + assert c.rowcount == 1, "Contest changed during write" + conn.commit() + ledger_manager.create_tx_contest_close(contest=contest) + return None + + def cancel_contest(self, contest: Contest) -> int: + assert contest.status == ContestStatus.CANCELLED, "status must be cancelled" + + return self.pg_config.execute_write( + query=""" + UPDATE contest_contest + SET status = %(status)s + WHERE uuid = %(contest_uuid)s + """, + params={ + "contest_uuid": contest.uuid, + "status": contest.status, + }, + ) + + +class RaffleContestManager(ContestBaseManager): + + def get_raffle_user_view(self, contest_uuid: UUIDStr, user: User) -> RaffleUserView: + + assert user.user_id and user.product_user_id, "invalid user" + query = f""" + SELECT + {CONTEST_SELECT}, + ce_agg.user_amount, + ce_agg.user_amount_today, + COALESCE(cw_json.user_winnings, '[]'::jsonb) AS user_winnings, + {USER_SELECT} + FROM contest_contest c + JOIN thl_user u ON u.id = %(user_id)s + {USER_ENTRIES_JOIN} + {USER_WINNINGS_JOIN} + WHERE c.uuid = %(contest_uuid)s; + """ + sql_res = self.pg_config.execute_sql_query( + query=query, + params={"user_id": user.user_id, "contest_uuid": contest_uuid}, + ) + assert len(sql_res) == 1 + d = sql_res[0] + for x in d["user_winnings"]: + x["uuid"] = UUID(x["uuid"]).hex + x["created_at"] = datetime.fromisoformat(x["created_at"]) + x["user"] = user + return RaffleUserView.model_validate_mysql(d) + + def enter_contest( + self, + contest_uuid: UUIDStr, + entry: ContestEntry, + country_iso: str, + ledger_manager: ThlLedgerManager, + ) -> ContestEntry: + """ + - Validates user is eligible to enter this contest + We need to look up the contest, b/c we need the contest-user-view, + with counts n stuff scoped to the requesting user + - If it is a cash contest: + - validates user has balance in their wallet + - does ledger txs, does enter_contest_db() + - else: + - enter_contest_db() + Note: for milestone contests, the API should prevent a user from + trying to enter it + """ + contest = self.get_raffle_user_view(contest_uuid=contest_uuid, user=entry.user) + assert contest.contest_type == ContestType.RAFFLE, "can only enter a raffle" + assert isinstance(contest, RaffleUserView) + assert contest.entry_type == entry.entry_type, "incompatible entry type" + + res, msg = contest.is_entry_eligible(entry=entry) + if not res: + raise ContestError(msg) + + res, msg = contest.is_user_eligible(country_iso=country_iso) + if not res: + raise ContestError(msg) + + if contest.entry_type == ContestEntryType.CASH: + tx = ledger_manager.create_tx_user_enter_contest( + contest_uuid=contest.uuid, contest_entry=entry + ) + + entry = self.enter_contest_db_work_raffle(contest=contest, entry=entry) + decision, msg = contest.should_end() + if decision: + contest.end_contest() + self.end_contest_with_winners(contest, ledger_manager) + + return entry + + def enter_contest_db_work_raffle( + self, contest: RaffleContest, entry: ContestEntry + ) -> ContestEntry: + assert contest.id, "Contest must be saved." + + # todo: retry if this fails + # 1) get the contest with all its entries, + contest.entries = self.get_entries_by_contest_id(contest_id=contest.id) + assert contest.current_amount == contest.get_current_amount() + assert contest.current_participants == contest.get_current_participants() + old_current_amount = contest.current_amount + + # 2) calculate new values of current_participants and current_amount + contest.entries.append(entry) + contest.current_amount = contest.get_current_amount() + contest.current_participants = contest.get_current_participants() + + data = entry.model_dump_mysql(contest_id=contest.id) + + # 3) IN 1 DB TX: update these 2 field on the contest, and create the entry + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=f""" + INSERT INTO contest_contestentry + (uuid, amount, user_id, + created_at, updated_at, contest_id) + VALUES (%(uuid)s, %(amount)s, %(user_id)s, + %(created_at)s, %(updated_at)s, %(contest_id)s) + """, + params=data, + ) + c.execute( + query=f""" + UPDATE contest_contest + SET current_amount = %(current_amount)s, + current_participants = %(current_participants)s + WHERE id = %(contest_id)s + -- Double click / Lock protection. No rows will be + -- changed if someone tries to enter the contest + -- while we're in the middle of this transaction + AND current_amount = %(old_current_amount)s + """, + params={ + "old_current_amount": old_current_amount, + "current_amount": contest.current_amount, + "current_participants": contest.current_participants, + "contest_id": contest.id, + }, + ) + assert ( + c.rowcount == 1 + ), "enter_contest_db_work_raffle: Mismatch amounts in contest entry" + conn.commit() + + return entry + + +class MilestoneContestManager(ContestBaseManager): + def get_milestone_user_view( + self, contest_uuid: UUIDStr, user: User + ) -> MilestoneUserView: + + assert user.user_id and user.product_user_id, "invalid user" + + # Note: do NOT just join both tables, or you'll end up with "JOIN multiplication". + # Have to join the contestwinner in a subquery. + # Note: In a milestone contest, there will only be 0 or 1 contest_contestentry rows + # per (user, contest), so no aggregation is done (for user_amount, etc). + sql_res = self.pg_config.execute_sql_query( + query=f""" + SELECT + {CONTEST_SELECT}, + COALESCE(ce.amount, 0) AS user_amount, + COALESCE(cw_json.user_winnings, '[]'::jsonb) AS user_winnings, + {USER_SELECT} + FROM contest_contest c + JOIN thl_user u + ON u.id = %(user_id)s + LEFT JOIN contest_contestentry ce + ON ce.contest_id = c.id + AND ce.user_id = %(user_id)s + {USER_WINNINGS_JOIN} + WHERE c.uuid = %(contest_uuid)s + LIMIT 2; + """, + params={"user_id": user.user_id, "contest_uuid": contest_uuid}, + ) + assert len(sql_res) == 1 + + d = sql_res[0] + for x in d["user_winnings"]: + x["uuid"] = UUID(x["uuid"]).hex + x["created_at"] = datetime.fromisoformat(x["created_at"]) + x["user"] = user + + return MilestoneUserView.model_validate_mysql(d) + + def enter_milestone_contest( + self, + contest_uuid: UUIDStr, + user: User, + country_iso: str, + ledger_manager: ThlLedgerManager, + incr: PositiveInt = 1, + ) -> None: + """ + This is "enter_contest" but for a milestone contest. There is a single + contest entry record per (contest, user). We'll validate the user is + eligible, then create or update it, then do contest maintenance. + """ + contest = self.get_milestone_user_view(contest_uuid=contest_uuid, user=user) + assert ( + contest.contest_type == ContestType.MILESTONE + ), "can only enter a milestone" + assert isinstance(contest, MilestoneUserView) + + res, msg = contest.is_user_eligible(country_iso=country_iso) + if not res: + raise ContestError(msg) + + self.enter_contest_db_work_milestone(contest=contest, user=user, incr=incr) + if contest.should_award(): + self.award_milestone_contest(contest, user, ledger_manager=ledger_manager) + + decision, reason = contest.should_end() + if decision: + contest.update( + status=ContestStatus.COMPLETED, + ended_at=datetime.now(tz=timezone.utc), + end_reason=reason, + ) + self.end_milestone_contest(contest) + + return None + + def enter_contest_db_work_milestone( + self, contest: MilestoneUserView, user: User, incr: PositiveInt + ) -> MilestoneEntry: + # Single entry per entry, sum to user's previous if exists + entry = MilestoneEntry(user=user, amount=incr) + data = entry.model_dump_mysql(contest_id=contest.id) + if contest.user_amount == 0: + self.pg_config.execute_write( + query=""" + INSERT INTO contest_contestentry + (uuid, amount, user_id, + created_at, updated_at, contest_id) + VALUES (%(uuid)s, %(amount)s, %(user_id)s, + %(created_at)s, %(updated_at)s, %(contest_id)s) + """, + params=data, + ) + contest.user_amount = entry.amount + else: + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + UPDATE contest_contestentry + SET amount = amount + %(amount)s, + updated_at = %(updated_at)s + WHERE user_id = %(user_id)s + AND contest_id = %(contest_id)s + AND amount = %(current_amount)s + """, + params=data | {"current_amount": contest.user_amount}, + ) + assert ( + c.rowcount == 1 + ), "enter_contest_db_work_milestone: Mismatch amounts in contest entry" + conn.commit() + contest.user_amount += entry.amount + return entry + + def end_milestone_contest(self, contest: MilestoneContest) -> None: + """ + A milestone contest has (possibly) paid out user's award already (once + each user has reached the milestone). So when the contest itself is + over, nothing really happens, money-wise. + """ + assert contest.status == ContestStatus.COMPLETED, "status must be completed" + assert isinstance(contest, MilestoneContest), "must pass MilestoneContest" + data = { + "status": contest.status.value, + "ended_at": contest.ended_at, + "end_reason": contest.end_reason, + "contest_uuid": contest.uuid, + } + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + UPDATE contest_contest + SET status = %(status)s, + ended_at = %(ended_at)s, + end_reason = %(end_reason)s + WHERE uuid = %(contest_uuid)s + AND status = 'active' + """, + params=data, + ) + assert c.rowcount == 1, "Contest changed during write" + conn.commit() + return None + + def award_milestone_contest( + self, + contest: MilestoneUserView, + user: User, + ledger_manager: ThlLedgerManager, + ) -> None: + """A user reached the milestone. The contest stays open (unless it + has reached the max winners). + """ + assert contest.should_award() + assert not contest.user_winnings, "user already was awarded" + winners = [ContestWinner(prize=prize, user=user) for prize in contest.prizes] + rows = [w.model_dump_mysql(contest_id=contest.id) for w in winners] + # The win_count is 1 !!! A user can only "win" a milestone once, no matter how + # many prizes are awarded. + win_count = 1 + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany( + query=""" + INSERT INTO contest_contestwinner + (uuid, created_at, user_id, + contest_id, prize, awarded_cash_amount) + VALUES (%(uuid)s, %(created_at)s, %(user_id)s, + %(contest_id)s, %(prize)s, %(awarded_cash_amount)s) + """, + params_seq=rows, + ) + c.execute( + query=""" + UPDATE contest_contest + SET win_count = win_count + %(win_count)s + WHERE uuid = %(contest_uuid)s + """, + params={ + "contest_uuid": contest.uuid, + "win_count": win_count, + }, + ) + conn.commit() + contest.win_count += win_count + ledger_manager.create_tx_milestone_winner(contest=contest, winners=winners) + return None + + +class LeaderboardContestManager(ContestBaseManager): + def get_leaderboard_user_view( + self, + contest_uuid: UUIDStr, + user: User, + redis_client: redis.Redis, + user_manager: UserManager, + ) -> LeaderboardContestUserView: + """ + A leaderboard contest has NO user_entries. The redis leaderboard + manager handles tracking everything. + """ + assert user.user_id and user.product_user_id, "invalid user" + + sql_res = self.pg_config.execute_sql_query( + query=f""" + SELECT + {CONTEST_SELECT}, + COALESCE(cw_json.user_winnings, '[]'::jsonb) AS user_winnings, + {USER_SELECT} + FROM contest_contest c + JOIN thl_user u + ON u.id = %(user_id)s + {USER_WINNINGS_JOIN} + WHERE c.uuid = %(contest_uuid)s + LIMIT 2; + """, + params={"user_id": user.user_id, "contest_uuid": contest_uuid}, + ) + assert len(sql_res) == 1 + + d = sql_res[0] + for x in d["user_winnings"]: + x["uuid"] = UUID(x["uuid"]).hex + x["created_at"] = datetime.fromisoformat(x["created_at"]) + x["user"] = user + c = LeaderboardContestUserView.model_validate_mysql(d) + c._redis_client = redis_client + c._user_manager = user_manager + + return c + + def end_contest_if_over( + self, contest: Contest, ledger_manager: ThlLedgerManager + ) -> None: + decision, reason = contest.should_end() + if decision: + contest.end_contest() + return self.end_contest_with_winners(contest, ledger_manager) + + +class ContestManager( + RaffleContestManager, MilestoneContestManager, LeaderboardContestManager +): + + def end_contest(self, contest: Contest, ledger_manager: ThlLedgerManager) -> None: + if isinstance(contest, (MilestoneContest)): + return self.end_milestone_contest(contest) + + elif isinstance(contest, (LeaderboardContest, RaffleContest)): + return self.end_contest_with_winners(contest, ledger_manager=ledger_manager) + + def check_for_contest_closing( + self, + ledger_manager: ThlLedgerManager, + redis_client: Redis, + user_manager: UserManager, + ) -> Dict[str, NonNegativeInt]: + # This is an administrative function that we'll run on a schedule, + # that will check for any open contests, for any BP, that should be + # closed, and then do it! + page = 1 + contests_checked = 0 + contests_closed = 0 + while True: + contests = self.get_many( + status=ContestStatus.ACTIVE, + include_winners=False, + starts_at_before=True, + page=page, + size=20, + ) + if not contests: + break + print(f"Got {len(contests)} contests") + contests_checked += len(contests) + this_contests_closed = self.check_for_contest_closing_chunk( + contests, + ledger_manager=ledger_manager, + redis_client=redis_client, + user_manager=user_manager, + ) + contests_closed += this_contests_closed + print(f"Closed {this_contests_closed} contests") + page += 1 + + return {"closed": contests_closed, "checked": contests_checked} + + def check_for_contest_closing_chunk( + self, + contests: Collection[Contest], + ledger_manager: ThlLedgerManager, + redis_client: Redis, + user_manager: UserManager, + ) -> NonNegativeInt: + contests_closed = 0 + for contest in contests: + should_end, reason = contest.should_end() + if should_end: + if hasattr(contest, "redis_client"): + contest.redis_client = redis_client + contest.user_manager = user_manager + contests_closed += 1 + contest.end_contest() + self.end_contest(contest, ledger_manager=ledger_manager) + return contests_closed + + def hit_milestone_triggers( + self, + event: ContestEntryTrigger, + user: User, + country_iso: str, + ledger_manager: ThlLedgerManager, + ) -> PositiveInt: + """For any open milestone contest that has a trigger on this event, + if the user is eligible, hit it! + """ + cs = self.get_many_by_user_eligible_milestone( + user=user, country_iso=country_iso, entry_trigger=event + ) + for c in cs: + self.enter_milestone_contest( + contest_uuid=c.uuid, + country_iso=country_iso, + user=user, + ledger_manager=ledger_manager, + ) + return len(cs) diff --git a/generalresearch/managers/thl/delete_request.py b/generalresearch/managers/thl/delete_request.py new file mode 100644 index 0000000..f9d15c5 --- /dev/null +++ b/generalresearch/managers/thl/delete_request.py @@ -0,0 +1,178 @@ +# from datetime import datetime, timezone +# from typing import Optional +# +# from generalresearch.managers.gr.authentication import GRUserManager +# from generalresearch.managers.thl.user_manager.user_manager import UserManager +# from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr +# from pydantic import BaseModel, Field, PositiveInt, model_validator +# from pydantic.json_schema import SkipJsonSchema +# +# from api.decorators import THL_WEB_RR, GR_DB +# +# GR_UM = GRUserManager(sql_helper=GR_DB) +# UM = UserManager(sql_helper_rr=THL_WEB_RR) +# + +# @pytest.mark.skip(reason="moving to pyutils 2.5.1") +# class TestUserDeleteRequestManager: +# +# def test_delete_request(self, gr_user, user, product, user_manager, gr_um): +# from api.models.product_user import DeleteRequest +# from api.managers.product_user import UserDeletionRequestManager +# +# # A valid Respondent and GR Admin account need to exist in the test +# # database for any of this to work +# user = user_manager.create_dummy( +# product_id=product.id, +# product_user_id=f"test-{uuid4().hex[:6]}", +# ) +# +# instance = DeleteRequest( +# product_id=user.product_id, +# product_user_id=user.product_user_id, +# created_by_user_id=gr_user.id, +# ) +# +# start: int = UserDeletionRequestManager().get_count_by_product_id( +# product_id=user.product_id +# ) +# +# UserDeletionRequestManager.save(deletion_request=instance) +# +# finish: int = UserDeletionRequestManager().get_count_by_product_id( +# product_id=user.product_id +# ) +# +# assert finish == start + 1 + + +# @pytest.mark.skip(reason="Moving to py-utils in 2.5.1") +# class TestProductUserDeleteRequest: +# +# def test_no_user_provided(self, product, business, team, gr_user): +# from api.models.product_user import DeleteRequest +# +# # product_id and product_user_id is required +# with pytest.raises(expected_exception=ValueError) as cm: +# DeleteRequest(created_by_user_id=gr_user.id) +# +# assert "2 validation errors" in str(cm.value) +# +# def test_no_user_exists(self, gr_user, product): +# from api.models.product_user import DeleteRequest +# +# with pytest.raises(expected_exception=ValueError) as cm: +# DeleteRequest( +# product_id=product.id, +# product_user_id=f"test-user-{uuid4().hex[:12]}", +# created_by_user_id=gr_user.id, +# ) +# +# assert "Could not find Worker" in str(cm.value) +# +# def test_no_create_by_user(self, user, product): +# from api.models.product_user import DeleteRequest +# +# with pytest.raises(expected_exception=ValueError) as cm: +# DeleteRequest( +# product_id=user.product_id, +# product_user_id=user.product_user_id, +# created_by_user_id=randint(a=999_999, b=999_999_999), +# ) +# assert "GRUser not found" in str(cm.value) + + +# +# class DeleteRequest(BaseModel): +# id: SkipJsonSchema[Optional[PositiveInt]] = Field(default=None, exclude=True) +# uuid: UUIDStr = Field(examples=[uuid4().hex], default_factory=lambda: uuid4().hex) +# +# product_id: UUIDStr = Field(examples=["00e96773d4ae47f8812488a976a080c8"]) +# product_user_id: str = Field( +# min_length=3, max_length=128, examples=["bpuid-68d989"] +# ) +# +# created: AwareDatetimeISO = Field( +# default=datetime.now(tz=timezone.utc), +# description="When the DeleteRequest was created, this is the UTC time " +# "that a Worker / Respondent's Profiling Questions were " +# "deleted.", +# ) +# created_by_user_id: SkipJsonSchema[PositiveInt] = Field(exclude=True) +# +# @model_validator(mode="after") +# def check_valid_worker(self) -> "DeleteRequest": +# """ Raise an error if the User that the GRUser is attempting to delete +# does not actually exist in the system. We can check the production +# thl-web user table here for real time users +# """ +# user = UM.get_user_if_exists( +# product_id=self.product_id, product_user_id=self.product_user_id +# ) +# +# if not user: +# raise ValueError("Could not find Worker") +# +# return self +# +# @model_validator(mode="after") +# def check_valid_owner(self) -> "DeleteRequest": +# """ Ensure we can track which GRUser made a deletion request so we can +# track the chain of command for who took what action. +# +# """ +# gr_user = GR_UM.get_by_id(gr_user_id=self.created_by_user_id) +# +# if not gr_user: +# raise ValueError("Could not find General Research account") +# +# return self + + +# @staticmethod +# def save(deletion_request: DeleteRequest) -> bool: +# with GR_DB.make_connection() as conn: +# with conn.cursor(row_factory=dict_row) as c: +# c: Cursor +# +# c.execute( +# query=f""" +# INSERT INTO product_user_deleterequest +# (uuid, product_id, product_user_id, created, +# created_by_user_id) +# VALUES (%s, %s, %s, %s, %s) +# """, +# params=[ +# deletion_request.uuid, +# deletion_request.product_id, +# deletion_request.product_user_id, +# deletion_request.created, +# deletion_request.created_by_user_id, +# ], +# ) +# +# conn.commit() +# +# return True +# +# +# @staticmethod +# def get_count_by_product_id(product_id: UUIDStr) -> NonNegativeInt: +# with GR_DB.make_connection() as conn: +# with conn.cursor(row_factory=dict_row) as c: +# c: Cursor +# +# c.execute( +# query=f""" +# SELECT COUNT(1) as cnt +# FROM product_user_deleterequest AS dr +# WHERE dr.product_id = %s +# """, +# params=[ +# product_id, +# ], +# ) +# res = c.fetchall() +# +# assert len(res) == 1, "invalid query" +# return int(res[0]["cnt"]) diff --git a/generalresearch/managers/thl/ipinfo.py b/generalresearch/managers/thl/ipinfo.py new file mode 100644 index 0000000..563b179 --- /dev/null +++ b/generalresearch/managers/thl/ipinfo.py @@ -0,0 +1,819 @@ +import ipaddress +from decimal import Decimal +from random import randint +from typing import List, Optional, Dict, Collection + +import faker +import pymysql +from more_itertools import chunked +from psycopg import Cursor +from pydantic import PositiveInt + +from generalresearch.managers.base import ( + PostgresManager, + PostgresManagerWithRedis, +) +from generalresearch.models.custom_types import ( + IPvAnyAddressStr, + CountryISOLike, +) +from generalresearch.models.thl.ipinfo import ( + IPGeoname, + GeoIPInformation, + IPInformation, + normalize_ip, +) +from generalresearch.models.thl.maxmind.definitions import UserType +from generalresearch.pg_helper import PostgresConfig + +fake = faker.Faker() + + +class IPGeonameManager(PostgresManager): + + def create_dummy( + self, + geoname_id: Optional[PositiveInt] = None, + continent_code: Optional[str] = None, + continent_name: Optional[str] = None, + country_iso: Optional[str] = None, + country_name: Optional[str] = None, + subdivision_1_iso: Optional[str] = None, + subdivision_1_name: Optional[str] = None, + subdivision_2_iso: Optional[str] = None, + subdivision_2_name: Optional[str] = None, + city_name: Optional[str] = None, + metro_code: Optional[int] = None, + time_zone: Optional[str] = None, + is_in_european_union: Optional[bool] = None, + ) -> IPGeoname: + return self.create( + geoname_id=geoname_id or randint(1, 999_999_999), + continent_code=continent_code or "na", + continent_name=continent_name or "North America", + country_iso=country_iso or "us", + country_name=country_name or "United States", + subdivision_1_iso=subdivision_1_iso or "fl", + subdivision_1_name=subdivision_1_name or "Florida", + subdivision_2_iso=subdivision_2_iso, + subdivision_2_name=subdivision_2_name, + city_name=city_name, + metro_code=metro_code, + time_zone=time_zone, + is_in_european_union=is_in_european_union, + ) + + def create_basic( + self, + geoname_id: PositiveInt, + is_in_european_union: bool, + country_iso: CountryISOLike, + country_name: str, + continent_code: str, + continent_name: str, + ) -> IPGeoname: + instance = IPGeoname.model_validate( + { + "geoname_id": geoname_id, + "country_iso": country_iso, + "is_in_european_union": is_in_european_union, + "country_name": country_name, + "continent_code": continent_code, + "continent_name": continent_name, + } + ) + self.pg_config.execute_write( + query=f""" + INSERT INTO thl_geoname ( + geoname_id, country_iso, is_in_european_union, country_name, + continent_code, continent_name, updated + ) + VALUES ( + %(geoname_id)s, %(country_iso)s, %(is_in_european_union)s, %(country_name)s, + %(continent_code)s, %(continent_name)s, %(updated)s + ) + ON CONFLICT (geoname_id) DO NOTHING; + """, + params=instance.model_dump(mode="json"), + ) + return instance + + def create_or_update(self, ipgeo: IPGeoname): + keys = list(ipgeo.model_fields.keys()) + data = ipgeo.model_dump_mysql() + + keys_str = ", ".join(keys) + values_str = ", ".join([f"%({k})s" for k in keys]) + update_cols = set(keys) - {"geoname_id"} + update_str = ", ".join([f"{k} = EXCLUDED.{k}" for k in update_cols]) + + query = f""" + INSERT INTO thl_geoname ({keys_str}) + VALUES ({values_str}) + ON CONFLICT (geoname_id) + DO UPDATE SET {update_str} + """ + self.pg_config.execute_write(query=query, params=data) + + def create( + self, + geoname_id: PositiveInt, + continent_code: str, + continent_name: str, + country_iso: Optional[str], + country_name: Optional[str] = None, + subdivision_1_iso: Optional[str] = None, + subdivision_1_name: Optional[str] = None, + subdivision_2_iso: Optional[str] = None, + subdivision_2_name: Optional[str] = None, + city_name: Optional[str] = None, + metro_code: Optional[int] = None, + time_zone: Optional[str] = None, + is_in_european_union: Optional[bool] = None, + ) -> IPGeoname: + + instance = IPGeoname.model_validate( + { + "geoname_id": geoname_id, + "continent_code": continent_code, + "continent_name": continent_name, + "country_iso": country_iso, + "country_name": country_name, + "subdivision_1_iso": subdivision_1_iso, + "subdivision_1_name": subdivision_1_name, + "subdivision_2_iso": subdivision_2_iso, + "subdivision_2_name": subdivision_2_name, + "city_name": city_name, + "metro_code": metro_code, + "time_zone": time_zone, + "is_in_european_union": is_in_european_union, + } + ) + + self.pg_config.execute_write( + query=f""" + INSERT INTO thl_geoname + ( geoname_id, continent_code, continent_name, + country_iso, country_name, + subdivision_1_iso, subdivision_1_name, + subdivision_2_iso, subdivision_2_name, + city_name, metro_code, time_zone, is_in_european_union, + updated + ) + VALUES ( + %(geoname_id)s, %(continent_code)s, %(continent_name)s, + %(country_iso)s, %(country_name)s, + %(subdivision_1_iso)s, %(subdivision_1_name)s, + %(subdivision_2_iso)s, %(subdivision_2_name)s, + %(city_name)s, %(metro_code)s, %(time_zone)s, %(is_in_european_union)s, + %(updated)s + ) + ON CONFLICT (geoname_id) DO NOTHING; + """, + params=instance.model_dump(mode="json"), + ) + + return instance + + def get_by_id(self, geoname_id: PositiveInt) -> "IPGeoname": + return self.fetch_geoname_ids(filter_ids=[geoname_id])[0] + + def fetch_geoname_ids( + self, + filter_ids: List[PositiveInt], + ) -> List[IPGeoname]: + + if len(filter_ids) == 0: + return [] + + with self.pg_config.make_connection() as sql_connection: + sql_connection: pymysql.Connection + with sql_connection.cursor() as c: + res = [] + for chunk in chunked(filter_ids, 500): + res.extend( + self.fetch_geoname_ids_( + c=c, + filter_ids=chunk, + ) + ) + return res + + def fetch_geoname_ids_( + self, + c: Cursor, + filter_ids: List[PositiveInt], + ) -> List[IPGeoname]: + + assert len(filter_ids) <= 500, "chunk me" + + c.execute( + query=f""" + SELECT g.geoname_id, + g.continent_code, g.continent_name, + g.country_iso, g.country_name, + g.subdivision_1_iso, g.subdivision_1_name, + g.subdivision_2_iso, g.subdivision_2_name, + g.city_name, g.metro_code, + g.time_zone, g.is_in_european_union, + g.updated + FROM thl_geoname AS g + WHERE g.geoname_id = ANY(%s); + """, + params=[filter_ids], + ) + return [IPGeoname.from_mysql(i) for i in c.fetchall()] + + +class IPInformationManager(PostgresManager): + + def create_dummy( + self, + ip: Optional[IPvAnyAddressStr] = None, + geoname_id: Optional[PositiveInt] = None, + country_iso: Optional[str] = None, + registered_country_iso: Optional[str] = None, + is_anonymous: Optional[bool] = None, + is_anonymous_vpn: Optional[bool] = None, + is_hosting_provider: Optional[bool] = None, + is_public_proxy: Optional[bool] = None, + is_tor_exit_node: Optional[bool] = None, + is_residential_proxy: Optional[bool] = None, + autonomous_system_number: Optional[PositiveInt] = None, + autonomous_system_organization: Optional[str] = None, + domain: Optional[str] = None, + isp: Optional[str] = None, + mobile_country_code: Optional[str] = None, + mobile_network_code: Optional[str] = None, + network: Optional[str] = None, + organization: Optional[str] = None, + static_ip_score: Optional[float] = None, + user_type: Optional[UserType] = None, + postal_code: Optional[str] = None, + latitude: Optional[Decimal] = None, + longitude: Optional[Decimal] = None, + accuracy_radius: Optional[int] = None, + ) -> "IPInformation": + return self.create( + ip=ip or fake.ipv4_public(), + geoname_id=geoname_id, + country_iso=country_iso or fake.country_code(), + registered_country_iso=registered_country_iso, + is_anonymous=is_anonymous, + is_anonymous_vpn=is_anonymous_vpn, + is_hosting_provider=is_hosting_provider, + is_public_proxy=is_public_proxy, + is_tor_exit_node=is_tor_exit_node, + is_residential_proxy=is_residential_proxy, + autonomous_system_number=autonomous_system_number, + autonomous_system_organization=autonomous_system_organization, + domain=domain, + isp=isp, + mobile_country_code=mobile_country_code, + mobile_network_code=mobile_network_code, + network=network, + organization=organization, + static_ip_score=static_ip_score, + user_type=user_type, + postal_code=postal_code, + latitude=latitude, + longitude=longitude, + accuracy_radius=accuracy_radius, + ) + + def create_basic( + self, + ip: IPvAnyAddressStr, + geoname_id: PositiveInt, + country_iso: str, + registered_country_iso: str, + ) -> IPInformation: + instance = IPInformation.model_validate( + { + "ip": ip, + "geoname_id": geoname_id, + "country_iso": country_iso, + "registered_country_iso": registered_country_iso, + } + ) + instance.normalize_ip() + self.pg_config.execute_write( + query=f""" + INSERT INTO thl_ipinformation + (ip, country_iso, registered_country_iso, geoname_id, updated) + VALUES (%(ip)s, %(country_iso)s, %(registered_country_iso)s, %(geoname_id)s, %(updated)s) + ON CONFLICT (ip) DO NOTHING; + """, + params=instance.model_dump(mode="json"), + ) + return instance + + def create( + self, + ip: IPvAnyAddressStr, + geoname_id: Optional[PositiveInt] = None, + country_iso: Optional[str] = None, + registered_country_iso: Optional[str] = None, + is_anonymous: Optional[bool] = None, + is_anonymous_vpn: Optional[bool] = None, + is_hosting_provider: Optional[bool] = None, + is_public_proxy: Optional[bool] = None, + is_tor_exit_node: Optional[bool] = None, + is_residential_proxy: Optional[bool] = None, + autonomous_system_number: Optional[PositiveInt] = None, + autonomous_system_organization: Optional[str] = None, + domain: Optional[str] = None, + isp: Optional[str] = None, + mobile_country_code: Optional[str] = None, + mobile_network_code: Optional[str] = None, + network: Optional[str] = None, + organization: Optional[str] = None, + static_ip_score: Optional[float] = None, + user_type: Optional[UserType] = None, + postal_code: Optional[str] = None, + latitude: Optional[Decimal] = None, + longitude: Optional[Decimal] = None, + accuracy_radius: Optional[int] = None, + ) -> "IPInformation": + + instance = IPInformation.model_validate( + { + "ip": ip, + "geoname_id": geoname_id, + "country_iso": country_iso, + "registered_country_iso": registered_country_iso, + "is_anonymous": is_anonymous, + "is_anonymous_vpn": is_anonymous_vpn, + "is_hosting_provider": is_hosting_provider, + "is_public_proxy": is_public_proxy, + "is_tor_exit_node": is_tor_exit_node, + "is_residential_proxy": is_residential_proxy, + "autonomous_system_number": autonomous_system_number, + "autonomous_system_organization": autonomous_system_organization, + "domain": domain, + "isp": isp, + "mobile_country_code": mobile_country_code, + "mobile_network_code": mobile_network_code, + "network": network, + "organization": organization, + "static_ip_score": static_ip_score, + "user_type": user_type, + "postal_code": postal_code, + "latitude": latitude, + "longitude": longitude, + "accuracy_radius": accuracy_radius, + } + ) + instance.normalize_ip() + + self.pg_config.execute_write( + query=f""" + INSERT INTO thl_ipinformation + ( ip, geoname_id, + country_iso, registered_country_iso, + is_anonymous, is_anonymous_vpn, + is_hosting_provider, is_public_proxy, + is_tor_exit_node, is_residential_proxy, + autonomous_system_number, autonomous_system_organization, + domain, isp, + mobile_country_code, mobile_network_code, + network, organization, static_ip_score, + user_type, postal_code, latitude, longitude, + accuracy_radius, + updated + ) + VALUES ( + %(ip)s, %(geoname_id)s, + %(country_iso)s, %(registered_country_iso)s, + %(is_anonymous)s, %(is_anonymous_vpn)s, + %(is_hosting_provider)s, %(is_public_proxy)s, + %(is_tor_exit_node)s, %(is_residential_proxy)s, + %(autonomous_system_number)s, %(autonomous_system_organization)s, + %(domain)s, %(isp)s, + %(mobile_country_code)s, %(mobile_network_code)s, + %(network)s, %(organization)s, %(static_ip_score)s, + %(user_type)s, %(postal_code)s, %(latitude)s, %(longitude)s, + %(accuracy_radius)s, + %(updated)s + ) + ON CONFLICT (ip) DO NOTHING; + """, + params=instance.model_dump(mode="json"), + ) + + return instance + + def create_or_update(self, ipinfo: IPInformation): + ipinfo.normalize_ip() + keys = [key for key, field in ipinfo.model_fields.items() if not field.exclude] + data = ipinfo.model_dump_mysql() + + keys_str = ", ".join(keys) + values_str = ", ".join([f"%({k})s" for k in keys]) + update_cols = set(keys) - {"ip"} + update_str = ", ".join([f"{k} = EXCLUDED.{k}" for k in update_cols]) + + query = f""" + INSERT INTO thl_ipinformation ({keys_str}) + VALUES ({values_str}) + ON CONFLICT (ip) DO UPDATE + SET {update_str} + """ + self.pg_config.execute_write(query, params=data) + + def get_ip_info(self, ip: IPvAnyAddressStr) -> Optional["IPInformation"]: + res = self.fetch_ip_information(filter_ips=[ip]) + if len(res) != 1: + return None + + return res[0] + + def fetch_ip_information( + self, + filter_ips: List[IPvAnyAddressStr], + ) -> List["IPInformation"]: + + if len(filter_ips) == 0: + return [] + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + res = [] + for chunk in chunked(filter_ips, 500): + res.extend( + self.fetch_ip_information_( + c=c, + filter_ips=chunk, + ) + ) + return res + + def fetch_ip_information_( + self, + c: Cursor, + filter_ips: List[IPvAnyAddressStr], + ) -> List["IPInformation"]: + """ + IPs are converted to normalized form (/64 network exploded) for DB lookup, + and are then matched back to the original queried form for return. + e.g. '2600:1700:ece0:9410:055d:faf3:c15d:06e4' is passed in, + it gets converted to '2600:1700:ece0:9410:0000:0000:0000:0000' for db lookup, + the result gets + """ + + assert len(filter_ips) <= 500, "chunk me" + normalized_ip_lookup = {ip: normalize_ip(ip)[0] for ip in filter_ips} + normalized_ips = set(normalized_ip_lookup.values()) + + c.execute( + query=f""" + SELECT i.ip, i.geoname_id, + i.country_iso, i.registered_country_iso, + i.is_anonymous, i.is_anonymous_vpn, i.is_hosting_provider, + i.is_public_proxy, i.is_tor_exit_node, i.is_residential_proxy, + i.autonomous_system_number, i.autonomous_system_organization, + i.domain, i.isp, + i.mobile_country_code, i.mobile_network_code, + i.network, i.organization, + i.static_ip_score, i.user_type, i.postal_code, + i.latitude, i.longitude, i.accuracy_radius, + i.updated + + FROM thl_ipinformation AS i + WHERE i.ip = ANY(%s) + """, + params=[list(normalized_ips)], + ) + + return [IPInformation.from_mysql(i) for i in c.fetchall()] + + @staticmethod + def monitor_ipinformation(pg_config: PostgresConfig) -> None: + """Continually check our IPInformation table to ensure location information + is being saved properly""" + + # Check the percentage of IPs that don't have a country_iso in the past 12 hours + # Split query is 1000x faster + query = """ + SELECT COUNT(*) AS numerator + FROM thl_ipinformation + WHERE updated >= NOW() - INTERVAL '12 hours' + AND country_iso IS NULL; + """ + numerator = list(pg_config.execute_sql_query(query=query))[0]["numerator"] + + query = f""" + SELECT COUNT(1) AS denominator + FROM thl_ipinformation + WHERE updated >= NOW() - INTERVAL '12 hours' + """ + denominator = list(pg_config.execute_sql_query(query=query))[0]["denominator"] + if denominator == 0: + pass + percent_empty = numerator / (denominator or 1) + # TODO: Post to telegraf / grafana + + return + + +class GeoIpInfoManager(PostgresManagerWithRedis): + + def get(self, ip_address: IPvAnyAddressStr) -> Optional[GeoIPInformation]: + res = self.get_cache(ip_address) + if res: + return res + res = self.get_mysql_if_exists(ip_address) + if res: + self.set_cache(res) + return res + + def get_multi( + self, ip_addresses: Collection[IPvAnyAddressStr] + ) -> Dict[IPvAnyAddressStr, Optional[GeoIPInformation]]: + if not ip_addresses: + return {} + # To deploy this, we still have (for the next 28 days) users who's + # ipv6 history was looked up and saved using the full /128. We need + # to pull those if the /64 doesn't exist. + # See notes in get_cache_multi & get_mysql_multi + res = self.get_cache_multi(ip_addresses=ip_addresses) + missing_ips = {k for k, v in res.items() if v is None and k in ip_addresses} + res_mysql = self.get_mysql_multi(ips=missing_ips) + self.set_cache_multi({k: v for k, v in res_mysql.items() if v}) + res.update(res_mysql) + return res + + def set_cache_multi( + self, ipinfo_map: Dict[IPvAnyAddressStr, GeoIPInformation] + ) -> None: + """Set multiple GeoIPInformation objects in Redis in one call.""" + if not ipinfo_map: + return + + pipe = self.redis_client.pipeline(transaction=False) + expire_seconds = 3 * 24 * 3600 + for ip, ipinfo in ipinfo_map.items(): + pipe.set( + self.get_cache_key(ip), + ipinfo.model_dump_json(), + ex=expire_seconds, + ) + pipe.execute() + + @staticmethod + def compress_ip(ip: str) -> str: + """ + To support looking up an ip in the db before we switched + to using the exploded form. (remove me 28 days after 2025-11-15) + """ + addr = ipaddress.ip_address(ip) + if addr.version == 4: + return str(addr) + return addr.compressed + + def get_cache_multi( + self, ip_addresses: Collection[IPvAnyAddressStr] + ) -> Dict[IPvAnyAddressStr, Optional[GeoIPInformation]]: + """Get multiple GeoIPInformation objects from Redis in one call. + + Returns a dict mapping IP address -> GeoIPInformation (or None if not in cache). + """ + if not ip_addresses: + return {} + # We must do it like this b/c we could have multiple /128 ips that normalize + # to the same normalized /64 ip, and we don't want to "loose" them. + ip_norm_lookup = {ip: normalize_ip(ip) for ip in ip_addresses} + normalized_ips = {v[0] for v in ip_norm_lookup.values()} + # also lookup exact matches (can remove this 28 days from 2025-11-15) + normalized_ips.update(ip_addresses) + # also lookup compressed form ... (remove me also) + normalized_ips.update({self.compress_ip(ip) for ip in ip_addresses}) + + keys = [self.get_cache_key(ip) for ip in normalized_ips] + res = self.redis_client.mget(keys) + res = [GeoIPInformation.model_validate_json(raw) for raw in res if raw] + gs = {x.ip: x for x in res} + + res2 = dict() + for ip, (normalized_ip, lookup_prefix) in ip_norm_lookup.items(): + if normalized_ip not in gs: + # try the non-normalized (remove me also 28 days from 2025-11-15) + if ip in gs: + res2[ip] = gs[ip].model_copy() + continue + res2[ip] = None + continue + g = gs[normalized_ip] + g.ip = ip + g.lookup_prefix = lookup_prefix + res2[g.ip] = g.model_copy() + return res2 + + def get_cache_key(self, ip_address: IPvAnyAddressStr) -> str: + return self.cache_prefix + f"thl:GeoIpInfoManager:{ip_address}" + + def clear_cache(self, ip_address: IPvAnyAddressStr) -> None: + # typically for testing + self.redis_client.delete(self.get_cache_key(ip_address=ip_address)) + + def set_cache(self, ipinfo: GeoIPInformation): + ipinfo = ipinfo.model_copy() + ipinfo.normalize_ip() + data = ipinfo.model_dump_json() + return self.redis_client.set( + self.get_cache_key(ip_address=ipinfo.ip), data, ex=3 * 24 * 3600 + ) + + def get_cache(self, ip_address: IPvAnyAddressStr) -> Optional[GeoIPInformation]: + normalized_ip, lookup_prefix = normalize_ip(ip_address) + res: str = self.get_cache_raw(normalized_ip) + if not res: + return None + g = GeoIPInformation.model_validate_json(res) + g.ip = ip_address + g.lookup_prefix = lookup_prefix + return g + + def get_cache_raw(self, ip_address: IPvAnyAddressStr) -> str: + return self.redis_client.get(self.get_cache_key(ip_address=ip_address)) + + def get_mysql_if_exists(self, ip_address: IPvAnyAddressStr): + try: + return self.get_mysql(ip_address=ip_address) + except AssertionError: + return None + + def get_mysql_raw(self, ip_address: IPvAnyAddressStr): + query = """ + SELECT + geo.geoname_id, + geo.continent_name, + LOWER(geo.continent_code) AS continent_code, + geo.country_name, + LOWER(geo.country_iso) AS geo_country_iso, + geo.subdivision_1_iso, + geo.subdivision_1_name, + geo.subdivision_2_iso, + geo.subdivision_2_name, + geo.city_name, + geo.metro_code, + geo.time_zone, + geo.is_in_european_union, + LOWER(ipinfo.country_iso) AS country_iso, + ipinfo.registered_country_iso, + ipinfo.is_anonymous, + ipinfo.is_anonymous_vpn, + ipinfo.is_hosting_provider, + ipinfo.is_public_proxy, + ipinfo.is_tor_exit_node, + ipinfo.is_residential_proxy, + ipinfo.autonomous_system_number, + ipinfo.autonomous_system_organization, + ipinfo.domain, + ipinfo.isp, + ipinfo.mobile_country_code, + ipinfo.mobile_network_code, + ipinfo.network, + ipinfo.organization, + ipinfo.static_ip_score, + ipinfo.user_type, + ipinfo.postal_code, + CAST(ipinfo.latitude AS float) AS latitude, + CAST(ipinfo.longitude AS float) AS longitude, + ipinfo.accuracy_radius, + ipinfo.ip, + ipinfo.updated + FROM thl_ipinformation AS ipinfo + LEFT JOIN thl_geoname AS geo + ON ipinfo.geoname_id = geo.geoname_id + WHERE ipinfo.ip = %s + """ + res = self.pg_config.execute_sql_query(query=query, params=[ip_address]) + assert len(res) == 1 + d = res[0] + if d.get("geo_country_iso") and (d["geo_country_iso"] != d["country_iso"]): + raise ValueError( + f'mismatch between ipinfo country {d["country_iso"]} and geoname country {d["geo_country_iso"]}' + ) + return d + + def get_mysql(self, ip_address: IPvAnyAddressStr): + normalized_ip, lookup_prefix = normalize_ip(ip_address) + d = self.get_mysql_raw(normalized_ip) + g = GeoIPInformation.from_mysql(d) + g.ip = ip_address + g.lookup_prefix = lookup_prefix + return g + + def recreate_cache(self, ip_address: IPvAnyAddressStr) -> GeoIPInformation: + res = self.get_mysql(ip_address) + self.set_cache(res) + return res + + def get_mysql_multi( + self, + ips: Collection[IPvAnyAddressStr], + ) -> Dict[IPvAnyAddressStr, Optional[GeoIPInformation]]: + + if len(ips) == 0: + return {} + + with self.pg_config.make_connection() as sql_connection: + sql_connection: pymysql.Connection + with sql_connection.cursor() as c: + res = {} + for chunk in chunked(ips, 500): + inner = self.get_mysql_multi_chunk( + c=c, + ips=chunk, + ) + res.update(inner) + return res + + def get_mysql_multi_chunk( + self, + c: Cursor, + ips: List[IPvAnyAddressStr], + ) -> Dict[IPvAnyAddressStr, Optional[GeoIPInformation]]: + + assert len(ips) <= 500, "chunk me" + + # We must do it like this b/c we could have multiple /128 ips that normalize + # to the same normalized /64 ip, and we don't want to "loose" them. + ip_norm_lookup = {ip: normalize_ip(ip) for ip in ips} + normalized_ips = {v[0] for v in ip_norm_lookup.values()} + # also lookup exact matches (can remove this 28 days from 2025-11-15) + normalized_ips.update(ips) + # also lookup compressed form ... (remove me also) + normalized_ips.update({self.compress_ip(ip) for ip in ips}) + + c.execute( + query=f""" + SELECT + geo.geoname_id, + geo.continent_name, + LOWER(geo.continent_code) AS continent_code, + geo.country_name, + LOWER(geo.country_iso) AS geo_country_iso, + geo.subdivision_1_iso, + geo.subdivision_1_name, + geo.subdivision_2_iso, + geo.subdivision_2_name, + geo.city_name, + geo.metro_code, + geo.time_zone, + geo.is_in_european_union, + LOWER(ipinfo.country_iso) AS country_iso, + ipinfo.registered_country_iso, + ipinfo.is_anonymous, + ipinfo.is_anonymous_vpn, + ipinfo.is_hosting_provider, + ipinfo.is_public_proxy, + ipinfo.is_tor_exit_node, + ipinfo.is_residential_proxy, + ipinfo.autonomous_system_number, + ipinfo.autonomous_system_organization, + ipinfo.domain, + ipinfo.isp, + ipinfo.mobile_country_code, + ipinfo.mobile_network_code, + ipinfo.network, + ipinfo.organization, + ipinfo.static_ip_score, + ipinfo.user_type, + ipinfo.postal_code, + CAST(ipinfo.latitude AS float) AS latitude, + CAST(ipinfo.longitude AS float) AS longitude, + ipinfo.accuracy_radius, + ipinfo.ip, + ipinfo.updated + FROM thl_ipinformation AS ipinfo + LEFT JOIN thl_geoname AS geo + ON ipinfo.geoname_id = geo.geoname_id + WHERE ipinfo.ip = ANY(%s) + """, + params=[list(normalized_ips)], + ) + + res = c.fetchall() + for d in res: + if d.get("geo_country_iso") and (d["geo_country_iso"] != d["country_iso"]): + raise ValueError( + f'mismatch between ipinfo country {d["country_iso"]} and geoname country {d["geo_country_iso"]}' + ) + gs = [GeoIPInformation.from_mysql(i) for i in res] + gs = {g.ip: g for g in gs} + res2 = dict() + for ip, (normalized_ip, lookup_prefix) in ip_norm_lookup.items(): + if normalized_ip not in gs: + # also can remove 28 days after 2025-11-15 + if ip in gs: + res2[ip] = gs[ip].model_copy() + continue + res2[ip] = None + continue + g = gs[normalized_ip] + g.ip = ip + g.lookup_prefix = lookup_prefix + res2[g.ip] = g.model_copy() + return res2 diff --git a/generalresearch/managers/thl/ledger_manager/__init__.py b/generalresearch/managers/thl/ledger_manager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/thl/ledger_manager/conditions.py b/generalresearch/managers/thl/ledger_manager/conditions.py new file mode 100644 index 0000000..3c03300 --- /dev/null +++ b/generalresearch/managers/thl/ledger_manager/conditions.py @@ -0,0 +1,217 @@ +import logging +from datetime import datetime, timezone, timedelta +from typing import Callable, Optional, Tuple, TYPE_CHECKING + +from generalresearch.config import JAMES_BILLINGS_TX_CUTOFF, JAMES_BILLINGS_BPID +from generalresearch.currency import USDCent +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.session import Wall, Session +from generalresearch.models.thl.user import User + +logging.basicConfig() +logger = logging.getLogger("LedgerManager") +logger.setLevel(logging.INFO) + +if TYPE_CHECKING: + from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerManager, + ) + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + + +def generate_condition_mp_payment(wall: "Wall") -> Callable: + """This returns a function that checks if the payment for this wall event + exists already. This function gets run after we acquire a lock. It + should return True if we want to continue (create a tx). + """ + wall_uuid = wall.uuid + + def _condition(lm: "LedgerManager") -> bool: + tag = f"{lm.currency.value}:mp_payment:{wall_uuid}" + txs = lm.get_tx_ids_by_tag(tag=tag) + return len(txs) == 0 + + return _condition + + +def generate_condition_bp_payment(session: "Session") -> Callable: + """This returns a function that checks if the payment for this Session + exists already. This function gets run after we acquire a lock. It + should return True if we want to continue (create a tx). + """ + session_uuid = session.uuid + + def _condition(lm: "LedgerManager") -> bool: + tag = f"{lm.currency.value}:bp_payment:{session_uuid}" + txs_ids = lm.get_tx_ids_by_tag(tag=tag) + return len(txs_ids) == 0 + + return _condition + + +def generate_condition_tag_exists(tag: str) -> Callable: + """This returns a function that checks if a tx with this tag already + exists. It should return True if we want to continue (create a tx). + """ + + def _condition(lm: "LedgerManager") -> bool: + txs_ids = lm.get_tx_ids_by_tag(tag=tag) + return len(txs_ids) == 0 + + return _condition + + +def generate_condition_bp_payout( + product: "Product", + amount: USDCent, + payoutevent_uuid: UUIDStr, + skip_one_per_day_check: bool = False, + skip_wallet_balance_check: bool = False, +) -> Callable: + created = datetime.now(tz=timezone.utc) + + def _condition( + lm: "ThlLedgerManager", + ) -> Tuple[bool, str]: + bp_wallet_account = lm.get_account_or_create_bp_wallet(product=product) + tag = f"{lm.currency.value}:bp_payout:{payoutevent_uuid}" + txs_ids = lm.get_tx_ids_by_tag(tag=tag) + + if len(txs_ids) != 0: + logger.info(f"{tag} failed condition check: already paid out payoutevent") + return False, "duplicate tag" + + if not skip_one_per_day_check: + txs = lm.get_tx_filtered_by_account( + account_uuid=bp_wallet_account.uuid, + time_start=created - timedelta(days=1), + time_end=created, + ) + txs = [tx for tx in txs if tx.metadata.get("tx_type") == "bp_payout"] + if len(txs) != 0: + logger.info(f"{tag} failed condition check >1 tx per day") + return False, ">1 tx per day" + + if not skip_wallet_balance_check: + balance: int = lm.get_account_balance(account=bp_wallet_account) + if balance < amount: + logger.info( + f"{tag} failed condition check balance: {balance} < requested amount: {amount}" + ) + return False, "insufficient balance" + + return True, "" + + return _condition + + +def generate_condition_user_payout_request( + user: User, payoutevent_uuid: UUIDStr, min_balance: Optional[int] = None +) -> Callable: + """This returns a function that checks if `user` has at least + `min_balance` in their wallet and that a payout request hasn't already + been issued with this payoutevent_uuid. + + min_balance is an Optional[int] and not a USDCent because I believe + that it could be negative. - Max 2024-07-18 + """ + + if min_balance is not None: + assert isinstance(min_balance, int) + + def _condition(lm: "ThlLedgerManager") -> bool: + tag = f"{lm.currency.value}:user_payout:{payoutevent_uuid}:request" + txs_ids = lm.get_tx_ids_by_tag(tag) + + if len(txs_ids) != 0: + logger.info(f"{tag} failed condition check duplicate transaction") + return False + + if min_balance is not None: + user_wallet_account = lm.get_account_or_create_user_wallet(user) + if user.product_id == JAMES_BILLINGS_BPID: + balance = lm.get_account_balance_timerange( + user_wallet_account, time_start=JAMES_BILLINGS_TX_CUTOFF + ) + else: + balance = lm.get_account_balance(user_wallet_account) + if balance < min_balance: + logger.info( + f"{tag} failed condition check balance: {balance} < requested amount: {min_balance}" + ) + return False + return True + + return _condition + + +def generate_condition_enter_contest( + user: User, tag: str, min_balance: USDCent +) -> Callable: + """This returns a function that checks if `user` has at least + `min_balance` in their wallet and that a tx doesn't already exist + with this tag + """ + assert isinstance(min_balance, USDCent), "balance must be USDCent" + + def _condition(lm: "ThlLedgerManager") -> Tuple[bool, str]: + txs_ids = lm.get_tx_ids_by_tag(tag) + if len(txs_ids) != 0: + logger.info(f"{tag} failed condition check duplicate transaction") + return False, "duplicate transaction" + + user_wallet_account = lm.get_account_or_create_user_wallet(user) + balance = lm.get_account_balance(user_wallet_account) + if balance < min_balance: + logger.info( + f"{tag} failed condition check balance: {balance} < requested amount: {min_balance}" + ) + return False, "insufficient balance" + return True, "" + + return _condition + + +def generate_condition_user_payout_action( + payoutevent_uuid: UUIDStr, action: str +) -> Callable: + """The balance has already been taken from the user's wallet, so there + is no balance check. We only just check that the ledger transaction + doesn't already exist. + + If the action is complete, we check if it hasn't already been cancelled. + If canceled, we check it hasn't been completed. + + :param action: should be in {'complete', 'cancel'} + """ + + def _condition(lm: "ThlLedgerManager") -> bool: + tag = f"{lm.currency.value}:user_payout:{payoutevent_uuid}:{action}" + txs_ids = lm.get_tx_ids_by_tag(tag) + if len(txs_ids) != 0: + logger.info(f"{tag} failed condition check duplicate transaction") + return False + + if action == "complete": + tag = f"{lm.currency.value}:user_payout:{payoutevent_uuid}:cancel" + txs = lm.get_tx_ids_by_tag(tag) + if len(txs) != 0: + logger.warning( + f"{tag} failed condition: trying to complete payout that was already cancelled" + ) + return False + + if action == "cancel": + tag = f"{lm.currency.value}:user_payout:{payoutevent_uuid}:complete" + txs = lm.get_tx_ids_by_tag(tag) + if len(txs) != 0: + logger.warning( + f"{tag} failed condition: trying to cancel payout that was already completed" + ) + return False + return True + + return _condition diff --git a/generalresearch/managers/thl/ledger_manager/exceptions.py b/generalresearch/managers/thl/ledger_manager/exceptions.py new file mode 100644 index 0000000..b79c153 --- /dev/null +++ b/generalresearch/managers/thl/ledger_manager/exceptions.py @@ -0,0 +1,49 @@ +class LedgerAccountDoesntExistError(Exception): + pass + + +class LedgerTransactionDoesntExistError(Exception): + pass + + +class LedgerTransactionCreateError(Exception): + """ + Ledger transaction creation failed + """ + + pass + + +class LedgerTransactionCreateLockError(LedgerTransactionCreateError): + """ + Ledger transaction creation failed because we could not acquire a lock + """ + + pass + + +class LedgerTransactionReleaseLockError(LedgerTransactionCreateError): + """ + There was an error releasing the redis lock. I'm not exactly sure why this + happens sometimes, but it does. Seems to be almost always during + back-populate as in sentry I see this very rarely. + """ + + pass + + +class LedgerTransactionFlagAlreadyExistsError(LedgerTransactionCreateError): + """ + Ledger transaction creation failed because the redis flag for this + tx was already set + """ + + pass + + +class LedgerTransactionConditionFailedError(LedgerTransactionCreateError): + """ + We tried to create a transaction but the condition check failed. + """ + + pass diff --git a/generalresearch/managers/thl/ledger_manager/ledger.py b/generalresearch/managers/thl/ledger_manager/ledger.py new file mode 100644 index 0000000..419f613 --- /dev/null +++ b/generalresearch/managers/thl/ledger_manager/ledger.py @@ -0,0 +1,1139 @@ +import logging +from collections import defaultdict +from datetime import timedelta, datetime, timezone +from typing import Optional, List, Dict, Callable, Collection, Tuple, Set +from uuid import UUID + +import redis +from more_itertools import chunked, flatten +from pydantic import AwareDatetime, PositiveInt +from redis.exceptions import LockError, LockNotOwnedError + +from generalresearch.currency import LedgerCurrency +from generalresearch.managers import parse_order_by +from generalresearch.managers.base import ( + Permission, + PostgresManager, + RedisManager, +) +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + LedgerTransactionDoesntExistError, + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionCreateLockError, + LedgerTransactionConditionFailedError, + LedgerTransactionReleaseLockError, + LedgerTransactionCreateError, +) +from generalresearch.models.custom_types import UUIDStr, check_valid_uuid +from generalresearch.models.thl.ledger import ( + LedgerAccount, + LedgerTransaction, + LedgerEntry, + UserLedgerTransactionTypesSummary, + UserLedgerTransactionType, +) +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +logging.basicConfig() +logger = logging.getLogger("LedgerManager") +logger.setLevel(logging.INFO) + +# We can re-use this in any query that is retrieving full TXs +FULL_TX_JOINS = """ +LEFT JOIN LATERAL ( + SELECT + string_agg(tm.key || '=' || tm.value, '&') AS key_value_pairs + FROM ledger_transactionmetadata tm + WHERE tm.transaction_id = lt.id +) meta ON TRUE +LEFT JOIN LATERAL ( + SELECT + jsonb_agg( + jsonb_build_object( + 'direction', le.direction, + 'amount', le.amount, + 'account_id', le.account_id, + 'entry_id', le.id + ) + ) AS entries_json + FROM ledger_entry le + WHERE le.transaction_id = lt.id +) entries ON TRUE +""" + + +class LedgerManagerBasePostgres(PostgresManager, RedisManager): + def __init__( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + currency: Optional[LedgerCurrency] = LedgerCurrency.USD, + testing: bool = False, + ): + if permissions is not None and ( + Permission.CREATE in permissions and redis_config is None + ): + raise ValueError("must pass redis_url when requesting CREATE permission") + cache_prefix = cache_prefix or "ledger-manager" + super().__init__( + pg_config=pg_config, + permissions=permissions, + redis_config=redis_config, + cache_prefix=cache_prefix, + ) + self.currency = currency + self.testing = testing + if self.testing: + self.currency = LedgerCurrency.TEST + + def make_filter_str( + self, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + account_uuid: Optional[str] = None, + metadata_key: Optional[str] = None, + metadata_value: Optional[str] = None, + ): + filters = [] + params = {} + if time_start or time_end: + time_end = time_end or datetime.now(tz=timezone.utc) + time_start = time_start or datetime(2017, 1, 1, tzinfo=timezone.utc) + assert time_start.tzinfo.utcoffset(time_start) == timedelta() + assert time_end.tzinfo.utcoffset(time_end) == timedelta() + filters.append("lt.created BETWEEN %(time_start)s AND %(time_end)s") + params["time_start"] = time_start.replace(tzinfo=None) + params["time_end"] = time_end.replace(tzinfo=None) + if account_uuid: + filters.append("le.account_id = %(account_uuid)s") + params["account_uuid"] = account_uuid + if metadata_key is not None: + filters.append("key = %(metadata_key)s") + params["metadata_key"] = metadata_key + if metadata_value is not None: + assert ( + metadata_key is not None + ), "cannot filter by metadata_value without metadata_key" + filters.append("value = %(metadata_value)s") + params["metadata_value"] = metadata_value + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + return filter_str, params + + +class LedgerTransactionManager(LedgerManagerBasePostgres): + + def create_tx( + self, + entries: List[LedgerEntry], + metadata: Optional[Dict[str, str]] = None, + ext_description: Optional[str] = None, + tag: Optional[str] = None, + created: Optional[AwareDatetime] = None, + ) -> LedgerTransaction: + """ + :returns a LedgerTransaction ID. This is because we can't fully populate + the response object with valid children (eg: Entries can't get + their ID because of c.executemany only returns back a single + lastrowid) + """ + + assert ( + Permission.CREATE in self.permissions + ), "LedgerTransactionManager has insufficient Permissions" + + if metadata is None: + metadata = dict() + if created is None: + created = datetime.now(tz=timezone.utc) + + t = LedgerTransaction( + created=created, + ext_description=ext_description, + tag=tag, + metadata=metadata, + entries=entries, + ) + d = t.model_dump_mysql(include={"created", "ext_description", "tag"}) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + # (1) Insert the Ledger Tx into the DB + c.execute( + """ + INSERT INTO ledger_transaction + (created, ext_description, tag) + VALUES (%(created)s, %(ext_description)s, %(tag)s) + RETURNING id; + """, + d, + ) + t.id = c.fetchone()["id"] + + # (2) Associate any metadata with the recently created Ledger Tx in the DB + metadata_values = [ + {"key": k, "value": v, "transaction_id": t.id} + for k, v in metadata.items() + ] + c.executemany( + """ + INSERT INTO ledger_transactionmetadata + (key, value, transaction_id) + VALUES (%(key)s, %(value)s, %(transaction_id)s) + """, + metadata_values, + ) + + # (3) Create the Ledger Tx Entries in the DB + for entry in entries: + entry.transaction_id = t.id + entry_values = [entry.model_dump(mode="json") for entry in entries] + c.executemany( + """ + INSERT INTO ledger_entry + (direction, amount, account_id, transaction_id) + VALUES (%(direction)s, %(amount)s, %(account_uuid)s, + %(transaction_id)s) + """, + entry_values, + ) + conn.commit() + return t + + def create_tx_protected( + self, + lock_key: str, + condition: Callable, + create_tx_func: Callable, + flag_key: Optional[str] = None, + skip_flag_check=False, + ) -> LedgerTransaction: + """ + The complexity here is that even we protect a transaction logic with + a lock, if two workers try to create the same transaction at a time, + the lock just prevents them from doing it simultaneously; they will + instead just do it sequentially. To prevent this, we can pass in a + conditional that evaluates AFTER the lock is acquired, and we break + if the condition is not met. + + 1) Try to acquire a lock. If the lock is currently held, quit. All the + following within held lock. + 2) Check if the flag exists. If so, quit. + 3) Check if the condition is True. If not, quit. (we have to do this + b/c the flag may have expired out of redis). + 4) Create transaction + 5) Release lock + + :param lock_key: A str that should protect the conditions we are + checking. Could be unique for this specific transaction or account. + e.g. For a user cashing out their wallet balance, we would lock on + the user.uuid, for paying for a task, we'd use the wall/session + uuid. + :param flag_key: A str that should be unique for this specific + transaction only (used for de-dupe purposes). e.g. User is cashing + out their wallet. We lock using the user's uuid, so they can't + cashout 2 different txs at the same time (and result in a negative + wallet balance). The flag is set based on the tx's ID just as a + quick de-dupe check to make sure we don't run the same tx twice. + :param condition: A function that gets run once we acquire the lock. It + should return True if we want to continue with creating a new tx. + The LedgerManager will get passed in as the first and only argument. + The condition should do things like 1) check if the tx is already + in the db, and/or 2) check if the account has sufficient funds to + cover the tx, for e.g. + :param create_tx_func: The function to run that creates the transaction. + :param skip_flag_check: If create_tx_func or condition call fails, the + flag would get set even though the transaction did not get created. + If we want to manually re-try it, we need to skip the flag check. + + :return: The transaction that was created (or raise + a LedgerTransactionCreateError()) + """ + rc = self.redis_client + lock_name = f"{self.cache_prefix}:transaction_lock:{lock_key}" + if flag_key is None: + flag_name = f"{self.cache_prefix}:transaction_flag:{lock_key}" + else: + flag_name = f"{self.cache_prefix}:transaction_flag:{flag_key}" + + # The lock is NOT blocking (by default). So if we can't acquire the + # lock immediately, it means someone else has it already, and is + # probably executing this transaction, so quit. + # The timeout is how long before the redis lock key expires, which + # would only happen if we didn't exit the `with` block normally + # (exiting the `with` block normally clears the lock key). + # We could also have `blocking_timeout`, which is how long to wait + # until we can acquire a lock, but this is used only if `blocking` + # is True. + # + # There is nothing here limiting how long we can spend working within + # the `with` block. + try: + tx_created = False + with rc.lock( + name=lock_name, + timeout=10, + blocking=False, + blocking_timeout=None, + ): + if skip_flag_check is False and rc.get(flag_name): + raise LedgerTransactionFlagAlreadyExistsError() + + # Maybe the flag set should be moved after the create_tx_func()? + # If we do that however, if the condition is failing and + # taking a long time, this would allow the tx to get retried + # over and over every 4 seconds, which is not good. + rc.set(name=flag_name, value=1, ex=3600 * 24) + # Condition returns either bool or Tuple[bool, str] + condition_res = condition(self) + if isinstance(condition_res, tuple): + condition_res, condition_msg = condition_res + else: + condition_msg = "" + if condition_res is False: + rc.delete(flag_name) + raise LedgerTransactionConditionFailedError(condition_msg) + + tx = create_tx_func() + tx_created = True + + except LockNotOwnedError: + # This happens if there is an error trying to release the lock. + # The tx most likely was created. + raise LedgerTransactionReleaseLockError() + + except LockError as e: + # There was an error acquiring the lock. The `with` block + # did not run. + logger.log(level=logging.ERROR, msg=str(e)) + rc.delete(flag_name) + raise LedgerTransactionCreateLockError() + + except redis.exceptions.RedisError as e: + if not tx_created: + # Redis failed before tx was created. Could be either on lock acquire, on + # flag check (get or set), on anything in the condition checks. + rc.delete(flag_name) + raise LedgerTransactionCreateError(f"Redis error: {e}") + else: + # Most likely redis fail on lock release. The tx was already created! + raise LedgerTransactionReleaseLockError(f"Redis error: {e}") + + return tx + + def get_tx_ids_by_tag(self, tag: str) -> set[PositiveInt]: + """`tag` is not a unique field, so it may return more than 1 + transaction. It should NOT return a substantial number of + transactions. Use filtering by metadata for that purpose. + + returns: a list of id, not the full transactions + """ + + assert ":" in tag, "Please confirm the tag is valid" + assert len(tag) > 6, "Please confirm the tag is valid" + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT lt.id + FROM ledger_transaction AS lt + WHERE tag = %s + LIMIT 101 + """, + params=[tag], + ) + if len(res) > 100: + raise ValueError(f"Too many txs with this tag: {tag}") + return {x["id"] for x in res} + + def get_tx_by_tag(self, tag: str) -> List[LedgerTransaction]: + tx_ids = self.get_tx_ids_by_tag(tag=tag) + return self.get_tx_by_ids(transaction_ids=tx_ids) + + def get_tx_ids_by_tags(self, tags: List[str]) -> set[PositiveInt]: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT lt.id, lt.tag, lt.created, lt.ext_description + FROM ledger_transaction AS lt + WHERE tag = ANY(%s) + """, + params=[list(tags)], + ) + + return {x["id"] for x in res} + + def get_txs_by_tags(self, tags: List[str]) -> List[LedgerTransaction]: + tx_ids = self.get_tx_ids_by_tags(tags=tags) + return self.get_tx_by_ids(transaction_ids=tx_ids) + + def get_tx_by_id(self, transaction_id: PositiveInt) -> LedgerTransaction: + assert isinstance(transaction_id, int), "transaction_id must be an PositiveInt" + + res = self.get_tx_by_ids(transaction_ids=[transaction_id]) + + if len(res) != 1: + raise LedgerTransactionDoesntExistError + + return res[0] + + def get_tx_by_ids( + self, + transaction_ids: Collection[PositiveInt], + ) -> List[LedgerTransaction]: + + args = {"transaction_ids": list(transaction_ids)} + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + lt.id AS transaction_id, + lt.created, + lt.ext_description, + lt.tag, + meta.key_value_pairs, + entries.entries_json + FROM ledger_transaction lt + {FULL_TX_JOINS} + WHERE lt.id = ANY(%(transaction_ids)s); + """, + params=args, + ) + return self.process_get_tx_mysql_rows_json(res) + + @staticmethod + def process_get_tx_mysql_rows_json( + rows: Collection[Dict], + ) -> List[LedgerTransaction]: + """Columns: transaction_id, created, ext_description, tag, + key_value_pairs, entries_json + - key_value_pairs: &-delimited key=value pairs + - entries_json: list of objects, containing keys: direction, + amount, entry_id, account_id + + """ + txs = [] + for row in rows: + if row["key_value_pairs"]: + metadata = { + key: value + for key, value in ( + pair.split("=") for pair in row["key_value_pairs"].split("&") + ) + } + else: + metadata = dict() + + entries = [ + LedgerEntry( + id=e["entry_id"], + amount=e["amount"], + direction=e["direction"], + account_uuid=UUID(e["account_id"]).hex, + transaction_id=row["transaction_id"], + ) + # Don't assume a Tx has Entries. We have cleanup methods that + # try to delete Tx if they failed (eg: during bp_payout) + # and we can't guarantee 2 entries per Tx + for e in row.get("entries_json", []) + ] + txs.append( + LedgerTransaction( + id=row["transaction_id"], + entries=entries, + metadata=metadata, + created=row["created"].replace(tzinfo=timezone.utc), + ext_description=row["ext_description"], + tag=row["tag"], + ) + ) + return txs + + def get_tx_filtered_by_account_summary( + self, + account_uuid: UUIDStr, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + ) -> UserLedgerTransactionTypesSummary: + filter_str, params = self.make_filter_str( + time_start=time_start, + time_end=time_end, + ) + params["account_uuid"] = account_uuid + + # We do direction * -1 b/c the values here are w.r.t the user. + # noinspection SqlShouldBeInGroupBy + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + tmd.value AS tx_type, + COUNT(1) AS entry_count, + MIN(le.amount * le.direction * -1) AS min_amount, + MAX(le.amount * le.direction * -1) AS max_amount, + SUM(le.amount * le.direction * -1) AS total_amount + FROM ledger_transaction lt + JOIN ledger_entry le + ON le.transaction_id = lt.id + AND le.account_id = %(account_uuid)s + JOIN ledger_transactionmetadata tmd + ON tmd.transaction_id = lt.id + AND tmd.key = 'tx_type' + {filter_str} + GROUP BY tmd.value + ORDER BY tmd.value; + """, + params=params, + ) + d = {x["tx_type"]: x for x in res} + return UserLedgerTransactionTypesSummary.model_validate(d) + + def get_tx_filtered_by_account_count( + self, + account_uuid: UUIDStr, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + ): + filter_str, params = self.make_filter_str( + time_start=time_start, + time_end=time_end, + ) + params["account_uuid"] = account_uuid + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT COUNT(DISTINCT lt.id) as cnt + FROM ledger_transaction lt + JOIN ledger_entry le + ON le.transaction_id = lt.id + AND le.account_id = %(account_uuid)s + {filter_str} + """, + params=params, + ) + return res[0]["cnt"] if res else 0 + + def get_tx_filtered_by_account( + self, + account_uuid: UUIDStr, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + order_by: Optional[str] = "created,tag", + ) -> List[LedgerTransaction]: + txs, _ = self.get_tx_filtered_by_account_paginated( + account_uuid=account_uuid, + time_start=time_start, + time_end=time_end, + order_by=order_by, + ) + return txs + + def get_balance_before_page( + self, + account_uuid: str, + oldest_created: datetime, + exclude_txs_before: Optional[datetime] = None, + ): + """ + In a paginated list of txs, if I want to calculate + a running balance, I need the balance in that account + starting at the oldest tx in the page + This is identical to get_account_balance_timerange + """ + params = { + "account_uuid": account_uuid, + "oldest_created": oldest_created, + } + exclude_str = "" + if exclude_txs_before: + exclude_str = "AND lt.created > %(exclude_txs_before)s" + params["exclude_txs_before"] = exclude_txs_before + query = f""" + SELECT + COALESCE(SUM(le.amount * le.direction * la.normal_balance), 0) AS balance_before_page + FROM ledger_transaction lt + JOIN ledger_entry le + ON le.transaction_id = lt.id + JOIN ledger_account la + ON la.uuid = le.account_id + WHERE le.account_id = %(account_uuid)s + AND lt.created < %(oldest_created)s + {exclude_str};""" + res = self.pg_config.execute_sql_query(query, params=params) + return res[0]["balance_before_page"] + + def include_running_balance( + self, + txs: List[UserLedgerTransactionType], + account_uuid: str, + exclude_txs_before: Optional[AwareDatetime] = None, + ): + """ + exclude_txs_before is NOT for filtering. It is a "hack" to exclude + transactions from before a certain date for balance consideration. + """ + if len(txs) == 0: + return txs + oldest_created = min([x.created for x in txs]) + balance_before_page = self.get_balance_before_page( + account_uuid=account_uuid, + oldest_created=oldest_created, + exclude_txs_before=exclude_txs_before, + ) + page_with_idx = list(enumerate(txs)) + page_with_idx.sort(key=lambda x: x[1].created) + balance = balance_before_page + for _, tx in page_with_idx: + balance += tx.amount + tx.balance_after = balance + # restore original order + page_with_idx.sort(key=lambda x: x[0]) + txs = [tx for _, tx in page_with_idx] + return txs + + def get_tx_filtered_by_account_paginated( + self, + account_uuid: UUIDStr, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + page: Optional[int] = None, + size: Optional[int] = None, + order_by: Optional[str] = "created,tag", + ) -> Tuple[List[LedgerTransaction], int]: + """ + If time_start and/or time_end are passed, the txs are filtered to + include only that range. + - time_start is optional, default = beginning of time + - time_end is optional, default = now + + If page/size are passed, return only that page of the filtered (by + account_uuid and optionally time) items. + + Returns (list of items, total (after filtering)). + :param account_uuid: Will return txs that have a ledger entry that touches this account + :param time_start: Filter to include this range. Default: beginning of time + :param time_end: Filter to include this range. Default: now + :param page: page starts at 1 + :param size: size of page, default (if page is not None) = 100. (1<=page<=100) + :param order_by: Required for pagination. Uses django-rest-framework ordering syntax, + e.g. '-created,tag' for (created desc, tag asc) + """ + + assert isinstance(account_uuid, str), "account_uuid must be a str" + check_valid_uuid(account_uuid) + + filter_str, params = self.make_filter_str( + time_start=time_start, time_end=time_end, account_uuid=account_uuid + ) + + if page is not None: + assert type(page) is int + assert page >= 1, "page starts at 1" + size = size if size is not None else 100 + assert type(size) is int + assert 1 <= size <= 100 + params["offset"] = (page - 1) * size + params["limit"] = size + paginated_filter_str = " LIMIT %(limit)s OFFSET %(offset)s" + total = self.get_tx_filtered_by_account_count( + account_uuid=account_uuid, + time_start=time_start, + time_end=time_end, + ) + else: + paginated_filter_str = "" + # Don't need to do a count if we aren't paginating + total = None + + order_by_str = parse_order_by(order_by) + + res = self.pg_config.execute_sql_query( + query=f""" + WITH tx_ids AS ( + SELECT lt.id + FROM ledger_transaction lt + JOIN ledger_entry le + ON le.transaction_id = lt.id + {filter_str} + GROUP BY lt.id, lt.created + {order_by_str} + {paginated_filter_str} + ) + SELECT + lt.id AS transaction_id, + lt.created, + lt.ext_description, + lt.tag, + meta.key_value_pairs, + entries.entries_json + + FROM tx_ids t + JOIN ledger_transaction lt ON lt.id = t.id + + {FULL_TX_JOINS} + + {order_by_str}; + """, + params=params, + ) + if total is None: + total = len(res) + + return ( + self.process_get_tx_mysql_rows_json(res), + total, + ) + + def get_tx_filtered_by_metadata( + self, + metadata_key: str, + metadata_value: str, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + ) -> List[LedgerTransaction]: + # Renamed from "get_tx_filtered" which is not a good name + + filter_str, params = self.make_filter_str( + time_start=time_start, + time_end=time_end, + metadata_key=metadata_key, + metadata_value=metadata_value, + ) + + res = self.pg_config.execute_sql_query( + query=f""" + WITH tx_ids AS ( + SELECT DISTINCT lt.id + FROM ledger_transaction lt + JOIN ledger_transactionmetadata ltm + ON ltm.transaction_id = lt.id + {filter_str} + ) + SELECT + lt.id AS transaction_id, + lt.created, + lt.ext_description, + lt.tag, + meta.key_value_pairs, + entries.entries_json + + FROM tx_ids t + JOIN ledger_transaction lt ON lt.id = t.id + {FULL_TX_JOINS} + """, + params=params, + ) + + return LedgerTransactionManager.process_get_tx_mysql_rows_json(res) + + +class LedgerMetadataManager(LedgerManagerBasePostgres): + """ + WARNING: TxtMetadata doesn't have an official Pydantic model + definition. So this is going to operate primarily on Dicts + """ + + def get_tx_metadata_by_txs( + self, transactions: List[LedgerTransaction] + ) -> Dict[PositiveInt, Dict]: + """ + Each transaction can have 1 metadata dictionary. However, each + metadata dictionary can have multiple key/value pairs that + corresponds to each metadata row in the database. + + """ + + tx_ids = set([tx.id for tx in transactions]) + res = self.pg_config.execute_sql_query( + query=""" + SELECT + tx_meta.id, tx_meta.key, + tx_meta.value, tx_meta.transaction_id + FROM ledger_transactionmetadata AS tx_meta + WHERE tx_meta.transaction_id = ANY(%s) + """, + params=[list(tx_ids)], + ) + + metadata = defaultdict(dict) + for x in res: + metadata[x["transaction_id"]][x["key"]] = x["value"] + + return metadata + + def get_tx_metadata_ids_by_tx( + self, transaction: LedgerTransaction + ) -> Set[PositiveInt]: + return self.get_tx_metadata_ids_by_txs(transactions=[transaction]) + + def get_tx_metadata_ids_by_txs( + self, transactions: List[LedgerTransaction] + ) -> Set[PositiveInt]: + """ + This explicitly returns the tx_metadata database ids. Potentially, + useful for counting total key/value pairs, and/or deleting records + from the database. + """ + + tx_ids = set([tx.id for tx in transactions]) + res = self.pg_config.execute_sql_query( + query=""" + SELECT tx_meta.id + FROM ledger_transactionmetadata AS tx_meta + WHERE tx_meta.transaction_id = ANY(%s) + """, + params=[list(tx_ids)], + ) + + return set([i["id"] for i in res]) + + +class LedgerEntryManager(LedgerManagerBasePostgres): + + def get_tx_entries_by_tx(self, transaction: LedgerTransaction) -> List[LedgerEntry]: + return self.get_tx_entries_by_txs(transactions=[transaction]) + + def get_tx_entries_by_txs( + self, transactions: List[LedgerTransaction] + ) -> List[LedgerEntry]: + tx_ids = set([tx.id for tx in transactions]) + tx_entries = self.pg_config.execute_sql_query( + query=""" + SELECT + entry.id, entry.direction, entry.amount, + entry.account_id as account_uuid, + entry.transaction_id + FROM ledger_entry AS entry + WHERE entry.transaction_id = ANY(%s) + """, + params=[list(tx_ids)], + ) + + return [LedgerEntry.model_validate(res) for res in tx_entries] + + +class LedgerAccountManager(LedgerManagerBasePostgres): + """This Manager class is primarily involved with any operations on the + ledger_account table within the ledger system. + + We have Ledger Accounts for many different purposes, + + """ + + def create_account(self, account: LedgerAccount) -> LedgerAccount: + assert ( + Permission.CREATE in self.permissions + ), "LedgerManager does not have sufficient permissions" + + d = account.model_dump(mode="json") + + # These we're excluded, so manually reassign them to the mysql args + d["reference_type"] = account.reference_type + d["qualified_name"] = account.qualified_name + + self.pg_config.execute_write( + query=""" + INSERT INTO ledger_account + (uuid, display_name, qualified_name, account_type, + normal_balance, reference_type, reference_uuid, + currency) + VALUES (%(uuid)s, %(display_name)s, %(qualified_name)s, + %(account_type)s, %(normal_balance)s, %(reference_type)s, + %(reference_uuid)s, %(currency)s) + """, + params=d, + ) + + return account + + def get_account( + self, qualified_name: str, raise_on_error=True + ) -> Optional[LedgerAccount]: + res = self.get_account_many( + qualified_names=[qualified_name], raise_on_error=raise_on_error + ) + return res[0] if len(res) == 1 else None + + def get_account_many_( + self, qualified_names: List[str], raise_on_error=True + ) -> List[Dict]: + assert len(qualified_names) <= 500, "chunk me" + + # qualified_name has a unique index so there can only be 0 or 1 match. + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + uuid, display_name, qualified_name, account_type, + normal_balance, reference_type, + reference_uuid, currency + FROM ledger_account + WHERE qualified_name = ANY(%s); + """, + params=[qualified_names], + ) + + if raise_on_error and (not res or len(res) != len(qualified_names)): + raise LedgerAccountDoesntExistError + + return list(res) + + def get_account_many( + self, qualified_names: List[str], raise_on_error=True + ) -> List[LedgerAccount]: + res = flatten( + [ + self.get_account_many_(chunk, raise_on_error) + for chunk in chunked(qualified_names, 500) + ] + ) + return [LedgerAccount.model_validate(i) for i in res] + + def get_account_or_create(self, account: LedgerAccount) -> LedgerAccount: + res: Optional[LedgerAccount] = self.get_account( + qualified_name=account.qualified_name, raise_on_error=False + ) + return res or self.create_account(account=account) + + def get_accounts(self, qualified_names: List[str]) -> List[LedgerAccount]: + return self.get_account_many(qualified_names, raise_on_error=True) + + def get_accounts_if_exists(self, qualified_names: List[str]) -> List[LedgerAccount]: + """Rather than returning None, this may return an empty list, or + a list that has less LedgerAccount instances than the number of + qualified_names that was passed in. + """ + return self.get_account_many(qualified_names, raise_on_error=False) + + def get_account_if_exists(self, qualified_name: str) -> Optional[LedgerAccount]: + return self.get_account(qualified_name, raise_on_error=False) + + def get_account_balance(self, account: LedgerAccount) -> int: + """In a debit normal account, the balance is the sum of debits minus + the sum of credits. + + In a credit normal account, the balance is the sum of credits minus + the sum of debits. + + This returns an int and not a USDCent because an Account's balance + could be negative. + """ + + # TODO: Move to RR with long timeout (2min+), it causes problems + res = self.pg_config.execute_sql_query( + query=f""" + SELECT SUM(amount * direction) AS total + FROM ledger_entry + WHERE account_id = %s + """, + params=[account.uuid], + ) + if res: + return int((res[0]["total"] or 0) * account.normal_balance) + else: + return 0 + + def get_account_balance_timerange( + self, + account: LedgerAccount, + time_start: Optional[AwareDatetime] = None, + time_end: Optional[AwareDatetime] = None, + ) -> int: + """ + This returns an int and not a USDCent because an Account's balance + could be negative. + """ + + # I want the balance for this account optionally filtered by + # transactions within a time range + filter_str, params = self.make_filter_str( + account_uuid=account.uuid, time_start=time_start, time_end=time_end + ) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT SUM(amount * direction * normal_balance) AS total + FROM ledger_entry AS le + JOIN ledger_transaction AS lt + ON le.transaction_id = lt.id + JOIN ledger_account AS la + ON le.account_id = la.uuid + {filter_str} + """, + params=params, + ) + if not res: + return 0 + return int(res[0]["total"]) if res[0]["total"] else 0 + + def get_account_filtered_balance( + self, + account: LedgerAccount, + metadata_key: str, + metadata_value: str, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + ) -> int: + """I want the balance for this account filtered by transactions with + a certain tag. + + NOTE: This query will be wrong if the metadata join was changed! + b/c if a transaction had multiple matching metadata rows, then the + ledger_entry row will get returned multiple times and the + account_balance will be SUMmed wrong! + + This returns an int and not a USDCent because an Account's balance + could be negative. + """ + filter_str, params = self.make_filter_str( + account_uuid=account.uuid, + time_start=time_start, + time_end=time_end, + metadata_key=metadata_key, + metadata_value=metadata_value, + ) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT SUM(amount * direction * normal_balance) AS total + FROM ledger_entry AS le + JOIN ledger_transaction AS lt + ON le.transaction_id = lt.id + JOIN ledger_transactionmetadata AS tm + ON lt.id = tm.transaction_id + JOIN ledger_account AS la + ON le.account_id = la.uuid + {filter_str} + """, + params=params, + ) + if not res: + return 0 + + return int(res[0]["total"]) if res[0]["total"] else 0 + + +class LedgerManager( + LedgerTransactionManager, + LedgerEntryManager, + LedgerAccountManager, + LedgerMetadataManager, +): + """This is the parent class manager for operating within the ledger + app. Many of the methods that are in here are unused, and written + when it was unclear what queries would be needed. + + As of discussion on 2025-05-01, more functionality should be put into + the TransactionManger, AccountManger, and even the creation of a + EntryManager or TransactionMetadataManger with various "verbose" + flags to determine how different relationships are returned. For + example, the TransactionManger doesn't need to always include details + about each Entry. + + Given that there is a "THL Ledger" the goal of these classes should be very + simple and related to the ledger itself, not any specific application + + """ + + def check_ledger_balanced(self) -> bool: + """This is for testing only, as it'll take forever to run this if + the ledger_manager is huge + """ + res = self.pg_config.execute_sql_query( + f""" + SELECT + SUM(CASE WHEN normal_balance = -1 THEN total ELSE 0 END) AS credit_total, + SUM(CASE WHEN normal_balance = 1 THEN total ELSE 0 END) AS debit_total + FROM ( + SELECT + SUM(amount * direction * normal_balance) AS total, + tl.normal_balance + FROM ledger_entry + JOIN ledger_account tl + ON ledger_entry.account_id = tl.uuid + GROUP BY account_id, normal_balance + ) x + """ + )[0] + return res["credit_total"] == res["debit_total"] + + def get_account_debit_credit_by_metadata( + self, + account: LedgerAccount, + metadata_key: str, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + ) -> Dict[str, Dict[str, int]]: + """Show me the sum of debit and credit scoped to this account, grouped + by all values of metadata_key + """ + filter_str, params = self.make_filter_str( + account_uuid=account.uuid, + metadata_key=metadata_key, + time_end=time_end, + time_start=time_start, + ) + + # noinspection SqlShouldBeInGroupBy + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + SUM(CASE WHEN direction = 1 THEN amount ELSE 0 END) AS debit, + SUM(CASE WHEN direction = -1 THEN amount ELSE 0 END) AS credit, + tm.value + FROM ledger_entry AS le + JOIN ledger_transaction AS lt + ON le.transaction_id = lt.id + JOIN ledger_transactionmetadata AS tm + ON lt.id = tm.transaction_id + JOIN ledger_account AS la + ON le.account_id = la.uuid + {filter_str} + GROUP BY tm.value + """, + params=params, + ) + if not res: + return {} + return { + x["value"]: {"credit": int(x["credit"]), "debit": int(x["debit"])} + for x in res + } + + def get_balances_timerange( + self, + time_start: Optional[AwareDatetime] = None, + time_end: Optional[AwareDatetime] = None, + ) -> Dict: + + filter_str, params = self.make_filter_str( + time_end=time_end, + time_start=time_start, + ) + + # noinspection SqlShouldBeInGroupBy + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + la.*, + SUM(CASE WHEN direction = 1 THEN amount ELSE 0 END) AS debit, + SUM(CASE WHEN direction = -1 THEN amount ELSE 0 END) AS credit + FROM ledger_entry AS le + JOIN ledger_transaction AS lt + ON le.transaction_id = lt.id + JOIN ledger_account AS la + ON le.account_id = la.uuid + {filter_str} + GROUP BY la.uuid + """, + params=params, + ) + d = { + LedgerAccount.model_validate(x): { + "debit": x["debit"], + "credit": x["credit"], + } + for x in res + } + for k, v in d.items(): + v["total"] = (v["debit"] - v["credit"]) * k.normal_balance.value + return d diff --git a/generalresearch/managers/thl/ledger_manager/thl_ledger.py b/generalresearch/managers/thl/ledger_manager/thl_ledger.py new file mode 100644 index 0000000..977b68f --- /dev/null +++ b/generalresearch/managers/thl/ledger_manager/thl_ledger.py @@ -0,0 +1,1968 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Optional, Callable, Collection, List, TYPE_CHECKING +from uuid import UUID + +import numpy as np +import pandas as pd +from pydantic import AwareDatetime, PositiveInt + +from generalresearch.config import ( + JAMES_BILLINGS_BPID, + JAMES_BILLINGS_TX_CUTOFF, +) +from generalresearch.currency import USDCent +from generalresearch.managers.base import Permission +from generalresearch.managers.thl.ledger_manager.conditions import ( + generate_condition_mp_payment, + generate_condition_bp_payment, + generate_condition_bp_payout, + generate_condition_user_payout_request, + generate_condition_user_payout_action, + generate_condition_tag_exists, + generate_condition_enter_contest, +) +from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerManager, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.contest.contest import Contest +from generalresearch.models.thl.contest.definitions import ( + ContestPrizeKind, + ContestType, +) +from generalresearch.models.thl.contest.milestone import MilestoneContest +from generalresearch.models.thl.contest.raffle import ( + ContestEntry, + ContestEntryType, + RaffleContest, +) +from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + LedgerTransaction, + LedgerEntry, + AccountType, + TransactionType, + TransactionMetadataColumns as tmc, + UserLedgerTransactions, +) +from generalresearch.models.thl.payout import UserPayoutEvent +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.session import Status, Session, Wall +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType + +if TYPE_CHECKING: + from generalresearch.models.thl.contest.contest import ContestWinner + +logging.basicConfig() +logger = logging.getLogger("LedgerManager") +logger.setLevel(logging.INFO) + + +class ThlLedgerManager(LedgerManager): + + def get_account_or_create_user_wallet(self, user: User) -> LedgerAccount: + """ + TODO: In the future we could create a user wallet account with a + currency other than USD (or test). This would be determined + by some BP config + """ + + assert user.user_id, "User must be saved" + + account = LedgerAccount( + display_name=f"User Wallet {user.uuid}", + qualified_name=f"{self.currency.value}:user_wallet:{user.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.USER_WALLET, + reference_type="user", + reference_uuid=user.uuid, + currency=self.currency, + ) + + return self.get_account_or_create(account=account) + + def get_account_or_create_bp_wallet_by_uuid( + self, product_uuid: UUIDStr + ) -> LedgerAccount: + assert UUID(product_uuid).hex == product_uuid, "Must provide a product_uuid" + account = LedgerAccount( + display_name=f"BP Wallet {product_uuid}", + qualified_name=f"{self.currency.value}:bp_wallet:{product_uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.BP_WALLET, + reference_type="bp", + reference_uuid=product_uuid, + currency=self.currency, + ) + + return self.get_account_or_create(account=account) + + def get_account_or_create_bp_wallet(self, product: Product) -> LedgerAccount: + assert isinstance(product, Product), "Must provide a Product" + return self.get_account_or_create_bp_wallet_by_uuid(product_uuid=product.uuid) + + def get_account_or_create_bp_commission_by_uuid( + self, product_uuid: UUIDStr + ) -> LedgerAccount: + assert UUID(product_uuid).hex == product_uuid, "Must provide a product_uuid" + account = LedgerAccount( + display_name=f"Revenue from commission {product_uuid}", + qualified_name=f"{self.currency.value}:revenue:bp_commission:{product_uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.REVENUE, + reference_type="bp", + reference_uuid=product_uuid, + currency=self.currency, + ) + return self.get_account_or_create(account=account) + + def get_account_or_create_bp_commission(self, product: Product) -> LedgerAccount: + assert isinstance(product, Product), "Must provide a Product" + return self.get_account_or_create_bp_commission_by_uuid( + product_uuid=product.uuid + ) + + def get_account_or_create_bp_expense( + self, product: Product, expense_name: str + ) -> LedgerAccount: + """ + Used exclusively for BP with managed user wallets. This account + tracks expenses associated with a BP, for e.g. 20% fee paid to + Amazon / Tango to issue gift cards / paypal. + + :param product: Product + :param expense_name: should be one of {'amt', 'tango', 'paypal'}. Could + grow as more payout methods are supported. + """ + return self.get_account_or_create_bp_expense_by_uuid( + product_uuid=product.uuid, expense_name=expense_name + ) + + def get_account_or_create_bp_expense_by_uuid( + self, product_uuid: UUIDStr, expense_name: str + ) -> LedgerAccount: + + account = LedgerAccount( + display_name=f"Expense {expense_name} {product_uuid}", + qualified_name=f"{self.currency.value}:expense:{expense_name}:{product_uuid}", + normal_balance=Direction.DEBIT, + account_type=AccountType.EXPENSE, + reference_type="bp", + reference_uuid=product_uuid, + currency=self.currency, + ) + + return self.get_account_or_create(account=account) + + def get_account_or_create_contest_wallet_by_uuid( + self, contest_uuid: UUIDStr + ) -> LedgerAccount: + assert UUID(contest_uuid).hex == contest_uuid, "Must provide a contest_uuid" + account = LedgerAccount( + display_name=f"Contest Wallet {contest_uuid}", + qualified_name=f"{self.currency.value}:{AccountType.CONTEST_WALLET.value}:{contest_uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.CONTEST_WALLET, + reference_type="contest", + reference_uuid=contest_uuid, + currency=self.currency, + ) + + return self.get_account_or_create(account=account) + + def get_account_or_create_contest_wallet( + self, contest: RaffleContest + ) -> "LedgerAccount": + assert isinstance(contest, RaffleContest), "Must provide a RaffleContest" + return self.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + + def get_or_create_bp_pending_payout_account( + self, product: "Product" + ) -> "LedgerAccount": + """ + Used exclusively for BP with managed user wallets. This account + holds funds that a BP's users have requested as payouts but are + still pending. Once the payout request is approved, the funds + move from here into an expense / cash account. + """ + + assert Permission.CREATE in self.permissions + + account = LedgerAccount( + display_name=f"BP Wallet Pending {product.id}", + qualified_name=f"{self.currency.value}:bp_wallet:pending:{product.id}", + normal_balance=Direction.CREDIT, + account_type=AccountType.BP_WALLET, + reference_type="bp", + reference_uuid=product.id, + currency=self.currency, + ) + + return self.get_account_or_create(account=account) + + def get_account_task_complete_revenue(self) -> "LedgerAccount": + return self.get_account( + qualified_name=f"{self.currency.value}:revenue:task_complete" + ) + + def get_account_cash(self) -> "LedgerAccount": + return self.get_account(qualified_name=f"{self.currency.value}:cash") + + def get_accounts_bp_wallet_for_products( + self, product_uuids: Collection[UUIDStr] + ) -> Collection[LedgerAccount]: + accounts = self.get_account_many( + qualified_names=[ + f"{self.currency.value}:bp_wallet:{p_uuid}" for p_uuid in product_uuids + ] + ) + assert len(accounts) == len(product_uuids) + + return accounts + + def get_tx_bp_payouts( + self, + account_uuids: Collection[UUIDStr], + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + ): + if time_start is None: + time_start = datetime(year=2017, month=1, day=1, tzinfo=timezone.utc) + + if time_end is None: + time_end = datetime.now(tz=timezone.utc) + + assert all( + isinstance(item, str) for item in account_uuids + ), "Must pass account_uuid as str" + + params = { + "time_start": time_start, + "time_end": time_end, + "tag_like": f"{self.currency.value}:bp_payout:%", + "account_uuids": list(account_uuids), + } + query = """ + SELECT lt.id, lt.tag, lt.created + FROM ledger_transaction AS lt + JOIN ledger_entry le ON lt.id = le.transaction_id + WHERE lt.created BETWEEN %(time_start)s AND %(time_end)s + AND tag LIKE %(tag_like)s + AND account_id = ANY(%(account_uuids)s); + """ + return self.pg_config.execute_sql_query(query=query, params=params) + + def create_tx_task_complete( + self, + wall: Wall, + user: User, + created: Optional[datetime] = None, + force=False, + ) -> PositiveInt: + """ + Create a transaction when we complete a task from a marketplace, + showing the marketplace paying us for the task complete. + + :param wall: the wall event that was completed + :param user: user who completed this wall event + :param created: should only be used for back-fill / testing. + Otherwise, == datetime.now() + :param force: If True, we skip the flag check to allow for retry of + a failed previous call. The locking and condition check still runs. + """ + f = lambda: self.create_tx_task_complete_(wall=wall, user=user, created=created) + + condition = generate_condition_mp_payment(wall=wall) + lock_key = f"{self.currency.value}:thl_wall:{wall.uuid}" + + return self.create_tx_protected( + lock_key=lock_key, + condition=condition, + create_tx_func=f, + skip_flag_check=force, + ) + + def create_tx_task_complete_( + self, wall: Wall, user: User, created: Optional[datetime] = None + ) -> LedgerTransaction: + + revenue_account = self.get_account_task_complete_revenue() + cash_account = self.get_account_cash() + metadata = { + tmc.USER: user.uuid, + tmc.WALL: wall.uuid, + tmc.SOURCE: wall.source, + tmc.TX_TYPE: TransactionType.MP_PAYMENT, + } + # This tag should uniquely identify this transaction (which should only happen once!) + tag = f"{self.currency.value}:mp_payment:{wall.uuid}" + amount = round(wall.cpi * 100) + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=revenue_account.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=cash_account.uuid, + amount=amount, + ), + ] + ext_description = f"Task Complete {wall.source.name} {wall.survey_id}" + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=ext_description, + created=created, + ) + + return t + + def create_tx_bp_payment( + self, session: Session, created: Optional[datetime] = None, force=False + ) -> LedgerTransaction: + """ + Create a transaction when we decide to report a session as complete + and make a payment to the BP and optionally to the user's wallet. + + :param session: the session event that was completed + :param created: should only be used for back-fill / testing. + Otherwise, == datetime.now() + :param force: If True, we skip the flag check to allow for retry of a + failed previous call. The locking and condition check still runs. + """ + assert session.status == Status.COMPLETE + assert session.payout > 0, "call session.determine_payments() first" + + f = lambda: self.create_tx_bp_payment_(session=session, created=created) + + condition = generate_condition_bp_payment(session) + lock_key = f"{self.currency.value}:thl_session:{session.uuid}" + + return self.create_tx_protected( + lock_key=lock_key, + condition=condition, + create_tx_func=f, + skip_flag_check=force, + ) + + def create_tx_bp_payment_( + self, session: Session, created: Optional[datetime] = None + ) -> LedgerTransaction: + user = session.user + assert user.product, "user.prefetch_product()" + assert session.payout > 0, "call session.determine_payments() first" + assert session.wall_events, "set session.wall_events first" + + metadata = { + tmc.USER: user.uuid, + tmc.SESSION: session.uuid, + tmc.TX_TYPE: TransactionType.BP_PAYMENT, + } + # This tag should uniquely identify this transaction (which should only happen once!) + tag = f"{self.currency.value}:bp_payment:{session.uuid}" + revenue_account = self.get_account_task_complete_revenue() + bp_wallet_account = self.get_account_or_create_bp_wallet(user.product) + bp_commission_account = self.get_account_or_create_bp_commission(user.product) + + # Don't use session.determine_payments() here, b/c during back-pop this may be changed + thl_net = Decimal( + sum(wall.cpi for wall in session.wall_events if wall.is_visible_complete()) + ) + thl_net = round(thl_net * 100) + bp_pay = round(session.payout * 100) + user_pay = ( + round(session.user_payout * 100) if session.user_payout is not None else 0 + ) + if bp_pay > thl_net: + # There are back-population issues (e.g. 5afcf8063ccb4662902ac727c2471202) + # were we paid the BP $0.39 for a $0.385 cpi complete. This is + # wrong because the round algorithm we use is HALF_EVEN, and so + # 0.385 should round to 0.38. + # https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even + logger.warning( + f"bp_pay {bp_pay} > thl_net {thl_net}. Capping bp_pay to thl_net." + ) + bp_pay = thl_net + if user_pay > bp_pay: + user_pay = bp_pay + + commission_amount = round(thl_net - bp_pay) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=revenue_account.uuid, + amount=thl_net, + ) + ] + + if commission_amount: + entries.append( + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_commission_account.uuid, + amount=commission_amount, + ) + ) + + if user.product.user_wallet_enabled: + bp_pay -= user_pay + user_account = self.get_account_or_create_user_wallet(user) + + if bp_pay: + entries.append( + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet_account.uuid, + amount=bp_pay, + ) + ) + + if user_pay: + entries.append( + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=user_account.uuid, + amount=user_pay, + ) + ) + ext_description = f"BP & User Payment {session.uuid}" + + else: + entries.append( + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet_account.uuid, + amount=bp_pay, + ) + ) + ext_description = f"BP Payment {session.uuid}" + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=ext_description, + created=created, + ) + + return t + + def create_tx_task_adjustment( + self, wall: Wall, user: User, created: Optional[datetime] = None + ) -> Optional[LedgerTransaction]: + """ + How is this different then create_tx_bp_adjustment + + """ + + if created is None: + created = wall.adjusted_timestamp + + revenue_account = self.get_account_task_complete_revenue() + cash_account = self.get_account_cash() + metadata = { + tmc.USER: user.uuid, + tmc.WALL: wall.uuid, + tmc.SOURCE: wall.source, + tmc.TX_TYPE: TransactionType.MP_ADJUSTMENT, + } + # This tag may not uniquely identify this tx, b/c it could get adjusted multiple times. + tag = f"{self.currency.value}:mp_adjustment:{wall.uuid}" + new_amount = round(wall.get_cpi_after_adjustment() * 100) + current_amount = self.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall.uuid, + ) + change_amount = new_amount - current_amount + + if change_amount > 0: + # Fail -> Complete: new_amt = 1, current_amt = 0, change = 1 + logger.info( + f"create_transaction_task_adjustment. current_amt: {current_amount}, new:amt: {new_amount}" + ) + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=revenue_account.uuid, + amount=change_amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=cash_account.uuid, + amount=change_amount, + ), + ] + + elif change_amount < 0: + # Complete -> Fail: new_amt = 0, current_amt = 1, change = -1 + logger.info( + f"create_transaction_task_adjustment. current_amt: {current_amount}, new:amt: {new_amount}" + ) + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=revenue_account.uuid, + amount=abs(change_amount), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=cash_account.uuid, + amount=abs(change_amount), + ), + ] + + else: + logger.info(f"create_transaction_task_adjustment. No transactions needed.") + return None + + amt_str = f"${abs(change_amount) / 100:,.2f}" + amt_str = amt_str if change_amount > 0 else "-" + amt_str + ext_description = ( + f"Task Adjustment {amt_str} {wall.source.name} {wall.survey_id}" + ) + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=ext_description, + created=created, + ) + + return t + + def create_tx_bp_adjustment( + self, session: Session, created: Optional[datetime] = None + ) -> Optional[LedgerTransaction]: + """ + How is this different then create_tx_task_adjustment + """ + + if created is None: + created = session.adjusted_timestamp + user = session.user + assert user.product, "user.prefetch_product()" + metadata = { + tmc.USER: user.uuid, + tmc.SESSION: session.uuid, + tmc.TX_TYPE: TransactionType.BP_ADJUSTMENT, + } + # This tag may not uniquely identify this tx, b/c it could get adjusted multiple times. + tag = f"{self.currency.value}:bp_adjustment:{session.uuid}" + revenue_account = self.get_account_task_complete_revenue() + bp_wallet_account = self.get_account_or_create_bp_wallet(product=user.product) + bp_commission_account = self.get_account_or_create_bp_commission( + product=user.product + ) + + new_payout = round(session.get_payout_after_adjustment() * 100) + thl_net = session.get_thl_net() + new_commission = round(user.product.determine_bp_commission(thl_net) * 100) + + current_commission = self.get_account_filtered_balance( + account=bp_commission_account, + metadata_key="thl_session", + metadata_value=session.uuid, + ) + change_commission = new_commission - current_commission + logger.info( + [ + "commissions: ", + new_commission, + current_commission, + change_commission, + ] + ) + + user_amt_str = "" + if user.product.user_wallet_enabled: + # If the user wallet is enabled, the user_payout "comes out" of + # the payout + payout_after_adj: Optional[Decimal] = ( + session.get_user_payout_after_adjustment() + ) + if payout_after_adj is None: + logger.info("session.get_user_payout_after_adjustment() return None") + return None + + new_user_payout = round(payout_after_adj * 100) + new_bp_payout = new_payout - new_user_payout + current_bp_payout = self.get_account_filtered_balance( + account=bp_wallet_account, + metadata_key="thl_session", + metadata_value=session.uuid, + ) + user_account = self.get_account_or_create_user_wallet(user) + current_user_payout = self.get_account_filtered_balance( + account=user_account, + metadata_key="thl_session", + metadata_value=session.uuid, + ) + change_bp_payout = new_bp_payout - current_bp_payout + change_user_payout = new_user_payout - current_user_payout + logger.info( + f"changes: {change_bp_payout}, {change_user_payout}, {change_commission}" + ) + user_amt_str = f"${abs(change_user_payout) / 100:,.2f}" + user_amt_str = ( + user_amt_str if change_user_payout > 0 else "-" + user_amt_str + ) + if change_bp_payout != 0: + entries = [ + LedgerEntry.from_amount( + account_uuid=revenue_account.uuid, + amount=( + change_bp_payout + change_commission + change_user_payout + ) + * -1, + ), + LedgerEntry.from_amount( + account_uuid=bp_wallet_account.uuid, + amount=change_bp_payout, + ), + ] + + if change_commission: + entries.append( + LedgerEntry.from_amount( + account_uuid=bp_commission_account.uuid, + amount=change_commission, + ) + ) + + if change_user_payout: + entries.append( + LedgerEntry.from_amount( + account_uuid=user_account.uuid, + amount=change_user_payout, + ) + ) + + else: + logger.info( + f"create_transaction_bp_adjustment. No transactions needed." + ) + return None + else: + new_bp_payout = new_payout + current_bp_payout = self.get_account_filtered_balance( + account=bp_wallet_account, + metadata_key="thl_session", + metadata_value=session.uuid, + ) + change_bp_payout = new_bp_payout - current_bp_payout + logger.info(f"changes: {change_bp_payout}, {change_commission}") + if change_bp_payout > 0: + # Fail -> Complete + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=revenue_account.uuid, + amount=change_bp_payout + change_commission, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet_account.uuid, + amount=change_bp_payout, + ), + ] + + # This is a very rare occurrence, but the change_commission + # could be negative if the BP's commission pct changed, and + # now the commission amount is lower even though a complete + # happened. This would only happen if the session had a + # complete already. + # + # e.x. $5 complete, 10% commission -> $4.50 payout, $0.50 + # comm. Now F->C a $1 event in the session, and the + # commission changed to 5%: total $6 complete, 5% + # commission -> $5.70 payout, $.30 comm. + # So the payout increased but the commission decreased. + + if change_commission: + entries.append( + LedgerEntry.from_amount( + account_uuid=bp_commission_account.uuid, + amount=change_commission, + ) + ) + + elif change_bp_payout < 0: + # Complete -> Fail + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=revenue_account.uuid, + amount=abs(change_bp_payout + change_commission), + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_wallet_account.uuid, + amount=abs(change_bp_payout), + ), + ] + if change_commission: + entries.append( + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_commission_account.uuid, + amount=abs(change_commission), + ) + ) + + else: + logger.info( + f"create_transaction_bp_adjustment. No transactions needed." + ) + return None + + logger.info(entries) + amt_str = f"${abs(change_bp_payout) / 100:,.2f}" + amt_str = amt_str if change_bp_payout > 0 else "-" + amt_str + ext_description = f"Session BP Payment Adj. {amt_str} {session.uuid}" + if user_amt_str: + ext_description += f" User Payment Adj. {user_amt_str}" + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=ext_description, + created=created, + ) + + return t + + def create_tx_bp_payout( + self, + product: Product, + amount: USDCent, + payoutevent_uuid: UUIDStr, + created: AwareDatetime, + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + skip_flag_check=False, + ) -> LedgerTransaction: + """This is when we pay "OUT" a BP their wallet balance. (Not a + payment for a task complete) + + - We're by default allowing 1 tx per BP per day. Set + allow_multiple_per_day to allow 1 tx per BP per minute. + - Checks to make sure the BP has at least amount in their wallet. + Set skip_wallet_balance_check to skip this check. + + :param product: The BP to pay + :param amount: The amount to pay out of the BP's wallet + :param payoutevent_uuid: Associates the ledger tx with a payout + event. This is also used to de-duplicate (only 1 tx per + payoutevent). + :param created: When this was paid. Can not be in the future. + :param skip_wallet_balance_check: Skips the condition checking the + BP has >= amount in their wallet. + :param skip_one_per_day_check: Skips the condition check of only + allowing 1 tx per BP per day. + :param skip_flag_check: If True, we skip the redis flag check to allow + for retry of a failed previous call. The Locking and condition + checks still run. + """ + + assert isinstance(amount, int) + assert isinstance(amount, USDCent) + + if skip_one_per_day_check or skip_wallet_balance_check: + skip_flag_check = True + + assert ( + datetime.now(tz=timezone.utc) > created + ), "created cannot be in the future" + f = lambda: self.create_tx_bp_payout_( + product=product, + amount=amount, + payoutevent_uuid=payoutevent_uuid, + created=created, + ) + + condition: Callable = generate_condition_bp_payout( + product=product, + amount=amount, + payoutevent_uuid=payoutevent_uuid, + skip_one_per_day_check=skip_one_per_day_check, + skip_wallet_balance_check=skip_wallet_balance_check, + ) + + lock_key = f"{self.currency.value}:bp_payout:{product.id}" + flag_key = f"{self.currency.value}:bp_payout:{payoutevent_uuid}" + return self.create_tx_protected( + lock_key=lock_key, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + flag_key=flag_key, + ) + + def create_tx_bp_payout_( + self, + product: Product, + amount: USDCent, + payoutevent_uuid: UUIDStr, + created: datetime, + ) -> LedgerTransaction: + + metadata = { + tmc.TX_TYPE: TransactionType.BP_PAYOUT, + tmc.EVENT: payoutevent_uuid, + } + # This tag might will uniquely identify this tx + tag = f"{self.currency.value}:bp_payout:{payoutevent_uuid}" + cash_account = self.get_account_cash() + bp_wallet_account = self.get_account_or_create_bp_wallet(product) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_wallet_account.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=cash_account.uuid, + amount=amount, + ), + ] + + ext_description = f"BP Payout" + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=ext_description, + created=created, + ) + + return t + + def create_tx_plug_bp_wallet( + self, + product: Product, + amount: USDCent, + created: AwareDatetime, + direction: Direction = Direction.DEBIT, + description: Optional[str] = None, + skip_flag_check=False, + ) -> LedgerTransaction: + """https://en.wikipedia.org/wiki/Plug_(accounting) + + The typical use case here to create a transaction to make up for + discrepancies in what our ledger shows versus what was actually paid + out to a BP. This may be due to receiving reconciliations from a + marketplace (which are in our ledger), but never actually being paid + for them. As such, we did not pay our suppliers for them. The plug + is temporary and can be reversed once marketplace payments are + reconciled. + + :param product: The account to create the transaction for is the + product's bp_wallet account. By default, the transaction is + balanced with the Cash account. + + :param amount: The amount for the transaction in USDCents. + + :param created: When this was paid. Can not be in the future. + + :param direction: A Direction.DEBIT will decrease the BP's wall + balance amount. A Direction.CREDIT will increase the BP's + balance amount. By default, we will always want to decrease + a BP Wallet amount. + + :param description + + :param skip_flag_check: If True, we skip the flag check to allow + for retry of a failed previous call. + """ + assert ( + datetime.now(tz=timezone.utc) > created + ), "created cannot be in the future" + assert isinstance(amount, int) + assert isinstance(amount, USDCent) + + f = lambda: self.create_tx_plug_bp_wallet_( + product=product, + amount=amount, + created=created, + direction=direction, + description=description, + ) + + # This tag won't necessarily uniquely identify this tx, as we could + # make multiple per year + tag = f"{self.currency.value}:plug:{product.id}:{created.strftime('%Y-%m-%d')}" + condition = lambda x: len(self.get_tx_ids_by_tag(tag)) == 0 + lock_key = f"{self.currency.value}:plug:{product.id}" + flag_key = tag + + return self.create_tx_protected( + lock_key=lock_key, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + flag_key=flag_key, + ) + + def create_tx_plug_bp_wallet_( + self, + product: Product, + amount: USDCent, + created: AwareDatetime, + direction: Direction, + description: Optional[str] = None, + ) -> LedgerTransaction: + + assert isinstance(amount, int) + assert isinstance(amount, USDCent) + + tag = f"{self.currency.value}:plug:{product.id}:{created.strftime('%Y-%m-%d')}" + metadata = {tmc.TX_TYPE: TransactionType.PLUG} + cash_account = self.get_account_cash() + bp_wallet_account = self.get_account_or_create_bp_wallet(product) + + match direction: + case Direction.DEBIT: + # Decrease the BP Wall balance (take away Supplier money) + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_wallet_account.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=cash_account.uuid, + amount=amount, + ), + ] + case Direction.CREDIT: + # Increase the BP Wall balance (giving the Supplier money) + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet_account.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=cash_account.uuid, + amount=amount, + ), + ] + case _: + raise ValueError("Invalid Direction") + + if description is None: + description = f"BP Plug" + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + return t + + def create_tx_user_payout_request( + self, + user: User, + payout_event: UserPayoutEvent, + created: Optional[datetime] = None, + skip_flag_check: Optional[bool] = False, + skip_wallet_balance_check: Optional[bool] = False, + ) -> LedgerTransaction: + """ + The funds move from the user's wallet into the BP's "pending" + wallet. Then, once the cashout request is completed, the + funds will be taken from the BP's pending wallet and the + commission will be recorded. + + Note: We are assuming the user that is requesting the payout is + requesting from their USD wallet. No other currencies are + supported now. + """ + assert ( + user.product.user_wallet_enabled + ), "Can only call this on an wallet enabled BPs" + amount = USDCent(payout_event.amount) + + amt_str = f"${int(amount) / 100:,.2f}" + descriptions = { + PayoutType.AMT_HIT: f"User Payout AMT Assignment Request {amt_str}", + PayoutType.AMT_BONUS: f"User Payout AMT Bonus Request {amt_str}", + PayoutType.PAYPAL: f"User Payout Paypal Request {amt_str}", + PayoutType.CASH_IN_MAIL: f"User Payout Cash Request {amt_str}", + PayoutType.TANGO: f"User Payout Tango Request {amt_str}", + } + description = descriptions[payout_event.payout_type] + + if payout_event.payout_type in { + PayoutType.AMT_HIT, + PayoutType.AMT_BONUS, + }: + """ + This is for AMT accounts only (currently JB). This is the + payment of a either 1) 1c or 5c (typically) assignment or 2) a + bonus for task complete to the user. The 20% commission will + be taken from the BP's wallet once the tx is completed. + """ + assert ( + user.product.user_wallet_amt + ), "Can only call this on an AMT-enabled BPs" + + f = lambda: self.create_tx_user_payout_request_( + user=user, + payout_event=payout_event, + description=description, + created=created, + ) + + min_balance: Optional[int] = int(amount) + if payout_event.payout_type == PayoutType.AMT_HIT: + # We allow the user's balance to reach up to -$1.00. + min_balance = -100 + amount + + if skip_wallet_balance_check: + min_balance = None + + condition: Callable = generate_condition_user_payout_request( + user=user, + payoutevent_uuid=payout_event.uuid, + min_balance=min_balance, + ) + + lock_key = f"{self.currency.value}:user_payout:{user.uuid}" + flag_key = f"{self.currency.value}:user_payout:{payout_event.uuid}:request" + + return self.create_tx_protected( + lock_key=lock_key, + flag_key=flag_key, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + + def create_tx_user_payout_complete( + self, + user: User, + payout_event: UserPayoutEvent, + created: Optional[datetime] = None, + fee_amount: Optional[Decimal] = None, + skip_flag_check: Optional[bool] = False, + ) -> LedgerTransaction: + """ + Once the cashout request is approved and completed, the funds + are taken from the BP's pending wallet, the commission will be + recorded, and the cash debited. + """ + assert ( + user.product.user_wallet_enabled + ), "Can only call this on an wallet enabled BPs" + + # Before we even do anything, we should check that a ledger tx exists for the request + request_tag = f"{self.currency.value}:user_payout:{payout_event.uuid}:request" + txs = self.get_tx_ids_by_tag(request_tag) + if len(txs) != 1: + raise ValueError( + f"Trying to complete user payout {payout_event.uuid} with no request tx found." + ) + + amount_usd = Decimal(payout_event.amount) / 100 + amt_str = f"${amount_usd:,.2f}" + descriptions = { + PayoutType.AMT_HIT: f"User Payout AMT Assignment Complete {amt_str}", + PayoutType.AMT_BONUS: f"User Payout AMT Bonus Complete {amt_str}", + PayoutType.PAYPAL: f"User Payout Paypal Complete {amt_str}", + PayoutType.CASH_IN_MAIL: f"User Payout Cash Complete {amt_str}", + PayoutType.TANGO: f"User Payout Tango Complete {amt_str}", + } + description = descriptions[payout_event.payout_type] + bp_wallet_account = self.get_account_or_create_bp_wallet(user.product) + + if payout_event.payout_type in { + PayoutType.AMT_HIT, + PayoutType.AMT_BONUS, + }: + assert ( + user.product.user_wallet_amt + ), "Can only call this on an AMT-enabled BP" + bp_expense_account = self.get_account_or_create_bp_expense( + product=user.product, expense_name="amt" + ) + + if fee_amount is None: + fee_amount = (amount_usd * Decimal("0.2")).quantize( + Decimal("0.01") + ) or Decimal("0.01") + + elif payout_event.payout_type == PayoutType.PAYPAL: + bp_expense_account = self.get_account_or_create_bp_expense( + product=user.product, expense_name="paypal" + ) + assert fee_amount is not None, "must set fee_amount" + + elif payout_event.payout_type == PayoutType.CASH_IN_MAIL: + bp_expense_account = self.get_account_or_create_bp_expense( + product=user.product, expense_name=PayoutType.CASH_IN_MAIL + ) + assert fee_amount is not None, "must set fee_amount" + + elif payout_event.payout_type == PayoutType.TANGO: + bp_expense_account = self.get_account_or_create_bp_expense( + product=user.product, expense_name="tango" + ) + if fee_amount is None: + fee_amount = (amount_usd * Decimal("0.035")).quantize(Decimal("0.01")) + else: + raise NotImplementedError() + + f = lambda: self.create_tx_user_payout_complete_( + user=user, + payout_event=payout_event, + fee_expense_account=bp_expense_account, + fee_payer_account=bp_wallet_account, + fee_amount=fee_amount, + description=description, + created=created, + ) + + condition = generate_condition_user_payout_action( + payout_event.uuid, action="complete" + ) + lock_key = f"{self.currency.value}:user_payout:{user.uuid}" + flag_key = f"{self.currency.value}:user_payout:{payout_event.uuid}:complete" + + return self.create_tx_protected( + lock_key=lock_key, + flag_key=flag_key, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + + def create_tx_user_payout_cancelled( + self, + user: User, + payout_event: UserPayoutEvent, + created: Optional[datetime] = None, + skip_flag_check: Optional[bool] = False, + ) -> LedgerTransaction: + assert ( + user.product.user_wallet_enabled + ), "Can only call this on an wallet enabled BPs" + + # Before we even do anything, we should check that a ledger tx exists for the request + request_tag = f"{self.currency.value}:user_payout:{payout_event.uuid}:request" + txs = self.get_tx_ids_by_tag(request_tag) + if len(txs) != 1: + raise ValueError( + f"Trying to cancel user payout {payout_event.uuid} with no request tx found." + ) + + description = f"User Payout Cancelled" + f = lambda: self.create_tx_user_payout_cancelled_( + user=user, + payout_event=payout_event, + description=description, + created=created, + ) + + condition = generate_condition_user_payout_action( + payoutevent_uuid=payout_event.uuid, action="cancel" + ) + lock_key = f"{self.currency.value}:user_payout:{user.uuid}" + flag_key = f"{self.currency.value}:user_payout:{payout_event.uuid}:cancel" + + return self.create_tx_protected( + lock_key=lock_key, + flag_key=flag_key, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + + def create_tx_user_payout_request_( + self, + user: User, + payout_event: UserPayoutEvent, + description: str, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + # This is the same for all user payout requests, regardless of the + # payout_type (paypal, amt, tango) + metadata = { + tmc.USER: user.uuid, + tmc.TX_TYPE: TransactionType.USER_PAYOUT_REQUEST, + tmc.EVENT2: payout_event.uuid, + tmc.PAYOUT_TYPE: payout_event.payout_type.value, + } + # This tag uniquely identifies this tx + tag = f"{self.currency.value}:user_payout:{payout_event.uuid}:request" + bp_pending_account = self.get_or_create_bp_pending_payout_account( + product=user.product + ) + # The USD assumption is "enforced" here, in that this call gets user's USD wallet. + user_wallet_account = self.get_account_or_create_user_wallet(user=user) + amount_cents = payout_event.amount + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=user_wallet_account.uuid, + amount=amount_cents, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_pending_account.uuid, + amount=amount_cents, + ), + ] + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + return t + + def create_tx_user_payout_complete_( + self, + user: User, + payout_event: UserPayoutEvent, + fee_expense_account: LedgerAccount, + fee_payer_account: LedgerAccount, + fee_amount: Decimal, + description: str, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + """ + Creates the LedgerTransaction for a completed user payout request. + + :param user: The user who is requesting the payout. The `amount` comes + from this user's wallet + :param payout_event: The payout event associated with this tx + :param fee_expense_account: Which account records the expense + associated with the transaction fee. + :param fee_payer_account: Which account actually pays the transaction + fee. Typically, this is the BP's wallet. + :param fee_amount: The amount of the transaction fee. + :param created: Whe the payout was completed + + :return: the ledger transaction + """ + + # TODO: The fee_payer_account must be the bp_wallet_account, or else + # we must change the user_payout_request logic to hold the fee + # amount from the user's wallet as well. + bp_wallet_account = self.get_account_or_create_bp_wallet(user.product) + assert fee_payer_account == bp_wallet_account, "unsupported fee_payer_account" + + metadata = { + tmc.USER: user.uuid, + tmc.TX_TYPE: TransactionType.USER_PAYOUT_COMPLETE, + tmc.EVENT2: payout_event.uuid, + tmc.PAYOUT_TYPE: payout_event.payout_type, + } + # This tag uniquely identifies this tx + tag = f"{self.currency.value}:user_payout:{payout_event.uuid}:complete" + cash_account = self.get_account_cash() + bp_pending_account = self.get_or_create_bp_pending_payout_account( + product=user.product + ) + + amount_cents = payout_event.amount + fee_cents = round(fee_amount * 100) + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_pending_account.uuid, + amount=amount_cents, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=cash_account.uuid, + amount=amount_cents, + ), + ] + + if fee_cents: + entries.extend( + [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=fee_payer_account.uuid, + amount=fee_cents, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=fee_expense_account.uuid, + amount=fee_cents, + ), + ] + ) + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + return t + + def create_tx_user_payout_cancelled_( + self, + user: User, + payout_event: UserPayoutEvent, + description: str, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + assert user.product + + metadata = { + tmc.USER: user.uuid, + tmc.TX_TYPE: TransactionType.USER_PAYOUT_CANCEL, + tmc.EVENT2: payout_event.uuid, + tmc.PAYOUT_TYPE: payout_event.payout_type, + } + # This tag uniquely identifies this tx + tag = f"{self.currency.value}:user_payout:{payout_event.uuid}:cancel" + bp_pending_account = self.get_or_create_bp_pending_payout_account(user.product) + user_wallet_account = self.get_account_or_create_user_wallet(user) + amount_cents: int = payout_event.amount + + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=user_wallet_account.uuid, + amount=amount_cents, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_pending_account.uuid, + amount=amount_cents, + ), + ] + + t = self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + return t + + def create_tx_user_bonus( + self, + user: User, + amount: Decimal, + ref_uuid: UUIDStr, + description: str, + source_account: Optional[LedgerAccount] = None, + created: Optional[datetime] = None, + skip_flag_check: Optional[bool] = False, + ) -> LedgerTransaction: + """ + Pay a user into their wallet balance. There is no fee here. There + is only a fee when the user requests a payout. The bonus could + be as a bribe, winnings for a contest, leaderboard, etc. + + :param source_account: Is this paid from the bp's wallet? or from us? + """ + assert ( + user.product.user_wallet_enabled + ), "Can only call this on an wallet enabled BPs" + assert user.product, "user.prefetch_product()" + + # This tag should uniquely id this tx. + f = lambda: self.create_tx_user_bonus_( + user=user, + amount=amount, + ref_uuid=ref_uuid, + description=description, + source_account=source_account, + created=created, + ) + + tag = f"{self.currency.value}:user_bonus:{ref_uuid}" + condition = generate_condition_tag_exists(tag) + + return self.create_tx_protected( + lock_key=tag, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + + def create_tx_user_bonus_( + self, + user: User, + amount: Decimal, + ref_uuid: UUIDStr, + description: str, + source_account: Optional[LedgerAccount] = None, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + + metadata = { + tmc.USER: user.uuid, + tmc.TX_TYPE: TransactionType.USER_BONUS, + tmc.BONUS: ref_uuid, + } + tag = f"{self.currency.value}:user_bonus:{ref_uuid}" + user_account = self.get_account_or_create_user_wallet(user) + + # TODO: the source_account could be a separate account than the + # BP's main wallet account .. + bp_account = self.get_account_or_create_bp_wallet(product=user.product) + if source_account: + assert source_account == bp_account, "not supported" + + amount_cents = round(amount * 100) + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_account.uuid, + amount=amount_cents, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=user_account.uuid, + amount=amount_cents, + ), + ] + + return self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + def create_tx_user_enter_contest( + self, + contest_uuid: UUIDStr, + contest_entry: ContestEntry, + skip_flag_check: Optional[bool] = False, + ) -> LedgerTransaction: + """ + User is requesting to enter a Raffle Contest. We'll DEBIT + funds from their wallet and CREDIT the contest wallet. + """ + assert ( + contest_entry.entry_type == ContestEntryType.CASH + ), "Can only call this for CASH Contests" + user = contest_entry.user + assert ( + user.product.user_wallet_enabled + ), "Can only call this on an wallet enabled BPs" + assert user.product, "user.prefetch_product()" + amount = contest_entry.amount + entry_uuid = contest_entry.uuid + created = contest_entry.created_at + + f = lambda: self.create_tx_user_enter_contest_( + user=user, + amount=amount, + contest_uuid=contest_uuid, + tag=tag, + created=created, + ) + + # This tag should uniquely id this tx. + tag = f"{self.currency.value}:enter_contest:{entry_uuid}" + # Checks that the user has at least this balance and that a tx with tag doesn't exist + condition = generate_condition_enter_contest( + user=user, tag=tag, min_balance=amount + ) + # Lock the whole thing along with any tx that the user can do to spend money + lock_key = f"{self.currency.value}:user_payout:{user.uuid}" + return self.create_tx_protected( + lock_key=lock_key, + flag_key=tag, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + + def create_tx_user_enter_contest_( + self, + user: User, + amount: USDCent, + contest_uuid: UUIDStr, + tag: str, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + description = f"Enter contest {amount.to_usd_str()} {contest_uuid}" + metadata = { + tmc.USER: user.uuid, + tmc.TX_TYPE: TransactionType.USER_ENTER_CONTEST, + tmc.CONTEST: contest_uuid, + } + user_account = self.get_account_or_create_user_wallet(user) + contest_account = self.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest_uuid + ) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=user_account.uuid, + amount=int(amount), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=contest_account.uuid, + amount=int(amount), + ), + ] + + return self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + def create_tx_contest_close( + self, + contest: Contest, + skip_flag_check: Optional[bool] = False, + ) -> LedgerTransaction: + """ + Contest is over. For each winner, we make a transaction. + If the prize is physical, the money goes into a prize-expense + account (for that BP), and if the prize is monetary, the money + goes into the winner's wallet. + + Any remaining money goes back into the BP's wallet ? todo + """ + if contest.contest_type in {ContestType.RAFFLE, ContestType.MILESTONE}: + assert ( + contest.entry_type == ContestEntryType.CASH + ), "Can only call this for CASH Contests" + + contest_account = self.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + bp_wallet = self.get_account_or_create_bp_wallet_by_uuid(contest.product_id) + bp_prize_expense_account = self.get_account_or_create_bp_expense_by_uuid( + contest.product_id, expense_name="Prize" + ) + + contest_account_balance = self.get_account_balance(contest_account) + print(f"{contest_account_balance=}") + + entries = [] + for w in contest.all_winners: + if w.prize.kind == ContestPrizeKind.CASH: + user_wallet = self.get_account_or_create_user_wallet(w.user) + entries.extend( + [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=contest_account.uuid, + amount=int(w.prize.estimated_cash_value), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=user_wallet.uuid, + amount=int(w.prize.estimated_cash_value), + ), + ] + ) + elif w.prize.kind == ContestPrizeKind.PHYSICAL: + entries.extend( + [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=contest_account.uuid, + amount=int(w.prize.estimated_cash_value), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_prize_expense_account.uuid, + amount=int(w.prize.estimated_cash_value), + ), + ] + ) + else: + # The prize is a promotion. It has no cash value now! The money goes + # back into the BP's wallet. The BP will eventually (supposedly) + # have to pay the expense of the promotion (e.g. 50% bonus on completes) + # once the user actually "redeems" the promotion. + pass + prize_value = sum( + [ + w.prize.estimated_cash_value + for w in contest.all_winners + if w.prize.kind in {ContestPrizeKind.CASH, ContestPrizeKind.PHYSICAL} + ] + ) + if prize_value > contest_account_balance: + logger.warning("Paying out more than the balance!") + extra_expense = prize_value - contest_account_balance + # Debit this balance from the BP's wallet + entries.extend( + [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=contest_account.uuid, + amount=int(extra_expense), + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_wallet.uuid, + amount=int(extra_expense), + ), + ] + ) + elif prize_value < contest_account_balance: + extra_income = contest_account_balance - prize_value + # BP's wallet gets the overage + entries.extend( + [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=contest_account.uuid, + amount=int(extra_income), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet.uuid, + amount=int(extra_income), + ), + ] + ) + + f = lambda: self.create_tx_contest_close_( + entries=entries, + contest_uuid=contest.uuid, + tag=tag, + created=contest.ended_at, + ) + + # This tag should uniquely id this tx. + tag = f"{self.currency.value}:contest_close:{contest.uuid}" + # Checks that a tx with tag doesn't exist + condition = generate_condition_tag_exists(tag=tag) + # Lock the whole thing by this contest id + lock_key = tag + tx = self.create_tx_protected( + lock_key=lock_key, + flag_key=tag, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + assert self.get_account_balance(contest_account) == 0 + return tx + + def create_tx_contest_close_( + self, + entries: List[LedgerEntry], + contest_uuid: UUIDStr, + tag: str, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + description = f"Close contest {contest_uuid}" + metadata = { + tmc.TX_TYPE: TransactionType.CLOSE_CONTEST, + tmc.CONTEST: contest_uuid, + } + return self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + def create_tx_milestone_winner( + self, + contest: MilestoneContest, + winners: List["ContestWinner"], + skip_flag_check: Optional[bool] = False, + ) -> LedgerTransaction: + """ + A user has reached a milestone. Pay out any cash or physical prizes, + coming from the BP's wallet. + """ + assert isinstance(contest, MilestoneContest), "invalid contest type" + assert all(w.user.user_id for w in winners), "user must be set" + assert len({w.user.user_id for w in winners}) == 1, "Cannot mix users" + user = winners[0].user + created_at = winners[0].created_at + + bp_wallet = self.get_account_or_create_bp_wallet_by_uuid(contest.product_id) + bp_prize_expense_account = self.get_account_or_create_bp_expense_by_uuid( + contest.product_id, expense_name="Prize" + ) + + entries = [] + for w in winners: + if w.prize.kind == ContestPrizeKind.CASH: + user_wallet = self.get_account_or_create_user_wallet(w.user) + entries.extend( + [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_wallet.uuid, + amount=int(w.prize.estimated_cash_value), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=user_wallet.uuid, + amount=int(w.prize.estimated_cash_value), + ), + ] + ) + elif w.prize.kind == ContestPrizeKind.PHYSICAL: + entries.extend( + [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=bp_wallet.uuid, + amount=int(w.prize.estimated_cash_value), + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_prize_expense_account.uuid, + amount=int(w.prize.estimated_cash_value), + ), + ] + ) + else: + # The prize is a promotion. It has no cash value now! + # The BP will pay any expenses associated with it. + pass + + f = lambda: self.create_tx_milestone_winner_( + entries=entries, + contest_uuid=contest.uuid, + user_uuid=user.uuid, + tag=tag, + created=created_at, + ) + + # This tag should uniquely id this tx. + tag = f"{self.currency.value}:contest_milestone:{contest.uuid}:{user.user_id}" + # Checks that a tx with tag doesn't exist + condition = generate_condition_tag_exists(tag=tag) + lock_key = tag + tx = self.create_tx_protected( + lock_key=lock_key, + flag_key=tag, + condition=condition, + create_tx_func=f, + skip_flag_check=skip_flag_check, + ) + return tx + + def create_tx_milestone_winner_( + self, + entries: List[LedgerEntry], + contest_uuid: UUIDStr, + user_uuid: UUIDStr, + tag: str, + created: Optional[datetime] = None, + ) -> LedgerTransaction: + description = f"Milestone award {contest_uuid}" + metadata = { + tmc.TX_TYPE: TransactionType.USER_MILESTONE, + tmc.CONTEST: contest_uuid, + tmc.USER: user_uuid, + } + return self.create_tx( + entries=entries, + metadata=metadata, + tag=tag, + ext_description=description, + created=created, + ) + + def get_user_wallet_balance( + self, user: User, since_days_ago: Optional[int] = None + ) -> int: + """ + Calculates all payments to user's wallet minus all payouts from + user's wallet. The user's wallet is a credit normal account, so + the balance is the sum of credits minus the sum of debits, which + should typically be positive (if the user has money available). + + :param user: User + :param since_days_ago: if None, we get over all time + :returns wallet balance in integer cents + """ + user.prefetch_product(self.pg_config) + assert ( + user.product.user_wallet_config.enabled + ), "Can't get wallet balance on non-managed account." + + now = datetime.now(tz=timezone.utc) + wallet = self.get_account_or_create_user_wallet(user) + if user.product_id == JAMES_BILLINGS_BPID: + assert since_days_ago is None + return self.get_account_balance_timerange( + wallet, time_start=JAMES_BILLINGS_TX_CUTOFF, time_end=now + ) + if since_days_ago: + start_dt = now - timedelta(days=since_days_ago) + return self.get_account_balance_timerange( + wallet, time_start=start_dt, time_end=now + ) + return self.get_account_balance(wallet) + + def get_user_redeemable_wallet_balance( + self, user: User, user_wallet_balance: int + ) -> PositiveInt: + """ + Returns the amount (from the user's wallet) that is currently + redeemable. This amount will be less than or equal to the + user_wallet_balance and non-negative. + + In the future, we want to model the risk of recon by day and by survey + buyer's historical recon behavior, but for now: + + Looking at historical data, we can expect that for the worst 5% of users + to get ~30 % of their completes reconciled. + After 3 days, about 25% of all "future" recons have happened, + 7 days: 50%, 14 days: 75%, till end of next month: 100%. + """ + now = datetime.now(tz=timezone.utc) + # The redeemable balance can NOT ever be more than the actual user_wallet_balance + + # Sum up the redeemable amount for each complete + user_id = user.user_id + wall = pd.DataFrame( + self.pg_config.execute_sql_query( + query=""" + SELECT finished, COALESCE(adjusted_user_payout, user_payout) as user_payout + FROM thl_session + WHERE user_id = %s AND status='c' + """, + params=[user_id], + ), + columns=["finished", "user_payout"], + ) + if wall.empty: + reserve = 0 + else: + wall["user_payout"] = wall["user_payout"].astype(float) + wall["user_payout_int"] = wall["user_payout"] * 100 + wall["days_since_complete"] = (now - wall["finished"]).dt.days + wall["pct_rdm"] = wall["days_since_complete"].apply(self.get_redeemable_pct) + wall.loc[wall["pct_rdm"] > 0.95, "pct_rdm"] = 1 + wall["redeemable"] = wall["pct_rdm"] * wall["user_payout_int"] + # Calculate money needed to save in reserve to cover the difference between + # money earned from completes and $ redeemable, subtract that from the + # wall balance. + reserve = round(wall["user_payout_int"].sum() - wall["redeemable"].sum()) + redeemable_balance = user_wallet_balance - reserve + redeemable_balance = 0 if redeemable_balance < 0 else redeemable_balance + + if redeemable_balance > 0: + # it is possible the user_wallet_balance is negative, in which case the redeemable + # balance is 0. Don't fail assertion if that happens. + assert redeemable_balance <= user_wallet_balance + return redeemable_balance + + def get_redeemable_pct( + self, days_since_complete: float, user_trust: float = 0.0 + ) -> float: + """ + Returns the percentage of a payment for a complete that should be redeemable given the + number of days since the complete occurred. + """ + days_since_complete = round(max([min([days_since_complete, 60]), 0])) + # Redeemable pct by days since complete. Logistic growth model: + # https://people.richland.edu/james/lecture/m116/logs/models.html + # Starts at 40% and goes up to 95% by day 38 (we then round up to + # 100% from there) with a 4-day hold at 40% (for an untrusted user). + initial_value = 0.40 + (0.20 * user_trust) + max_value = 1 + day_delay = 4 - (2 * user_trust) + rate = 0.1 + (0.1 * user_trust) + b = (max_value - initial_value) / initial_value + days = max([(days_since_complete - day_delay), 0]) + y = 1 / (1 + (b * np.exp(-1 * rate * days))) + pct_rdm = np.clip(y, a_min=0, a_max=0.95) / 0.95 + # We can plot it with ... + # x = [timedelta(days=d) for d in range(60)] + # plt.plot([d.days for d in x], [self.get_redeemable_amount(d) for d in x]) + return pct_rdm + + def get_user_txs( + self, + user: User, + time_start: Optional[datetime] = None, + time_end: Optional[datetime] = None, + page: int = 1, + size: int = 50, + order_by: Optional[str] = "created,tag", + ) -> UserLedgerTransactions: + user.prefetch_product(self.pg_config) + user_account = self.get_account_or_create_user_wallet(user) + exclude_txs_before = None + + if user.product_id == JAMES_BILLINGS_BPID: + time_start = ( + max([JAMES_BILLINGS_TX_CUTOFF, time_start]) + if time_start is not None + else JAMES_BILLINGS_TX_CUTOFF + ) + exclude_txs_before = JAMES_BILLINGS_TX_CUTOFF + + txs, total = self.get_tx_filtered_by_account_paginated( + user_account.uuid, + time_start=time_start, + time_end=time_end, + page=page, + size=size, + order_by=order_by, + ) + summary = self.get_tx_filtered_by_account_summary( + user_account.uuid, time_start=time_start, time_end=time_end + ) + # the 'total' should equal the sum of the UserLedgerTransactionTypeSummary.entry_count for each field + utx = UserLedgerTransactions.from_txs( + user_account=user_account, + txs=txs, + product_id=user.product_id, + payout_format=user.product.payout_config.payout_format, + summary=summary, + page=page, + size=size, + total=total, + ) + # Now calculate the rolling balance. Modifies utx.transactions in place + self.include_running_balance( + txs=utx.transactions, + account_uuid=user_account.uuid, + exclude_txs_before=exclude_txs_before, + ) + return utx diff --git a/generalresearch/managers/thl/maxmind/__init__.py b/generalresearch/managers/thl/maxmind/__init__.py new file mode 100644 index 0000000..59e0af8 --- /dev/null +++ b/generalresearch/managers/thl/maxmind/__init__.py @@ -0,0 +1,162 @@ +from typing import Collection, Optional + +import geoip2.models + +from generalresearch.managers.base import ( + Permission, + PostgresManagerWithRedis, +) +from generalresearch.managers.thl.ipinfo import ( + IPInformationManager, + IPGeonameManager, + GeoIpInfoManager, +) +from generalresearch.managers.thl.maxmind.basic import MaxmindBasicManager +from generalresearch.managers.thl.maxmind.insights import ( + get_insights_ip_information, + should_call_insights, +) +from generalresearch.models.custom_types import IPvAnyAddressStr +from generalresearch.models.thl.ipinfo import ( + IPInformation, + IPGeoname, + GeoIPInformation, + normalize_ip, +) +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + + +class MaxmindManager(PostgresManagerWithRedis): + def __init__( + self, + maxmind_account_id: str, + maxmind_license_key: str, + pg_config: PostgresConfig, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + ): + self.ipinfo_manager = IPInformationManager(pg_config=pg_config) + self.ipgeo_manager = IPGeonameManager(pg_config=pg_config) + self.geoipinfo_manager = GeoIpInfoManager( + pg_config=pg_config, redis_config=redis_config + ) + + self.basic_maxmind_manager = MaxmindBasicManager( + data_dir="/tmp/", + maxmind_account_id=maxmind_account_id, + maxmind_license_key=maxmind_license_key, + ) + + self.maxmind_account_id = maxmind_account_id + self.maxmind_license_key = maxmind_license_key + + super().__init__( + pg_config=pg_config, + redis_config=redis_config, + permissions=permissions, + ) + + def store_basic_ip_information(self, res: geoip2.models.Country) -> None: + geoname_id = res.country.geoname_id + assert geoname_id, "Must have a Geoname ID to store" + + res_geo = self.ipgeo_manager.fetch_geoname_ids(filter_ids=[geoname_id]) + if len(res_geo) == 0: + self.ipgeo_manager.create_basic( + geoname_id=geoname_id, + is_in_european_union=res.country.is_in_european_union, + country_iso=res.country.iso_code, + country_name=res.country.name, + continent_name=res.continent.name, + continent_code=res.continent.code, + ) + + self.ipinfo_manager.create_basic( + ip=res.traits.ip_address, + country_iso=res.country.iso_code, + registered_country_iso=res.registered_country.iso_code, + geoname_id=geoname_id, + ) + + def store_insights_ip_information(self, res: geoip2.models.Insights) -> None: + ipinfo = IPInformation.from_insights(res) + geoname_id = ipinfo.geoname_id + res_geo = self.ipgeo_manager.fetch_geoname_ids([geoname_id]) + if len(res_geo) == 0: + ipgeo = IPGeoname.from_insights(res) + self.ipgeo_manager.create_or_update(ipgeo=ipgeo) + self.ipinfo_manager.create_or_update(ipinfo=ipinfo) + + return None + + def get_or_create_ip_information( + self, + ip_address: IPvAnyAddressStr, + force_insights: bool = False, + ) -> Optional[GeoIPInformation]: + """ + This is the 'top-level' IP handling call. + + - Check to see if we already 'know about' this IP. If so, return + it. Otherwise: + - Lookup basic or detailed info. Cache the result. maxmind lookup + happens synchronously. If `pool`, the db operation happens async + and we don't necessarily return the insights info. + """ + res = self.geoipinfo_manager.get(ip_address) + if res and ( + (force_insights is True and res.basic is False) or (force_insights is False) + ): + return res + return self.run_ip_information(ip_address, force_insights=force_insights) + + def run_ip_information( + self, + ip_address: IPvAnyAddressStr, + force_insights: bool = False, + ) -> Optional[GeoIPInformation]: + """ + Assumes this IP is "unknown" to us (not in the ipinformation table). + Quick lookup IP using geoip2.Database. If its "good", lookup detailed + info. Run db update. + """ + # Quick lookup IP using geoip2.database + basic_res = self.basic_maxmind_manager.get_basic_ip_information(ip_address) + if basic_res is None: + # IP is not 'valid'. We do nothing because if we see it again, it'll just hit the + # geoip2.database (and redis and mysql_rr) which is ok... so no biggie. + return None + + if force_insights or should_call_insights(res=basic_res): + # IP is valid and country is good. Look up insights. + return self.get_and_store_insights(ip_address) + + else: + # IP is valid, but from a spammy country. + self.store_basic_ip_information(res=basic_res) + return self.geoipinfo_manager.get(ip_address) + + def get_and_store_insights( + self, + ip_address: IPvAnyAddressStr, + ) -> GeoIPInformation: + + rc = self.redis_client + normalized_ip, lookup_prefix = normalize_ip(ip_address) + # Protect the actual calling of this with a lock + with rc.lock(f"insights-lock:{normalized_ip}", timeout=2, blocking_timeout=1): + # Check again we don't have it (or it is only the basic that is cached) + res = self.geoipinfo_manager.get_cache(ip_address=ip_address) + if res is not None and res.basic is False: + return res + + res_mm = get_insights_ip_information( + ip_address=normalized_ip, + maxmind_account_id=self.maxmind_account_id, + maxmind_license_key=self.maxmind_license_key, + ) + self.store_insights_ip_information(res_mm) + res = self.geoipinfo_manager.recreate_cache(ip_address) + + return res diff --git a/generalresearch/managers/thl/maxmind/basic.py b/generalresearch/managers/thl/maxmind/basic.py new file mode 100644 index 0000000..d065c13 --- /dev/null +++ b/generalresearch/managers/thl/maxmind/basic.py @@ -0,0 +1,134 @@ +import logging +import os +import subprocess +from datetime import timedelta +from pathlib import Path +from threading import RLock +from typing import Optional, Union +from uuid import uuid4 + +import geoip2.database +import geoip2.models +import requests +from cachetools import cached, TTLCache +from geoip2.errors import AddressNotFoundError + +from generalresearch.managers.base import Manager +from generalresearch.models.custom_types import ( + IPvAnyAddressStr, + CountryISOLike, +) + + +logger = logging.getLogger() + + +class MaxmindBasicManager(Manager): + + def __init__( + self, + data_dir: Union[str, Path], + maxmind_account_id: str, + maxmind_license_key: str, + ): + + self.data_dir = data_dir + self.maxmind_account_id = maxmind_account_id + self.maxmind_license_key = maxmind_license_key + + self.run_update_geoip_db() + super().__init__() + + @cached( + cache=TTLCache(maxsize=1, ttl=timedelta(hours=1).total_seconds()), + lock=RLock(), + ) + def get_geoip_db(self): + db_path = os.path.join(self.data_dir, "GeoIP2-Country.mmdb") + return geoip2.database.Reader(fileish=db_path) + + def get_basic_ip_information( + self, ip_address: IPvAnyAddressStr + ) -> Optional[geoip2.models.Country]: + try: + return self.get_geoip_db().country(ip_address) + except (ValueError, AddressNotFoundError): + return None + + def get_country_iso_from_ip_geoip2db( + self, ip: IPvAnyAddressStr + ) -> Optional[CountryISOLike]: + res = self.get_basic_ip_information(ip_address=ip) + if res: + return res.country.iso_code.lower() + + def run_update_geoip_db(self) -> None: + # runs update_geoip_db with slack panic if fails + db_path = os.path.join(self.data_dir, "GeoIP2-Country.mmdb") + if os.path.exists(db_path): + logger.info("GeoIP2-Country.mmdb already exists!") + else: + logger.info("Updating GeoIP2-Country.mmdb") + try: + self.update_geoip_db() + except Exception as e: + # TODO: Alert + pass + + def update_geoip_db(self) -> None: + """ + Download, checksum, extract from archive, confirm it works, then replace file on disk. + # note: allowed 2,000 downloads per day, so I'm not bothering to implement + # last modified or whatever checks. + # https://support.maxmind.com/geoip-faq/databases-and-database-updates/is-there-a-limit-to-how-often-i-can + -download-a-database-from-my-maxmind-account/ + + """ + db_url = ( + f"https://download.maxmind.com/app/geoip_download?edition_id=GeoIP2-Country&" + f"license_key={self.maxmind_license_key}&suffix=tar.gz" + ) + sha256_url = ( + f"https://download.maxmind.com/app/geoip_download?edition_id=GeoIP2-Country&" + f"license_key={self.maxmind_license_key}&suffix=tar.gz.sha256" + ) + u = uuid4().hex + cwd = f"/tmp/{u}/" + os.makedirs(name=cwd, exist_ok=True) + + res = requests.get(db_url) + # db_file_name looks like "GeoIP2-Country_20210806.tar.gz" + db_file_name = res.headers.get("Content-Disposition").split("filename=")[1] + tmp_db_file = cwd + db_file_name + with open(tmp_db_file, "wb") as f: + f.write(res.content) + res = requests.get(sha256_url) + tmp_sha256_file = cwd + "db.sha256" + with open(tmp_sha256_file, "wb") as f: + f.write(res.content) + subprocess.check_call(args=["sha256sum", "-c", tmp_sha256_file], cwd=cwd) + # Extract + db_name = db_file_name.replace(".tar.gz", "") + subprocess.check_call( + args=[ + "tar", + "-xf", + tmp_db_file, + "--strip-components", + "1", + f"{db_name}/GeoIP2-Country.mmdb", + ], + cwd=cwd, + ) + + # Confirm it works + g = geoip2.database.Reader(fileish=cwd + "GeoIP2-Country.mmdb") + g.country("111.111.111.111").country.iso_code.lower() + + # update file on disk + prod_db = os.path.join(self.data_dir, "GeoIP2-Country.mmdb") + subprocess.check_call(["mv", cwd + "GeoIP2-Country.mmdb", prod_db]) + + # clean up + assert cwd.startswith("/tmp/") + subprocess.check_call(["rm", "-r", cwd]) diff --git a/generalresearch/managers/thl/maxmind/insights.py b/generalresearch/managers/thl/maxmind/insights.py new file mode 100644 index 0000000..b83bded --- /dev/null +++ b/generalresearch/managers/thl/maxmind/insights.py @@ -0,0 +1,52 @@ +import logging +from typing import Optional + +import geoip2.database +import geoip2.models +import geoip2.webservice +import slack +from geoip2.errors import ( + AddressNotFoundError, + AuthenticationError, + InvalidRequestError, + OutOfQueriesError, +) + +logger = logging.getLogger() + + +def get_insights_ip_information( + ip_address: str, + maxmind_account_id: str, + maxmind_license_key: str, +) -> Optional[geoip2.models.Insights]: + + # (2) We want more information, proceed further ($0.002) + client = geoip2.webservice.Client( + account_id=maxmind_account_id, license_key=maxmind_license_key, timeout=1 + ) + logger.info(f"get_insights_ip_information: {ip_address}") + try: + res = client.insights(ip_address) + + except (AuthenticationError, OutOfQueriesError) as e: + # TODO: Alert + return None + + except (AddressNotFoundError, InvalidRequestError): + return None + else: + return res + + +def should_call_insights(res: geoip2.models.Country) -> bool: + """ + Call insights immediately if the IP is either: + - in the continent of North America, Europe, or Oceania + - in the country of Japan, Singapore, Israel, Hong Kong, Taiwan, South Korea + """ + if res.continent.code.upper() in {"NA", "EU", "OC"}: + return True + if res.country.iso_code.upper() in {"JP", "SG", "IL", "HK", "TW", "KR"}: + return True + return False diff --git a/generalresearch/managers/thl/payout.py b/generalresearch/managers/thl/payout.py new file mode 100644 index 0000000..e99cc25 --- /dev/null +++ b/generalresearch/managers/thl/payout.py @@ -0,0 +1,1256 @@ +from collections import defaultdict +from datetime import timezone, datetime, timedelta +from random import randint, choice as rand_choice +from time import sleep +from typing import Collection, Optional, Dict, List, Union +from uuid import UUID, uuid4 + +import numpy as np +import pandas as pd +from psycopg import sql +from pydantic import AwareDatetime, PositiveInt, NonNegativeInt + +from generalresearch.currency import USDCent +from generalresearch.decorators import LOG +from generalresearch.managers.base import ( + PostgresManagerWithRedis, +) +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.managers.thl.product import ProductManager +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr +from generalresearch.models.gr.business import Business +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + OrderBy, +) +from generalresearch.models.thl.payout import ( + PayoutEvent, + UserPayoutEvent, + BrokerageProductPayoutEvent, + BusinessPayoutEvent, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashoutRequestInfo, + CashMailOrderData, +) + + +class PayoutEventManager(PostgresManagerWithRedis): + """This is the default base Payout Event Manger. It acts as a base for + mixing up two different concepts: + - User Payout Events (money to Users / respondents) + - Brokerage Product Payout Events (money to Suppliers) + + """ + + def set_account_lookup_table(self, thl_lm: ThlLedgerManager) -> None: + """This needs to run from grl-flow or from somewhere that has thl-redis + access + """ + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT uuid, reference_uuid + FROM ledger_account + WHERE qualified_name LIKE '{thl_lm.currency.value}:bp_wallet:%' + """ + ) + account_to_product = {i["uuid"]: i["reference_uuid"] for i in res} + product_to_account = {i["reference_uuid"]: i["uuid"] for i in res} + + rc = self.redis_client + rc.hset(name="pem:account_to_product", mapping=account_to_product) + rc.hset(name="pem:product_to_account", mapping=product_to_account) + + return None + + def get_by_uuid(self, pe_uuid: UUIDStr) -> PayoutEvent: + res = self.pg_config.execute_sql_query( + query=""" + SELECT ep.uuid, + debit_account_uuid, + cashout_method_uuid, + ep.created, ep.amount, ep.status, + ep.ext_ref_id, ep.payout_type, + ep.request_data::jsonb, + ep.order_data::jsonb + FROM event_payout AS ep + WHERE ep.uuid = %s + """, + params=[pe_uuid], + ) + assert len(res) == 1, f"{pe_uuid} expected 1 result, got {len(res)}" + return PayoutEvent.model_validate(res[0]) + + def update( + self, + payout_event: Union[UserPayoutEvent, BrokerageProductPayoutEvent], + status: PayoutStatus, + ext_ref_id: Optional[str] = None, + order_data: Optional[Dict] = None, + ) -> None: + # These 3 things are the only modifiable attributes + ext_ref_id = ext_ref_id if ext_ref_id is not None else payout_event.ext_ref_id + order_data = order_data if order_data is not None else payout_event.order_data + payout_event.update(status=status, ext_ref_id=ext_ref_id, order_data=order_data) + + d = payout_event.model_dump_mysql() + query = sql.SQL( + """ + UPDATE event_payout SET + status = %(status)s, + ext_ref_id = %(ext_ref_id)s, + order_data = %(order_data)s + WHERE uuid = %(uuid)s; + """ + ) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=d) + assert ( + c.rowcount == 1 + ), "Nothing was updated! Are you sure this payout_event exists?" + conn.commit() + + return None + + +class UserPayoutEventManager(PayoutEventManager): + + def get_by_uuid(self, pe_uuid: UUIDStr) -> UserPayoutEvent: + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT ep.uuid, + ep.debit_account_uuid, + ep.cashout_method_uuid, + ep.created, ep.amount, ep.status, ep.ext_ref_id, ep.payout_type, + ep.request_data::jsonb, + ep.order_data::jsonb, + -- User Payout specific + ac.name as description, + la.reference_type as account_reference_type, + la.reference_uuid as account_reference_uuid + FROM event_payout AS ep + LEFT JOIN accounting_cashoutmethod AS ac + ON ep.cashout_method_uuid = ac.id + LEFT JOIN ledger_account AS la + ON la.uuid = ep.debit_account_uuid + WHERE ep.uuid = %s + """, + params=[pe_uuid], + ) + + assert len(res) == 1, f"{pe_uuid} expected 1 result, got {len(res)}" + + d = res[0] + pe = UserPayoutEvent.model_validate(d) + if pe.order_data and pe.payout_type == PayoutType.CASH_IN_MAIL: + pe.order_data = CashMailOrderData.model_validate(pe.order_data) + + return pe + + def get_payout_detail(self, pe_uuid: UUIDStr) -> CashoutRequestInfo: + # This gets the payout event, and then extracts information for + # the purposes of returning to the user. + pe = self.get_by_uuid(pe_uuid=pe_uuid) + + transaction_info = dict() + order: Dict = pe.order_data + if pe.payout_type == PayoutType.TANGO and pe.status == PayoutStatus.COMPLETE: + reward = order["reward"] + if "credentialList" in reward: + reward["credential_list"] = reward.pop("credentialList") + if "redemptionInstructions" in reward: + reward["redemption_instructions"] = reward.pop("redemptionInstructions") + transaction_info = order["reward"] + elif pe.payout_type == PayoutType.PAYPAL and pe.status == PayoutStatus.COMPLETE: + info = {"transaction_id": order["transaction_id"]} + transaction_info = info + elif ( + pe.payout_type == PayoutType.CASH_IN_MAIL + and pe.status == PayoutStatus.COMPLETE + ): + transaction_info = pe.order_data.model_dump(mode="json") + + return CashoutRequestInfo( + id=pe_uuid, + status=pe.status, + description=pe.description, + transaction_info=transaction_info, + message="", + ) + + def filter_by( + self, + reference_uuid: Optional[str] = None, + debit_account_uuids: Optional[Collection[UUIDStr]] = None, + amount: Optional[int] = None, + created: Optional[datetime] = None, + created_after: Optional[datetime] = None, + product_ids: Collection[str] = None, + bp_user_ids: Optional[Collection[str]] = None, + cashout_method_uuids: Collection[UUIDStr] = None, + cashout_types: Optional[Collection[PayoutType]] = None, + statuses: Optional[Collection[PayoutStatus]] = None, + ) -> List[UserPayoutEvent]: + """Try to retrieve payout events by the product_id/user_uuid, amount, + and optionally timestamp. + + WARNING: This is only on the "payout events" table and nothing to + do with the Ledger itself. Therefore, the product_ids query + doesn't return Brokerage Product Payouts (the ACH or Wire events + to Suppliers) as part of the query. + + *** IT IS ONLY FOR USER PAYOUTS *** + + Note: what used to be in thl-grpcs "ListCashoutRequests" calling + "list_cashout_requests" was merged into this. + """ + args = [] + filters = [] + if reference_uuid: + # This could be a product_id or a user_uuid + filters.append("la.reference_uuid = %s") + args.append(reference_uuid) + if debit_account_uuids: + # Or we could use the bp_wallet or user_wallet's account uuid + # instead of looking up by the product/user + filters.append("ep.debit_account_uuid = ANY(%s)") + args.append(debit_account_uuids) + if amount: + filters.append("ep.amount = %s") + args.append(amount) + if created: + filters.append("ep.created = %s") + args.append(created.replace(tzinfo=None)) + if created_after: + filters.append("ep.created >= %s") + args.append(created_after.replace(tzinfo=None)) + if product_ids: + filters.append("product_id = ANY(%s)") + args.append(product_ids) + if bp_user_ids: + filters.append("product_user_id = ANY(%s)") + args.append(bp_user_ids) + if cashout_method_uuids: + filters.append("cashout_method_uuid = ANY(%s)") + args.append(cashout_method_uuids) + if cashout_types: + filters.append("payout_type = ANY(%s)") + args.append([x.value for x in cashout_types]) + if statuses: + filters.append("status = ANY(%s)") + args.append([x.value for x in statuses]) + + assert len(filters) > 0, "must pass at least 1 filter" + filter_str = "WHERE " + " AND ".join(filters) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT + ep.uuid, ep.debit_account_uuid, + ep.created, ep.amount, ep.status, + ep.ext_ref_id, ep.payout_type, ep.cashout_method_uuid, + ep.order_data::jsonb, + ep.request_data::jsonb, + ac.name as description, + la.reference_type as account_reference_type, + la.reference_uuid as account_reference_uuid + FROM event_payout AS ep + LEFT JOIN accounting_cashoutmethod AS ac + ON ep.cashout_method_uuid = ac.id + LEFT JOIN ledger_account AS la + ON la.uuid = ep.debit_account_uuid + LEFT JOIN thl_user u + ON la.reference_uuid = u.uuid + {filter_str} + """, + params=args, + ) + + pes = [] + for d in res: + pes.append(UserPayoutEvent.model_validate(d)) + return pes + + def create( + self, + debit_account_uuid: UUIDStr, + cashout_method_uuid: UUIDStr, + payout_type: PayoutType, + amount: PositiveInt, + # --- Optional: Default / Default Factory --- + uuid: Optional[UUIDStr] = None, + status: Optional[PayoutStatus] = None, + created: Optional[AwareDatetimeISO] = None, + request_data: Optional[Dict] = None, + # --- Optional: None --- + account_reference_type: Optional[str] = None, + account_reference_uuid: Optional[UUIDStr] = None, + description: Optional[str] = None, + ext_ref_id: Optional[str] = None, + order_data: Optional[Dict | CashMailOrderData] = None, + ) -> UserPayoutEvent: + + payout_event = UserPayoutEvent( + uuid=uuid or uuid4().hex, + debit_account_uuid=debit_account_uuid, + account_reference_type=account_reference_type, + account_reference_uuid=account_reference_uuid, + cashout_method_uuid=cashout_method_uuid, + description=description, + created=created or datetime.now(tz=timezone.utc), + amount=amount, + status=status or PayoutStatus.PENDING, + ext_ref_id=ext_ref_id, + payout_type=payout_type, + request_data=request_data or {}, + order_data=order_data, + ) + d = payout_event.model_dump_mysql() + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=f""" + INSERT INTO event_payout ( + uuid, debit_account_uuid, created, cashout_method_uuid, amount, + status, ext_ref_id, payout_type, order_data, request_data + ) VALUES ( + %(uuid)s, %(debit_account_uuid)s, %(created)s, + %(cashout_method_uuid)s, %(amount)s, %(status)s, + %(ext_ref_id)s, %(payout_type)s, %(order_data)s, + %(request_data)s + ); + """, + params=d, + ) + assert c.rowcount == 1, f"expected 1 row inserted, got {c.rowcount}" + conn.commit() + + return payout_event + + def create_dummy( + self, + uuid: Optional[UUIDStr] = None, + debit_account_uuid: Optional[UUIDStr] = None, + account_reference_type: Optional[str] = None, + account_reference_uuid: Optional[UUIDStr] = None, + cashout_method_uuid: Optional[UUIDStr] = None, + description: Optional[str] = None, + created: Optional[AwareDatetimeISO] = None, + amount: Optional[PositiveInt] = None, + status: Optional[PayoutStatus] = None, + ext_ref_id: Optional[str] = None, + payout_type: Optional[PayoutType] = None, + request_data: Optional[Dict] = None, + order_data: Optional[Dict | CashMailOrderData] = None, + ) -> UserPayoutEvent: + debit_account_uuid = debit_account_uuid or uuid4().hex + cashout_method_uuid = cashout_method_uuid or uuid4().hex + # account_reference_type = account_reference_type or f"acct-ref-{uuid4().hex}" + # account_reference_uuid = account_reference_uuid or uuid4().hex + # cashout_method_uuid = cashout_method_uuid or uuid4().hex + amount = amount or randint(a=99, b=9_999) + status = status or rand_choice(list(PayoutStatus)) + + description = description or f"desc-{uuid4().hex[:12]}" + # ext_ref_id = ext_ref_id or f"ext-ref-{uuid4().hex[:8]}" + payout_type = payout_type or rand_choice(list(PayoutType)) + request_data = request_data or {} + # order_data = order_data or None + + return self.create( + uuid=uuid, + debit_account_uuid=debit_account_uuid, + account_reference_type=account_reference_type, + account_reference_uuid=account_reference_uuid, + cashout_method_uuid=cashout_method_uuid, + description=description, + created=created, + amount=amount, + status=status, + ext_ref_id=ext_ref_id, + payout_type=payout_type, + request_data=request_data, + order_data=order_data, + ) + + +class BrokerageProductPayoutEventManager(PayoutEventManager): + # This is what makes a PayoutEvent a Brokerage Product Payout + CASHOUT_METHOD_UUID = "602113e330cf43ae85c07d94b5100291" + + def get_by_uuid( + self, + pe_uuid: UUIDStr, + # --- Support resources --- + account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, + ) -> BrokerageProductPayoutEvent: + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT ep.uuid, + ep.debit_account_uuid, + ep.cashout_method_uuid, + ep.created, ep.amount, ep.status, ep.ext_ref_id, ep.payout_type, + ep.request_data::jsonb, + ep.order_data::jsonb + FROM event_payout AS ep + WHERE ep.uuid = %s + """, + params=[pe_uuid], + ) + assert len(res) == 1, f"{pe_uuid} expected 1 result, got {len(res)}" + + d = res[0] + + # This isn't really need for creation... but we're doing it so that + # it can return back a full BrokerageProductPayoutEvent instance + if account_product_mapping is None: + rc = self.redis_client + account_product_mapping: Dict = rc.hgetall(name="pem:account_to_product") + assert isinstance(account_product_mapping, dict) + d["product_id"] = account_product_mapping[d["debit_account_uuid"]] + + return BrokerageProductPayoutEvent.model_validate(d) + + @staticmethod + def check_for_ledger_tx( + thl_ledger_manager: ThlLedgerManager, + product_id: UUIDStr, + amount: USDCent, + payout_event: BrokerageProductPayoutEvent, + ) -> bool: + """ + Checks if a ledger tx for this payout event exists properly in the DB. + It looks up by the tag (which is uniquely specified by the payout event uuid), + and then confirms that the associated transaction if a bp_payout, for the + specified Product, for the same amount. + + Returns True if the tx exists and looks ok, False if no txs with that tag + are found, and raises a ValueError if something is inconsistent. + """ + tag = f"{thl_ledger_manager.currency.value}:bp_payout:{payout_event.uuid}" + txs = thl_ledger_manager.get_tx_by_tag(tag) + + if not txs: + return False + + if len(txs) != 1: + raise ValueError(f"Two transactions found for tag: {tag}!") + + tx = txs[0] + if ( + (len(tx.entries) != 2) + or (tx.entries[0].amount != amount) + or (tx.metadata["tx_type"] != "bp_payout") + or (tx.metadata["event_payout"] != payout_event.uuid) + ): + raise ValueError( + f"Found existing tx with tag: {tag}, but different than expected!" + ) + bp_wallet_account = thl_ledger_manager.get_account_or_create_bp_wallet_by_uuid( + product_uuid=product_id + ) + entry = [x for x in tx.entries if x.direction == Direction.DEBIT][0] + if entry.account_uuid != bp_wallet_account.uuid: + raise ValueError( + f"Found existing tx with tag: {tag}, but for a different account!" + ) + + return True + + def create( + self, + uuid: Optional[UUIDStr] = None, + debit_account_uuid: Optional[UUIDStr] = None, + created: AwareDatetimeISO = None, + amount: PositiveInt = None, + status: Optional[PayoutStatus] = None, + ext_ref_id: Optional[str] = None, + payout_type: PayoutType = None, + request_data: Dict = None, + order_data: Optional[Dict | CashMailOrderData] = None, + # --- Support resources --- + account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, + ) -> BrokerageProductPayoutEvent: + if request_data is None: + request_data = dict() + + # This isn't really need for creation... but we're doing it so that + # it can return back a full BrokerageProductPayoutEvent instance + if account_product_mapping is None: + rc = self.redis_client + account_product_mapping: Dict = rc.hgetall(name="pem:account_to_product") + assert isinstance(account_product_mapping, dict) + product_id = account_product_mapping[debit_account_uuid] + + bp_payout_event = BrokerageProductPayoutEvent( + uuid=uuid or uuid4().hex, + debit_account_uuid=debit_account_uuid, + cashout_method_uuid=self.CASHOUT_METHOD_UUID, + created=created or datetime.now(tz=timezone.utc), + amount=amount, + status=status, + ext_ref_id=ext_ref_id, + payout_type=payout_type, + request_data=request_data, + order_data=order_data, + product_id=product_id, + ) + d = bp_payout_event.model_dump_mysql() + + self.pg_config.execute_write( + query=f""" + INSERT INTO event_payout ( + uuid, debit_account_uuid, created, cashout_method_uuid, amount, + status, ext_ref_id, payout_type, order_data, request_data + ) VALUES ( + %(uuid)s, %(debit_account_uuid)s, %(created)s, + %(cashout_method_uuid)s, %(amount)s, %(status)s, + %(ext_ref_id)s, %(payout_type)s, %(order_data)s, + %(request_data)s + ); + """, + params=d, + ) + + return bp_payout_event + + def filter_by( + self, + reference_uuid: Optional[str] = None, + ext_ref_id: Optional[str] = None, + debit_account_uuids: Optional[Collection[UUIDStr]] = None, + amount: Optional[int] = None, + created: Optional[datetime] = None, + created_after: Optional[datetime] = None, + product_ids: Collection[str] = None, + bp_user_ids: Optional[Collection[str]] = None, + cashout_types: Optional[Collection[PayoutType]] = None, + statuses: Optional[Collection[PayoutStatus]] = None, + ) -> List[BrokerageProductPayoutEvent]: + """Try to retrieve payout events by the product_id/user_uuid, amount, + and optionally timestamp. + + WARNING: This is only on the "payout events" table and nothing to + do with the Ledger itself. Therefore, the product_ids query + doesn't return Brokerage Product Payouts (the ACH or Wire events + to Suppliers) as part of the query. + + *** IT IS ONLY FOR USER PAYOUTS *** + + Note: what used to be in thl-grpcs "ListCashoutRequests" calling + "list_cashout_requests" was merged into this. + """ + args = [] + filters = [] + if reference_uuid: + # This could be a product_id or a user_uuid + filters.append("la.reference_uuid = %s") + args.append(reference_uuid) + if ext_ref_id: + # This is transaction id for tracking ACH/Wires with a banking + # institution + filters.append("ep.ext_ref_id = %s") + args.append(ext_ref_id) + if debit_account_uuids: + # Or we could use the bp_wallet or user_wallet's account uuid + # instead of looking up by the product/user + filters.append("ep.debit_account_uuid = ANY(%s)") + args.append(debit_account_uuids) + if amount: + filters.append("ep.amount = %s") + args.append(amount) + if created: + filters.append("ep.created = %s") + args.append(created.replace(tzinfo=None)) + if created_after: + filters.append("ep.created >= %s") + args.append(created_after.replace(tzinfo=None)) + if product_ids: + filters.append("product_id = ANY(%s)") + args.append(product_ids) + if bp_user_ids: + filters.append("product_user_id = ANY(%s)") + args.append(bp_user_ids) + if cashout_types: + filters.append("payout_type = ANY(%s)") + args.append([x.value for x in cashout_types]) + if statuses: + filters.append("status = ANY(%s)") + args.append([x.value for x in statuses]) + + assert len(filters) > 0, "must pass at least 1 filter" + filter_str = " AND ".join(filters) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT ep.uuid, + ep.debit_account_uuid, + ep.cashout_method_uuid, + ep.created, + ep.amount, ep.status, ep.ext_ref_id, ep.payout_type, + ep.request_data::jsonb, ep.order_data::jsonb, + ac.name as description, + la.reference_type as account_reference_type, + la.reference_uuid as account_reference_uuid + FROM event_payout AS ep + LEFT JOIN accounting_cashoutmethod AS ac + ON ep.cashout_method_uuid = ac.id + LEFT JOIN ledger_account AS la + ON la.uuid = ep.debit_account_uuid + LEFT JOIN thl_user u + ON la.reference_uuid = u.uuid + WHERE cashout_method_uuid = '{self.CASHOUT_METHOD_UUID}' + AND {filter_str} + """, + params=args, + ) + + rc = self.redis_client + account_product_mapping = rc.hgetall(name="pem:account_to_product") + + pes = [] + for d in res: + for k in [ + "uuid", + "debit_account_uuid", + "account_reference_uuid", + "cashout_method_uuid", + ]: + if d[k] is not None: + d[k] = UUID(d[k]).hex + + d["product_id"] = account_product_mapping[d["debit_account_uuid"]] + pes.append(BrokerageProductPayoutEvent.model_validate(d)) + + return pes + + def get_bp_payout_events_for_accounts( + self, accounts: Collection[LedgerAccount] + ) -> List[BrokerageProductPayoutEvent]: + return self.filter_by( + debit_account_uuids=[i.uuid for i in accounts], + cashout_types=[PayoutType.ACH], + ) + + def get_bp_bp_payout_events_for_products( + self, + thl_ledger_manager: ThlLedgerManager, + product_uuids: Collection[UUIDStr], + order_by: Optional[OrderBy] = OrderBy.ASC, + ) -> List["BrokerageProductPayoutEvent"]: + """This is a terrible name, but it returns the + BPPayoutEvent model type rather than a list of PayoutEvents. + + We do this for the Supplier centric APIs where they don't know, + or care about the underlying ledger account structure. + """ + assert len(product_uuids) > 0, "Must provide product_uuids" + accounts = thl_ledger_manager.get_accounts_bp_wallet_for_products( + product_uuids=product_uuids + ) + + assert len(accounts) == len(product_uuids), "Unequal Product & Account lists" + + rc = self.redis_client + account_product_mapping = rc.hgetall(name="pem:account_to_product") + + payout_events: List[BrokerageProductPayoutEvent] = ( + self.get_bp_payout_events_for_accounts( + accounts=accounts, + ) + ) + + return BrokerageProductPayoutEvent.from_payout_events( + payout_events=payout_events, + account_product_mapping=account_product_mapping, + order_by=order_by, + ) + + def retry_create_bp_payout_event_tx( + self, + thl_ledger_manager: ThlLedgerManager, + product: Product, + payout_event_uuid: UUIDStr, + skip_wallet_balance_check: bool = False, + skip_one_per_day_check: bool = False, + ) -> BrokerageProductPayoutEvent: + """If a create_bp_payout_event call fails, this can be called with + the associated payoutevent. + """ + bp_pe: BrokerageProductPayoutEvent = self.get_by_uuid(payout_event_uuid) + assert bp_pe.status == PayoutStatus.FAILED, "Only use this on failed payouts" + created = bp_pe.created + + assert not self.check_for_ledger_tx( + thl_ledger_manager=thl_ledger_manager, + payout_event=bp_pe, + product_id=bp_pe.product_id, + amount=bp_pe.amount_usd, + ), "Transaction exists! You should mark the payout event status as complete" + + return self._create_tx_bp_payout_from_payout_event( + thl_ledger_manager=thl_ledger_manager, + bp_pe=bp_pe, + product=product, + amount=bp_pe.amount_usd, + created=created, + skip_one_per_day_check=skip_one_per_day_check, + skip_wallet_balance_check=skip_wallet_balance_check, + ) + + def create_bp_payout_event( + self, + thl_ledger_manager: ThlLedgerManager, + product: Product, + amount: USDCent, + payout_type: PayoutType = PayoutType.ACH, + ext_ref_id: Optional[str] = None, + created: Optional[AwareDatetime] = None, + skip_wallet_balance_check: bool = False, + skip_one_per_day_check: bool = False, + ) -> BrokerageProductPayoutEvent: + """This should be called when a BP is paid out money from their + wallet. Typically, this is an ACH payment. This function creates + the PayoutEvent and the Ledger entries. + + :param thl_ledger_manager: + :param product: The BP being paid. Assuming we're paying them out + of the balance of their USD wallet account. + :param amount: We're assuming everything is in USD, and we're + paying out a USD currency account. We could theoretically also + pay, for e.g. a Bitcoin account with a bitcoin transfer, but + this is not supported for now. + :param payout_type: PayoutType. default ACH + :param cashout_method_uuid: The entry in the + accounting_cashoutmethod table that records payment method + details. By default, the generic ACH cashout method (that has + no actual banking details). + + :param ext_ref_id: This is a unique ID for the Supplier Payment. + Typically it'll be from JP Morgan Chase, but may also just be + random if we can retrieve anything + + :param created: + + :param skip_wallet_balance_check: By default, this will fail unless + the BP's wallet actually has the amount requested. + + :param skip_one_per_day_check: Safety mechanism, checks if there + has already been a payout to this wallet in the past 24 hours. + + :return: + """ + + assert isinstance(amount, USDCent), "Must provide a USDCent" + + if created: + # Try to do a quick dupe check first before we create the payout event + pes = self.filter_by( + reference_uuid=product.id, amount=amount, created=created + ) + if len(pes) > 0: + raise ValueError(f"Payout event already exists!: {pes}") + + if created is None: + created = datetime.now(tz=timezone.utc) + + # TODO: Explain why we're doing this. Why is it important to have + # Payout Events when the ledger has everything that should be + # needed. + bp_wallet = thl_ledger_manager.get_account_or_create_bp_wallet(product=product) + + bp_pe: BrokerageProductPayoutEvent = self.create( + debit_account_uuid=bp_wallet.uuid, + payout_type=payout_type, + amount=amount, + ext_ref_id=ext_ref_id, + created=created, + status=PayoutStatus.PENDING, + ) + return self._create_tx_bp_payout_from_payout_event( + thl_ledger_manager=thl_ledger_manager, + bp_pe=bp_pe, + product=product, + amount=amount, + created=created, + skip_one_per_day_check=skip_one_per_day_check, + skip_wallet_balance_check=skip_wallet_balance_check, + ) + + def _create_tx_bp_payout_from_payout_event( + self, + thl_ledger_manager: ThlLedgerManager, + bp_pe: BrokerageProductPayoutEvent, + product: Product, + amount: USDCent, + created: Optional[AwareDatetime] = None, + skip_wallet_balance_check: bool = False, + skip_one_per_day_check: bool = False, + ) -> BrokerageProductPayoutEvent: + """ + This should not be called directly. + Creates the ledger transaction for a BP Payout, given a PayoutEvent. + Handles exceptions: Check if the ledger tx actually exists or not, and set the + payout event status accordingly. + """ + try: + thl_ledger_manager.create_tx_bp_payout( + product=product, + amount=amount, + payoutevent_uuid=bp_pe.uuid, + created=created, + skip_wallet_balance_check=skip_wallet_balance_check, + skip_one_per_day_check=skip_one_per_day_check, + ) + + except Exception as e: + e.pe_uuid = bp_pe.uuid + if self.check_for_ledger_tx( + thl_ledger_manager=thl_ledger_manager, + product_id=product.uuid, + amount=amount, + payout_event=bp_pe, + ): + LOG.warning(f"Got exception {e} but ledger tx exists! Continuing ... ") + self.update(payout_event=bp_pe, status=PayoutStatus.COMPLETE) + return bp_pe + else: + LOG.warning(f"Got exception {e}. No ledger tx was created.") + self.update(payout_event=bp_pe, status=PayoutStatus.FAILED) + raise e + + self.update(payout_event=bp_pe, status=PayoutStatus.COMPLETE) + return bp_pe + + def get_bp_payout_events_for_product( + self, + thl_ledger_manager: ThlLedgerManager, + product: Product, + ) -> List[BrokerageProductPayoutEvent]: + account = thl_ledger_manager.get_account_or_create_bp_wallet(product=product) + return self.get_bp_payout_events_for_accounts(accounts=[account]) + + def get_bp_payout_events_for_account( + self, account: LedgerAccount + ) -> List[BrokerageProductPayoutEvent]: + return self.get_bp_payout_events_for_accounts(accounts=[account]) + + def get_bp_payout_events_for_products( + self, + thl_ledger_manager: ThlLedgerManager, + product_uuids: Collection[UUIDStr], + ) -> List[BrokerageProductPayoutEvent]: + accounts = thl_ledger_manager.get_accounts_bp_wallet_for_products( + product_uuids=product_uuids + ) + return self.get_bp_payout_events_for_accounts(accounts=accounts) + + +class BusinessPayoutEventManager(BrokerageProductPayoutEventManager): + + def update_ext_reference_ids( + self, + new_value: str, + current_value: Optional[str] = None, + ) -> None: + """ + There are scenarios where an ACH/Wire payout event was saved with + a generic or anonymized reference identifier. We may want to be + able to go back and update all of those transaction IDs. + + """ + + if current_value is None: + raise ValueError("Dangerous to do ambiguous updates") + + # SELECT first to check that records exist + res = self.filter_by(ext_ref_id=current_value) + if len(res) == 0: + raise Warning("No event_payouts found to UPDATE") + + # As of 2025, no single Business has more than 10,000 Products, + # leave the limit in as an additional safeguard. + query = """ + UPDATE event_payout + SET ext_ref_id = %s + WHERE ext_ref_id = %s + """ + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=[new_value, current_value]) + assert c.rowcount < 10000 + conn.commit() + + return None + + def delete_failed_business_payout(self, ext_ref_id: str, thl_lm: ThlLedgerManager): + """ + Sometimes ACH/Wire payouts fail due to multiple reasons (timeouts, + Business Product having insufficient funds, etc). This is a utility + method that finds all event_payouts, and deletes them with all the + associated: + (1) Transactions + (2) Transaction Metadata + (3) Transaction Entries + + and then proceeds to delete them all in reverse order (so there is + no orphan / FK constraint issues). + """ + + # (1) Find all by payout_event + event_payouts = self.filter_by(ext_ref_id=ext_ref_id) + if len(event_payouts) == 0: + raise Warning("No event_payouts found to DELETE") + + # sum([i["amount"] for i in event_payouts])/100 + event_payout_uuids = [i.uuid for i in event_payouts] + + # (2) Find all ledger_transactions + tags = [f"{thl_lm.currency.value}:bp_payout:{x}" for x in event_payout_uuids] + transactions = thl_lm.get_txs_by_tags(tags=tags) + transaction_ids = [tx.id for tx in transactions] + print("XXX1", transaction_ids) + # assert len(tags) == len(transactions) + + # (3) Find all ledger_transactionmetadata: assert two rows per tx + tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions) + # assert len(tx_metadata) == len(transaction_ids)*2 + + # (4) Find all ledger_entry: assert two rows per tx + tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions) + tx_entry_ids = [tx_entry.id for tx_entry in tx_entries] + # assert len(tx_entry) == len(transaction_ids)*2 + + # (5) Delete records + + # DELETE: tx_entry + self.pg_config.execute_write( + query=""" + DELETE + FROM ledger_entry + WHERE transaction_id = ANY(%s) + AND id = ANY(%s) + """, + params=[transaction_ids, tx_entry_ids], + ) + + # DELETE: tx_metadata + self.pg_config.execute_write( + query=""" + DELETE + FROM ledger_transactionmetadata + WHERE transaction_id = ANY(%s) + AND id = ANY(%s) + """, + params=[transaction_ids, list(tx_metadata_ids)], + ) + + # DELETE: transactions + self.pg_config.execute_write( + query=""" + DELETE + FROM ledger_transaction + WHERE id = ANY(%s) + """, + params=[transaction_ids], + ) + + # DELETE: event_payouts + self.pg_config.execute_write( + query=""" + DELETE + FROM event_payout + WHERE ext_ref_id = %s + AND uuid = ANY(%s) + """, + params=[ext_ref_id, event_payout_uuids], + ) + + return None + + def get_business_payout_events_for_products( + self, + thl_ledger_manager: ThlLedgerManager, + product_uuids: Collection[UUIDStr], + order_by: Optional[OrderBy] = OrderBy.ASC, + ) -> List["BusinessPayoutEvent"]: + res = self.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_ledger_manager, + product_uuids=product_uuids, + order_by=order_by, + ) + + return self.from_bp_payout_events(bp_payout_events=res) + + @staticmethod + def from_bp_payout_events( + bp_payout_events: Collection["BrokerageProductPayoutEvent"], + ) -> List["BusinessPayoutEvent"]: + if len(bp_payout_events) == 0: + return [] + + grouped = defaultdict(list) + for bp_pe in bp_payout_events: + grouped[bp_pe.ext_ref_id].append(bp_pe) + + res = [] + for ex_ref_id, members in grouped.items(): + res.append(BusinessPayoutEvent.model_validate({"bp_payouts": members})) + + return res + + @staticmethod + def recoup_proportional( + df: pd.DataFrame, + target_amount: Union[USDCent, NonNegativeInt], + ) -> pd.DataFrame: + """ + Recoup a target amount from rows proportionally based on a numeric column. + + Does not filter the dataframe. Length in == Length out + + Parameters: + - df: pandas DataFrame + - target_amount: total amount to recoup + + Returns: + - A new DataFrame with columns: + - original amounts + - weights + - proposed and actual deductions + - remaining balances + """ + w_df = df.copy(deep=True) + target_amount = USDCent(target_amount) + total_available = int(w_df["available_balance"].sum()) + + if total_available == 0: + raise ValueError("Total available amount is empty, cannot recoup") + + if int(target_amount) > total_available: + raise ValueError( + f"Target amount ({target_amount}) exceeds total available " + f"({total_available})." + ) + + # Calculate weight and proportional deduction + w_df["weight"] = w_df["available_balance"] / total_available + w_df["raw_deduction"] = w_df["weight"] * target_amount + w_df["deduction"] = np.floor(w_df["raw_deduction"]).astype(int) + w_df["remainder"] = w_df["raw_deduction"] - w_df["deduction"] + # While this is updated initially, we'll also update it on every + # loop to make sure we only pull from + w_df["remaining_balance"] = w_df["available_balance"] - w_df["deduction"] + + shortfall: int = int(target_amount) - w_df["deduction"].sum() + + while shortfall > 0: + # Distribute remaining cents to rows with the largest remainder + extra_idxs = ( + w_df[w_df["remaining_balance"] >= 1] + .sort_values(by="weight", ascending=False) + .index[:shortfall] + ) + w_df.loc[extra_idxs, "deduction"] += 1 + + shortfall: int = int(target_amount) - w_df["deduction"].sum() + w_df["remaining_balance"] = w_df["available_balance"] - w_df["deduction"] + + assert w_df[ + w_df["deduction"] > w_df["available_balance"] + ].empty, "Trying to deduct more from an Product than what is available" + + return w_df + + @staticmethod + def distribute_amount( + df: pd.DataFrame, + amount: USDCent, + weight_col="weight", + balance_col="remaining_balance", + ) -> pd.Series: + """ + Distributes an integer amount across dataframe rows proportionally, + ensuring the total equals exactly the desired amount (in cents). + + Parameters: + ----------- + df : pd.DataFrame + The dataframe with product information + amount : USDCent + The total amount to distribute (in cents) + weight_col : str + Column name containing the weights + balance_col : str + Column name containing the balance constraint + + Returns: + -------- + pd.Series + A series with integer allocations that sum to exactly the amount + """ + res_df = df.copy(deep=True) + + # Calculate ideal fractional allocation + ideal_allocation = res_df[weight_col] * int(amount) + + # Ensure we don't exceed available balance + ideal_allocation = np.minimum(ideal_allocation, res_df[balance_col]) + + # Start with floor values + allocation = np.floor(ideal_allocation).astype(int) + + # Calculate remainders + remainders = ideal_allocation - allocation + + # Distribute the remaining cents to rows with largest remainders + shortage = int(amount) - allocation.sum() + + if shortage > 0: + + assert shortage < len(remainders), ( + "The shortage cent amount must be less than or equal to the " + "length of the remainders if we intend of taking a penny " + "from each" + ) + + remainders.sort_values(ascending=False, inplace=True) + from itertools import islice + + # Add 1 cent to the top 'shortage' rows + for idx, value in islice(remainders.items(), shortage): + # Only add if it doesn't exceed the balance + if allocation.loc[idx] < df[balance_col].loc[idx]: + allocation.loc[idx] += 1 + + return allocation + + def create_from_ach_or_wire( + self, + business: Business, + amount: USDCent, + pm: ProductManager, + thl_lm: ThlLedgerManager, + created: Optional[datetime] = None, + transaction_id: Optional[str] = None, + ) -> Optional[BusinessPayoutEvent]: + """This records a single banking transfer to a supplier. Takes a + specific Business that was paid out and how much. It then determines + how to distribute the amount to each Brokerage Product in the + Business. + + :param business + :param amount + :param pm + :param thl_lm: this must have rw permissions to add transactions to + the ledger + :param created + :param transaction_id + + :return: + """ + assert business.balance is not None, ( + "Must provide a full version of a Business in order to calculate" + "the required Brokerage Product amounts." + ) + + assert amount > 100_00, "Must issue Supplier Payouts at least $100 minimum." + LOG.warning("Paying out ") + + if created: + LOG.warning("Payouts in the past, require the parquet files to be rebuilt.") + assert created < datetime.now(tz=timezone.utc) + + else: + created = datetime.now(tz=timezone.utc) + + # Gather the total amount available balance from each and put into + # a simple DF. We're using the available balance because we need it + # to always be positive.. and we never want to get into a negative + # situation again, so it's best to be extra conservative. + res = { + pb.product_id: pb.available_balance + for pb in business.balance.product_balances + } + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + res = BusinessPayoutEventManager.recoup_proportional( + df=df, target_amount=business.balance.recoup + ) + + # Can't pay any Products that don't have a remaining balance + res = res[res["remaining_balance"] > 0] + + assert ( + res.deduction.sum() == business.balance.recoup + ), "recoup_proportional failure" + + res["issue_amount"] = BusinessPayoutEventManager.distribute_amount( + df=res, amount=amount + ) + + assert res.issue_amount.sum() == amount, "issue_amount failure" + + # Can't pay any Products that don't have an issue amount + res = res[res["issue_amount"] > 0] + + recouped_amounts: List[Dict[str, int]] = res[ + ["product_id", "remaining_balance", "issue_amount"] + ].to_dict(orient="records") + + # Get all of the products at once so we're not doing it for every interation + products = pm.get_by_uuids( + product_uuids=[i["product_id"] for i in recouped_amounts] + ) + + bp_payouts: List[BrokerageProductPayoutEvent] = [] + for idx, item in enumerate(recouped_amounts): + product = next((p for p in products if p.uuid == item["product_id"]), None) + assert product is not None + + try: + bp_pe: BrokerageProductPayoutEvent = self.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(item["issue_amount"]), + created=created + timedelta(milliseconds=idx + 1), + ext_ref_id=transaction_id, + ) + + assert bp_pe.status == PayoutStatus.COMPLETE + bp_payouts.append(bp_pe) + + except (Exception,) as e: + # Cleanup bp_payouts + print("Exception", e) + return None + + if bp_pe.status == PayoutStatus.FAILED: + sleep(1) + + try: + bp_pe = self.retry_create_bp_payout_event_tx( + thl_ledger_manager=thl_lm, + product=product, + payout_event_uuid=bp_pe.uuid, + ) + assert bp_pe.status == PayoutStatus.COMPLETE + bp_payouts.append(bp_pe) + + except (Exception,) as e: + # Cleanup bp_payouts + return None + + return BusinessPayoutEvent.model_validate({"bp_payouts": bp_payouts}) diff --git a/generalresearch/managers/thl/product.py b/generalresearch/managers/thl/product.py new file mode 100644 index 0000000..c7d38c2 --- /dev/null +++ b/generalresearch/managers/thl/product.py @@ -0,0 +1,570 @@ +import json +import logging +import operator +from datetime import timezone, datetime +from decimal import Decimal +from threading import Lock +from typing import Collection, Optional, List, TYPE_CHECKING, Union +from uuid import uuid4, UUID + +from cachetools import TTLCache, cachedmethod, keys +from more_itertools import chunked +from psycopg import Cursor +from pydantic import ValidationError +from sentry_sdk import capture_exception + +from generalresearch.decorators import LOG +from generalresearch.managers.base import ( + Permission, + PostgresManager, +) +from generalresearch.models.custom_types import UUIDStr, is_valid_uuid +from generalresearch.pg_helper import PostgresConfig + +logger = logging.getLogger() + +if TYPE_CHECKING: + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.product import ( + UserCreateConfig, + PayoutConfig, + SessionConfig, + UserWalletConfig, + SourcesConfig, + UserHealthConfig, + ProfilingConfig, + SupplyConfigs, + ) + + +class ProductManager(PostgresManager): + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + ): + super().__init__(pg_config=pg_config, permissions=permissions) + self.uuid_cache = TTLCache(maxsize=1024, ttl=5 * 60) + self.uuid_lock = Lock() + + def cache_clear(self, product_uuid: UUIDStr) -> None: + # Calling get_by_uuid with or without kwargs hits different internal keys in the cache! + with self.uuid_lock: + self.uuid_cache.pop(keys.hashkey(product_uuid), None) + self.uuid_cache.pop(keys.hashkey(product_uuid=product_uuid), None) + + @cachedmethod( + operator.attrgetter("uuid_cache"), lock=operator.attrgetter("uuid_lock") + ) + def get_by_uuid( + self, + product_uuid: UUIDStr, + ) -> "Product": + assert is_valid_uuid(product_uuid), "invalid uuid" + res = self.fetch_uuids( + product_uuids=[product_uuid], + ) + # do this so we uniformly raise AssertionErrors + assert len(res) == 1, "product not found" + return res[0] + + def get_by_uuids( + self, + product_uuids: List[UUIDStr], + ) -> List["Product"]: + + res = self.fetch_uuids( + product_uuids=product_uuids, + ) + assert len(product_uuids) == len(res), "incomplete product response" + return res + + @cachedmethod( + operator.attrgetter("uuid_cache"), lock=operator.attrgetter("uuid_lock") + ) + def get_by_uuid_if_exists( + self, + product_uuid: UUIDStr, + ) -> Optional["Product"]: + # many=False, raise_on_error=False + try: + return self.fetch_uuids( + product_uuids=[product_uuid], + )[0] + except (AssertionError,): + return None + except (IndexError,): + return None + + def get_by_uuids_if_exists( + self, + product_uuids: List[UUIDStr], + ) -> List["Product"]: + # Same as .get_by_uuids but doesn't raise Exception if len(product_uuids) != len(res) + return self.fetch_uuids( + product_uuids=product_uuids, + ) + + def get_all(self, rand_limit: Optional[int]) -> List["Product"]: + product_uuids = self.get_all_uuids(rand_limit=rand_limit) + return self.fetch_uuids(product_uuids=product_uuids) + + def get_all_uuids(self, rand_limit: Optional[int]) -> List[UUIDStr]: + + if rand_limit: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT p.id::uuid + FROM userprofile_brokerageproduct AS p + ORDER BY RANDOM() + LIMIT %s + """, + params=[rand_limit], + ) + + else: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT p.id::uuid + FROM userprofile_brokerageproduct AS p + """ + ) + return [i["id"] for i in res] + + def fetch_uuids( + self, + product_uuids: Optional[List[UUIDStr]] = None, + business_uuids: Optional[List[UUIDStr]] = None, + team_uuids: Optional[List[UUIDStr]] = None, + ) -> List["Product"]: + LOG.debug(f"PM.fetch_uuids({product_uuids=}, {business_uuids=}, {team_uuids=})") + + assert ( + sum( + bool(x) # This will also be False is the array is empty + for x in [product_uuids, business_uuids, team_uuids] + ) + == 1 + ), "Can only provide one set of identifiers" + + filter_column = None + filter_uuids = None + if bool(product_uuids): + assert all(is_valid_uuid(v) for v in product_uuids), "invalid uuid passed" + filter_column = "id" + filter_uuids = product_uuids + elif bool(business_uuids): + assert all(is_valid_uuid(v) for v in business_uuids), "invalid uuid passed" + filter_column = "business_id" + filter_uuids = business_uuids + elif bool(team_uuids): + assert all(is_valid_uuid(v) for v in team_uuids), "invalid uuid passed" + filter_column = "team_id" + filter_uuids = team_uuids + + assert filter_column is not None + + if filter_uuids is None or len(filter_uuids) == 0: + return [] + + with self.pg_config.make_connection() as sql_connection: + with sql_connection.cursor() as c: + res = [] + for chunk in chunked(filter_uuids, 500): + res.extend( + self.fetch_uuids_( + c=c, filter_uuids=chunk, filter_column=filter_column + ) + ) + return res + + def fetch_uuids_( + self, c: Cursor, filter_uuids: List[UUIDStr], filter_column: str + ) -> List["Product"]: + from generalresearch.models.thl.product import Product + + assert len(filter_uuids) <= 500, "chunk me" + assert filter_column in {"id", "business_id", "team_id"} + + # Step 1: Retrieve the basic columns from the "Product table" + query = f""" + SELECT + bp.id, + bp.id_int, + bp.name, + bp.enabled, + bp.created::timestamptz, + bp.team_id::uuid, + bp.business_id::uuid, + bp.commission AS commission_pct, + bp.grs_domain as harmonizer_domain, + bp.redirect_url, + bp.session_config::jsonb, + bp.payout_config::jsonb, + bp.user_create_config::jsonb, + bp.offerwall_config::jsonb, + bp.profiling_config::jsonb, + bp.user_health_config::jsonb, + bp.yield_man_config::jsonb, + t.tags + FROM userprofile_brokerageproduct AS bp + LEFT JOIN ( + SELECT product_id, STRING_AGG(tag, ',') as tags + FROM userprofile_brokerageproducttag + GROUP BY product_id + ) t ON t.product_id = bp.id_int + WHERE {filter_column} = ANY(%s) + """ + + c.execute(query, [list(filter_uuids)]) + + res = c.fetchall() + + if len(res) == 0: + return [] + for x in res: + x["id"] = UUID(x["id"]).hex + x["team_id"] = UUID(x["team_id"]).hex if x["team_id"] else None + x["business_id"] = UUID(x["business_id"]).hex if x["business_id"] else None + x["tags"] = set(x["tags"].split(",")) if x["tags"] else set() + + res1 = {i["id"]: i for i in res} + + # Step 2: Retrieve additional metadata from the "Product Config table" + c.execute( + query=""" + SELECT bpc.product_id::uuid as product_id, bpc.key, bpc.value::jsonb + FROM userprofile_brokerageproductconfig AS bpc + WHERE product_id = ANY(%s) + AND key IN ('sources_config', 'user_wallet') + """, + # Pulling from keys b/c no reason to try to retrieve any config + # k,v rows for products that we know aren't in the other table. + params=[list(res1.keys())], + ) + kv_res = c.fetchall() + for item in kv_res: + item["value"] = item["value"][item["key"]] + if item["key"] == "user_wallet": + item["key"] = "user_wallet_config" + + # Step 2.1: go through them all, and add the key,vals to the correct + # Product in the dictionary + for item in kv_res: + k: str = item["key"] + product_id: str = UUID(item["product_id"]).hex + res1[product_id][k] = item["value"] + r = [] + for k, v in res1.items(): + try: + r.append(Product.model_validate(v)) + except ValidationError as e: + logger.info(f"failed to parse product: {k}") + raise e + return r + + def create_dummy( + self, + product_id: Optional[UUIDStr] = None, + team_id: Optional[UUIDStr] = None, + business_id: Optional[UUIDStr] = None, + name: Optional[str] = None, + redirect_url: Optional[str] = None, + harmonizer_domain: Optional[str] = None, + commission_pct: Decimal = Decimal("0.05000"), + sources_config: Optional[Union["SourcesConfig", "SupplyConfigs"]] = None, + payout_config: Optional["PayoutConfig"] = None, + session_config: Optional["SessionConfig"] = None, + profiling_config: Optional["ProfilingConfig"] = None, + user_wallet_config: Optional["UserWalletConfig"] = None, + user_create_config: Optional["UserCreateConfig"] = None, + user_health_config: Optional["UserHealthConfig"] = None, + ) -> "Product": + """To be used in tests, where we don't care about certain fields""" + product_id = product_id if product_id else uuid4().hex + team_id = team_id if team_id else uuid4().hex + name = name if name else f"name-{product_id[:12]}" + redirect_url = redirect_url if redirect_url else "https://www.example.com/" + + return self.create( + product_id=product_id, + team_id=team_id, + business_id=business_id, + name=name, + redirect_url=redirect_url, + harmonizer_domain=harmonizer_domain, + commission_pct=commission_pct, + sources_config=sources_config, + payout_config=payout_config, + session_config=session_config, + profiling_config=profiling_config, + user_wallet_config=user_wallet_config, + user_create_config=user_create_config, + user_health_config=user_health_config, + ) + + def create( + self, + product_id: UUIDStr, + team_id: UUIDStr, + name: str, + redirect_url: str, + business_id: Optional[UUIDStr] = None, + harmonizer_domain: Optional[str] = None, + commission_pct: Decimal = Decimal("0.05"), + sources_config: Optional[Union["SourcesConfig", "SupplyConfigs"]] = None, + payout_config: Optional["PayoutConfig"] = None, + session_config: Optional["SessionConfig"] = None, + profiling_config: Optional["ProfilingConfig"] = None, + user_wallet_config: Optional["UserWalletConfig"] = None, + user_create_config: Optional["UserCreateConfig"] = None, + user_health_config: Optional["UserHealthConfig"] = None, + ) -> "Product": + """Create a Product with all the basic defaults and return the instance""" + from generalresearch.models.thl.product import ( + UserCreateConfig, + PayoutConfig, + SessionConfig, + UserWalletConfig, + SourcesConfig, + UserHealthConfig, + ProfilingConfig, + Product, + ) + + now = datetime.now(tz=timezone.utc) + + # TODO: Add product_id, and possibly name uniqueness validation to the + # pydantic model definition itself. The create manager doesn't need + # to do this IMO.. but it also means it'll need to be fast and simple + # in the model validation steps. + + product_data = { + "id": product_id, + "name": name, + "created": now, + "team_id": team_id, + "business_id": business_id, + "commission_pct": commission_pct, + "redirect_url": redirect_url, + "sources_config": sources_config or SourcesConfig(), + "payout_config": payout_config or PayoutConfig(), + "session_config": session_config or SessionConfig(), + "profiling_config": profiling_config or ProfilingConfig(), + "user_wallet_config": user_wallet_config or UserWalletConfig(), + "user_create_config": user_create_config or UserCreateConfig(), + "user_health_config": user_health_config or UserHealthConfig(), + } + # If not defined, we want the default to be used. So we can't pass + # it in or else the validators fail. + if harmonizer_domain: + product_data["harmonizer_domain"] = harmonizer_domain + + instance = Product.model_validate(product_data) + + # Notes: I intentionally removed the name update stuff in here. IMO + # we should have an update method on the manager to handle any of the + # possible update operations and be explicit about it. + + # Notes: I intentionally removed the ledger key lock now that we're + # not using it for any of the accounting work. It's not worth trying + # to carry forward in any form. + + # Goes in BPC: sources_config, user_wallet + insert_data = instance.model_dump_mysql( + include={ + "id", + "name", + "created", + "enabled", + "team_id", + "business_id", + "commission_pct", + "harmonizer_domain", + "redirect_url", + # JSON configs + "payout_config", + "session_config", + "user_create_config", + # We haven't done anything with these, but for mysql + # they need to be passed + "offerwall_config", + "profiling_config", + "user_health_config", + "yield_man_config", + } + ) + # These things don't have the same name in the db + insert_data["commission"] = str(instance.commission_pct) + insert_data["grs_domain"] = insert_data.pop("harmonizer_domain") + insert_data["payments_enabled"] = instance.payments_enabled + + try: + insert_data["id_int"] = list( + self.pg_config.execute_sql_query( + f""" + SELECT COALESCE(MAX(id_int), 0) + 1 as id_int + FROM userprofile_brokerageproduct + """ + ) + )[0]["id_int"] + instance.id_int = insert_data["id_int"] + + query = """ + INSERT INTO userprofile_brokerageproduct ( + id, name, created, enabled, payments_enabled, + team_id, business_id, + commission, grs_domain, redirect_url, + session_config, payout_config, + user_create_config, offerwall_config, + profiling_config, user_health_config, + yield_man_config, id_int + ) + VALUES ( + %(id)s, %(name)s, %(created)s, %(enabled)s, %(payments_enabled)s, + %(team_id)s, %(business_id)s, + %(commission)s, %(grs_domain)s, %(redirect_url)s, + %(session_config)s, %(payout_config)s, + %(user_create_config)s, %(offerwall_config)s, + %(profiling_config)s, %(user_health_config)s, + %(yield_man_config)s, %(id_int)s + ); + """ + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params=insert_data) + conn.commit() + + # I'm not going to be specific here because we will expand this soon + # to store in a single table / new datastore + # + # from pymysql import IntegrityError + # except IntegrityError as e: + except (Exception,) as e: + + try: + return self.get_by_uuid(product_uuid=instance.id) + except (Exception,) as e2: + pass + finally: + self.cache_clear(instance.id) + + # If we couldn't find the Product, then go ahead and raise. + capture_exception(e) + raise e + + bpconfig = instance.model_dump( + include={"sources_config", "user_wallet"}, mode="json" + ) + + bpc = {k: json.dumps({k: v}) for k, v in bpconfig.items()} + values = [[k, v, instance.id] for k, v in bpc.items()] + + query = """ + INSERT INTO userprofile_brokerageproductconfig + (key,value,product_id) + VALUES (%s, %s, %s); + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query, values) + conn.commit() + + # We should clear the cache here, b/c we might have tried to get it before, + # using get_by_uuid_if_exists, which set the cache to None + self.cache_clear(product_uuid=product_id) + + return instance + + def update(self, new_product: "Product") -> None: + product_uuid = new_product.id + old_product = self.get_by_uuid(product_uuid=product_uuid) + old_dump = old_product.model_dump(mode="json") + new_dump = new_product.model_dump(mode="json") + assert set(old_dump.keys()) == set(new_dump.keys()) + + keys_to_update = set() + for k in set(old_dump.keys()): + if old_dump[k] != new_dump[k]: + keys_to_update.add(k) + + not_allowed = {"id", "created", "team_id", "business_id"} + if keys_to_update & not_allowed: + raise ValueError(f"Not allowed to change: {keys_to_update & not_allowed}") + + if not keys_to_update: + return None + + in_bp_keys = { + "name", + "enabled", + "team_id", + "redirect_url", + "session_config", + "payout_config", + "user_create_config", + "offerwall_config", + "profiling_config", + "user_health_config", + "yield_man_config", + # naming ---- ... + "commission", + "harmonizer_domain", + "grs_domain", + } + in_bpc_keys = {"sources_config", "user_wallet", "user_wallet_config"} + if keys_to_update & in_bp_keys: + data = new_product.model_dump_mysql() + # These things don't have the same name in the db + data["commission"] = str(new_product.commission_pct) + data["grs_domain"] = data.pop("harmonizer_domain") + data = {k: v for k, v in data.items() if k in in_bp_keys} + data["id"] = product_uuid + update_str = ", ".join(f"{k}=%({k})s" for k in data.keys()) + self.pg_config.execute_write( + f""" + UPDATE userprofile_brokerageproduct + SET {update_str} + WHERE id = %(id)s + """, + data, + ) + + if keys_to_update & in_bpc_keys: + bpconfig = new_product.model_dump( + include={"sources_config", "user_wallet"}, mode="json" + ) + + bpc = {k: json.dumps({k: v}) for k, v in bpconfig.items()} + data = [] + if "sources_config" in keys_to_update: + data.append( + { + "id": product_uuid, + "key": "sources_config", + "value": bpc["sources_config"], + }, + ) + if "user_wallet_config" in keys_to_update: + data.append( + { + "id": product_uuid, + "key": "user_wallet", + "value": bpc["user_wallet"], + } + ) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + for d in data: + c.execute( + """ + UPDATE userprofile_brokerageproductconfig + SET value = %(value)s + WHERE product_id = %(id)s AND key = %(key)s + """, + d, + ) + conn.commit() + + self.cache_clear(product_uuid) diff --git a/generalresearch/managers/thl/profiling/__init__.py b/generalresearch/managers/thl/profiling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/managers/thl/profiling/question.py b/generalresearch/managers/thl/profiling/question.py new file mode 100644 index 0000000..1ad27ac --- /dev/null +++ b/generalresearch/managers/thl/profiling/question.py @@ -0,0 +1,157 @@ +import random +import threading +from typing import Collection, List, Tuple + +from cachetools import cached, TTLCache +from pydantic import ValidationError + +from generalresearch.decorators import LOG +from generalresearch.managers.base import PostgresManager +from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + UPKImportance, +) + + +class QuestionManager(PostgresManager): + + def get_multi_upk(self, question_ids: Collection[str]) -> List[UpkQuestion]: + query = """ + SELECT data, property_code, explanation_template, explanation_fragment_template + FROM marketplace_question + WHERE id = ANY(%(question_ids)s); + """ + res = self.pg_config.execute_sql_query( + query, {"question_ids": list(question_ids)} + ) + for x in res: + x["data"]["ext_question_id"] = x["property_code"] + x["data"]["explanation_template"] = x["explanation_template"] + x["data"]["explanation_fragment_template"] = x[ + "explanation_fragment_template" + ] + x["data"].pop("categories", None) + return [UpkQuestion.model_validate(x["data"]) for x in res] + + @cached( + cache=TTLCache(maxsize=256, ttl=3600 + random.randint(-900, 900)), + lock=threading.Lock(), + info=True, + ) + def get_questions_ranked( + self, country_iso: str, language_iso: str + ) -> List[UpkQuestion]: + query = """ + SELECT data, property_code, explanation_template, explanation_fragment_template + FROM marketplace_question + WHERE country_iso = %(country_iso)s + AND language_iso = %(language_iso)s + AND property_code NOT LIKE 'gr:%%' + AND property_code NOT LIKE 'g:%%' + AND is_live + """ + res = self.pg_config.execute_sql_query( + query=query, + params={"country_iso": country_iso, "language_iso": language_iso}, + ) + qs: List[UpkQuestion] = [] + for x in res: + x["data"]["ext_question_id"] = x["property_code"] + x["data"]["explanation_template"] = x["explanation_template"] + x["data"]["explanation_fragment_template"] = x[ + "explanation_fragment_template" + ] + x["data"].pop("categories", None) + q = UpkQuestion.model_validate(x["data"]) + if not q.importance: + q.importance = UPKImportance( + task_count=x["data"].get("task_count", 0), + task_score=x["data"].get("task_score", 0), + ) + qs.append(q) + + res = sorted(qs, key=lambda x: x.importance.task_score, reverse=True) + return res + + @cached( + cache=TTLCache(maxsize=256, ttl=3600 + random.randint(-900, 900)), + lock=threading.Lock(), + info=True, + ) + def lookup_by_property( + self, property_code: str, country_iso: str, language_iso: str + ) -> UpkQuestion: + query = f""" + SELECT data, property_code, explanation_template, explanation_fragment_template + FROM marketplace_question + WHERE property_code = %(property_code)s + AND country_iso = %(country_iso)s + AND language_iso = %(language_iso)s + LIMIT 2; + """ + params = { + "property_code": property_code, + "country_iso": country_iso, + "language_iso": language_iso, + } + res = self.pg_config.execute_sql_query(query=query, params=params) + assert len(res) == 1, f"expected 1, got {len(res)} results" + x = res[0] + x["data"]["ext_question_id"] = x["property_code"] + x["data"]["explanation_template"] = x["explanation_template"] + x["data"]["explanation_fragment_template"] = x["explanation_fragment_template"] + x["data"].pop("categories", None) + return UpkQuestion.model_validate(x["data"]) + + def filter_by_property( + self, lookup: Collection[Tuple[str, str, str]] + ) -> List[UpkQuestion]: + """ + lookup is [(property_code, country_iso, language_iso)] + """ + where_str = " OR ".join( + "(property_code = %s AND country_iso = %s AND language_iso = %s)" + for _ in lookup + ) + query = f""" + SELECT data, property_code, explanation_template, explanation_fragment_template + FROM marketplace_question + WHERE {where_str} + """ + flat_params = [item for tup in lookup for item in tup] + res = self.pg_config.execute_sql_query(query, params=flat_params) + for x in res: + x["data"]["ext_question_id"] = x["property_code"] + x["data"]["explanation_template"] = x["explanation_template"] + x["data"]["explanation_fragment_template"] = x[ + "explanation_fragment_template" + ] + x["data"].pop("categories", None) + res2 = [] + for x in res: + try: + res2.append(UpkQuestion.model_validate(x["data"])) + except ValidationError as e: + LOG.warning(e) + return res2 + + def update_question_explanation(self, q: UpkQuestion): + # Assuming the question already exists in the db, and we're updating + # the fields explanation_template and explanation_fragment_template + assert q.id, "q.id must be set" + query = """ + UPDATE marketplace_question + SET explanation_template = %(explanation_template)s, + explanation_fragment_template = %(explanation_fragment_template)s + WHERE id = %(id)s;""" + params = { + "id": q.id, + "explanation_template": q.explanation_template, + "explanation_fragment_template": q.explanation_fragment_template, + } + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params) + assert c.rowcount == 1 + conn.commit() + return None diff --git a/generalresearch/managers/thl/profiling/schema.py b/generalresearch/managers/thl/profiling/schema.py new file mode 100644 index 0000000..581270b --- /dev/null +++ b/generalresearch/managers/thl/profiling/schema.py @@ -0,0 +1,75 @@ +from threading import RLock +from typing import List +from uuid import UUID + +from cachetools import cached, TTLCache + +from generalresearch.managers.base import PostgresManager +from generalresearch.models.thl.profiling.upk_property import ( + UpkProperty, +) + + +class UpkSchemaManager(PostgresManager): + + @cached(cache=TTLCache(maxsize=1, ttl=18 * 60), lock=RLock()) + def get_props_info(self) -> List[UpkProperty]: + query = """ + SELECT + p.id AS property_id, + p.label AS property_label, + p.cardinality, + p.prop_type, + pc.country_iso, + pc.gold_standard, + allowed_items.allowed_items, + cats.categories + FROM marketplace_property p + JOIN marketplace_propertycountry pc + ON p.id = pc.property_id + -- allowed_items: all items for this property + country + LEFT JOIN LATERAL ( + SELECT jsonb_agg( + jsonb_build_object( + 'id', mi.id, + 'label', mi.label, + 'description', mi.description + ) ORDER BY mi.label + ) AS allowed_items + FROM marketplace_propertyitemrange pir + JOIN marketplace_item mi + ON pir.item_id = mi.id + WHERE pir.property_id = p.id + AND pir.country_iso = pc.country_iso + ) allowed_items ON TRUE + + -- categories: all categories for this property + LEFT JOIN LATERAL ( + SELECT + jsonb_agg( + jsonb_build_object( + 'uuid', cat.uuid, + 'label', cat.label, + 'path', cat.path, + 'adwords_vertical_id', cat.adwords_vertical_id + ) + ) AS categories + FROM marketplace_propertycategoryassociation pcat + JOIN marketplace_category cat ON pcat.category_id = cat.id + WHERE pcat.property_id = p.id + ) AS cats ON TRUE; + """ + res = self.pg_config.execute_sql_query(query) + for x in res: + for c in x["categories"]: + c["uuid"] = UUID(c["uuid"]).hex + if x["allowed_items"]: + for c in x["allowed_items"]: + c["id"] = UUID(c["id"]).hex + return [UpkProperty.model_validate(x) for x in res] + + def get_props_info_for_country(self, country_iso: str) -> List[UpkProperty]: + assert country_iso.lower() == country_iso + res = self.get_props_info() + res = [x for x in res if x.country_iso == country_iso].copy() + return res diff --git a/generalresearch/managers/thl/profiling/uqa.py b/generalresearch/managers/thl/profiling/uqa.py new file mode 100644 index 0000000..6800d32 --- /dev/null +++ b/generalresearch/managers/thl/profiling/uqa.py @@ -0,0 +1,211 @@ +import logging +from datetime import datetime, timedelta, timezone +from typing import Collection, List, Optional + +from generalresearch.managers.base import PostgresManagerWithRedis +from generalresearch.models.thl.profiling.user_question_answer import ( + DUMMY_UQA, + UserQuestionAnswer, +) +from generalresearch.models.thl.user import User + +logger = logging.getLogger() + + +class UQAManager(PostgresManagerWithRedis): + + CACHE_PREFIX = "thl-grpc:uqa-cache-v2" + + def redis_key(self, user: User) -> str: + return f"{self.CACHE_PREFIX}:{user.user_id}" + + def redis_lock_key(self, user: User) -> str: + return self.redis_key(user) + ":lock" + + def update_cache( + self, + user: User, + uqas: List[UserQuestionAnswer], + ): + """ + Adds new answers to the redis cache for this user. If the cache + doesn't exist, does a db query to populate the cache. + """ + REDIS = self.redis_client + redis_key = self.redis_key(user) + redis_lock = self.redis_lock_key(user) + json_answers = [uqa.model_dump_json() for uqa in uqas] + with REDIS.lock(redis_lock, timeout=2): + # Append the new answers into the cache. + res = REDIS.rpushx(redis_key, *json_answers) + if res == 0: + # If the cache doesn't exist, we need to make it from scratch + uqas = self.get_from_db(user) + all_json_answers = {uqa.model_dump_json() for uqa in uqas} + # And then make sure thew new answers are in it (b/c we query from the RR) + all_json_answers.update(set(json_answers)) + REDIS.rpush(redis_key, *all_json_answers) + REDIS.expire(redis_key, 3600) + self.clear_user_demographic_cache(user) + return res + + def clear_cache(self, user: User) -> None: + self.redis_client.delete(self.redis_key(user)) + + def recreate_cache(self, user: User) -> Collection[UserQuestionAnswer]: + REDIS = self.redis_client + redis_key = self.redis_key(user) + redis_lock = self.redis_lock_key(user) + with REDIS.lock(redis_lock, timeout=2): + # once we've acquired the lock, we need to check again if someone else just made it + if REDIS.exists(redis_key): + values = REDIS.lrange(redis_key, 0, -1) + uqas = [UserQuestionAnswer.model_validate_json(x) for x in values] + else: + uqas = self.get_from_db(user) + if not uqas: + # We can't set this to an empty list in redis (no + # difference between None and []). + # + # I don't know what is the best thing to do here, so I'm + # gonna push in a "dummy" UQA, just to indicate this has + # been set and we shouldn't query the db anymore unless + # something changes... + REDIS.rpush(redis_key, DUMMY_UQA.model_dump_json()) + else: + REDIS.rpush(redis_key, *[uqa.model_dump_json() for uqa in uqas]) + REDIS.expire(redis_key, 3600 * 3 * 24) + + uqas = self._dedupe_and_clean_uqas(uqas) + self.clear_user_demographic_cache(user) + return uqas + + def _dedupe_and_clean_uqas( + self, + uqas: List[UserQuestionAnswer], + ) -> List[UserQuestionAnswer]: + # Remove anything older than 30 days + uqas = [uqa for uqa in uqas if not uqa.is_stale()] + + # Dedupe, latest answer per question + new_uqas = set() + seen_question_ids = set() + uqas = sorted(uqas, key=lambda x: x.timestamp, reverse=True) + for uqa in uqas: + if uqa.question_id not in seen_question_ids: + seen_question_ids.add(uqa.question_id) + new_uqas.add(uqa) + + return sorted(new_uqas, key=lambda x: x.timestamp, reverse=True) + + def get(self, user: User) -> List[UserQuestionAnswer]: + uqas = self.get_from_cache(user=user) + + if uqas is None: + uqas = self.recreate_cache(user) + return self._dedupe_and_clean_uqas(uqas) + + def get_from_cache(self, user: User) -> Optional[List[UserQuestionAnswer]]: + redis_key = self.redis_key(user) + + # Do the exists check and the list retrieval in a single transaction + with self.redis_client.pipeline() as pipe: + exists = pipe.exists(redis_key) + if exists: + pipe.lrange(redis_key, 0, -1) + result = pipe.execute() + exists = result[0] + values = result[1] + if not exists: + logger.info(f"{redis_key} doesn't exist") + return None + uqas = [UserQuestionAnswer.model_validate_json(x) for x in values] + logger.info(f"{redis_key} exists") + return uqas + + def get_from_db(self, user: User) -> List[UserQuestionAnswer]: + logger.info(f"get_uqa_from_db: {user.user_id}") + # Only store the latest row per question_id. We don't need it multiple times. + since = datetime.now(tz=timezone.utc) - timedelta(days=30) + + # We CAN use the RR, b/c either + # 1) the cache expired and the user hasn't sent an answer recently + # or 2) The user just sent an answer, so we'll make sure it gets put into the results + # after this query runs. + query = f""" + WITH ranked AS ( + SELECT + uqa.*, + ROW_NUMBER() OVER ( + PARTITION BY question_id + ORDER BY created DESC + ) AS rn + FROM marketplace_userquestionanswer uqa + WHERE uqa.user_id = %(user_id)s + AND uqa.created > %(since)s + ) + SELECT + r.question_id::uuid, + r.created::timestamptz AS timestamp, + r.calc_answer::jsonb AS calc_answers, + r.answer::jsonb, + r.user_id, + mq.property_code, + mq.country_iso, + mq.language_iso + FROM ranked r + JOIN marketplace_question mq + ON r.question_id = mq.id + WHERE rn = 1 + ORDER BY r.created; + """ + res = self.pg_config.execute_sql_query( + query=query, + params={"user_id": user.user_id, "since": since}, + ) + uqas = [UserQuestionAnswer.model_validate(x) for x in res] + return uqas + + def clear_user_demographic_cache( + self, + user: User, + ) -> None: + # this will get regenerated by thl-grpc when an offerwall call is made + redis_key = f"thl-grpc:user-demographics:{user.user_id}" + self.redis_client.delete(redis_key) + + return None + + def create( + self, + user: User, + uqas: List[UserQuestionAnswer], + session_id: Optional[str] = None, + ): + for uqa in uqas: + if uqa.user_id is None: + uqa.user_id = user.user_id + else: + assert uqa.user_id == user.user_id + self.create_in_db(uqas=uqas, session_id=session_id) + self.update_cache(user=user, uqas=uqas) + return None + + def create_in_db( + self, uqas: List[UserQuestionAnswer], session_id: Optional[str] = None + ): + values = [uqa.model_dump_mysql(session_id=session_id) for uqa in uqas] + query = """ + INSERT INTO marketplace_userquestionanswer + (created, session_id, answer, question_id, user_id, calc_answer) + VALUES ( + %(created)s, %(session_id)s, %(answer)s, + %(question_id)s, %(user_id)s, %(calc_answer)s + ); + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=values) + conn.commit() + return None diff --git a/generalresearch/managers/thl/profiling/user_upk.py b/generalresearch/managers/thl/profiling/user_upk.py new file mode 100644 index 0000000..4449afa --- /dev/null +++ b/generalresearch/managers/thl/profiling/user_upk.py @@ -0,0 +1,343 @@ +import json +from collections import defaultdict +from datetime import timedelta, datetime, timezone +from typing import Dict, Union, Set, List, Collection, Optional, Tuple +from uuid import UUID + +from psycopg import Cursor + +from generalresearch.managers.base import ( + Permission, + PostgresManagerWithRedis, +) +from generalresearch.managers.thl.profiling.schema import UpkSchemaManager +from generalresearch.models.thl.profiling.upk_property import ( + Cardinality, + PropertyType, + UpkProperty, +) +from generalresearch.models.thl.profiling.upk_question_answer import ( + UpkQuestionAnswer, +) +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + + +class UserUpkManager(PostgresManagerWithRedis): + def __init__( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + ): + super().__init__( + pg_config=pg_config, + redis_config=redis_config, + permissions=permissions, + cache_prefix=cache_prefix, + ) + self.upk_schema_manager = UpkSchemaManager(pg_config=pg_config) + + def clear_upk_cache(self, user_id: int) -> None: + self.redis_client.delete(f"thl-grpc:user-upk:{user_id}") + return None + + def get_user_upk(self, user_id: int) -> List[UpkQuestionAnswer]: + res = self.redis_client.get(f"thl-grpc:user-upk:{user_id}") + if res: + return [UpkQuestionAnswer.model_validate(x) for x in json.loads(res)] + res = self.get_user_upk_mysql(user_id) + value = json.dumps([x.model_dump(mode="json") for x in res]) + self.redis_client.set(f"thl-grpc:user-upk:{user_id}", value, ex=60 * 60 * 24) + return res + + def get_user_upk_mysql(self, user_id: int) -> List[UpkQuestionAnswer]: + since = datetime.now(tz=timezone.utc) - timedelta(days=89) + + query = """ + SELECT + x.property_id, + mp.label AS property_label, + mi.id AS item_id, + mi.label AS item_label, + mp.prop_type, + mp.cardinality, + x.created, + x.value_num, + x.value_text, + x.country_iso + FROM ( + SELECT + property_id, + value::uuid AS item_id, + NULL::numeric AS value_num, + NULL::text AS value_text, + created, + country_iso + FROM marketplace_userprofileknowledgeitem AS upki + WHERE user_id = %(user_id)s AND created > %(since)s + + UNION ALL + + SELECT + property_id, + NULL::uuid AS item_id, + value::numeric AS value_num, + NULL::text AS value_text, + created, + country_iso + FROM marketplace_userprofileknowledgenumerical + WHERE user_id = %(user_id)s AND created > %(since)s + + UNION ALL + + SELECT + property_id, + NULL::uuid AS item_id, + NULL::numeric AS value_num, + value::text AS value_text, + created, + country_iso + FROM marketplace_userprofileknowledgetext + WHERE user_id = %(user_id)s AND created > %(since)s + ) x + JOIN marketplace_property mp + ON x.property_id = mp.id + LEFT JOIN marketplace_item mi + ON x.item_id = mi.id::uuid; + """ + params = {"user_id": user_id, "since": since} + res = self.pg_config.execute_sql_query(query, params=params) + for x in res: + x["user_id"] = user_id + x["property_id"] = UUID(x["property_id"]).hex + if x["item_id"]: + x["item_id"] = UUID(x["item_id"]).hex + return [UpkQuestionAnswer.model_validate(x) for x in res] + + def get_user_upk_simple( + self, user_id, country_iso="us" + ) -> Dict[str, Union[Set[str], str, float]]: + res = self.get_user_upk(user_id=user_id) + res = [x for x in res if x.country_iso == country_iso] + d: Dict[str, Union[Set[str], str, float]] = defaultdict(set) + for x in res: + if x.cardinality == Cardinality.ZERO_OR_ONE: + d[x.property_label] = x.value + else: + d[x.property_label].add(x.value) + return dict(d) + + def get_age_gender( + self, user_id, country_iso="us" + ) -> Tuple[Optional[int], Optional[str]]: + # Returns an integer year for age, and {'male', 'female', 'other_gender'} + d = self.get_user_upk_simple(user_id, country_iso) + age = d.get("age_in_years") + if age is not None: + age = int(age) + gender = d.get("gender") + return age, gender + + def get_upk_schema(self, country_iso: str) -> List[UpkProperty]: + return self.upk_schema_manager.get_props_info_for_country( + country_iso=country_iso + ) + + def populate_user_upk_from_dict(self, upk_ans_dict): + + country_isos = {x["country_iso"] for x in upk_ans_dict} + assert len(country_isos) == 1 + country_iso = list(country_isos)[0] + for x in upk_ans_dict: + x["pred"] = x["pred"].replace("gr:", "") + x["obj"] = x["obj"].replace("gr:", "") + prop_labels = {x["pred"] for x in upk_ans_dict} + + props = self.get_upk_schema(country_iso=country_iso) + props = [ + x + for x in props + if x.property_label in prop_labels or x.property_id in prop_labels + ] + label_to_prop = {x.property_label: x for x in props} + id_to_prop = {x.property_id: x for x in props} + + for x in upk_ans_dict: + prop = label_to_prop.get(x["pred"]) or id_to_prop[x["pred"]] + x["property_id"] = prop.property_id + x["property_label"] = prop.property_label + x["prop_type"] = prop.prop_type + x["cardinality"] = prop.cardinality + x["created"] = x["timestamp"] + if prop.prop_type == PropertyType.UPK_ITEM: + if x["obj"] in prop.allowed_items_by_id: + x["item_label"] = prop.allowed_items_by_id[x["obj"]].label + x["item_id"] = x["obj"] + else: + x["item_label"] = x["obj"] + x["item_id"] = prop.allowed_items_by_label[x["obj"]].id + elif prop.prop_type == PropertyType.UPK_TEXT: + x["value_text"] = x["obj"] + elif prop.prop_type == PropertyType.UPK_NUMERICAL: + x["value_num"] = x["obj"] + + upk_ans = [UpkQuestionAnswer.model_validate(x) for x in upk_ans_dict] + return upk_ans + + def upsert_user_profile_knowledge(self, c: Cursor, row: UpkQuestionAnswer): + prop_type_table = { + PropertyType.UPK_ITEM: "marketplace_userprofileknowledgeitem", + PropertyType.UPK_NUMERICAL: "marketplace_userprofileknowledgenumerical", + PropertyType.UPK_TEXT: "marketplace_userprofileknowledgetext", + } + prop_type_value = { + PropertyType.UPK_ITEM: "item_id", + PropertyType.UPK_NUMERICAL: "value_num", + PropertyType.UPK_TEXT: "value_text", + } + table = prop_type_table[row.prop_type] + value = prop_type_value[row.prop_type] + args = row.model_dump_mysql() + + c.execute( + f""" + SELECT id FROM {table} + WHERE user_id = %(user_id)s AND property_id = %(property_id)s AND country_iso = %(country_iso)s + LIMIT 1""", + args, + ) + existing = c.fetchone() + + if existing: + c.execute( + f""" + UPDATE {table} + SET value = %({value})s, + created = %(created)s, + question_id = %(question_id)s, + session_id = %(session_id)s + WHERE user_id = %(user_id)s AND + property_id = %(property_id)s AND + country_iso = %(country_iso)s + """, + args, + ) + else: + c.execute( + f""" + INSERT INTO {table} + (property_id, value, created, country_iso, question_id, user_id, session_id) + VALUES (%(property_id)s, %({value})s, %(created)s, %(country_iso)s, + %(question_id)s, %(user_id)s, %(session_id)s) + """, + args, + ) + + def upsert_user_profile_knowledge_multi_item( + self, c: Cursor, row: UpkQuestionAnswer + ): + args = row.model_dump_mysql() + + c.execute( + """ + SELECT id FROM marketplace_userprofileknowledgeitem + WHERE user_id = %(user_id)s AND + property_id = %(property_id)s AND + country_iso = %(country_iso)s AND + value = %(item_id)s + LIMIT 1""", + args, + ) + existing = c.fetchone() + + if existing: + c.execute( + """ + UPDATE marketplace_userprofileknowledgeitem + SET created = %(created)s, + question_id = %(question_id)s, + session_id = %(session_id)s + WHERE user_id = %(user_id)s AND + property_id = %(property_id)s AND + country_iso = %(country_iso)s AND + value = %(item_id)s + """, + args, + ) + else: + c.execute( + """ + INSERT INTO marketplace_userprofileknowledgeitem + (property_id, value, created, country_iso, question_id, user_id, session_id) + VALUES (%(property_id)s, %(item_id)s, %(created)s, %(country_iso)s, + %(question_id)s, %(user_id)s, %(session_id)s) + """, + args, + ) + + def delete_user_profile_knowledge_multi_item( + self, c: Cursor, row: UpkQuestionAnswer + ) -> None: + args = row.model_dump_mysql() + c.execute( + """ + DELETE FROM marketplace_userprofileknowledgeitem + WHERE user_id = %(user_id)s AND + property_id = %(property_id)s AND + country_iso = %(country_iso)s AND + value = %(item_id)s + """, + args, + ) + + return None + + def set_user_upk(self, upk_ans: List[UpkQuestionAnswer]): + user_id = {x.user_id for x in upk_ans} + assert len(user_id) == 1, "only run for 1 user at a time" + user_id = list(user_id)[0] + + curr_upk = self.get_user_upk(user_id=user_id) + curr_upk_simple = self.get_user_upk_simple(user_id=user_id) + + new_upk_simple = defaultdict(set) + delete_items = set() + upk_multi = list() + delete_upk_multi = list() + for x in upk_ans: + # For zero or more (multiple values) We want all values to equal these. + # Might involve deleting values if they exist and are not in upk_ans + if ( + x.cardinality == Cardinality.ZERO_OR_MORE + and x.prop_type != PropertyType.UPK_ITEM + ): + raise ValueError("unsupported") + if ( + x.cardinality == Cardinality.ZERO_OR_MORE + and x.prop_type == PropertyType.UPK_ITEM + ): + new_upk_simple[x.property_label].add(x.item_label) + upk_multi.append(x) + for k, v in new_upk_simple.items(): + prop_delete_labels = curr_upk_simple.get(k, set()) - v + for x in prop_delete_labels: + delete_items.add((k, x)) + if delete_items: + for x in curr_upk: + if (x.property_label, x.item_label) in delete_items: + delete_upk_multi.append(x) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + for x in upk_multi: + self.upsert_user_profile_knowledge_multi_item(c, row=x) + for x in delete_upk_multi: + self.delete_user_profile_knowledge_multi_item(c, row=x) + for x in upk_ans: + if x.cardinality == Cardinality.ZERO_OR_ONE: + # If the cardinality is 0 or 1, we're inserting or updating the answer + self.upsert_user_profile_knowledge(c, x) + conn.commit() + self.clear_upk_cache(user_id=user_id) diff --git a/generalresearch/managers/thl/session.py b/generalresearch/managers/thl/session.py new file mode 100644 index 0000000..d328f7d --- /dev/null +++ b/generalresearch/managers/thl/session.py @@ -0,0 +1,669 @@ +from datetime import datetime, timedelta, timezone +from decimal import Decimal +from typing import Optional, Dict, Tuple, List, Any, Collection +from uuid import uuid4, UUID + +from faker import Faker +from psycopg import sql +from pydantic import NonNegativeInt + +from generalresearch.managers import parse_order_by +from generalresearch.managers.base import ( + Permission, + PostgresManager, +) +from generalresearch.managers.thl.product import ProductManager +from generalresearch.models import DeviceType +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.legacy.bucket import Bucket +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + SessionStatusCode2, +) +from generalresearch.models.thl.session import ( + Session, + Wall, +) +from generalresearch.models.thl.task_status import ( + TaskStatusResponse, + TasksStatusResponse, +) +from generalresearch.models.thl.user import User + +fake = Faker() + + +class SessionManager(PostgresManager): + # I'm assuming the SessionManager will ALWAYS be passed a SqlHelper via + # thl_web_rw_db b/c the UPDATE operations... the SELECT operations + # will also be done with thl_web_rw_db bc of potential ReadReplica + # latency issues. + + def create( + self, + started: datetime, + user: User, + country_iso: Optional[str] = None, + device_type: Optional[DeviceType] = None, + ip: Optional[str] = None, + bucket: Optional[Bucket] = None, + url_metadata: Optional[Dict] = None, + uuid_id: Optional[str] = None, + ) -> Session: + """Creates a Session. Prefer to use this rather than instantiating the + model directly, because we're explicitly defining here which keys + should be set and which won't get set until later. + """ + if uuid_id is None: + uuid_id = uuid4().hex + + session = Session( + uuid=uuid_id, + started=started, + user=user, + country_iso=country_iso, + device_type=device_type, + ip=ip, + clicked_bucket=bucket, + url_metadata=url_metadata, + ) + + d = session.model_dump_mysql() + query = sql.SQL( + """ + INSERT INTO thl_session ( + uuid, user_id, started, loi_min, loi_max, + user_payout_min, user_payout_max, country_iso, + device_type, ip, url_metadata + ) VALUES ( + %(uuid)s, %(user_id)s, %(started)s, %(loi_min)s, %(loi_max)s, + %(user_payout_min)s, %(user_payout_max)s, %(country_iso)s, + %(device_type)s, %(ip)s, %(url_metadata_json)s + ) RETURNING id; + """ + ) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=d) + session.id = c.fetchone()["id"] + conn.commit() + return session + + def create_dummy( + self, + # -- Create Dummy "optional" -- # + started: Optional[datetime] = None, + user: Optional[User] = None, + # -- Optional -- # + country_iso: Optional[str] = None, + device_type: Optional[DeviceType] = None, + ip: Optional[str] = None, + bucket: Optional[Bucket] = None, + url_metadata: Optional[Dict] = None, + uuid_id: Optional[str] = None, + ) -> Session: + """To be used in tests, where we don't care about certain fields""" + started = started or fake.date_time_between( + start_date=datetime(year=1900, month=1, day=1), + end_date=datetime(year=2000, month=1, day=1), + tzinfo=timezone.utc, + ) + user = user or User( + user_id=fake.random_int(min=1, max=2_147_483_648), uuid=uuid4().hex + ) + + return self.create( + started=started, + user=user, + country_iso=country_iso, + device_type=device_type, + ip=ip, + bucket=bucket, + url_metadata=url_metadata, + uuid_id=uuid_id, + ) + + def get_from_uuid(self, session_uuid: UUIDStr) -> Session: + query = f""" + SELECT + s.id AS session_id, + s.uuid AS session_uuid, + s.user_id, s.started, s.finished, s.loi_min, s.loi_max, + s.user_payout_min, s.user_payout_max, s.country_iso, s.device_type, + s.ip, s.status, s.status_code_1, s.status_code_2, s.payout, + s.user_payout, s.adjusted_status, s.adjusted_payout, + s.adjusted_user_payout, s.adjusted_timestamp, s.url_metadata::jsonb, + u.product_id, u.product_user_id, u.uuid AS user_uuid + FROM thl_session AS s + LEFT JOIN thl_user AS u + ON s.user_id = u.id + WHERE s.uuid = %(session_uuid)s + LIMIT 2 + """ + res = self.pg_config.execute_sql_query( + query=query, params={"session_uuid": session_uuid} + ) + assert len(res) == 1 + return self.session_from_mysql(res[0]) + + def get_from_id(self, session_id: int) -> Session: + query = f""" + SELECT + s.id AS session_id, + s.uuid AS session_uuid, + s.user_id, s.started, s.finished, s.loi_min, s.loi_max, + s.user_payout_min, s.user_payout_max, s.country_iso, s.device_type, + s.ip, s.status, s.status_code_1, s.status_code_2, s.payout, + s.user_payout, s.adjusted_status, s.adjusted_payout, + s.adjusted_user_payout, s.adjusted_timestamp, s.url_metadata::jsonb, + u.product_id, u.product_user_id, u.uuid AS user_uuid + FROM thl_session AS s + LEFT JOIN thl_user AS u + ON s.user_id = u.id + WHERE s.id = %(session_id)s + LIMIT 2 + """ + res = self.pg_config.execute_sql_query( + query=query, params={"session_id": session_id} + ) + assert len(res) == 1 + return self.session_from_mysql(res[0]) + + def session_from_mysql(self, d: Dict) -> Session: + d["id"] = d.pop("session_id") + d["uuid"] = UUID(d.pop("session_uuid")).hex + d["user"] = User( + product_id=UUID(d.pop("product_id")).hex, + product_user_id=d.pop("product_user_id"), + uuid=UUID(d.pop("user_uuid")).hex, + user_id=d.pop("user_id"), + ) + + d["loi_min"] = ( + timedelta(seconds=d["loi_min"]) if d["loi_min"] is not None else None + ) + d["loi_max"] = ( + timedelta(seconds=d["loi_max"]) if d["loi_max"] is not None else None + ) + bucket_keys = [ + "loi_min", + "loi_max", + "user_payout_min", + "user_payout_max", + ] + if all(d.get(k) is None for k in bucket_keys): + d["clicked_bucket"] = None + else: + d["clicked_bucket"] = Bucket( + loi_min=d.get("loi_min"), + loi_max=d.get("loi_max"), + user_payout_min=d.get("user_payout_min"), + user_payout_max=d.get("user_payout_max"), + ) + for k in bucket_keys: + d.pop(k, None) + if d["url_metadata"] is not None: + d["url_metadata"] = {k: str(v) for k, v in d["url_metadata"].items()} + return Session.model_validate(d) + + def finish_with_status( + self, + session: Session, + finished: Optional[datetime] = None, + status: Optional[Status] = None, + status_code_1: Optional[StatusCode1] = None, + status_code_2: Optional[SessionStatusCode2] = None, + payout: Optional[Decimal] = None, + user_payout: Optional[Decimal] = None, + ) -> Session: + # We have to update all the fields at once, or else we'll get + # validation errors. There doesn't seem to be a clean way of doing this. + # model_copy with update doesn't trigger the validators, so we + # re-run model_validate after + finished = finished if finished else datetime.now(tz=timezone.utc) + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "status_code_2": status_code_2, + "finished": finished, + "payout": payout, + "user_payout": user_payout, + } + ) + d = session.model_dump_mysql() + self.pg_config.execute_write( + query=f""" + UPDATE thl_session + SET status = %(status)s, status_code_1 = %(status_code_1)s, + status_code_2 = %(status_code_2)s, finished = %(finished)s, + payout = %(payout)s, user_payout = %(user_payout)s + WHERE id = %(id)s; + """, + params=d, + ) + return session + + def adjust_status(self, session: Session) -> None: + assert session.user.product, "prefetch product" + modified = session.adjust_status() + if not modified: + return None + + d = { + "adjusted_status": ( + session.adjusted_status.value if session.adjusted_status else None + ), + "adjusted_timestamp": session.adjusted_timestamp, + # These are Decimals which is why we str() them + "adjusted_payout": ( + str(session.adjusted_payout) + if session.adjusted_payout is not None + else None + ), + "adjusted_user_payout": ( + str(session.adjusted_user_payout) + if session.adjusted_user_payout is not None + else None + ), + "uuid": session.uuid, + } + + self.pg_config.execute_write( + query=""" + UPDATE thl_session + SET adjusted_status = %(adjusted_status)s, + adjusted_timestamp = %(adjusted_timestamp)s, + adjusted_payout = %(adjusted_payout)s, + adjusted_user_payout = %(adjusted_user_payout)s + WHERE uuid = %(uuid)s; + """, + params=d, + ) + + return None + + def filter_paginated( + self, + user_id: Optional[int] = None, + session_uuids: Optional[List[UUIDStr]] = None, + product_uuids: Optional[List[UUIDStr]] = None, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + status: Optional[Status] = None, + adjusted_after: Optional[datetime] = None, + adjusted_before: Optional[datetime] = None, + page: int = 1, + size: int = 100, + order_by: Optional[str] = "-started", + ) -> Tuple[List[Session], int]: + """ + Sessions are filtered using user, product_uuids, started_after, & + started_before (if set). + - started_after is optional, default = beginning of time + - started_before is optional, default = now + + If page/size are passed, return only that page of the filtered (by + account_uuid and optionally time) items. Returns (list of items, total + (after filtering)). + + :param user_id: Return sessions from this User. Cannot pass both user_id and product_uuids + :param product_uuids: Return sessions from these products. Cannot pass both user_id and product_uuids + :param started_after: Filter to include this range. Default: beginning of time + :param started_before: Filter to include this range. Default: now + :param status: Filter for sessions with this status. + :param adjusted_after: Filter for sessions adjusted after this timestamp. + :param adjusted_before: Filter for sessions adjusted before this timestamp. If either adjusted_after + or adjusted_before is not None, then only adjusted sessions will be returned. + :param page: page starts at 1 + :param size: size of page, default (if page is not None) = 100. (1<=page<=100) + :param order_by: Required for pagination. Uses django-rest-framework ordering syntax, + e.g. '-created,tag' for (created desc, tag asc) + """ + filter_str, params = self.make_filter_str( + user_id=user_id, + session_uuids=session_uuids, + product_uuids=product_uuids, + started_after=started_after, + started_before=started_before, + status=status, + adjusted_after=adjusted_after, + adjusted_before=adjusted_before, + ) + + if page is not None: + assert type(page) is int + assert page >= 1, "page starts at 1" + size = size if size is not None else 100 + assert type(size) is int + assert 1 <= size <= 100 + params["offset"] = (page - 1) * size + params["limit"] = size + paginated_filter_str = "LIMIT %(limit)s OFFSET %(offset)s" + total = self.filter_count( + user_id=user_id, + session_uuids=session_uuids, + product_uuids=product_uuids, + started_after=started_after, + started_before=started_before, + status=status, + adjusted_before=adjusted_before, + adjusted_after=adjusted_after, + ) + else: + paginated_filter_str = "" + # Don't need to do a count if we aren't paginating + total = None + + order_by_str = parse_order_by(order_by) + query = f""" + SELECT + s.id AS session_id, s.uuid AS session_uuid, + s.user_id, s.started, s.finished, s.loi_min, s.loi_max, + s.user_payout_min, s.user_payout_max, s.country_iso, s.device_type, + s.ip, s.status, s.status_code_1, s.status_code_2, s.payout, + s.user_payout, s.adjusted_status, s.adjusted_payout, + s.adjusted_user_payout, s.adjusted_timestamp, s.url_metadata::jsonb, + u.product_id, u.product_user_id, u.uuid AS user_uuid, + COALESCE(walls.walls_json, '[]'::jsonb) AS walls_json + FROM thl_session s + + JOIN thl_user u + ON s.user_id = u.id + + LEFT JOIN LATERAL ( + SELECT jsonb_agg( + jsonb_build_object( + 'uuid', w.uuid, + 'started', w.started::timestamptz, + 'finished', w.finished::timestamptz, + 'source', w.source, + 'survey_id', w.survey_id, + 'req_survey_id', w.req_survey_id, + 'cpi', w.cpi, + 'req_cpi', w.req_cpi, + 'buyer_id', w.buyer_id, + 'status', w.status, + 'status_code_1', w.status_code_1, + 'status_code_2', w.status_code_2, + 'ext_status_code_1', w.ext_status_code_1, + 'ext_status_code_2', w.ext_status_code_2, + 'ext_status_code_3', w.ext_status_code_3, + 'adjusted_timestamp', w.adjusted_timestamp::timestamptz, + 'adjusted_status', w.adjusted_status, + 'adjusted_cpi', w.adjusted_cpi, + 'report_notes', w.report_notes, + 'report_value', w.report_value + ) + ) AS walls_json + FROM thl_wall w + WHERE w.session_id = s.id + ) walls ON TRUE + + {filter_str} + {order_by_str} + {paginated_filter_str} + """ + res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + if total is None: + total = len(res) + + return ( + self.session_from_mysql_rows_json(res), + total, + ) + + def session_from_mysql_rows_json( + self, + rows: Collection[Dict], + ) -> List[Session]: + """Columns: thl_session.*, thl_user.*, walls_json + - walls_json: list of objects, containing keys: thl_wall.* + """ + sessions = [] + for row in rows: + walls = [ + Wall( + uuid=UUID(w["uuid"]).hex, + started=datetime.fromisoformat(w["started"]), + finished=( + datetime.fromisoformat(w["finished"]) if w["finished"] else None + ), + source=w["source"], + survey_id=w["survey_id"], + buyer_id=w["buyer_id"], + status=w["status"], + status_code_1=w["status_code_1"], + status_code_2=w["status_code_2"], + ext_status_code_1=w["ext_status_code_1"], + ext_status_code_2=w["ext_status_code_2"], + ext_status_code_3=w["ext_status_code_3"], + adjusted_cpi=( + Decimal(w["adjusted_cpi"]).quantize(Decimal("0.01")) + if w["adjusted_cpi"] is not None + else None + ), + adjusted_status=w["adjusted_status"], + adjusted_timestamp=( + datetime.fromisoformat(w["adjusted_timestamp"]) + if w["adjusted_timestamp"] + else None + ), + report_notes=w["report_notes"], + report_value=w["report_value"], + req_survey_id=w["req_survey_id"], + req_cpi=Decimal(w["req_cpi"]).quantize(Decimal("0.01")), + cpi=Decimal(w["cpi"]).quantize(Decimal("0.01")), + session_id=row["session_id"], + user_id=row["user_id"], + ) + for w in row["walls_json"] + ] + walls = sorted(walls, key=lambda x: x.started) + row.pop("walls_json") + s = self.session_from_mysql(row) + s.wall_events = walls + sessions.append(s) + return sessions + + @staticmethod + def make_filter_str( + user_id: Optional[int] = None, + session_uuids: Optional[List[UUIDStr]] = None, + product_uuids: Optional[List[UUIDStr]] = None, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + status: Optional[Status] = None, + adjusted_after: Optional[datetime] = None, + adjusted_before: Optional[datetime] = None, + extra_filters: Optional[str] = None, + ) -> Tuple[str, Dict[str, Any]]: + filters = [] + params = {} + + if started_before or started_after: + started_after = started_after or datetime(2017, 1, 1, tzinfo=timezone.utc) + started_before = started_before or datetime.now(tz=timezone.utc) + assert ( + started_after.tzinfo == timezone.utc + ), "started_after must be tz-aware as UTC" + assert ( + started_before.tzinfo == timezone.utc + ), "started_before must be tz-aware as UTC" + assert ( + started_after < started_before + ), "started_after must be before started_before" + filters.append("started BETWEEN %(started_after)s AND %(started_before)s") + params["started_after"] = started_after + params["started_before"] = started_before + + if adjusted_before or adjusted_after: + adjusted_after = adjusted_after or datetime(2017, 1, 1, tzinfo=timezone.utc) + adjusted_before = adjusted_before or datetime.now(tz=timezone.utc) + assert ( + adjusted_after.tzinfo == timezone.utc + ), "adjusted_after must be tz-aware as UTC" + assert ( + adjusted_before.tzinfo == timezone.utc + ), "adjusted_before must be tz-aware as UTC" + assert ( + adjusted_after < adjusted_before + ), "adjusted_after must be before adjusted_before" + filters.append( + "adjusted_timestamp BETWEEN %(adjusted_after)s AND %(adjusted_before)s" + ) + params["adjusted_after"] = adjusted_after + params["adjusted_before"] = adjusted_before + + if user_id: + assert product_uuids is None + filters.append("user_id = %(user_id)s") + params["user_id"] = user_id + + if product_uuids: + assert user_id is None + filters.append("product_id = ANY(%(product_uuids)s)") + params["product_uuids"] = product_uuids + + if session_uuids: + filters.append("s.uuid = ANY(%(session_uuids)s)") + params["session_uuids"] = session_uuids + + if status: + # We need to include the cases where status is NULL as ABANDON. We'll handle the distinction + # between TIMEOUT (no status, older than 90 min) and UNKNOWN (no status, newer than 90 min) later. + params["status"] = status.value + filters.append(f"COALESCE(status, 'a') = %(status)s") + + if extra_filters: + filters.append(extra_filters) + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + return filter_str, params + + def filter( + self, + started_since: Optional[datetime] = None, + started_between: Optional[Tuple[datetime, datetime]] = None, + user: Optional[User] = None, + product_uuids: Optional[List[UUIDStr]] = None, + team_uuids: Optional[List[UUIDStr]] = None, + business_uuids: Optional[List[UUIDStr]] = None, + order_by: str = "-started", + limit: Optional[int] = None, + ) -> List[Session]: + # to be deprecated ... + + if team_uuids: + raise NotImplementedError("Cannot filter by Teams (yet)") + + if business_uuids: + raise NotImplementedError("Cannot filter by Businesses (yet)") + + if started_since and started_between: + raise ValueError() + started_after = None + started_before = None + if started_since: + started_after = started_since + if started_between: + started_after, started_before = started_between + + return self.filter_paginated( + user_id=user.user_id if user is not None else None, + product_uuids=product_uuids, + started_after=started_after, + started_before=started_before, + size=limit or 100, + order_by=order_by, + )[0] + + def filter_count( + self, + user_id: Optional[int] = None, + session_uuids: Optional[List[UUIDStr]] = None, + product_uuids: Optional[List[UUIDStr]] = None, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + status: Optional[Status] = None, + adjusted_after: Optional[datetime] = None, + adjusted_before: Optional[datetime] = None, + extra_filters: Optional[str] = None, + ) -> NonNegativeInt: + filter_str, params = self.make_filter_str( + user_id=user_id, + session_uuids=session_uuids, + product_uuids=product_uuids, + started_after=started_after, + started_before=started_before, + status=status, + adjusted_after=adjusted_after, + adjusted_before=adjusted_before, + extra_filters=extra_filters, + ) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT COUNT(1) AS cnt + FROM thl_session AS s + JOIN thl_user AS u + ON s.user_id = u.id + {filter_str} + """, + params=params, + ) + return res[0]["cnt"] if res else 0 + + def get_task_status_response( + self, session_uuid: UUIDStr + ) -> Optional[TaskStatusResponse]: + res, total = self.filter_paginated(session_uuids=[session_uuid]) + if total == 0: + return None + session = res[0] + PM = ProductManager(pg_config=self.pg_config, permissions=[Permission.READ]) + product = PM.get_by_uuid(product_uuid=session.user.product_id) + return TaskStatusResponse.from_session(session=session, product=product) + + def get_tasks_status_response( + self, + product_uuid: UUIDStr, + user_id: Optional[int] = None, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + status: Optional[Status] = None, + adjusted_after: Optional[datetime] = None, + adjusted_before: Optional[datetime] = None, + page: int = 1, + size: int = 100, + order_by: Optional[str] = "-started", + ) -> Optional[TasksStatusResponse]: + PM = ProductManager(pg_config=self.pg_config, permissions=[Permission.READ]) + product = PM.get_by_uuid(product_uuid=product_uuid) + + # This is for filtering. If we're not filtering by user, then add the product_id filter + product_uuids = [product_uuid] if user_id is None else None + res, total = self.filter_paginated( + user_id=user_id, + product_uuids=product_uuids, + started_after=started_after, + started_before=started_before, + status=status, + adjusted_after=adjusted_after, + adjusted_before=adjusted_before, + page=page, + size=size, + order_by=order_by, + ) + tsrs = [ + TaskStatusResponse.from_session(session=session, product=product) + for session in res + ] + + return TasksStatusResponse.model_validate( + {"tasks_status": tsrs, "page": page, "size": size, "total": total} + ) diff --git a/generalresearch/managers/thl/survey.py b/generalresearch/managers/thl/survey.py new file mode 100644 index 0000000..871ab83 --- /dev/null +++ b/generalresearch/managers/thl/survey.py @@ -0,0 +1,791 @@ +from collections import defaultdict +from datetime import datetime, timezone +from typing import Collection, List, Tuple, Optional + +import pandas as pd +from more_itertools import chunked +from psycopg import sql + +from generalresearch.managers.base import PostgresManager, Permission +from generalresearch.managers.thl.buyer import BuyerManager +from generalresearch.managers.thl.category import CategoryManager +from generalresearch.models import Source +from generalresearch.models.custom_types import SurveyKey +from generalresearch.models.thl.survey.model import ( + Survey, + SurveyStat, +) +from generalresearch.pg_helper import PostgresConfig + + +class SurveyManager(PostgresManager): + + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + ): + super().__init__(pg_config=pg_config, permissions=permissions) + self.buyer_manager = BuyerManager(pg_config=pg_config, permissions=permissions) + self.category_manager = CategoryManager(pg_config=pg_config) + + def create_or_update(self, surveys: List[Survey]): + """ + The only field that is checked for a possible update is `is_live`! + """ + assert len({s.source for s in surveys}) == 1, "Only do one source at a time" + source = surveys[0].source + survey_ids = [s.survey_id for s in surveys] + assert len(survey_ids) == len(set(survey_ids)), "duplicate survey_ids" + + # Handle the buyers + buyer_codes = {s.buyer_code for s in surveys} + self.buyer_manager.bulk_get_or_create(source=source, codes=buyer_codes) + for s in surveys: + s.buyer_id = self.buyer_manager.source_code_pk[s.buyer_natural_key] + + existing_surveys = self.filter_by_natural_key( + source=source, survey_ids=survey_ids + ) + existing_nks = {s.natural_key for s in existing_surveys} + + to_create = [ + survey for survey in surveys if survey.natural_key not in existing_nks + ] + if to_create: + self.create_bulk(surveys=to_create) + to_create_survey_ids = [s.survey_id for s in to_create] + created_surveys = self.filter_by_natural_key( + source=source, survey_ids=to_create_survey_ids + ) + existing_surveys.extend(created_surveys) + + # Sometimes surveys get turned back on. Check that here + potentially_update = [s for s in surveys if s.natural_key in existing_nks] + existing_d = {s.survey_id: s for s in existing_surveys} + to_update = [] + for s in potentially_update: + if existing_d[s.survey_id].is_live != s.is_live: + s.id = existing_d[s.survey_id].id + to_update.append(s) + if to_update: + self.update_is_live(to_update) + return { + "survey_created_count": len(to_create), + "survey_updated_count": len(to_update), + } + + def create_bulk(self, surveys: List[Survey]): + for chunk in chunked(surveys, 500): + self.create_bulk_chunk(chunk) + return None + + def create_bulk_chunk(self, surveys: List[Survey]): + assert len(surveys) <= 500, "chunk me" + + query = """ + INSERT INTO marketplace_survey ( + source, survey_id, created_at, updated_at, + is_live, is_recontact, buyer_id, eligibility_criteria + ) VALUES ( + %(source)s, %(survey_id)s, %(created_at)s, %(updated_at)s, + %(is_live)s, %(is_recontact)s, %(buyer_id)s, %(eligibility_criteria)s + ) ON CONFLICT (source, survey_id) DO NOTHING;""" + params = [s.model_dump_sql() for s in surveys] + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=params) + conn.commit() + return None + + def update_is_live(self, surveys: List[Survey]): + ids_ON = [s.id for s in surveys if s.is_live] + ids_OFF = [s.id for s in surveys if not s.is_live] + query_ON = """ + UPDATE marketplace_survey + SET is_live = TRUE, updated_at = NOW() + WHERE id = ANY(%(ids)s); + """ + query_OFF = """ + UPDATE marketplace_survey + SET is_live = FALSE, updated_at = NOW() + WHERE id = ANY(%(ids)s); + """ + if ids_ON: + self.pg_config.execute_write( + query_ON, + params={"ids": ids_ON}, + ) + if ids_OFF: + self.pg_config.execute_write( + query_OFF, + params={"ids": ids_OFF}, + ) + + def filter_by_keys( + self, + survey_keys: Collection[SurveyKey], + include_categories: bool = False, + ): + assert len(survey_keys) <= 1000 + if len(survey_keys) == 0: + return [] + + params = dict() + survey_source_ids = defaultdict(set) + + for sk in survey_keys: + source, survey_id = sk.split(":") + survey_source_ids[Source(source).value].add(survey_id) + + sk_filters = [] + for source, survey_ids in survey_source_ids.items(): + sk_filters.append( + f"(s.source = '{source}' AND s.survey_id = ANY(%(survey_ids_{source})s))" + ) + params[f"survey_ids_{source}"] = list(survey_ids) + + filter_str = f"WHERE ({' OR '.join(sk_filters)})" + + if include_categories: + CATEGORY_JOIN = """ + LEFT JOIN LATERAL ( + SELECT + jsonb_agg( + jsonb_build_object( + 'category', + jsonb_build_object( + 'id', c.id, + 'uuid', replace(c.uuid::text, '-', ''), + 'label', c.label, + 'path', c.path, + 'adwords_vertical_id', c.adwords_vertical_id, + 'parent_id', c.parent_id + ), + 'strength', sc.strength + ) + ORDER BY c.id + ) AS categories + FROM marketplace_surveycategory sc + JOIN marketplace_category c + ON c.id = sc.category_id + WHERE sc.survey_id = s.id + ) cat ON TRUE + """ + + query = f""" + SELECT + s.*, + b.code as buyer_code, + COALESCE(cat.categories, '[]'::jsonb) AS categories + FROM marketplace_survey s + LEFT JOIN marketplace_buyer b on s.buyer_id = b.id + {CATEGORY_JOIN} + {filter_str}; + """ + else: + query = f""" + SELECT s.*, b.code as buyer_code + FROM marketplace_survey s + LEFT JOIN marketplace_buyer b on s.buyer_id = b.id + {filter_str}; + """ + + res = self.pg_config.execute_sql_query( + query, + params=params, + ) + return [Survey.model_validate(x) for x in res] + + def filter_by_natural_key(self, source: Source, survey_ids: Collection[str]): + res = [] + for chunk in chunked(survey_ids, 1000): + res.extend(self.filter_by_natural_key_chunk(source, chunk)) + return res + + def filter_by_natural_key_chunk(self, source: Source, survey_ids: Collection[str]): + query = """ + SELECT id, source, survey_id, created_at, updated_at, + is_live, is_recontact, buyer_id, eligibility_criteria + FROM marketplace_survey + WHERE source = %(source)s AND + survey_id = ANY(%(survey_ids)s); + """ + res = self.pg_config.execute_sql_query( + query, + params={"survey_ids": list(survey_ids), "source": source.value}, + ) + return [Survey.model_validate(x) for x in res] + + def filter_by_source_live(self, source: Source): + """ + Return all live surveys for this source + """ + query = """ + SELECT id, source, survey_id, created_at, updated_at, + is_live, is_recontact, buyer_id, eligibility_criteria + FROM marketplace_survey + WHERE source = %(source)s AND is_live; + """ + res = self.pg_config.execute_sql_query(query, params={"source": source.value}) + return [Survey.model_validate(x) for x in res] + + def filter_by_live(self, fields: Optional[List[str]] = None): + """ + Return all live surveys + """ + fields_default = """id, source, survey_id, created_at, updated_at, + is_live, is_recontact, buyer_id, eligibility_criteria""" + fields = ", ".join(fields) if fields else fields_default + query = f""" + SELECT {fields} + FROM marketplace_survey + WHERE is_live; + """ + res = self.pg_config.execute_sql_query(query) + return [Survey.model_validate(x) for x in res] + + def turn_off_by_natural_key(self, source: Source, survey_ids: Collection[str]): + params = {"survey_ids": list(survey_ids), "source": source.value} + query = """ + UPDATE marketplace_survey + SET is_live = FALSE, updated_at = NOW() + WHERE source = %(source)s AND + survey_id = ANY(%(survey_ids)s) + RETURNING id; + """ + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params=params) + survey_pks = [x["id"] for x in c.fetchall()] + conn.commit() + + query = """ + UPDATE marketplace_surveystat + SET survey_is_live = FALSE, updated_at = NOW() + WHERE survey_is_live AND + survey_id = ANY(%(survey_pks)s); + """ + self.pg_config.execute_write( + query, + params={"survey_pks": survey_pks}, + ) + return None + + def update_surveys_categories(self, surveys: List[Survey] = None) -> None: + for chunk in chunked(surveys, 500): + self.update_surveys_categories_chunk(chunk) + return None + + def update_surveys_categories_chunk(self, surveys: List[Survey] = None) -> None: + assert len(surveys) <= 500, "chunk me" + temp_table_sql = sql.SQL( + """ + CREATE TEMP TABLE tmp_survey_categories ( + survey_id bigint, + category_id int, + strength float8 + ) ON COMMIT DROP; + """ + ) + # noinspection SqlResolve + insert_values_sql = sql.SQL( + "INSERT INTO tmp_survey_categories VALUES (%s, %s, %s)" + ) + # noinspection SqlResolve + delete_sql = sql.SQL( + """ + DELETE FROM marketplace_surveycategory sc + WHERE NOT EXISTS ( + SELECT 1 + FROM tmp_survey_categories t + WHERE t.survey_id = sc.survey_id + AND t.category_id = sc.category_id + ) + AND sc.survey_id IN ( + SELECT DISTINCT survey_id FROM tmp_survey_categories + );""" + ) + # noinspection SqlResolve + upsert_sql = sql.SQL( + """ + INSERT INTO marketplace_surveycategory (survey_id, category_id, strength) + SELECT survey_id, category_id, strength + FROM tmp_survey_categories + ON CONFLICT (survey_id, category_id) + DO UPDATE SET + strength = EXCLUDED.strength;""" + ) + + rows = [ + (survey.id, c.category.id, c.strength) + for survey in surveys + for c in survey.categories + ] + with self.pg_config.make_connection() as conn: + # noinspection PyArgumentList + with conn.transaction(): + with conn.cursor() as c: + c.execute(temp_table_sql) + c.executemany(insert_values_sql, rows) + c.execute(delete_sql) + c.execute(upsert_sql) + conn.commit() + + def get_survey_categories(self): + query = """ + SELECT + s.source, s.survey_id, + jsonb_agg( + jsonb_build_object( + 'category_id', sc.category_id, + 'strength', sc.strength + ) + ) as categories + FROM marketplace_survey s + JOIN marketplace_surveycategory sc ON s.id = sc.survey_id + WHERE is_live + GROUP BY s.source, s.survey_id + """ + return self.pg_config.execute_sql_query(query) + + +class SurveyStatManager(PostgresManager): + KEYS = [ + "survey_id", + "quota_id", + "country_iso", + "version", + "cpi", + "complete_too_fast_cutoff", + "prescreen_conv_alpha", + "prescreen_conv_beta", + "conv_alpha", + "conv_beta", + "dropoff_alpha", + "dropoff_beta", + "completion_time_mu", + "completion_time_sigma", + "mobile_eligible_alpha", + "mobile_eligible_beta", + "desktop_eligible_alpha", + "desktop_eligible_beta", + "tablet_eligible_alpha", + "tablet_eligible_beta", + "long_fail_rate", + "user_report_coeff", + "recon_likelihood", + "score_x0", + "score_x1", + "score", + "updated_at", + "survey_is_live", + "survey_survey_id", + "survey_source", + ] + + SURVEY_STATS_COL_MAP = { + "PRESCREEN_CONVERSION.alpha": "prescreen_conv_alpha", + "PRESCREEN_CONVERSION.beta": "prescreen_conv_beta", + "CONVERSION.alpha": "conv_alpha", + "CONVERSION.beta": "conv_beta", + "COMPLETION_TIME.mu": "completion_time_mu", + "COMPLETION_TIME.sigma": "completion_time_sigma", + "LONG_FAIL.value": "long_fail_rate", + "USER_REPORT_COEFF.value": "user_report_coeff", + "RECON_LIKELIHOOD.value": "recon_likelihood", + "DROPOFF_RATE.alpha": "dropoff_alpha", + "DROPOFF_RATE.beta": "dropoff_beta", + "IS_MOBILE_ELIGIBLE.alpha": "mobile_eligible_alpha", + "IS_MOBILE_ELIGIBLE.beta": "mobile_eligible_beta", + "IS_DESKTOP_ELIGIBLE.alpha": "desktop_eligible_alpha", + "IS_DESKTOP_ELIGIBLE.beta": "desktop_eligible_beta", + "IS_TABLET_ELIGIBLE.alpha": "tablet_eligible_alpha", + "IS_TABLET_ELIGIBLE.beta": "tablet_eligible_beta", + "cpi": "cpi", + } + + def __init__( + self, + pg_config: PostgresConfig, + permissions: Collection[Permission] = None, + ): + super().__init__(pg_config=pg_config, permissions=permissions) + self.survey_manager = SurveyManager( + pg_config=pg_config, permissions=permissions + ) + # self.ensure_surveystat_key_type() + + # + # def ensure_surveystat_key_type(self): + # SQL = """ + # DO $$ + # BEGIN + # IF NOT EXISTS ( + # SELECT 1 + # FROM pg_type t + # JOIN pg_namespace n ON n.oid = t.typnamespace + # WHERE t.typname = 'surveystat_key' + # AND n.nspname = 'public' + # ) THEN + # CREATE TYPE public.surveystat_key AS ( + # survey_id bigint, + # quota_id varchar(32), + # country_iso varchar(2), + # version integer + # ); + # END IF; + # END + # $$;""" + # with self.pg_config.make_connection() as conn: + # with conn.cursor() as c: + # c.execute(SQL) + # conn.commit() + # return None + + # def register_surveystat_key(self, conn): + # info = CompositeInfo.fetch(conn, "surveystat_key") + # info.register(conn) + + def update_or_create(self, survey_stats: List[SurveyStat]): + """ + This manager is NOT responsible for creating surveys or buyers. + It will check to make sure they exist + """ + if len(survey_stats) == 0: + return [] + assert all(s.survey_survey_id is not None for s in survey_stats) + assert all(s.survey_source is not None for s in survey_stats) + assert ( + len({s.survey_source for s in survey_stats}) == 1 + ), "Only do one source at a time" + source = survey_stats[0].survey_source + nks = [s.natural_key for s in survey_stats] + assert len(nks) == len(set(nks)), "duplicate natural_keys" + + # Look up survey pks + survey_ids = [s.survey_survey_id for s in survey_stats] + surveys = self.survey_manager.filter_by_natural_key( + source=source, survey_ids=survey_ids + ) + nk_to_pk = {s.natural_key: s.id for s in surveys} + for ss in survey_stats: + try: + ss.survey_id = nk_to_pk[ss.survey_natural_key] + except KeyError as e: + raise ValueError( + f"Survey {e.args[0]} does not exist. Must create surveys first" + ) + # print(f"----aa-----: {datetime.now().isoformat()}") + self.upsert_sql(survey_stats=survey_stats) + # print(f"----ab-----: {datetime.now().isoformat()}") + return None + # keys = [s.unique_key for s in survey_stats] + # print(keys[:4]) + # survey_stats = self.filter_by_unique_keys(keys) + # print(len(survey_stats)) + # print(f"----ac-----: {datetime.now().isoformat()}") + # # For testing/deterministic + # survey_stats = sorted(survey_stats, key=lambda s: s.natural_key) + # return survey_stats + + def upsert_sql(self, survey_stats: List[SurveyStat]): + for chunk in chunked(survey_stats, 1000): + self.upsert_sql_chunk(survey_stats=chunk) + return None + + # def insert_sql(self, survey_stats: List[SurveyStat]): + # for chunk in chunked(survey_stats, 1000): + # self.insert_sql_chunk(survey_stats=chunk) + # return None + # + # def insert_sql_chunk(self, survey_stats: List[SurveyStat]): + # assert len(survey_stats) <= 1000, "chunk me" + # keys = self.keys + # keys_str = ", ".join(keys) + # values_str = ", ".join([f"%({k})s" for k in keys]) + # unique_cols = ["survey_id", "quota_id", "country_iso", "version"] + # unique_cols_str = ", ".join(unique_cols) + # + # query = f""" + # INSERT INTO marketplace_surveystat ({keys_str}) + # VALUES ({values_str}) + # ON CONFLICT ({unique_cols_str}) + # DO NOTHING ;""" + # params = [ss.model_dump_sql() for ss in survey_stats] + # with self.pg_config.make_connection() as conn: + # with conn.cursor() as c: + # c.executemany(query=query, params_seq=params) + # conn.commit() + # return None + + def upsert_sql_chunk(self, survey_stats: List[SurveyStat]): + assert len(survey_stats) <= 1000, "chunk me" + keys = self.KEYS + keys_str = ", ".join(keys) + values_str = ", ".join([f"%({k})s" for k in keys]) + unique_cols = ["survey_id", "quota_id", "country_iso", "version"] + unique_cols_str = ", ".join(unique_cols) + update_cols = set(keys) - set(unique_cols) - {"updated_at", "is_live"} + update_str = ", ".join( + [f"{k} = EXCLUDED.{k}" for k in update_cols] + ["updated_at = NOW()"] + ) + + query = f""" + INSERT INTO marketplace_surveystat ({keys_str}) + VALUES ({values_str}) + ON CONFLICT ({unique_cols_str}) + DO UPDATE SET {update_str};""" + now = datetime.now(tz=timezone.utc) + params = [ss.model_dump_sql() | {"updated_at": now} for ss in survey_stats] + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=params) + conn.commit() + return None + + def filter_by_unique_keys(self, keys: Collection[Tuple]): + res = [] + for chunk in chunked(keys, 5000): + res.extend(self.filter_by_unique_keys_chunk(chunk)) + return res + + def filter_by_unique_keys_chunk(self, keys: Collection[Tuple]): + values_sql = ", ".join(["(%s, %s, %s, %s)"] * len(keys)) + query = f""" + SELECT + ss.* + FROM marketplace_surveystat ss + JOIN ( + VALUES {values_sql} + ) AS v(survey_id, quota_id, country_iso, version) + ON (ss.survey_id, ss.quota_id, ss.country_iso, ss.version) + = (v.survey_id, v.quota_id, v.country_iso, v.version); + """ + params = [item for row in keys for item in row] + with self.pg_config.make_connection() as conn: + # self.register_surveystat_key(conn) + with conn.cursor() as c: + c.execute(query, params=params) + res = c.fetchall() + # print('\n'.join([x['QUERY PLAN'] for x in res])) + return [SurveyStat.model_validate(x) for x in res] + + def update_surveystats_for_source( + self, + source: Source, + surveys: List[Survey], + survey_stats: List[SurveyStat], + ): + """ + What ym-survey-stats actually calls. + 1. All surveys for this source not in this list of surveys + get turned off + 2. Get or create all surveys and buyers + 3. Update survey stats + """ + # Assert the surveys and surveystats we passed are all + # for this Source + survey_source = {s.source for s in surveys} + assert len(survey_source) == 1 and survey_source == {source} + # And that the surveys in the surveystats match the passed in Surveys + surveys_nks = {s.natural_key for s in surveys} + ss_surveys_nks = {ss.survey_natural_key for ss in survey_stats} + assert surveys_nks == ss_surveys_nks + + # Turn off not live surveys + live_surveys = self.survey_manager.filter_by_source_live(source=source) + live_ids = {s.survey_id for s in live_surveys} + new_ids = {s.survey_id for s in surveys} + turn_off_surveys = live_ids - new_ids + self.survey_manager.turn_off_by_natural_key( + source=source, survey_ids=turn_off_surveys + ) + + # Create or Update (is_live) Surveys + res = self.survey_manager.create_or_update(surveys) + + # Update ss + self.update_or_create(survey_stats=survey_stats) + + return res + + def filter_by_updated_since(self, since): + return self.filter(updated_after=since, is_live=None) + + def filter_by_live(self): + return self.filter(is_live=True) + + def make_filter_str( + self, + is_live: Optional[bool] = True, + updated_after: Optional[datetime] = None, + min_score: Optional[float] = None, + survey_keys: Optional[Collection[SurveyKey]] = None, + sources: Optional[Collection[Source]] = None, + country_iso: Optional[str] = None, + ): + filters = [] + params = dict() + if updated_after is not None: + params["updated_after"] = updated_after + filters.append("ss.updated_at >= %(updated_after)s") + if min_score: + params["min_score"] = min_score + filters.append("score >= %(min_score)s") + if is_live is not None: + if is_live: + filters.append("ss.survey_is_live") + else: + filters.append("NOT ss.survey_is_live") + if sources: + assert survey_keys is None + params["sources"] = [s.value for s in sources] + filters.append("survey_source = ANY(%(sources)s)") + if country_iso: + params["country_iso"] = country_iso + filters.append("country_iso = %(country_iso)s") + if survey_keys is not None: + # Instead of doing a big IN with a big set of tuples, since we know + # we only have N possible sources, we just split by that and do + # a set of: + # ( (survey_source = 'x' and survey_survey_id IN ('1', '2') ) OR + # (survey_source = 'y' and survey_survey_id IN ('3', '4') ) ... ) + sk_filters = [] + survey_source_ids = defaultdict(set) + for sk in survey_keys: + source, survey_id = sk.split(":") + survey_source_ids[Source(source).value].add(survey_id) + for source, survey_ids in survey_source_ids.items(): + sk_filters.append( + f"(survey_source = '{source}' AND survey_survey_id = ANY(%(survey_ids_{source})s))" + ) + params[f"survey_ids_{source}"] = list(survey_ids) + # potential bug here ! --v Make sure this is wrapped in parentheses! + filters.append(f"({' OR '.join(sk_filters)})") + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + return filter_str, params + + def filter_count( + self, + is_live: Optional[bool] = True, + updated_after: Optional[datetime] = None, + min_score: Optional[float] = None, + survey_keys: Optional[Collection[SurveyKey]] = None, + sources: Optional[Collection[Source]] = None, + country_iso: Optional[str] = None, + ) -> int: + filter_str, params = self.make_filter_str( + is_live=is_live, + updated_after=updated_after, + min_score=min_score, + survey_keys=survey_keys, + sources=sources, + country_iso=country_iso, + ) + query = f""" + SELECT COUNT(1) as cnt + FROM marketplace_surveystat ss + {filter_str}; + """ + return self.pg_config.execute_sql_query(query, params=params)[0]["cnt"] + + def filter( + self, + is_live: Optional[bool] = True, + updated_after: Optional[datetime] = None, + min_score: Optional[float] = None, + survey_keys: Optional[Collection[SurveyKey]] = None, + sources: Optional[Collection[Source]] = None, + country_iso: Optional[str] = None, + page: Optional[int] = None, + size: Optional[int] = None, + order_by: Optional[str] = None, + debug: Optional[bool] = False, + ): + filter_str, params = self.make_filter_str( + is_live=is_live, + updated_after=updated_after, + min_score=min_score, + survey_keys=survey_keys, + sources=sources, + country_iso=country_iso, + ) + + paginated_filter_str = "" + if page is not None: + assert page != 0, "page starts at 1" + size = size if size is not None else 100 + params["offset"] = (page - 1) * size + params["limit"] = size + paginated_filter_str = " LIMIT %(limit)s OFFSET %(offset)s" + + order_by_str = "" + if order_by: + assert order_by in {"score DESC", "score", "updated_at DESC", "updated_at"} + order_by_str = f"ORDER BY {order_by}" + + query = f""" + SELECT + quota_id, country_iso, cpi, + complete_too_fast_cutoff, + prescreen_conv_alpha, prescreen_conv_beta, + conv_alpha, conv_beta, + dropoff_alpha, dropoff_beta, + completion_time_mu, completion_time_sigma, + mobile_eligible_alpha, mobile_eligible_beta, + desktop_eligible_alpha, desktop_eligible_beta, + tablet_eligible_alpha, tablet_eligible_beta, + long_fail_rate, user_report_coeff, recon_likelihood, + score_x0, score_x1, updated_at, version, score, + survey_is_live, survey_source, survey_survey_id + FROM marketplace_surveystat ss + {filter_str} + {order_by_str} + {paginated_filter_str} ; + """ + if debug: + print(query) + print(params) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute("SET work_mem = '256MB';") + c.execute("SET statement_timeout = '10s';") + c.execute(query, params=params) + res = c.fetchall() + return [SurveyStat.model_validate(x) for x in res] + + def filter_to_merge_table( + self, + is_live: Optional[bool] = True, + updated_after: Optional[datetime] = None, + min_score: Optional[float] = 0.0001, + ): + survey_stats = self.filter( + is_live=is_live, updated_after=updated_after, min_score=min_score + ) + if not survey_stats: + return None + extra_cols = { + "survey_id", + "quota_id", + "country_iso", + "version", + "updated_at", + "score_x0", + "score_x1", + "survey_is_live", + "survey_source", + "survey_survey_id", + } + data = [] + for ss in survey_stats: + d = {k: getattr(ss, v) for k, v in self.SURVEY_STATS_COL_MAP.items()} + d.update({k: getattr(ss, k) for k in extra_cols}) + d["sid"] = ss.survey_natural_key + data.append(d) + df = pd.DataFrame(data) + df = df.set_index("sid") + df["cpi"] = df["cpi"].astype(float) + return df diff --git a/generalresearch/managers/thl/survey_penalty.py b/generalresearch/managers/thl/survey_penalty.py new file mode 100644 index 0000000..4e2c104 --- /dev/null +++ b/generalresearch/managers/thl/survey_penalty.py @@ -0,0 +1,112 @@ +import json +import threading +from collections import defaultdict +from datetime import timedelta +from typing import Optional, List, Tuple, Dict + +from generalresearch.decorators import LOG +from generalresearch.managers.base import RedisManager +from generalresearch.models.custom_types import ( + UUIDStr, +) +from generalresearch.models.thl.survey.penalty import ( + BPSurveyPenalty, + TeamSurveyPenalty, + PenaltyListAdapter, + Penalty, +) +from generalresearch.redis_helper import RedisConfig +from cachetools import cachedmethod, TTLCache + + +class SurveyPenaltyManager(RedisManager): + """ + Penalties are stored in redis with keys index by the product_id or team_id. + So getting the penalties for a BP will return all surveys that have + penalties for that BP. + The redis object is a hash, where the key is "survey-penalty-{bp/team}", + and each has fields per source. The field value is a list of + json-dumped SurveyPenalty objects. + Since we calculate the penalties batched by marketplace, when we set + this field it *replaces* all previous penalties for that + BP/team - marketplace. + """ + + def __init__( + self, + redis_config: RedisConfig, + cache_prefix: Optional[str] = None, + **kwargs, + ): + super().__init__(redis_config=redis_config, cache_prefix=cache_prefix, **kwargs) + self.redis_prefix = ( + f"{self.cache_prefix}:survey-penalty" + if self.cache_prefix + else "survey-penalty" + ) + self.cache = TTLCache(maxsize=128, ttl=60) + self.cache_lock = threading.Lock() + + def get_redis_key(self, penalty: Penalty) -> str: + if penalty.kind == "team": + return f"{self.redis_prefix}:{penalty.team_id}" + elif penalty.kind == "bp": + return f"{self.redis_prefix}:{penalty.product_id}" + else: + raise AssertionError("unreachable") + + def get_redis_key_for_id(self, uuid_id: UUIDStr): + return f"{self.redis_prefix}:{uuid_id}" + + def set_penalties(self, penalties: List[Penalty]): + """ """ + if len(penalties) > 1000: + LOG.warning("SurveyPenaltyManager.set_penalties batch me!") + assert len(penalties) < 10_000, "something is surely wrong" + self.cache.clear() + d = defaultdict(lambda: defaultdict(list)) + for p in penalties: + d[self.get_redis_key(p)][p.source.value].append(p.model_dump(mode="json")) + + pipe = self.redis_client.pipeline(transaction=False) + for key, mapping in d.items(): + mapping = { + k: json.dumps(v, separators=(",", ":")) for k, v in mapping.items() + } + pipe.hmset(key, mapping=mapping) + pipe.expire(key, timedelta(days=1)) + pipe.hexpire(key, timedelta(days=1), *mapping.keys()) + pipe.execute() + + def _load_penalties( + self, product_id: UUIDStr, team_id: UUIDStr + ) -> Tuple[List[BPSurveyPenalty], List[TeamSurveyPenalty]]: + pipe = self.redis_client.pipeline(transaction=False) + bp_res, team_res = ( + pipe.hgetall(self.get_redis_key_for_id(product_id)) + .hgetall(self.get_redis_key_for_id(team_id)) + .execute() + ) + bp_penalties = [] + for v in bp_res.values(): + bp_penalties.extend(PenaltyListAdapter.validate_python(json.loads(v))) + team_penalties = [] + for v in team_res.values(): + team_penalties.extend(PenaltyListAdapter.validate_python(json.loads(v))) + return bp_penalties, team_penalties + + @cachedmethod(lambda self: self.cache, lock=lambda self: self.cache_lock) + def get_penalties_for( + self, product_id: UUIDStr, team_id: UUIDStr + ) -> Dict[str, float]: + """ + Returns a dict with keys survey sids ({source}:{survey_id}) and values penalties. + e.g. {'s:1234': 0.8} + """ + bp_penalties, team_penalties = self._load_penalties( + product_id=product_id, team_id=team_id + ) + penalties: dict[str, float] = {} + for p in (*bp_penalties, *team_penalties): + penalties[p.sid] = max(p.penalty, penalties.get(p.sid, 0.0)) + return penalties diff --git a/generalresearch/managers/thl/task_adjustment.py b/generalresearch/managers/thl/task_adjustment.py new file mode 100644 index 0000000..15802e3 --- /dev/null +++ b/generalresearch/managers/thl/task_adjustment.py @@ -0,0 +1,187 @@ +import logging +from datetime import datetime, timezone +from decimal import Decimal +from functools import cached_property +from typing import Optional + +from generalresearch.managers import parse_order_by +from generalresearch.managers.base import ( + PostgresManager, +) +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.managers.thl.session import SessionManager +from generalresearch.managers.thl.wall import WallManager +from generalresearch.models.thl.definitions import ( + WallAdjustedStatus, + Status, +) +from generalresearch.models.thl.session import ( + _check_adjusted_status_wall_consistent, +) +from generalresearch.models.thl.task_adjustment import TaskAdjustmentEvent + + +class TaskAdjustmentManager(PostgresManager): + + @cached_property + def wall_manager(self): + return WallManager(pg_config=self.pg_config) + + @cached_property + def session_manager(self): + return SessionManager(pg_config=self.pg_config) + + def filter_by_wall_uuid( + self, + wall_uuid, + page: int = 1, + size: int = 100, + order_by: Optional[str] = "-created", + ): + params = {"wall_uuid": wall_uuid} + order_by_str = parse_order_by(order_by) + paginated_filter_str = "LIMIT %(limit)s OFFSET %(offset)s" + params["offset"] = (page - 1) * size + params["limit"] = size + res = self.pg_config.execute_sql_query( + f""" + SELECT + uuid, + adjusted_status, + ext_status_code, + amount, + alerted, + created, + user_id, + wall_uuid, + started, + source, + survey_id + FROM thl_taskadjustment + WHERE wall_uuid = %(wall_uuid)s + {order_by_str} + {paginated_filter_str};""", + params=params, + ) + return [TaskAdjustmentEvent.model_validate(x) for x in res] + + def create_task_adjustment_event(self, event: TaskAdjustmentEvent): + # Only insert a new record into thl_taskadjustment if the status for this wall_uuid + # is different from the last one. Don't need the same thing twice + res = self.filter_by_wall_uuid( + wall_uuid=event.wall_uuid, page=1, size=1, order_by="-created" + ) + + if res and event.adjusted_status == res[0].adjusted_status: + # We already have this and it's the same change. Still call the wall_manager.adjust_status + # and ledger code b/c 1) it also won't do the same thing twice, and 2) we could be out of sync + # so check anyway. + return res[0] + + self.pg_config.execute_write( + """ + INSERT INTO thl_taskadjustment + (uuid, adjusted_status, ext_status_code, amount, alerted, + created, user_id, wall_uuid, started, source, survey_id) + VALUES (%(uuid)s, %(adjusted_status)s, %(ext_status_code)s, %(amount)s, %(alerted)s, + %(created)s, %(user_id)s, %(wall_uuid)s, %(started)s, %(source)s, %(survey_id)s) + """, + params=event.model_dump(mode="json"), + ) + return event + + def handle_single_recon( + self, + ledger_manager: ThlLedgerManager, + wall_uuid: str, + adjusted_status: WallAdjustedStatus, + alert_time: Optional[datetime] = None, + ext_status_code: Optional[str] = None, + adjusted_cpi: Optional[Decimal] = None, + ): + """ + We just got an adjustment notification from a marketplace. + + See note on TaskAdjustmentEvent.adjusted_status. + These fields (specifically adjusted_status and adjusted_cpi) are CHANGES/DELTAS + as just communicated by the marketplace, not what the Wall's final adjusted_* will be. + """ + alert_time = alert_time or datetime.now(tz=timezone.utc) + assert alert_time.tzinfo == timezone.utc + + wall = self.wall_manager.get_from_uuid(wall_uuid) + session = self.session_manager.get_from_id(wall.session_id) + user = session.user + user.prefetch_product(self.pg_config) + + if adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL: + amount_usd = wall.cpi * -1 + adjusted_cpi = 0 + elif adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE: + amount_usd = wall.cpi + adjusted_cpi = wall.cpi + elif adjusted_status == WallAdjustedStatus.CPI_ADJUSTMENT: + amount_usd = adjusted_cpi + elif adjusted_status == WallAdjustedStatus.CONFIRMED_COMPLETE: + amount_usd = None + else: + raise ValueError + + # If the wall event is a complete -> fail -> complete, we are going to + # receive an adjusted_status.adjust_to_complete, but internally, + # this is going to set the adjusted_status to None (b/c it was already a complete) + if ( + wall.status == Status.COMPLETE + and adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE + ): + new_adjusted_status = None + new_adjusted_cpi = None + elif ( + wall.status != Status.COMPLETE + and adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + ): + new_adjusted_status = None + new_adjusted_cpi = None + else: + new_adjusted_status = adjusted_status + new_adjusted_cpi = adjusted_cpi + + # Validate that this event's transition is allowed + try: + _check_adjusted_status_wall_consistent( + status=wall.status, + cpi=wall.cpi, + adjusted_status=wall.adjusted_status, + adjusted_cpi=wall.adjusted_cpi, + new_adjusted_status=new_adjusted_status, + new_adjusted_cpi=new_adjusted_cpi, + ) + except AssertionError as e: + logging.warning(e) + return None + + event = TaskAdjustmentEvent( + adjusted_status=adjusted_status, + alerted=alert_time, + amount=amount_usd, + wall_uuid=wall_uuid, + started=wall.started, + source=wall.source, + survey_id=wall.survey_id, + user_id=user.user_id, + ext_status_code=ext_status_code, + ) + + self.create_task_adjustment_event(event=event) + self.wall_manager.adjust_status( + wall, + adjusted_status=new_adjusted_status, + adjusted_cpi=new_adjusted_cpi, + adjusted_timestamp=alert_time, + ) + ledger_manager.create_tx_task_adjustment(wall, user=user, created=alert_time) + session.wall_events = self.wall_manager.get_wall_events(session.id) + self.session_manager.adjust_status(session) + ledger_manager.create_tx_bp_adjustment(session, created=alert_time) diff --git a/generalresearch/managers/thl/user_compensate.py b/generalresearch/managers/thl/user_compensate.py new file mode 100644 index 0000000..c5018d9 --- /dev/null +++ b/generalresearch/managers/thl/user_compensate.py @@ -0,0 +1,89 @@ +from datetime import datetime, timezone +from decimal import Decimal +from typing import Optional +from uuid import uuid4 + +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.user import User + + +def user_compensate( + ledger_manager: ThlLedgerManager, + user: User, + amount_int: int, + ext_ref=None, + description=None, + skip_flag_check: Optional[bool] = False, +) -> UUIDStr: + """ + Compensate a user. aka "bribe". The money is paid out of the BP's wallet balance. + Amount is in USD cents. + """ + pg_config = ledger_manager.pg_config + redis_client = ledger_manager.redis_client + + now = datetime.now(tz=timezone.utc) + assert type(amount_int) is int + user.prefetch_product(pg_config=pg_config) + assert ( + user.product.user_wallet_enabled + ), "Trying to compensate user without managed wallet" + + # Simple dedupe mechanism. Don't allow more than 1 per user_id every 1 min. + if not skip_flag_check: + flag_just_set = bool( + redis_client.set( + f"thl-grpc:user_compensate:{user.user_id}", 1, nx=True, ex=60 + ) + ) + assert flag_just_set, "User already compensated within the past minute!" + + # If there is an external reference ID, don't allow it to be used twice + if ext_ref: + res = pg_config.execute_sql_query( + query=f""" + SELECT 1 + FROM event_bribe + WHERE ext_ref_id = %s + """, + params=[ext_ref], + ) + assert not res, f"UserCompensate: ext_ref {ext_ref} already used!" + + # Create a Bribe instance, that stores info about this event + bribe_uuid = uuid4().hex + + # Create a new bribe instance + account = ledger_manager.get_account_or_create_user_wallet(user) + pg_config.execute_write( + query=f""" + INSERT INTO event_bribe + (uuid, credit_account_uuid, created, amount, ext_ref_id, description, data) + VALUES (%s, %s, %s, %s, %s, %s, %s) + """, + params=[ + bribe_uuid, + account.uuid, + now, + amount_int, + ext_ref, + description, + None, + ], + ) + # For now, all Ledger Accounts are USD + amount_usd = Decimal(amount_int) / 100 + if description is None: + description = f"Bonus ${amount_usd:,.2f}" + ledger_manager.create_tx_user_bonus( + user, + amount=amount_usd, + ref_uuid=bribe_uuid, + description=description, + skip_flag_check=skip_flag_check, + ) + + return bribe_uuid diff --git a/generalresearch/managers/thl/user_manager/__init__.py b/generalresearch/managers/thl/user_manager/__init__.py new file mode 100644 index 0000000..b8b1c35 --- /dev/null +++ b/generalresearch/managers/thl/user_manager/__init__.py @@ -0,0 +1,108 @@ +import csv +import logging +import os +import threading +import time +from pathlib import Path +from threading import RLock +from typing import Dict + +from cachetools import cached, TTLCache + +from generalresearch.managers.thl.user_manager import mysql_user_manager +from generalresearch.models.thl.product import Product + +logger = logging.getLogger() + + +class UserDoesntExistError(Exception): + pass + + +class UserCreateNotAllowedError(Exception): + pass + + +def download_bp_trust(): + raise DeprecationWarning("No more S3") + + +@cached(TTLCache(maxsize=1, ttl=5 * 60), lock=RLock()) +def get_bp_trust_df(): + from importlib.resources import files + + fp = str( + files("generalresearch.resources").joinpath("brokerage_trust_calculated.csv") + ) + # cols = ['product_id', 'team_id', 'team_name', 'business_id', 'business_name', + # 'product_name', 'bp_trust', 'team_trust', 'entrance_limit_expire_sec', + # 'entrance_limit_value'] + + if not os.path.exists(fp): + Path(fp).touch() + threading.Thread(target=download_bp_trust).start() + # raise exception so its not cached + raise FileNotFoundError() + if time.time() - os.path.getmtime(fp) > 3600: + Path(fp).touch() + threading.Thread(target=download_bp_trust).start() + bptrust = parse_bp_trust_df(fp) + + return bptrust + + +convert_int = lambda x: int(float(x)) + + +def parse_bp_trust_df(fp) -> Dict: + dtype = { + "bp_trust": float, + "team_trust": float, + "entrance_limit_expire_sec": convert_int, + "entrance_limit_value": convert_int, + "median_daily_completes_7d": convert_int, + } + bptrust = dict() + + with open(fp, newline="") as csvfile: + reader = csv.reader(csvfile) + header = next(reader) + for row in reader: + d = dict(zip(header, row)) + for k, v in dtype.items(): + if k in d: + d[k] = v(d[k]) + bptrust[d["product_id"]] = d + + return bptrust + + +def get_bp_user_create_limit_hourly(product: Product) -> int: + """The BP's hour user creation limit is calculated as: 4 times the median + daily completes over the past 7 days, with a default range of 60 to 1000 + per hour. For e.g. if a BP has 0 median daily completes, the user + creation limit is 60/hr. We can also override the default 60-1000 range + using the product.user_create_config. + """ + global_default = 120 + default = product.user_create_config.clip_hourly_create_limit(global_default) + + try: + bptrust = get_bp_trust_df() + except (FileNotFoundError, StopIteration): + return default + + if product.id not in bptrust: + return default + + if "median_daily_completes_7d" not in bptrust[product.id]: + logger.exception("missing median_daily_completes_7d column") + return default + + user_create_limit_daily = bptrust[product.id]["median_daily_completes_7d"] * 8 + user_create_limit_hourly = user_create_limit_daily / 24 + user_create_limit_hourly = max(min(user_create_limit_hourly, 5000), global_default) + user_create_limit_hourly = product.user_create_config.clip_hourly_create_limit( + user_create_limit_hourly + ) + return user_create_limit_hourly diff --git a/generalresearch/managers/thl/user_manager/memcached_user_manager.py b/generalresearch/managers/thl/user_manager/memcached_user_manager.py new file mode 100644 index 0000000..d2c68ed --- /dev/null +++ b/generalresearch/managers/thl/user_manager/memcached_user_manager.py @@ -0,0 +1,49 @@ +# from typing import List, Optional +# +# import pylibmc +# +# from generalresearch.models.thl.user import User +# +# +# class MemcachedUserManager: +# def __init__(self, servers: List[str], cache_prefix: Optional[str] = None): +# self.servers = servers +# self.cache_prefix = cache_prefix if cache_prefix else "user-lookup" +# +# def create_client(self): +# # Clients are NOT thread safe. Make a new one each time +# +# # There's a receive_timeout and send_timeout also, but the documentation is incomprehensible, +# # and they don't seem to do anything??? (I tested setting them at 1ms and I can't +# # get it to fail) +# # https://sendapatch.se/projects/pylibmc/behaviors.html +# mc_client = pylibmc.Client(servers=self.servers, binary=True, +# behaviors={'connect_timeout': 100}) +# return mc_client +# +# def get_user(self, *, product_id: str = None, product_user_id: str = None, user_id: int = None, +# user_uuid: UUIDStr = None) -> User: +# # assume we did input validation in user_manager.get_user() function +# mc_client = self.create_client() +# if user_uuid: +# d = mc_client.get(f"{self.cache_prefix}:uuid:{user_uuid}") +# elif user_id: +# d = mc_client.get(f"{self.cache_prefix}:user_id:{user_id}") +# else: +# d = mc_client.get(f"{self.cache_prefix}:ubp:{product_id}:{product_user_id}") +# if d: +# return User.model_validate_json(d) +# +# def set_user(self, user: User): +# d = user.to_json() +# mc_client = self.create_client() +# mc_client.set(f"{self.cache_prefix}:uuid:{user.uuid}", d, time=60 * 60 * 24) +# mc_client.set(f"{self.cache_prefix}:user_id:{user.user_id}", d, time=60 * 60 * 24) +# mc_client.set(f"{self.cache_prefix}:ubp:{user.product_id}:{user.product_user_id}", d, time=60 * 60 * 24) +# +# def clear_user(self, user: User): +# # this should only be used by tests +# mc_client = self.create_client() +# mc_client.delete(f"{self.cache_prefix}:uuid:{user.uuid}") +# mc_client.delete(f"{self.cache_prefix}:user_id:{user.user_id}") +# mc_client.delete(f"{self.cache_prefix}:ubp:{user.product_id}:{user.product_user_id}") diff --git a/generalresearch/managers/thl/user_manager/mysql_user_manager.py b/generalresearch/managers/thl/user_manager/mysql_user_manager.py new file mode 100644 index 0000000..ab2c6c3 --- /dev/null +++ b/generalresearch/managers/thl/user_manager/mysql_user_manager.py @@ -0,0 +1,287 @@ +import logging +from datetime import datetime, timezone +from functools import lru_cache +from typing import Optional, Collection, List +from uuid import uuid4 + +import psycopg +from psycopg import sql + +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class MysqlUserManager: + def __init__(self, pg_config: PostgresConfig, is_read_replica: bool): + self.pg_config = pg_config + self.is_read_replica = is_read_replica + + def _set_last_seen(self, user: User) -> None: + # Don't call this directly. Use UserManager.set_last_seen() + assert not self.is_read_replica + now = datetime.now(tz=timezone.utc) + self.pg_config.execute_write( + """ + UPDATE thl_user + SET last_seen = %s + WHERE id = %s + """, + params=[now, user.user_id], + ) + + def get_user_from_mysql( + self, + *, + product_id: Optional[str] = None, + product_user_id: Optional[str] = None, + user_id: Optional[int] = None, + user_uuid: Optional[UUIDStr] = None, + can_use_read_replica=True, + ) -> Optional[User]: + + logger.info( + f"get_user_from_mysql: {product_id}, {product_user_id}, {user_id}, {user_uuid}" + ) + assert ( + (product_id and product_user_id) or user_id or user_uuid + ), "Must pass either (product_id, product_user_id), or user_id, or uuid" + if product_id or product_user_id: + assert ( + product_id and product_user_id + ), "Must pass both product_id and product_user_id" + assert ( + sum(map(bool, [product_id or product_id, user_id, user_uuid])) == 1 + ), "Must pass only 1 of (product_id, product_user_id), or user_id, or uuid" + + # Using RR: Assume we check redis first for newly created users + if can_use_read_replica is False: + assert self.is_read_replica is False + + if product_id: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT id AS user_id, product_id, product_user_id, + uuid, blocked, created, last_seen + FROM thl_user + WHERE product_id = %s + AND product_user_id = %s + LIMIT 1 + """, + params=[product_id, product_user_id], + ) + + elif user_id: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT id AS user_id, product_id, product_user_id, + uuid, blocked, created, last_seen + FROM thl_user + WHERE id = %s + LIMIT 1 + """, + params=[user_id], + ) + + else: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT id AS user_id, product_id, product_user_id, + uuid, blocked, created, last_seen + FROM thl_user + WHERE uuid = %s + LIMIT 1 + """, + params=[user_uuid], + ) + + if res: + res = res[0] + # todo: add other cols into User (`last_ip`, `last_geoname_id`, `last_country_iso`) + return User.from_db(res) + + def create_user( + self, + product_user_id: str, + product_id: str, + created: Optional[datetime] = None, + ) -> User: + """Creates a thl_user record for a new user.""" + assert self.is_read_replica is False + # assert that the product exists + if not self.product_id_exists(product_id=product_id): + raise ValueError(f"userprofile_brokerageproduct not found: {product_id}") + + now = created or datetime.now(tz=timezone.utc) + user_uuid = uuid4().hex + params = { + "user_uuid": user_uuid, + "product_id": product_id, + "product_user_id": product_user_id, + "created": now, + "last_seen": now, + } + + # in postgres, you do not include the auto-increment id column + query = sql.SQL( + """ + INSERT INTO thl_user + (uuid, product_id, product_user_id, created, + last_seen, blocked, last_country_iso, last_geoname_id, last_ip) + VALUES (%(user_uuid)s, %(product_id)s, %(product_user_id)s, %(created)s, + %(last_seen)s, FALSE, NULL, NULL, NULL) + RETURNING id; + """ + ) + + try: + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + user_id = c.fetchone()["id"] + except psycopg.IntegrityError as e: + # Two machines/processes are trying to create this same (product_id, product_user_id) + # at the same time. There's a unique index, so mysql will not let two be created. + # The 2nd should get an IntegrityError, meaning this already exists, and we can just query it. + logger.info( + f"mysql_user_manager.create_user_new integrity error: {product_id} {product_user_id}" + ) + user_mysql = self.get_user_from_mysql( + product_id=product_id, + product_user_id=product_user_id, + can_use_read_replica=False, + ) + if user_mysql: + return user_mysql + else: + # We specifically queried the NON read-replica, and we got an IntegrityError, so + # something else must be wrong... + raise e + else: + user = User( + user_id=user_id, + product_id=product_id, + product_user_id=product_user_id, + uuid=user_uuid, + last_seen=now, + created=now, + ) + + return user + + @lru_cache(maxsize=5000) + def product_id_exists(self, product_id: str): + # 'id' is the primary key, there can only be 0 or 1 + query = """ + SELECT id + FROM userprofile_brokerageproduct + WHERE id = %s; + """ + res = self.pg_config.execute_sql_query(query, [product_id]) + return len(res) > 0 + + def _block_user( + self, + user: User, + ) -> None: + # Don't call this directly. Use UserManager.block_user() + assert not self.is_read_replica + # id is primary key, there can only be 1 row + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=""" + UPDATE thl_user SET blocked = %s + WHERE id = %s + """, + params=[True, user.user_id], + ) + assert c.rowcount == 1, "User does not exist" + conn.commit() + + def is_whitelisted(self, user: User): + res = self.pg_config.execute_sql_query( + f""" + SELECT value + FROM userprofile_userstat + WHERE user_id = %s + AND key = 'USER_HEALTH.access_control'""", + [user.user_id], + ) + if res and res[0]["value"] is not None: + return bool(int(res[0]["value"])) + return False + + def fetch_by_bpuids( + self, + *, + product_id: str, + product_user_ids: Collection[str], + ) -> List[User]: + assert product_id, "must pass product_id" + assert len(product_user_ids) > 0, "must pass 1 or more product_user_ids" + assert len(product_user_ids) <= 500, "limit 500 product_user_ids" + assert isinstance( + product_user_ids, (list, set) + ), "must pass a collection of product_user_ids" + res = self.pg_config.execute_sql_query( + query=""" + SELECT id AS user_id, product_id, product_user_id, + uuid, blocked, created, last_seen + FROM thl_user + WHERE product_id = %(product_id)s + AND product_user_id = ANY(%(product_user_ids)s) + LIMIT 500 + """, + params={ + "product_id": product_id, + "product_user_ids": product_user_ids, + }, + ) + return [User.from_db(x) for x in res] + + def fetch( + self, + *, + user_ids: Collection[int] = None, + user_uuids: Collection[str] = None, + ) -> List[User]: + assert (user_ids or user_uuids) and not ( + user_ids and user_uuids + ), "Must pass ONE of user_ids, user_uuids" + if user_ids: + assert len(user_ids) <= 500, "limit 500 user_ids" + assert isinstance( + user_ids, (list, set) + ), "must pass a collection of user_ids" + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT id AS user_id, product_id, product_user_id, + uuid, blocked, created, last_seen + FROM thl_user + WHERE id = ANY(%(user_ids)s) + LIMIT 500 + """, + params={"user_ids": user_ids}, + ) + else: + assert len(user_uuids) <= 500, "limit 500 user_uuids" + assert isinstance( + user_uuids, (list, set) + ), "must pass a collection of user_uuids" + res = self.pg_config.execute_sql_query( + query=f""" + SELECT id AS user_id, product_id, product_user_id, + uuid, blocked, created, last_seen + FROM thl_user + WHERE uuid = ANY(%(user_uuids)s) + LIMIT 500 + """, + params={"user_uuids": user_uuids}, + ) + return [User.from_db(x) for x in res] diff --git a/generalresearch/managers/thl/user_manager/rate_limit.py b/generalresearch/managers/thl/user_manager/rate_limit.py new file mode 100644 index 0000000..bd0f9ac --- /dev/null +++ b/generalresearch/managers/thl/user_manager/rate_limit.py @@ -0,0 +1,76 @@ +import logging + +from limits import storage, strategies, RateLimitItemPerHour, RateLimitItem +from limits.limits import TIME_TYPES, safe_string +from pydantic import RedisDsn + +from generalresearch.managers.thl.user_manager import ( + UserCreateNotAllowedError, + get_bp_user_create_limit_hourly, +) +from generalresearch.models.thl.product import Product + +logger = logging.getLogger() + + +class RateLimitItemPerHourConstantKey(RateLimitItem): + """ + Per hour rate limit, where the key is specified manually + """ + + GRANULARITY = TIME_TYPES["hour"] + + def key_for(self, *identifiers: str) -> str: + """ + By default, the key includes the rate limit values. e.g. + `LIMITER/thl-grpc/allow_user_create/f1eb616ae68e488ab5b1f6839cb06f6a/61/1/hour` + This changes so that the key does not include the `61/1/hour` part, so that if the + actual limit changes (the limit of 61 hits per hour in this example), the + cache item key doesn't change + """ + remainder = "/".join([safe_string(k) for k in identifiers]) + return f"{self.namespace}/{remainder}" + + +class UserManagerLimiter: + def __init__(self, redis: RedisDsn): + self.redis = redis + + # memcache supported: connect_timeout=1, timeout=1), what about redis? + self.storage = storage.RedisStorage(uri=str(redis)) + + self.window = strategies.FixedWindowRateLimiter(storage=self.storage) + # self.window = strategies.MovingWindowRateLimiter(storage=self.storage) + + def raise_allow_user_create(self, product: Product) -> None: + """ + Checks if this product_id is allowed to create a new user now. + Sends only 1 sentry event per product_id per hour (if the product_id has exceeded the limit) + :raises UserCreateNotAllowedError + """ + allowed = self.user_create_allowed(product=product) + if not allowed: + err_msg = f"product_id {product.id} exceeded user creation limit" + + # -- Don't spam Sentry.io + sentry_rl = RateLimitItemPerHour(1) + if self.window.hit( + sentry_rl, + "thl-grpc", + "allow_user_create", + "sentry", + product.id, + cost=1, + ): + logger.exception(err_msg) + + raise UserCreateNotAllowedError(err_msg) + + def user_create_allowed(self, product: Product) -> bool: + """ + :returns if this product_id is not allowed to create a user + """ + rl_value = get_bp_user_create_limit_hourly(product) + rl = RateLimitItemPerHourConstantKey(rl_value) + + return self.window.hit(rl, "thl-grpc", "allow_user_create", product.id, cost=1) diff --git a/generalresearch/managers/thl/user_manager/redis_user_manager.py b/generalresearch/managers/thl/user_manager/redis_user_manager.py new file mode 100644 index 0000000..8e69740 --- /dev/null +++ b/generalresearch/managers/thl/user_manager/redis_user_manager.py @@ -0,0 +1,88 @@ +from typing import Optional + +import redis +from pydantic import RedisDsn + +from generalresearch.models.thl.user import User + + +class RedisUserManager: + def __init__( + self, + redis_dsn: RedisDsn, + cache_prefix: Optional[str] = None, + redis_timeout: Optional[float] = None, + ): + self.redis = redis_dsn + self.redis_timeout = redis_timeout if redis_timeout else 0.10 + self.client = self.create_client() + self.cache_prefix = cache_prefix if cache_prefix else "user-lookup" + + def create_client(self) -> redis.Redis: + # Clients are thread safe. We can just create one upon init + redis_config_dict = { + "url": str(self.redis), + "decode_responses": True, + "socket_timeout": self.redis_timeout, + "socket_connect_timeout": self.redis_timeout, + } + + return redis.Redis.from_url(**redis_config_dict) + + def get_user( + self, + *, + product_id: Optional[str] = None, + product_user_id: Optional[str] = None, + user_id: Optional[int] = None, + user_uuid: Optional[str] = None, + ) -> User: + # assume we did input validation in user_manager.get_user() function + if user_uuid: + d = self.client.get(f"{self.cache_prefix}:uuid:{user_uuid}") + + elif user_id: + d = self.client.get(f"{self.cache_prefix}:user_id:{user_id}") + + else: + d = self.client.get( + f"{self.cache_prefix}:ubp:{product_id}:{product_user_id}" + ) + + if d: + return User.model_validate_json(d) + + def set_user(self, user: User) -> None: + d = user.to_json() + with self.client.pipeline(transaction=False) as p: + p.set( + name=f"{self.cache_prefix}:uuid:{user.uuid}", + value=d, + ex=60 * 60 * 24, + ) + p.set( + name=f"{self.cache_prefix}:user_id:{user.user_id}", + value=d, + ex=60 * 60 * 24, + ) + p.set( + name=f"{self.cache_prefix}:ubp:{user.product_id}:{user.product_user_id}", + value=d, + ex=60 * 60 * 24, + ) + + p.execute() + + return None + + def clear_user(self, user: User) -> None: + # this should only be used by tests + with self.client.pipeline(transaction=False) as p: + p.delete(f"{self.cache_prefix}:uuid:{user.uuid}") + p.delete(f"{self.cache_prefix}:user_id:{user.user_id}") + p.delete( + f"{self.cache_prefix}:ubp:{user.product_id}:{user.product_user_id}" + ) + p.execute() + + return None diff --git a/generalresearch/managers/thl/user_manager/user_manager.py b/generalresearch/managers/thl/user_manager/user_manager.py new file mode 100644 index 0000000..dd32f8b --- /dev/null +++ b/generalresearch/managers/thl/user_manager/user_manager.py @@ -0,0 +1,378 @@ +import logging +from datetime import datetime +from functools import lru_cache +from typing import Collection, Optional, List +from uuid import uuid4 + +from pydantic import RedisDsn + +from generalresearch.managers.base import Permission +from generalresearch.managers.thl.product import ProductManager +from generalresearch.managers.thl.user_manager import UserDoesntExistError +from generalresearch.managers.thl.user_manager.mysql_user_manager import ( + MysqlUserManager, +) +from generalresearch.managers.thl.user_manager.rate_limit import ( + UserManagerLimiter, +) +from generalresearch.managers.thl.user_manager.redis_user_manager import ( + RedisUserManager, +) +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig +from generalresearch.utils.copying_cache import deepcopy_return + +logging.basicConfig() +logger = logging.getLogger() +auditlog = logging.getLogger("auditlog") + + +class UserManager: + def __init__( + self, + redis: Optional[RedisDsn] = None, + pg_config: Optional[PostgresConfig] = None, + pg_config_rr: Optional[PostgresConfig] = None, + sql_permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + redis_timeout: Optional[float] = None, + ): + + if sql_permissions is None: + sql_permissions = [] + + if pg_config is not None: + assert ( + pg_config_rr is not None + ), "you should pass RR credentials also for fast lookups" + + assert Permission.DELETE not in sql_permissions, "delete not allowed" + if Permission.UPDATE in sql_permissions or Permission.CREATE in sql_permissions: + assert pg_config is not None, "must pass pg_config" + + self.sql_permissions = set(sql_permissions) if sql_permissions else set() + self.mysql_user_manager = None + if pg_config: + self.mysql_user_manager = MysqlUserManager(pg_config, is_read_replica=False) + + self.mysql_user_manager_rr = None + if pg_config_rr: + self.mysql_user_manager_rr = MysqlUserManager( + pg_config_rr, is_read_replica=True + ) + + self.user_manager_limiter = None + self.redis_user_manager = None + if redis: + # Assuming we have full write access to redis if clients exist + self.user_manager_limiter = UserManagerLimiter(redis=redis) + self.redis_user_manager = RedisUserManager( + redis_dsn=redis, + cache_prefix=cache_prefix, + redis_timeout=redis_timeout, + ) + + self.product_manager = ProductManager( + pg_config=pg_config, permissions=[Permission.READ] + ) + + def set_last_seen(self, user: User) -> None: + assert Permission.UPDATE in self.sql_permissions, "permission error" + return self.mysql_user_manager._set_last_seen(user) + + def audit_log( + self, + user: User, + level: int, + event_type: str, + event_msg: Optional[str] = None, + event_value: Optional[float] = None, + ) -> None: + from generalresearch.managers.thl.userhealth import AuditLogManager + from generalresearch.models.thl.userhealth import AuditLogLevel + + alm = AuditLogManager(pg_config=self.mysql_user_manager.pg_config) + alm.create( + user_id=user.user_id, + level=AuditLogLevel(level), + event_type=event_type, + event_msg=event_msg, + event_value=event_value, + ) + + return None + + def cache_clear(self): + # Generally this is used in testing. This clears the .get_user's lru_cache. + # There is no way of clearing only a specific key from the cache. + # It does not clear any redis caches; that has to be done separately. + self.get_user.__wrapped__.cache_clear() + + @deepcopy_return + @lru_cache(maxsize=10000) + def get_user( + self, + *, + product_id: Optional[str] = None, + product_user_id: Optional[str] = None, + user_id: Optional[int] = None, + user_uuid: Optional[UUIDStr] = None, + ) -> User: + """ + Retrieve User from (product_id & product_user_id) or (user_id), or (uuid). + Looks up in lru_cache, then (redis, memcached), then mysql. + Raises UserDoesntExistError if user is not found. + (the * makes all arguments keyword-only arguments) + """ + assert ( + (product_id and product_user_id) or user_id or user_uuid + ), "Must pass either (product_id, product_user_id), or user_id, or uuid" + if product_id or product_user_id: + assert ( + product_id and product_user_id + ), "Must pass both product_id and product_user_id" + assert ( + sum(map(bool, [product_id or product_id, user_id, user_uuid])) == 1 + ), "Must pass only 1 of (product_id, product_user_id), or user_id, or uuid" + user = self.get_user_inmemory_cache( + product_id=product_id, + product_user_id=product_user_id, + user_id=user_id, + user_uuid=user_uuid, + ) + + if user: + return user + + # We can use the read-replica here b/c when we create a user we'll + # put it in the in-memory cache + mysql_user_manager = self.mysql_user_manager_rr or self.mysql_user_manager + user = mysql_user_manager.get_user_from_mysql( + product_id=product_id, + product_user_id=product_user_id, + user_id=user_id, + user_uuid=user_uuid, + can_use_read_replica=True, + ) + + # Note: Do not return None for a user that doesn't exist. If the user + # doesn't exist in mysql, this function will return None until the + # cache expires. Throw exception instead. + if user is None: + raise UserDoesntExistError( + f"user doesn't exist: {product_id}, {product_user_id}, {user_id}, {user_uuid}" + ) + + # Set the redis/memcached caches to we don't end up hitting mysql again! + self.set_user_inmemory_cache(user) + return user + + def get_user_if_exists( + self, + product_id: Optional[str] = None, + product_user_id: Optional[str] = None, + ) -> Optional[User]: + """ + Look up User from (product_id & product_user_id). Returns + None if user does not exist. + """ + try: + return self.get_user(product_id=product_id, product_user_id=product_user_id) + except UserDoesntExistError: + return None + + def get_user_inmemory_cache( + self, + *, + product_id: Optional[str] = None, + product_user_id: Optional[str] = None, + user_id: Optional[int] = None, + user_uuid: Optional[UUIDStr] = None, + ) -> Optional[User]: + + input_str = f"{product_id}, {product_user_id}, {user_id}, {user_uuid}" + if self.redis_user_manager: + import redis + + logger.info(f"get_user from redis: {input_str}") + try: + user = self.redis_user_manager.get_user( + product_id=product_id, + product_user_id=product_user_id, + user_id=user_id, + user_uuid=user_uuid, + ) + except ( + redis.exceptions.TimeoutError, + redis.exceptions.ConnectionError, + ) as e: + logger.info(f"get_user from redis failed: {input_str}, {e}") + else: + return user + + return None + + def set_user_inmemory_cache(self, user: User) -> None: + if self.redis_user_manager: + import redis + + try: + self.redis_user_manager.set_user(user) + except ( + redis.exceptions.TimeoutError, + redis.exceptions.ConnectionError, + ) as e: + logger.info(f"redis.set_user failed: {user}, {e}") + + return None + + def clear_user_inmemory_cache(self, user: User) -> None: + if self.redis_user_manager: + # this should only be used by tests + self.redis_user_manager.clear_user(user) + + return None + + def get_or_create_user(self, product_user_id: str, product_id: str) -> User: + """ + Given a bp_user_id and a product_id, get or create a User + """ + assert Permission.CREATE in self.sql_permissions + assert self.mysql_user_manager is not None + assert ( + self.redis_user_manager is not None + ), "need at least redis to synchronize user creation" + + assert ( + self.user_manager_limiter is not None + ), "Need user_manager_limiter to get_or_create_user" + # Attempt to create common_struct solely for validation purposes + if not User.is_valid_ubp( + product_id=product_id, product_user_id=product_user_id + ): + # Hopefully FSB checks this before it gets here and returns a helpful error message + raise ValueError("invalid product_id/product_user_id") + + u = self.get_user_if_exists( + product_id=product_id, product_user_id=product_user_id + ) + if u is not None: + return u + + return self.create_user(product_id=product_id, product_user_id=product_user_id) + + def create_user( + self, + product_user_id: str, + product_id: Optional[UUIDStr] = None, + product: Optional[Product] = None, + created: Optional[datetime] = None, + ) -> User: + + assert ( + self.user_manager_limiter is not None + ), "Need user_manager_limiter to create_user" + assert product_id or product, "Needs a product_id or a Product instance" + + if product is None: + product = self.product_manager.get_by_uuid(product_uuid=product_id) + + # This will raise a UserCreateNotAllowedError Exception if the + # product_id is over the limit + + # TODO: DB source for enable/disable user creation rate limit + # if product.id not in {}: + # self.user_manager_limiter.raise_allow_user_create(product=product) + + user = self.mysql_user_manager.create_user( + product_user_id=product_user_id, + product_id=product.id, + created=created, + ) + + self.set_user_inmemory_cache(user=user) + + return user + + def create_dummy( + self, + # --- Create dummy "optional" --- # + product_user_id: Optional[str] = None, + # --- Optional --- # + product_id: Optional[UUIDStr] = None, + product: Optional[Product] = None, + created: Optional[datetime] = None, + ) -> User: + + product_user_id = product_user_id or uuid4().hex + + return self.create_user( + product_user_id=product_user_id, + product_id=product_id, + product=product, + created=created, + ) + + def product_id_exists(self, product_id: str) -> bool: + mysql_user_manager = self.mysql_user_manager_rr or self.mysql_user_manager + return mysql_user_manager.product_id_exists(product_id) + + def block_user(self, user: User) -> bool: + """ + Block this user "permanently". + Writes to `300large`.`thl_user`.blocked. + :param user: User + :return: if the user has been blocked (i.e. False if they are already blocked) + """ + if user.blocked: + logger.info(f"User {user} is already blocked") + return False + if self.is_whitelisted(user): + logger.info(f"User {user} is whitelisted") + return False + + self.mysql_user_manager._block_user(user=user) + user.blocked = True + + # If we change something about a user, we should update the in-memory caches + self.set_user_inmemory_cache(user) + # There is no way to clear a single key from the lru_cache... + # https://bugs.python.org/issue28178 + self.cache_clear() + return True + + def is_whitelisted(self, user: User) -> bool: + """ + We have a user whitelist/blocklist system, which protects a user against a hard block + + Currently, this sets a key in the userprofile_userstat table. + TODO: this should be a property of the user? + """ + return self.mysql_user_manager.is_whitelisted(user=user) + + def fetch_by_bpuids( + self, + *, + product_id: str, + product_user_ids: Collection[str], + ) -> List[User]: + assert product_id, "must pass product_id" + assert len(product_user_ids) > 0, "must pass 1 or more product_user_ids" + return self.mysql_user_manager_rr.fetch_by_bpuids( + product_id=product_id, product_user_ids=product_user_ids + ) + + def fetch( + self, + *, + user_ids: Collection[int] = None, + user_uuids: Collection[str] = None, + ) -> List[User]: + assert (user_ids or user_uuids) and not ( + user_ids and user_uuids + ), "Must pass ONE of user_ids, user_uuids" + return self.mysql_user_manager_rr.fetch( + user_ids=user_ids, user_uuids=user_uuids + ) diff --git a/generalresearch/managers/thl/user_manager/user_metadata_manager.py b/generalresearch/managers/thl/user_manager/user_metadata_manager.py new file mode 100644 index 0000000..23c9d3c --- /dev/null +++ b/generalresearch/managers/thl/user_manager/user_metadata_manager.py @@ -0,0 +1,141 @@ +from typing import List, Optional, Collection + +from generalresearch.managers.base import PostgresManager +from generalresearch.models.thl.user_profile import UserMetadata + + +class UserMetadataManager(PostgresManager): + def filter( + self, + user_ids: Optional[Collection[int]] = None, + email_addresses: Optional[Collection[str]] = None, + email_sha256s: Optional[Collection[str]] = None, + email_sha1s: Optional[Collection[str]] = None, + email_md5s: Optional[Collection[str]] = None, + ) -> List[UserMetadata]: + for arg in [ + user_ids, + email_addresses, + email_sha256s, + email_sha1s, + email_md5s, + ]: + assert arg is None or isinstance( + arg, (set, list) + ), "must pass a collection of objects" + filters = [] + params = {} + if user_ids: + params["user_id"] = list(set(user_ids)) + filters.append("user_id = ANY(%(user_id)s)") + if email_addresses: + params["email_address"] = list(set(email_addresses)) + filters.append("email_address = ANY(%(email_address)s)") + if email_sha256s: + params["email_sha256"] = list(set(email_sha256s)) + filters.append("email_sha256 = ANY(%(email_sha256)s)") + if email_sha1s: + params["email_sha1"] = list(set(email_sha1s)) + filters.append("email_sha1 = ANY(%(email_sha1)s)") + if email_md5s: + params["email_md5"] = list(set(email_md5s)) + filters.append("email_md5 = ANY(%(email_md5)s)") + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + res = self.pg_config.execute_sql_query( + f""" + SELECT user_id, email_address, email_sha256, email_sha1, email_md5 + FROM thl_usermetadata + {filter_str} + """, + params, + ) + return [UserMetadata.from_db(**x) for x in res] + + def get_if_exists( + self, + user_id: Optional[int] = None, + email_address: Optional[str] = None, + email_sha256: Optional[str] = None, + email_sha1: Optional[str] = None, + email_md5: Optional[str] = None, + ) -> Optional[UserMetadata]: + filters = { + "user_ids": user_id, + "email_addresses": email_address, + "email_sha256s": email_sha256, + "email_sha1s": email_sha1, + "email_md5s": email_md5, + } + filters = {k: [v] for k, v in filters.items() if v is not None} + assert len(filters) == 1, "Exactly ONE filter argument must be provided." + res = self.filter(**filters) + if len(res) == 0: + return None + if len(res) > 1: + raise ValueError("More than 1 result returned!") + return UserMetadata.model_validate(res[0]) + + def get( + self, + user_id: Optional[int] = None, + email_address: Optional[str] = None, + email_sha256: Optional[str] = None, + email_sha1: Optional[str] = None, + email_md5: Optional[str] = None, + ) -> UserMetadata: + res = self.get_if_exists( + user_id=user_id, + email_address=email_address, + email_sha256=email_sha256, + email_sha1=email_sha1, + email_md5=email_md5, + ) + if res is None: + if user_id is not None: + # We don't raise a "not found" here, b/c it just means that + # nothing has been set for this user. + return UserMetadata(user_id=user_id) + else: + # We are filtering, not looking up a user's info, so in this + # case, nothing is found + raise ValueError("not found") + return res + + def update(self, user_metadata: UserMetadata) -> int: + """ + The row in the thl_usermetadata might not exist. We'll implicitly create it + if it doesn't yet exist. The caller does not need to know this detail. + """ + res = self.get_if_exists(user_id=user_metadata.user_id) + + # We're assuming the user itself exists. There's a foreign key so the + # db call will fail if it doesn't, so we don't need to check + # it beforehand. + if not res: + return self._create(user_metadata=user_metadata) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + """ + UPDATE thl_usermetadata + SET email_address = %(email_address)s, email_sha256 = %(email_sha256)s, + email_sha1 = %(email_sha1)s, email_md5 = %(email_md5)s + WHERE user_id = %(user_id)s; + """, + params=user_metadata.to_db(), + ) + rowcount = c.rowcount + conn.commit() + return rowcount + + def _create(self, user_metadata: UserMetadata) -> int: + return self.pg_config.execute_write( + """ + INSERT INTO thl_usermetadata + (user_id, email_address, email_sha256, email_sha1, email_md5) + VALUES (%(user_id)s, %(email_address)s, %(email_sha256)s, %(email_sha1)s, %(email_md5)s); + """, + params=user_metadata.to_db(), + ) diff --git a/generalresearch/managers/thl/user_streak.py b/generalresearch/managers/thl/user_streak.py new file mode 100644 index 0000000..fd3c1ba --- /dev/null +++ b/generalresearch/managers/thl/user_streak.py @@ -0,0 +1,150 @@ +from datetime import date, datetime +from typing import Optional, List, Tuple + +import pandas as pd + +from generalresearch.managers.base import PostgresManager +from generalresearch.managers.leaderboard import country_timezone +from generalresearch.models.thl.user_streak import ( + UserStreak, + StreakPeriod, + StreakFulfillment, + StreakState, + PERIOD_TO_PD_FREQ, +) + + +class UserStreakManager(PostgresManager): + + def get_user_country(self, user_id: int) -> None: + # For the purposes of streaks, the country they are in is + # the first country they were active in + res = self.pg_config.execute_sql_query( + """ + SELECT country_iso + FROM thl_session + WHERE user_id = %(user_id)s + ORDER BY started + LIMIT 1; + """, + {"user_id": user_id}, + ) + if res: + return res[0]["country_iso"] + + return None + + def get_user_active_days_query(self, user_id: int, country_iso: str): + tz = country_timezone()[country_iso] + query = """ + SELECT + (started AT TIME ZONE %(tz)s)::date AS d, + MAX((status = 'c')::int) AS is_complete + FROM thl_session + WHERE user_id = %(user_id)s + AND status IS NOT NULL + AND (status_code_1 IS NULL OR status_code_1 NOT IN (16, 18, 19)) + GROUP BY d + ORDER BY d;""" + params = {"user_id": user_id, "tz": str(tz)} + return self.pg_config.execute_sql_query(query, params) + + def get_user_streaks( + self, user_id: int, country_iso: Optional[str] = None + ) -> List[UserStreak]: + country_iso = country_iso or self.get_user_country(user_id=user_id) + if country_iso is None: + return [] + res = self.get_user_active_days_query(user_id=user_id, country_iso=country_iso) + + active_days = [x["d"] for x in res] + complete_days = [x["d"] for x in res if x["is_complete"]] + + streaks: List[UserStreak] = [] + + for period in StreakPeriod: + for fulfillment, days in [ + (StreakFulfillment.ACTIVE, active_days), + (StreakFulfillment.COMPLETE, complete_days), + ]: + current, longest, state, last_period = compute_streaks_from_days( + days=days, + country_iso=country_iso, + period=StreakPeriod(period), + ) + + streaks.append( + UserStreak( + fulfillment=fulfillment, + period=StreakPeriod(period), + current_streak=current, + longest_streak=longest, + state=state, + last_fulfilled_period_start=last_period, + country_iso=country_iso, + user_id=user_id, + ) + ) + # Don't return any that are empty (no current or longest streak) + streaks = [s for s in streaks if s.current_streak or s.longest_streak] + + return streaks + + +def compute_streaks_from_days( + days: List[date], + country_iso: str, + period: StreakPeriod, + today: Optional[date] = None, +) -> Tuple[int, int, StreakState, Optional[date]]: + """ + :returns: (current_streak, longest_streak, streak_state, last_period_start) + """ + + if not days: + return 0, 0, StreakState.BROKEN, None + + tz = country_timezone()[country_iso] + today = today or datetime.now(tz=tz).date() + + freq = PERIOD_TO_PD_FREQ[period] + + # Convert raw days -> pandas Periods + periods = sorted({pd.Timestamp(d).to_period(freq) for d in days}) + today_period = pd.Timestamp(today).to_period(freq) + + # period -> period start dates + # period_starts = [p.start_time.date() for p in periods] + # today_start = today_period.start_time.date() + + # ---- longest streak ---- + longest = 1 + running = 1 + + for i in range(1, len(periods)): + if periods[i] == periods[i - 1] + 1: + running += 1 + else: + longest = max(longest, running) + running = 1 + + longest = max(longest, running) + + # ---- current streak ---- + last_run = 1 + for i in range(len(periods) - 1, 0, -1): + if periods[i] == periods[i - 1] + 1: + last_run += 1 + else: + break + + last_period = periods[-1] + last_period_start = last_period.start_time.date() + + if last_period == today_period: + return last_run, longest, StreakState.ACTIVE, last_period_start + + if last_period + 1 == today_period: + return last_run, longest, StreakState.AT_RISK, last_period_start + + return 0, longest, StreakState.BROKEN, last_period_start diff --git a/generalresearch/managers/thl/userhealth.py b/generalresearch/managers/thl/userhealth.py new file mode 100644 index 0000000..6f66ce3 --- /dev/null +++ b/generalresearch/managers/thl/userhealth.py @@ -0,0 +1,579 @@ +import ipaddress +from datetime import timezone, datetime, timedelta +from itertools import zip_longest +from random import choice as rchoice, random +from typing import List, Collection, Optional, Dict, Tuple + +import faker +from pydantic import PositiveInt, NonNegativeInt + +from generalresearch.decorators import LOG +from generalresearch.managers.base import ( + Permission, + PostgresManager, + PostgresManagerWithRedis, +) +from generalresearch.managers.thl.ipinfo import GeoIpInfoManager +from generalresearch.models.custom_types import IPvAnyAddressStr +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from generalresearch.models.thl.user_iphistory import ( + IPRecord, + UserIPRecord, + UserIPHistory, +) +from generalresearch.models.thl.userhealth import AuditLog, AuditLogLevel +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +fake = faker.Faker() + + +class UserIpHistoryManager(PostgresManagerWithRedis): + def __init__( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + ): + super().__init__( + pg_config=pg_config, + redis_config=redis_config, + permissions=permissions, + cache_prefix=cache_prefix, + ) + self.geoipinfo_manager = GeoIpInfoManager( + pg_config=pg_config, + redis_config=redis_config, + cache_prefix=cache_prefix, + ) + + def get_redis_key(self, user_id: int) -> str: + return f"py-utils:user-ip-history:{user_id}" + + def get_user_ip_records_sql(self, user_id: int) -> List[UserIPRecord]: + # The IP metadata is ONLY for the 'ip', NOT for any forwarded ips. + # This might get called immediately after a write, so use the non-rr + res = self.pg_config.execute_sql_query( + query=f""" + SELECT iph.ip, iph.created, iph.user_id, + geo.subdivision_1_iso, + ipinfo.country_iso, + ipinfo.is_anonymous + FROM userhealth_iphistory iph + LEFT JOIN thl_ipinformation AS ipinfo + ON iph.ip = ipinfo.ip + LEFT JOIN thl_geoname AS geo + ON ipinfo.geoname_id = geo.geoname_id + WHERE iph.user_id = %s + AND created > NOW() - INTERVAL '28 days' + ORDER BY iph.created DESC + LIMIT 100 + """, + params=[user_id], + ) + + res = [UserIPRecord.model_validate(x) for x in res] + return res + + def get_user_ip_history_cache(self, user_id: int) -> Optional[UserIPHistory]: + res = self.redis_client.get(self.get_redis_key(user_id)) + if res: + return UserIPHistory.model_validate_json(res) + return None + + def delete_user_ip_history_cache(self, user_id: int) -> None: + self.redis_client.delete(self.get_redis_key(user_id)) + return None + + def set_user_ip_history_cache(self, user_id: int, iph: UserIPHistory) -> None: + value = iph.model_dump_json() + self.redis_client.set(self.get_redis_key(user_id), value, ex=3 * 24 * 3600) + return None + + def recreate_user_ip_history_cache(self, user_id: int) -> None: + self.delete_user_ip_history_cache(user_id=user_id) + records = self.get_user_ip_records_sql(user_id=user_id) + # todo: we may get dns records from somewhere else here ... + iph = UserIPHistory(user_id=user_id, ips=records) + self.set_user_ip_history_cache(user_id=user_id, iph=iph) + return None + + def get_user_ip_history(self, user_id: int) -> UserIPHistory: + assert isinstance(user_id, int) + iph = self.get_user_ip_history_cache(user_id=user_id) + if iph: + LOG.debug(f"get_user_ip_history got in cache: {iph.model_dump_json()}") + + else: + LOG.debug("get_user_ip_history cache not found, using mysql") + records = self.get_user_ip_records_sql(user_id=user_id) + # todo: we may get dns records from somewhere else here ... + iph = UserIPHistory(user_id=user_id, ips=records) + self.set_user_ip_history_cache(user_id=user_id, iph=iph) + + iph.enrich_ips(pg_config=self.pg_config, redis_config=self.redis_config) + return iph + + def get_user_latest_ip( + self, user: User, exclude_anon: bool = False + ) -> Optional[str]: + record = self.get_user_latest_ip_record(user=user, exclude_anon=exclude_anon) + if record: + return record.ip + return None + + def get_user_latest_ip_record( + self, user: User, exclude_anon: bool = False + ) -> Optional[UserIPRecord]: + iphistory = self.get_user_ip_history(user_id=user.user_id) + + if iphistory.ips: + if exclude_anon: + return next( + filter( + lambda x: not x.information.is_anonymous, + iphistory.ips[::-1], + ), + None, + ) + else: + return iphistory.ips[-1] + + return None + + def get_user_latest_country( + self, user: User, exclude_anon: bool = False + ) -> Optional[str]: + """Get the country the user is in, based off their latest ip.""" + ipr = self.get_user_latest_ip_record(user, exclude_anon=exclude_anon) + # The ipr.information should exist, but it is possible the user has + # no IP history at all, so the record is None + return ipr.country_iso if ipr is not None else None + + def is_user_anonymous(self, user: User) -> Optional[bool]: + # Get the user's latest ip. is it marked as anonymous? + # Note: it is possible we only did a "basic" lookup of this IP so + # we don't know if they are anonymous. Default to False + # Return None if the user has no IP history at all + ipr = self.get_user_latest_ip_record(user) + if ipr: + return ipr.is_anonymous if ipr.is_anonymous is not None else False + return None + + +class IPRecordManager(PostgresManagerWithRedis): + + def __init__( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + permissions: Collection[Permission] = None, + cache_prefix: Optional[str] = None, + ): + super().__init__( + pg_config=pg_config, + redis_config=redis_config, + permissions=permissions, + cache_prefix=cache_prefix, + ) + self.user_ip_history_manager = UserIpHistoryManager( + pg_config=self.pg_config, + redis_config=self.redis_config, + cache_prefix=self.cache_prefix, + permissions=self.permissions, + ) + + def create_dummy( + self, + user_id: PositiveInt, + ip: Optional[IPvAnyAddressStr] = None, + forwarded_ip1: Optional[IPvAnyAddressStr] = None, + forwarded_ip2: Optional[IPvAnyAddressStr] = None, + forwarded_ip3: Optional[IPvAnyAddressStr] = None, + forwarded_ip4: Optional[IPvAnyAddressStr] = None, + forwarded_ip5: Optional[IPvAnyAddressStr] = None, + forwarded_ip6: Optional[IPvAnyAddressStr] = None, + ) -> IPRecord: + return self.create( + user_id=user_id, + ip=ip or fake.ipv4_public(), + forwarded_ip1=(forwarded_ip1 or fake.ipv4_public()), + forwarded_ip2=(forwarded_ip2 or fake.ipv6() if random() < 0.5 else None), + forwarded_ip3=( + forwarded_ip3 or fake.ipv4_public() if random() < 0.25 else None + ), + forwarded_ip4=forwarded_ip4, + forwarded_ip5=forwarded_ip5, + forwarded_ip6=forwarded_ip6, + ) + + def create_unpack( + self, + user_id: PositiveInt, + ip: IPvAnyAddressStr, + forwarded_ips: List[str], + ) -> IPRecord: + if len(forwarded_ips) > 6: + raise ValueError("A maximum of 6 forwarded IPs is allowed.") + + padded = list(forwarded_ips) + [None] * (6 - len(forwarded_ips)) + + return self.create(user_id, ip, *padded) + + def create( + self, + user_id: PositiveInt, + ip: IPvAnyAddressStr, + forwarded_ip1: IPvAnyAddressStr, + forwarded_ip2: IPvAnyAddressStr, + forwarded_ip3: IPvAnyAddressStr, + forwarded_ip4: IPvAnyAddressStr, + forwarded_ip5: IPvAnyAddressStr, + forwarded_ip6: IPvAnyAddressStr, + ) -> IPRecord: + + data = { + "user_id": user_id, + "ip": ipaddress.ip_address(ip).exploded, + "created": datetime.now(tz=timezone.utc), + } + + fips_cols = [ + "forwarded_ip1", + "forwarded_ip2", + "forwarded_ip3", + "forwarded_ip4", + "forwarded_ip5", + "forwarded_ip6", + ] + for col, ip in zip_longest( + fips_cols, + [ + forwarded_ip1, + forwarded_ip2, + forwarded_ip3, + forwarded_ip4, + forwarded_ip5, + forwarded_ip6, + ], + fillvalue=None, + ): + data[col] = ipaddress.ip_address(ip).exploded if ip else ip + + self.pg_config.execute_write( + query=f""" + INSERT INTO userhealth_iphistory ( + user_id, ip, created, + forwarded_ip1, forwarded_ip2, forwarded_ip3, + forwarded_ip4, forwarded_ip5, forwarded_ip6 + ) + VALUES ( + %(user_id)s, %(ip)s, %(created)s, + %(forwarded_ip1)s, %(forwarded_ip2)s, %(forwarded_ip3)s, + %(forwarded_ip4)s, %(forwarded_ip5)s, %(forwarded_ip6)s + ); + """, + params=data, + ) + self.recreate_user_ip_history_cache(user_id=user_id) + + return IPRecord.from_mysql(data) + + def get_user_latest_ip_record(self, user: User) -> Optional[IPRecord]: + res = self.filter_ip_records(user_ids=[user.user_id], limit=1) + if res: + return res[0] + return None + + def filter_ip_records( + self, + filter_ips: Optional[List[IPvAnyAddressStr]] = None, + user_ids: Optional[List[PositiveInt]] = None, + created_from: Optional[datetime] = None, + limit: Optional[int] = None, + ) -> List[IPRecord]: + + assert any([filter_ips, user_ids, created_from]), "Must provide filter criteria" + + if filter_ips is not None and not filter_ips: + raise AssertionError("Must provide valid filter_ips filter lists") + + if user_ids is not None and not user_ids: + raise AssertionError("Must provide valid user_id filter lists") + + filters = [] + params = {} + if filter_ips: + params["filter_ips"] = filter_ips + filters.append("ip = ANY(%(filter_ips)s)") + + if user_ids: + params["user_ids"] = user_ids + filters.append("user_id = ANY(%(user_ids)s)") + + if created_from: + params["created_from"] = created_from + filters.append("created >= %(created_from)s") + + filter_str = " AND ".join(filters) + filter_str = "WHERE " + filter_str if filter_str else "" + if limit is not None: + assert type(limit) is int + assert 0 <= limit <= 1000 + limit_str = f"LIMIT {limit}" if limit is not None else "" + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT i.ip, i.user_id, + i.forwarded_ip1, i.forwarded_ip2, + i.forwarded_ip3, i.forwarded_ip4, + i.forwarded_ip5, i.forwarded_ip6, + i.created + FROM userhealth_iphistory AS i + {filter_str} + ORDER BY created DESC + {limit_str} + """, + params=params, + ) + + return [IPRecord.from_mysql(i) for i in res] + + def recreate_user_ip_history_cache(self, user_id: int): + return self.user_ip_history_manager.recreate_user_ip_history_cache( + user_id=user_id + ) + + +class AuditLogManager(PostgresManager): + + def create_dummy( + self, + user_id: PositiveInt, + level: Optional[AuditLogLevel] = None, + event_type: Optional[str] = None, + event_msg: Optional[str] = None, + event_value: Optional[float] = None, + ) -> AuditLog: + + event_types = { + "offerwall-enter.blocked", + "offerwall-enter.rate-limited", + "offerwall-enter.url-modified", + } + + return self.create( + user_id=user_id, + level=level or rchoice(list(AuditLogLevel)), + event_type=event_type or rchoice(list(event_types)), + event_msg=event_msg, + event_value=event_value, + ) + + def create( + self, + user_id: PositiveInt, + level: AuditLogLevel, + event_type: str, + event_msg: Optional[str] = None, + event_value: Optional[float] = None, + ) -> AuditLog: + """AuditLogs may exist with the same event_type, and with different levels""" + + al = AuditLog.model_validate( + { + "user_id": user_id, + "created": datetime.now(tz=timezone.utc), + "level": level, + "event_type": event_type, + "event_msg": event_msg, + "event_value": event_value, + } + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute( + query=f""" + INSERT INTO userhealth_auditlog + (user_id, created, level, event_type, event_msg, event_value) + VALUES ( %(user_id)s , %(created)s, %(level)s, %(event_type)s, + %(event_msg)s, %(event_value)s) + RETURNING id; + """, + params=al.model_dump_mysql(), + ) + pk = c.fetchone()["id"] + conn.commit() + + al.id = pk + return al + + def get_by_id(self, auditlog_id: PositiveInt) -> AuditLog: + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT al.* + FROM userhealth_auditlog AS al + WHERE al.id = %s + LIMIT 2; + """, + params=(auditlog_id,), + ) + + if len(res) == 0: + raise Exception(f"No AuditLog with id of '{auditlog_id}'") + + if len(res) > 1: + raise Exception(f"Too many AuditLog found with id of '{auditlog_id}'") + + return AuditLog.from_mysql(res[0]) + + def filter_by_product(self, product: Product) -> List[AuditLog]: + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT al.* + FROM userhealth_auditlog AS al + INNER JOIN thl_user AS u + ON u.id = al.user_id + WHERE u.product_id = %s + ORDER BY al.created DESC + LIMIT 2500; + """, + params=(product.uuid,), + ) + + return [AuditLog.from_mysql(i) for i in res] + + def filter_by_user_id(self, user_id: PositiveInt) -> List[AuditLog]: + res = self.pg_config.execute_sql_query( + query=f""" + SELECT * + FROM userhealth_auditlog AS al + WHERE al.user_id = %s + ORDER BY al.created DESC + LIMIT 2500; + """, + params=(user_id,), + ) + + return [AuditLog.from_mysql(i) for i in res] + + def filter( + self, + user_ids: Collection[int], + level: Optional[int] = None, + level_ge: Optional[int] = None, + event_type: Optional[str] = None, + event_type_like: Optional[str] = None, + event_msg: Optional[str] = None, + created_after: Optional[datetime] = None, + ) -> List[AuditLog]: + + filter_str, args = self.make_filter_str( + user_ids=user_ids, + level=level, + level_ge=level_ge, + event_type=event_type, + event_type_like=event_type_like, + event_msg=event_msg, + created_after=created_after, + ) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT user_id, created, level, event_type, + event_msg, event_value + FROM userhealth_auditlog + WHERE {filter_str} + """, + params=args, + ) + + return [AuditLog.from_mysql(i) for i in res] + + def filter_count( + self, + user_ids: Collection[int], + level: Optional[int] = None, + level_ge: Optional[int] = None, + event_type: Optional[str] = None, + event_type_like: Optional[str] = None, + event_msg: Optional[str] = None, + created_after: Optional[datetime] = None, + ) -> NonNegativeInt: + + filter_str, args = self.make_filter_str( + user_ids=user_ids, + level=level, + level_ge=level_ge, + event_type=event_type, + event_type_like=event_type_like, + event_msg=event_msg, + created_after=created_after, + ) + + res = self.pg_config.execute_sql_query( + query=f""" + SELECT COUNT(1) as c + FROM userhealth_auditlog + WHERE {filter_str} + """, + params=args, + ) + + assert len(res) == 1 + + return int(res[0]["c"]) + + @staticmethod + def make_filter_str( + user_ids: Collection[int], + level: Optional[int] = None, + level_ge: Optional[int] = None, + event_type: Optional[str] = None, + event_type_like: Optional[str] = None, + event_msg: Optional[str] = None, + created_after: Optional[datetime] = None, + ) -> Tuple[str, Dict]: + assert user_ids, "must pass at least 1 user_id" + assert all( + [isinstance(uid, int) for uid in user_ids] + ), "must pass user_id as int" + + if created_after is None: + created_after = datetime.now(tz=timezone.utc) - timedelta(days=7) + + filters = [ + "user_id = ANY(%(user_ids)s)", + "created >= %(created_after)s", + ] + args = {"user_ids": user_ids, "created_after": created_after} + + if level: + assert level_ge is None + filters.append("level = %(level)s") + args["level"] = level + if level_ge: + assert level is None + filters.append("level >= %(level_ge)s") + args["level_ge"] = level_ge + + if event_type: + assert event_type_like is None + filters.append("event_type = %(event_type)s") + args["event_type"] = event_type + if event_type_like: + assert event_type is None + filters.append("event_type LIKE %(event_type_like)s") + args["event_type_like"] = event_type_like + + if event_msg: + filters.append("event_msg = %(event_msg)s") + args["event_msg"] = event_msg + + filter_str = " AND ".join(filters) + return filter_str, args diff --git a/generalresearch/managers/thl/wall.py b/generalresearch/managers/thl/wall.py new file mode 100644 index 0000000..cd1bdbf --- /dev/null +++ b/generalresearch/managers/thl/wall.py @@ -0,0 +1,675 @@ +import logging +from collections import defaultdict +from datetime import datetime, timezone, timedelta +from decimal import Decimal, ROUND_DOWN +from functools import cached_property +from random import choice as rchoice +from typing import Optional, Collection, List +from uuid import uuid4 + +from faker import Faker +from psycopg import sql +from psycopg.rows import dict_row +from pydantic import AwareDatetime, PositiveInt, PostgresDsn, RedisDsn + +from generalresearch.managers import parse_order_by +from generalresearch.managers.base import ( + Permission, + PostgresManager, + PostgresManagerWithRedis, +) +from generalresearch.models import Source +from generalresearch.models.custom_types import UUIDStr, SurveyKey +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallStatusCode2, + ReportValue, + WallAdjustedStatus, +) +from generalresearch.models.thl.ledger import OrderBy +from generalresearch.models.thl.session import ( + check_adjusted_status_wall_consistent, + Wall, + WallAttempt, +) +from generalresearch.models.thl.survey.model import TaskActivity +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +logger = logging.getLogger("WallManager") +fake = Faker() + + +class WallManager(PostgresManager): + def __init__( + self, + pg_config: PostgresConfig, + permissions: Optional[Collection[Permission]] = None, + ): + assert pg_config.row_factory == dict_row + super().__init__(pg_config=pg_config, permissions=permissions) + + def create( + self, + session_id: int, + user_id: int, + started: datetime, + source: Source, + req_survey_id: str, + req_cpi: Decimal, + buyer_id: Optional[str] = None, + uuid_id: Optional[str] = None, + ) -> Wall: + """ + Creates a Wall event. Prefer to use this rather than instantiating + the model directly, because we're explicitly defining here which keys + should be set and which won't get set until later. + """ + if uuid_id is None: + uuid_id = uuid4().hex + + wall = Wall( + session_id=session_id, + user_id=user_id, + uuid=uuid_id, + started=started, + source=source, + buyer_id=buyer_id, + req_survey_id=req_survey_id, + req_cpi=req_cpi, + ) + d = wall.model_dump_mysql() + query = """ + INSERT INTO thl_wall ( + uuid, started, source, buyer_id, req_survey_id, + req_cpi, survey_id, cpi, session_id + ) VALUES ( + %(uuid)s, %(started)s, %(source)s, + %(buyer_id)s, %(req_survey_id)s, %(req_cpi)s, + %(survey_id)s, %(cpi)s, %(session_id)s + ); + """ + self.pg_config.execute_write(query=query, params=d) + return wall + + def create_dummy( + self, + session_id: Optional[int] = None, + user_id: Optional[int] = None, + started: Optional[datetime] = None, + source: Optional[Source] = None, + req_survey_id: Optional[str] = None, + req_cpi: Optional[Decimal] = None, + buyer_id: Optional[str] = None, + uuid_id: Optional[str] = None, + ): + """To be used in tests, where we don't care about certain fields""" + + user_id = user_id or fake.random_int(min=1, max=2_147_483_648) + started = started or fake.date_time_between( + start_date=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + end_date=datetime.now(tz=timezone.utc), + tzinfo=timezone.utc, + ) + + if session_id is None: + from generalresearch.managers.thl.session import SessionManager + + session = SessionManager(pg_config=self.pg_config).create_dummy( + started=started + ) + session_id = session.id + + source = source or rchoice(list(Source)) + req_survey_id = req_survey_id or uuid4().hex + req_cpi = req_cpi or Decimal(fake.random_int(min=1, max=150) / 100).quantize( + Decimal(".01"), rounding=ROUND_DOWN + ) + + return self.create( + session_id=session_id, + user_id=user_id, + started=started, + source=source, + req_survey_id=req_survey_id, + req_cpi=req_cpi, + buyer_id=buyer_id, + uuid_id=uuid_id, + ) + + def get_from_uuid(self, wall_uuid: UUIDStr) -> Wall: + query = """ + SELECT + tw.uuid, tw.source, tw.buyer_id, tw.survey_id, + tw.req_survey_id, tw.cpi, tw.req_cpi, tw.started, + tw.finished, tw.status, tw.status_code_1, + tw.status_code_2, tw.ext_status_code_1, + tw.ext_status_code_2, tw.ext_status_code_3, + tw.report_value, tw.report_notes, tw.adjusted_status, + tw.adjusted_cpi, tw.adjusted_timestamp, tw.session_id, + ts.user_id + FROM thl_wall AS tw + JOIN thl_session AS ts + ON tw.session_id = ts.id + WHERE tw.uuid = %(wall_uuid)s + LIMIT 2; + """ + res = self.pg_config.execute_sql_query(query, params={"wall_uuid": wall_uuid}) + assert len(res) == 1, f"Expected 1 result, got {len(res)}" + return Wall.model_validate(res[0]) + + def get_from_uuid_if_exists(self, wall_uuid: UUIDStr) -> Optional[Wall]: + try: + return self.get_from_uuid(wall_uuid=wall_uuid) + except AssertionError: + return None + + def finish( + self, + wall: Wall, + status: Status, + status_code_1: StatusCode1, + finished: datetime, + ext_status_code_1: Optional[str] = None, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, + status_code_2: Optional[WallStatusCode2] = None, + survey_id: Optional[str] = None, + cpi: Optional[Decimal] = None, + ) -> None: + """This wall event is finished. This would be called if/when we get a + callback for this wall event. Some other code is responsible for + translating external status codes to grl statuses + """ + wall.finish( + status=status, + status_code_1=status_code_1, + status_code_2=status_code_2, + ext_status_code_1=ext_status_code_1, + ext_status_code_2=ext_status_code_2, + ext_status_code_3=ext_status_code_3, + finished=finished, + survey_id=survey_id, + cpi=cpi, + ) + d = { + "status": status, + "status_code_1": status_code_1.value, + "status_code_2": status_code_2.value if status_code_2 else None, + "finished": finished, + "ext_status_code_1": ext_status_code_1, + "ext_status_code_2": ext_status_code_2, + "ext_status_code_3": ext_status_code_3, + "uuid": wall.uuid, + } + extra = [] + if survey_id is not None: + extra.append("survey_id = %(survey_id)s") + d["survey_id"] = survey_id + if cpi is not None: + extra.append("cpi = %(cpi)s") + d["cpi"] = str(cpi) + extra_str = "," + ", ".join(extra) if extra else "" + + query = f""" + UPDATE thl_wall + SET status=%(status)s, status_code_1=%(status_code_1)s, + status_code_2=%(status_code_2)s, finished=%(finished)s, + ext_status_code_1=%(ext_status_code_1)s, + ext_status_code_2=%(ext_status_code_2)s, + ext_status_code_3=%(ext_status_code_3)s + {extra_str} + WHERE uuid = %(uuid)s; + """ + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query, params=d) + assert c.rowcount == 1 + conn.commit() + + return None + + def get_wall_events( + self, + session_id: Optional[PositiveInt] = None, + session_ids: Optional[List[PositiveInt]] = None, + order_by: OrderBy = OrderBy.ASC, + ) -> List[Wall]: + + if session_id is not None and session_ids is not None: + raise ValueError("Cannot provide both session_id and session_ids") + + if session_id is None and session_ids is None: + raise ValueError("Must provide either session_id or session_ids") + + ids = session_ids if session_ids is not None else [session_id] + + if len(ids) > 500: + raise ValueError("Cannot look up more than 500 Sessions at once.") + + query = f""" + SELECT + tw.uuid, tw.source, tw.buyer_id, tw.survey_id, + tw.req_survey_id, tw.cpi, tw.req_cpi, tw.started, + tw.finished, tw.status, tw.status_code_1, + tw.status_code_2, tw.ext_status_code_1, + tw.ext_status_code_2, tw.ext_status_code_3, + tw.report_value, tw.report_notes, tw.adjusted_status, + tw.adjusted_cpi, tw.adjusted_timestamp, tw.session_id, + ts.user_id + FROM thl_wall AS tw + JOIN thl_session AS ts + ON tw.session_id = ts.id + WHERE tw.session_id = ANY(%s) + ORDER BY tw.started {order_by.value} + """ + res = self.pg_config.execute_sql_query(query=query, params=[ids]) + return [Wall.model_validate(d) for d in res] + + def adjust_status( + self, + wall: Wall, + adjusted_timestamp: AwareDatetime, + adjusted_status: Optional[WallAdjustedStatus] = None, + adjusted_cpi: Optional[Decimal] = None, + ) -> None: + assert wall.status, "Wall must have an existing Status" + + # Be generous here, and if adjusted_status is adj to fail and + # adjusted_cpi is None, set it to 0 + if ( + adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + and adjusted_cpi is None + ): + adjusted_cpi = 0 + elif ( + adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE + and adjusted_cpi is None + ): + adjusted_cpi = wall.cpi + + allowed, msg = check_adjusted_status_wall_consistent( + status=wall.status, + cpi=wall.cpi, + adjusted_status=wall.adjusted_status, + adjusted_cpi=wall.adjusted_cpi, + new_adjusted_status=adjusted_status, + new_adjusted_cpi=adjusted_cpi, + ) + + if not allowed: + raise ValueError(msg) + + wall.update( + adjusted_status=adjusted_status, + adjusted_cpi=adjusted_cpi, + adjusted_timestamp=adjusted_timestamp, + ) + d = { + "adjusted_status": ( + wall.adjusted_status.value if wall.adjusted_status else None + ), + "adjusted_timestamp": adjusted_timestamp, + "adjusted_cpi": ( + str(wall.adjusted_cpi) if wall.adjusted_cpi is not None else None + ), + "uuid": wall.uuid, + } + + query = sql.SQL( + """ + UPDATE thl_wall + SET adjusted_status = %(adjusted_status)s, + adjusted_timestamp = %(adjusted_timestamp)s, + adjusted_cpi = %(adjusted_cpi)s + WHERE uuid = %(uuid)s; + """ + ) + + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=d) + assert c.rowcount == 1 + conn.commit() + + return None + + def report( + self, + wall: Wall, + report_value: ReportValue, + report_notes: Optional[str] = None, + report_timestamp: Optional[AwareDatetime] = None, + ) -> None: + wall.report( + report_value=report_value, + report_notes=report_notes, + report_timestamp=report_timestamp, + ) + params = { + "uuid": wall.uuid, + "report_value": report_value.value, + "status": wall.status.value, + "finished": wall.finished, + "report_notes": report_notes, + } + query = sql.SQL( + """ + UPDATE thl_wall + SET report_value = %(report_value)s, + report_notes = %(report_notes)s, + status = %(status)s, + finished = %(finished)s + WHERE uuid = %(uuid)s; + """ + ) + with self.pg_config.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + assert c.rowcount == 1 + conn.commit() + return None + + def filter_count_attempted_live(self, user_id: int) -> int: + """ + Get the number of surveys this user has attempted that + are still currently live. This can be shown as port of + a "progress bar" for eligible, live, surveys they've + already attempted. + """ + query = f""" + SELECT + COUNT(1) as cnt + FROM thl_wall w + JOIN thl_session s ON w.session_id = s.id + JOIN marketplace_survey ms ON + ms.source = w.source AND + ms.survey_id = w.req_survey_id AND + ms.is_live + WHERE user_id = %(user_id)s + AND w.source != 'g' + """ + params = {"user_id": user_id} + res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + return res[0]["cnt"] + + def filter_wall_attempts_paginated( + self, + user_id: int, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + page: int = 1, + size: int = 100, + order_by: Optional[str] = "-started", + ) -> List[WallAttempt]: + """ + Returns WallAttempt + """ + filters = [] + params = {} + filters.append("user_id = %(user_id)s") + params["user_id"] = user_id + default_started = datetime.now(tz=timezone.utc) - timedelta(days=90) + started_after = started_after or default_started + started_before = started_before or datetime.now(tz=timezone.utc) + assert ( + started_before.tzinfo == timezone.utc + ), "started_before must be tz-aware as UTC" + assert ( + started_after < started_before + ), "started_after must be before started_before" + # Don't use BETWEEN b/c we want exclusive started_after here + filters.append( + "(w.started > %(started_after)s AND w.started <= %(started_before)s)" + ) + params["started_after"] = started_after + params["started_before"] = started_before + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + + assert page >= 1, "page starts at 1" + assert 1 <= size <= 500 + params["offset"] = (page - 1) * size + params["limit"] = size + paginated_filter_str = "LIMIT %(limit)s OFFSET %(offset)s" + + order_by_str = parse_order_by(order_by) + query = f""" + SELECT + w.req_survey_id, + w.started::timestamptz, + w.source, + w.uuid::uuid, + s.user_id + FROM thl_wall w + JOIN thl_session s on w.session_id = s.id + + {filter_str} + {order_by_str} + {paginated_filter_str} + """ + res = self.pg_config.execute_sql_query( + query=query, + params=params, + ) + return [WallAttempt.model_validate(x) for x in res] + + def filter_wall_attempts( + self, + user_id: int, + started_after: Optional[datetime] = None, + started_before: Optional[datetime] = None, + order_by: Optional[str] = "-started", + ) -> List[WallAttempt]: + started_before = started_before or datetime.now(tz=timezone.utc) + res = [] + page = 1 + while True: + chunk = self.filter_wall_attempts_paginated( + user_id=user_id, + started_after=started_after, + started_before=started_before, + order_by=order_by, + page=page, + size=250, + ) + res.extend(chunk) + if not chunk: + break + page += 1 + + return res + + def get_survey_activities( + self, survey_keys: Collection[SurveyKey], product_id: Optional[str] = None + ) -> List[TaskActivity]: + query_base = """ + row_stats AS ( + SELECT + source, survey_id, + count(*) FILTER (WHERE effective_status IS NULL) AS in_progress_count, + max(started) AS last_entrance, + max(finished) FILTER (WHERE effective_status = 'c') AS last_complete + FROM classified + GROUP BY source, survey_id + ), + status_agg AS ( + SELECT + source, survey_id, + jsonb_object_agg(effective_status, cnt) AS status_counts + FROM ( + SELECT source, survey_id, effective_status, count(*) AS cnt + FROM classified + WHERE effective_status IS NOT NULL + GROUP BY source, survey_id, effective_status + ) s + GROUP BY source, survey_id + ), + status_code_1_agg AS ( + SELECT + source, survey_id, + jsonb_object_agg(status_code_1, cnt) AS status_code_1_counts + FROM ( + SELECT source, survey_id, status_code_1, count(*) AS cnt + FROM classified + WHERE status_code_1 IS NOT NULL + GROUP BY source, survey_id, status_code_1 + ) sc + GROUP BY source, survey_id + ) + SELECT + rs.source, + rs.survey_id, + rs.in_progress_count, + rs.last_entrance, + rs.last_complete, + COALESCE(sa.status_counts, '{}'::jsonb) as status_counts, + COALESCE(sc1.status_code_1_counts, '{}'::jsonb) as status_code_1_counts + FROM row_stats rs + LEFT JOIN status_agg sa + ON sa.source = rs.source + AND sa.survey_id = rs.survey_id + LEFT JOIN status_code_1_agg sc1 + ON sc1.source = rs.source + AND sc1.survey_id = rs.survey_id + ORDER BY rs.source, rs.survey_id; + """ + + params = dict() + filters = [] + + # Instead of doing a big IN with a big set of tuples, since we know + # we only have N possible sources, we just split by that and do e.g.: + # ( (source = 'x' and survey_id IN ('1', '2') ) OR + # (source = 'y' and survey_id IN ('3', '4') ) ... ) + sk_filters = [] + survey_source_ids = defaultdict(set) + for sk in survey_keys: + source, survey_id = sk.split(":") + survey_source_ids[Source(source).value].add(survey_id) + for source, survey_ids in survey_source_ids.items(): + sk_filters.append( + f"(source = '{source}' AND survey_id = ANY(%(survey_ids_{source})s))" + ) + params[f"survey_ids_{source}"] = list(survey_ids) + # Make sure this is wrapped in parentheses! + filters.append(f"({' OR '.join(sk_filters)})") + + product_query_join = "" + if product_id is not None: + product_query_join = """ + JOIN thl_session ON w.session_id = thl_session.id + JOIN thl_user ON thl_user.id = thl_session.user_id""" + filters.append("product_id = %(product_id)s") + params["product_id"] = product_id + + filter_str = "WHERE " + " AND ".join(filters) if filters else "" + query_filter = f""" + WITH classified AS ( + SELECT + CASE WHEN w.status IS NULL AND now() - w.started >= interval '90 minutes' + THEN 't' ELSE w.status + END AS effective_status, + w.status_code_1, + w.started, + w.finished, + w.source, + w.survey_id + FROM thl_wall w {product_query_join} + {filter_str} + ), + """ + query = query_filter + query_base + res = self.pg_config.execute_sql_query(query, params) + return [TaskActivity.model_validate(x) for x in res] + + +class WallCacheManager(PostgresManagerWithRedis): + + @cached_property + def wall_manager(self): + return WallManager(pg_config=self.pg_config) + + def get_cache_key_(self, user_id: int) -> str: + assert type(user_id) is int, "user_id must be int" + return f"{self.cache_prefix}:generate_attempts:{user_id}" + + def get_flag_key_(self, user_id: int) -> str: + assert type(user_id) is int, "user_id must be int" + return f"{self.cache_prefix}:generate_attempts:flag:{user_id}" + + def is_flag_set(self, user_id: int) -> bool: + # This flag gets set if a new wall event is created. Whenever we + # update the cache we'll delete the flag. + assert type(user_id) is int, "user_id must be int" + return bool(self.redis_client.get(self.get_flag_key_(user_id=user_id))) + + def set_flag(self, user_id: int) -> None: + # Upon a wall entrance, set this, so we know we have to refresh the cache + assert type(user_id) is int, "user_id must be int" + self.redis_client.set(self.get_flag_key_(user_id=user_id), 1, ex=60 * 60 * 24) + + def clear_flag(self, user_id: int) -> None: + assert type(user_id) is int, "user_id must be int" + self.redis_client.delete(self.get_flag_key_(user_id=user_id)) + + def get_attempts_redis_(self, user_id: int) -> List[WallAttempt]: + redis_key = self.get_cache_key_(user_id=user_id) + # Returns a list even if there is nothing set + res = self.redis_client.lrange(redis_key, 0, 5000) + attempts = [WallAttempt.model_validate_json(x) for x in res] + return attempts + + def update_attempts_redis_(self, attempts: List[WallAttempt], user_id: int) -> None: + if not attempts: + return None + redis_key = self.get_cache_key_(user_id=user_id) + # Make sure attempts is ordered, so the most recent is last + # "LPUSH mylist a b c will result into a list containing c as first element, + # b as second element and a as third element" + attempts = sorted(attempts, key=lambda x: x.started) + json_res = [attempt.model_dump_json() for attempt in attempts] + res = self.redis_client.lpush(redis_key, *json_res) + self.redis_client.expire(redis_key, time=60 * 60 * 24) + # So this doesn't grow forever, keep only the most recent 5k + self.redis_client.ltrim(redis_key, 0, 4999) + return None + + def get_attempts(self, user_id: int) -> List[WallAttempt]: + """ + This is used in the GetOpportunityIDs call to get a list of surveys + (& surveygroups) which should be excluded for this user. We don't + need to know the status or if they finished the survey, just they + entered it, so we don't need to fetch 90 min backfills. The + WallAttempts are stored in a Redis List, ordered most-recent + in index 0. + """ + assert type(user_id) is int, "user_id must be int" + + wall_modified = self.is_flag_set(user_id=user_id) + if not wall_modified: + return self.get_attempts_redis_(user_id=user_id) + + # Attempt to get the most recent wall attempt + redis_key = self.get_cache_key_(user_id=user_id) + res: Optional[str] = self.redis_client.lindex(redis_key, 0) # type: ignore[assignment] + if res is None: + # Nothing in the cache, query for all from db + attempts = self.wall_manager.filter_wall_attempts(user_id=user_id) + self.update_attempts_redis_(attempts=attempts, user_id=user_id) + self.clear_flag(user_id=user_id) + return attempts + + # See if there is anything after the latest cached wall event we have + w = WallAttempt.model_validate_json(res) + new_attempts = self.wall_manager.filter_wall_attempts( + user_id=user_id, started_after=w.started + ) + self.update_attempts_redis_(attempts=new_attempts, user_id=user_id) + self.clear_flag(user_id=user_id) + return self.get_attempts_redis_(user_id=user_id) diff --git a/generalresearch/managers/thl/wallet/__init__.py b/generalresearch/managers/thl/wallet/__init__.py new file mode 100644 index 0000000..3b34756 --- /dev/null +++ b/generalresearch/managers/thl/wallet/__init__.py @@ -0,0 +1,147 @@ +from decimal import Decimal +from typing import Optional, Dict + +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.managers.thl.payout import ( + UserPayoutEventManager, + PayoutEventManager, +) +from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, +) +from generalresearch.managers.thl.userhealth import UserIpHistoryManager +from generalresearch.managers.thl.wallet.approve import ( + approve_paypal_order, + approve_amt_cashout, +) +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.payout import UserPayoutEvent +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashMailOrderData, +) + + +def manage_pending_cashout( + payout_id: str, + new_status: PayoutStatus, + user_payout_event_manager: UserPayoutEventManager, + user_ip_history_manager: UserIpHistoryManager, + user_manager: UserManager, + ledger_manager: ThlLedgerManager, + order_data: Optional[Dict | CashMailOrderData] = None, +) -> UserPayoutEvent: + """ + Called by a UI actions performed by Todd. This rejects/approves/cancels + a payout event. We're calling this "cashout" because that is the + terminology used in generalresearch, even though the cashouts are stored + in the payoutevent table + + :param payout_id: the payoutevent pk hex + :param new_status: + :param user_payout_event_manager + :param user_ip_history_manager + :param user_manager + :param ledger_manager + :param order_data: For Cash_in_mail, pass this in. + + :returns: PayoutEvent object + """ + pe: UserPayoutEvent = user_payout_event_manager.get_by_uuid(payout_id) + pe.check_status_change_allowed(status=new_status) + assert pe.account_reference_type == "user" + user = user_manager.get_user(user_uuid=pe.account_reference_uuid) + user.prefetch_product(user_manager.mysql_user_manager.pg_config) + + assert ( + user.product.user_wallet_enabled + ), "manage_pending_cashout called on user without managed wallet" + assert not user.blocked, "manage_pending_cashout: Blocked user" + assert not user_ip_history_manager.is_user_anonymous( + user + ), "manage_pending_cashout: Anonymous user" + + # Just assign it with direct casting/type annotation + payout_event_manager: PayoutEventManager = user_payout_event_manager + + if new_status == PayoutStatus.APPROVED: + if pe.payout_type == PayoutType.TANGO: + from generalresearch.managers.thl.wallet.tango import ( + complete_tango_order, + ) + + complete_tango_order( + user=user, + payout_event=pe, + payout_event_manager=payout_event_manager, + ledger_manager=ledger_manager, + ) + + elif pe.payout_type == PayoutType.PAYPAL: + approve_paypal_order( + payout_event=pe, payout_event_manager=payout_event_manager + ) + + elif pe.payout_type in {PayoutType.AMT_BONUS, PayoutType.AMT_HIT}: + approve_amt_cashout( + user=user, + payout_event=pe, + payout_event_manager=payout_event_manager, + ledger_manager=ledger_manager, + ) + + elif pe.payout_type == PayoutType.CASH_IN_MAIL: + assert order_data, "must pass order_data" + payout_event_manager.update( + pe, status=PayoutStatus.APPROVED, order_data=order_data + ) + ledger_manager.create_tx_user_payout_complete( + user, + payout_event=pe, + fee_amount=Decimal(order_data.shipping_cost) / 100, + ) + + else: + raise ValueError(f"unsupported payout_type: {pe.payout_type}") + + return pe + + elif new_status == PayoutStatus.COMPLETE: + # Used only for AMT/dummy cashouts that are actually paid out not + # by us. They are informing us that the cashout was successfully + # sent to the user + if pe.payout_type in {PayoutType.AMT_BONUS, PayoutType.AMT_HIT}: + # We already do this under approve_amt_cashout() + pass + + elif pe.payout_type == PayoutType.PAYPAL: + # This is an issue here in that we actually don't know what the + # fee is until it is sent and we read it back from paypal's csv + # result. We have to just run this with a custom script, which + # uses manual_complete_paypal_order() + raise ValueError("user custom paypal script for this") + + payout_event_manager.update(pe, status=new_status) + + return pe + + elif new_status == PayoutStatus.REJECTED: + # They lose the money in their wallet at this point, no ledger txs occur. + payout_event_manager.update(pe, status=new_status) + return pe + + elif new_status == PayoutStatus.CANCELLED: + # create another ledger item putting the money back into their wallet. + payout_event_manager.update(pe, status=new_status) + ledger_manager.create_tx_user_payout_cancelled(user, payout_event=pe) + return pe + + elif new_status == PayoutStatus.FAILED: + # We just update the status (like in PayoutStatus.REJECTED). No ledger xs + payout_event_manager.update(pe, status=new_status) + return pe + + else: + raise ValueError(f"unsupported status: {new_status}") diff --git a/generalresearch/managers/thl/wallet/approve.py b/generalresearch/managers/thl/wallet/approve.py new file mode 100644 index 0000000..4a7ae5e --- /dev/null +++ b/generalresearch/managers/thl/wallet/approve.py @@ -0,0 +1,52 @@ +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.managers.thl.payout import PayoutEventManager +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.payout import UserPayoutEvent +from generalresearch.models.thl.user import User + + +def approve_paypal_order( + payout_event: UserPayoutEvent, payout_event_manager: PayoutEventManager +): + """ + The order has been approved, but it hasn't actually been sent. + """ + assert payout_event.status in { + PayoutStatus.PENDING, + PayoutStatus.FAILED, + }, "attempting to manage payout that is not pending (or you can retry a failed order)" + + payout_event_manager.update(payout_event, status=PayoutStatus.APPROVED) + + interface = payout_event.request_data.get("interface") + if interface == "api": + # todo: Use the Payouts API to sent this payout, and then update the DB + pass + + else: + # Flow monitoring for payouts where the type is paypal, the status is + # approved, and the interface is web, and then it'll send the payout and + # update the status to complete (and create a ledger item for the fee) + pass + + return payout_event + + +def approve_amt_cashout( + user: User, + payout_event: UserPayoutEvent, + ledger_manager: ThlLedgerManager, + payout_event_manager: PayoutEventManager, +) -> None: + """ + This is going to be paid out by the requester (the jb-lambdas) as an AMT bonus. + """ + assert payout_event.status in { + PayoutStatus.PENDING, + PayoutStatus.FAILED, + }, "attempting to manage payout that is not pending (or you can retry a failed order)" + + payout_event_manager.update(payout_event, status=PayoutStatus.APPROVED) + ledger_manager.create_tx_user_payout_complete(user, payout_event=payout_event) diff --git a/generalresearch/managers/thl/wallet/tango.py b/generalresearch/managers/thl/wallet/tango.py new file mode 100644 index 0000000..5587435 --- /dev/null +++ b/generalresearch/managers/thl/wallet/tango.py @@ -0,0 +1,162 @@ +from typing import Any, Dict + +import slack + +from generalresearch.config import ( + is_debug, +) +from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, +) +from generalresearch.managers.thl.payout import PayoutEventManager +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.payout import UserPayoutEvent +from generalresearch.models.thl.user import User + +# from raas.api_helper import APIHelper +# from raas.exceptions.raas_client_exception import RaasClientException +# from raas.raas_client import RaasClient + +# RaasClient.config.environment = 1 +# api_client = RaasClient( +# platform_name=TANGO_PLATFORM_NAME, platform_key=TANGO_PLATFORM_KEY +# ) +# it really annoyingly logs the entire http response. turn it off +# api_client.catalog.logger.setLevel(logging.INFO) +# api_client.exchange_rates.logger.setLevel(logging.INFO) +# api_client.orders.logger.setLevel(logging.INFO) +# api_client.status.logger.setLevel(logging.INFO) +# api_client.accounts.logger.setLevel(logging.INFO) +# api_client.customers.logger.setLevel(logging.INFO) +# api_client.fund.logger.setLevel(logging.INFO) + + +def complete_tango_order( + user: User, + payout_event: UserPayoutEvent, + payout_event_manager: PayoutEventManager, + ledger_manager: ThlLedgerManager, +): + """ + We approved the Tango card redemption. Actually request the card. + + (Note: we're skipping the PENDING -> APPROVED -> COMPLETE order for tango. + When a tango request gets APPROVED, we COMPLETE it (or FAIL!) in the + same step) + """ + assert payout_event.status in { + PayoutStatus.PENDING, + PayoutStatus.FAILED, + }, "attempting to manage payout that is not pending (or you can retry a failed order)" + request = payout_event.request_data + ref_id = request["externalRefID"] + # amount_usd = Decimal(payout_event.request_data["amount_usd"]) + + # Note: tango uses the ref_id to uniquify orders, so locking is not actually needed as long + # as the ref_id is the same. + try: + order = create_tango_order( + request_data=payout_event.request_data, ref_id=ref_id + ) + + except Exception as e: + # todo: its possible the order went through, but something else was wrong + # we should try to retrieve the order by its ref_id and confirm it really + # failed... + payout_event_manager.update(payout_event, status=PayoutStatus.FAILED) + return payout_event + + # update TangoPayoutEvent with the order data + payout_event_manager.update( + payout_event, + status=order["status"], + ext_ref_id=order["referenceOrderID"], + order_data=order, + ) + + ledger_manager.create_tx_user_payout_complete(user, payout_event=payout_event) + + return payout_event + + +def get_tango_order(ref_id: str): + """ + Retrieve a tango order by its external ref ID. + We should have set it to the TangoPayoutEvent instance uuid associated + with this Tango order (lowercase no dashes). + :return: the json order data or None if doesn't exist + """ + raise NotImplementedError("convert to requests") + # orders = api_client.orders.get_orders({"external_ref_id": ref_id}).orders + # if orders: + # return json.loads(APIHelper.json_serialize(orders[0])) + + +def create_tango_order(request_data: Dict, ref_id: str) -> Dict[str, Any]: + """ + Create a tango gift card order. + Throws exception if anything is not right. + # https://integration-www.tangocard.com/raas_api_console/v2/ + # https://www.apimatic.io/apidocs/tangocard/v/2_3_4#/python + + :param utid: Card identifier + :param amount: requested card value in USD + :param ref_id: TangoPayoutEvent.uuid + :return: + """ + # make sure we don't create more than one tango order for a single PayoutEvent + assert get_tango_order(ref_id) is None + amount = request_data["amount"] + request_data.pop("amount_usd", None) + request_data.pop("description", None) + + if is_debug(): + return { + "status": "COMPLETE", + "referenceOrderID": "test", + "reward": { + "credentials": { + "Security Code": "XXXX-XXXX", + "Redemption URL": "https://codes.rewardcodes.com/r2/1/XXXX", + }, + "credentialList": [ + { + "type": "text", + "label": "Security Code", + "value": "XXXX-XXXX", + }, + { + "type": "url", + "label": "Redemption URL", + "value": "https://codes.rewardcodes.com/r2/1/XXXX", + }, + ], + "redemptionInstructions": "do your thang fam", + }, + } + + raise NotImplementedError("convert to requests") + # try: + # order = api_client.orders.create_order(request_data) + # order = json.loads(APIHelper.json_serialize(order)) + # except RaasClientException as e: + # e = json.loads(APIHelper.json_serialize(e)) + # try: + # msgs = [x["message"] for x in e["errors"]] + # print(" | ".join(msgs)) + # except Exception: + # pass + # capture_exception() + # raise e + # except Exception as e: + # capture_exception() + # raise e + + # amount_f: float = float(amount) + # assert order["status"] == "COMPLETE" + # assert abs(order["amountCharged"]["total"] - amount_f) < 0.0200001 + # assert order["amountCharged"]["currencyCode"] == "USD" + # if order["denomination"]["currencyCode"] == "USD": + # assert order["denomination"]["value"] == amount_f + # + # return order diff --git a/generalresearch/mariadb.py b/generalresearch/mariadb.py new file mode 100644 index 0000000..d95e46c --- /dev/null +++ b/generalresearch/mariadb.py @@ -0,0 +1,42 @@ +import mariadb.constants +from mariadb.constants import EXT_FIELD_TYPE + +# This should be an enum, or a dictionary ... of course it's not +# field_flags = {k: getattr(mariadb.constants.FIELD_FLAG, k) for k in dir(mariadb.constants.FIELD_FLAG) +# if not k.startswith('__')} +ext_field_flags = { + k: getattr(EXT_FIELD_TYPE, k) for k in dir(EXT_FIELD_TYPE) if not k.startswith("__") +} +ext_field_flags_rev = {v: k for k, v in ext_field_flags.items()} + + +# def decode_field_flags(field_flag: int): +# # https://mariadb-corporation.github.io/mariadb-connector-python/cursor.html +# # This was written by chatgpt basically... idk how binary works. +# decoded_flags = [] +# for flag, value in field_flags.items(): +# if field_flag & value: +# decoded_flags.append(flag) +# +# return decoded_flags + + +def example(): + # actually we don't need the field flags. I didn't see, but there is an + # extended field type returned also. Which explicitly tags uuid fields. + conn = mariadb.connect( + host="127.0.0.1", user="root", password="", database="300large-morning" + ) + c = conn.cursor() + c.execute("SELECT user_id, pid as greg FROM morning_userpid limit 1") + for m in zip(c.metadata["field"], c.metadata["ext_type_or_format"]): + # here we can just check if the field's ext_field_flag == 'UUID' (2) + print(m[0], ext_field_flags_rev[m[1]]) + + +def get_column_types(): + # How does django do this? + res = """ + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = 'morning_userpid' AND table_schema = DATABASE()""" diff --git a/generalresearch/models/__init__.py b/generalresearch/models/__init__.py new file mode 100644 index 0000000..0a013e2 --- /dev/null +++ b/generalresearch/models/__init__.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from enum import Enum + +from generalresearch.utils.enum import ReprEnumMeta + + +class Source(str, Enum, metaclass=ReprEnumMeta): + # The external marketplace, or the source of the survey / work. + # Max length of the value is 2. + GRS = "g" + CINT = "c" + DALIA = "a" # deprecated + DYNATA = "d" + ETX = "et" + FULL_CIRCLE = "f" + INNOVATE = "i" + LUCID = "l" + MORNING_CONSULT = "m" + OPEN_LABS = "n" + POLLFISH = "o" + PRECISION = "e" + PRODEGE_USER = "r" # deprecated + PRODEGE = "pr" # using 'r' for vendor_wall + PULLEY = "p" # deprecated + REPDATA = "rd" # using 'q' for vendor_wall + SAGO = "h" + SPECTRUM = "s" + TESTING = "t" # Used internally for testing + TESTING2 = "u" # Used internally for testing + WXET = "w" + + +class DebitKey(int, Enum, metaclass=ReprEnumMeta): + # The debit key for marketplaces + CINT = 8 + DALIA = 9 + DYNATA = 6 + # ETX = None + FULL_CIRCLE = 15 + INNOVATE = 7 + LUCID = 0 + MORNING_CONSULT = 12 + # OPEN_LABS = None + POLLFISH = 13 + PRECISION = 14 + PRODEGE = 11 + SAGO = 10 + SPECTRUM = 5 + # WXET = None + + +class DeviceType(int, Enum, metaclass=ReprEnumMeta): + UNKNOWN = 0 + MOBILE = 1 + DESKTOP = 2 + TABLET = 3 + + +class LogicalOperator(str, Enum, metaclass=ReprEnumMeta): + OR = "OR" + AND = "AND" + # There is currently no use case for NOT. See MarketplaceCondition.explain_not + NOT = "NOT" + + +class TaskStatus(str, Enum, metaclass=ReprEnumMeta): + # A survey is live if it is open and, given all conditions are met, is + # possible to send in traffic. All other statuses are just variants of + # NOT Live (not accepting traffic) + LIVE = "LIVE" + + # This is a generic NOT Live status. A marketplace may use other more + # specific statuses but in practice they don't matter because all we care + # about is if the task is LIVE. + NOT_LIVE = "NOT_LIVE" + + # We need a status to mark if a survey we thought was live does not come + # back from the API, we'll mark it as NOT_FOUND. + NOT_FOUND = "NOT_FOUND" + + +class TaskCalculationType(str, Enum): + COMPLETES = "COMPLETES" + STARTS = "STARTS" + + @classmethod + def from_api(cls, v: str) -> "TaskCalculationType": + return { + "complete": cls.COMPLETES, + "completes": cls.COMPLETES, + "survey start": cls.STARTS, + "survey starts": cls.STARTS, + "start": cls.STARTS, + "prescreens": cls.STARTS, + "prescreen": cls.STARTS, + }[v.lower()] + + @classmethod + def prodege_from_api(cls, v: int) -> "TaskCalculationType": + return {1: cls.COMPLETES, 2: cls.STARTS}[v] + + @classmethod + def innovate_from_api(cls, v: int) -> "TaskCalculationType": + return {0: cls.COMPLETES, 1: cls.STARTS}[v] + + +class URLQueryKey(str, Enum, metaclass=ReprEnumMeta): + PRODUCT_ID = "39057c8b" + PRODUCT_USER_ID = "c184efc0" + SESSION_ID = "0bb50182" + + +MAX_INT32 = 2**31 diff --git a/generalresearch/models/admin/__init__.py b/generalresearch/models/admin/__init__.py new file mode 100644 index 0000000..6e57d56 --- /dev/null +++ b/generalresearch/models/admin/__init__.py @@ -0,0 +1,59 @@ +from datetime import datetime, timezone +from typing import Optional + +import pandas as pd +from dateutil import relativedelta + + +def get_date_list(start_datetime: datetime, end_datetime: Optional[datetime] = None): + start_datetime = start_datetime.replace(tzinfo=timezone.utc) + end_datetime = end_datetime if end_datetime else datetime.now(tz=timezone.utc) + return ( + pd.date_range(start_datetime, end_datetime, freq="1D") + .strftime("%Y-%m-%d") + .tolist() + ) + + +def year_start(periods_ago: int = 6) -> datetime: + """ + Returns the starting date of the last N Full + years. Goal is to provide a simple way to + know when to do filters from + """ + n: datetime = datetime.now(tz=timezone.utc) + d: datetime = n - relativedelta.relativedelta(years=periods_ago) + return d.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0) + + +def month_start(periods_ago: int = 6) -> datetime: + """ + Returns the starting date of the last N Full + months. Goal is to provide a simple way to + know when to do filters from + """ + n: datetime = datetime.now(tz=timezone.utc) + d: datetime = n - relativedelta.relativedelta(months=periods_ago) + return d.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + +def day_start(periods_ago: int = 6) -> datetime: + """ + Returns the starting date of the last N Full + days. Goal is to provide a simple way to + know when to do filters from + """ + n: datetime = datetime.now(tz=timezone.utc) + d: datetime = n - relativedelta.relativedelta(days=periods_ago) + return d.replace(hour=0, minute=0, second=0, microsecond=0) + + +def hour_start(periods_ago: int = 6) -> datetime: + """ + Returns the starting date of the last N Full + hours. Goal is to provide a simple way to + know when to do filters from + """ + n: datetime = datetime.now(tz=timezone.utc) + d: datetime = n - relativedelta.relativedelta(hours=periods_ago) + return d.replace(minute=0, second=0, microsecond=0) diff --git a/generalresearch/models/admin/request.py b/generalresearch/models/admin/request.py new file mode 100644 index 0000000..30e9fcd --- /dev/null +++ b/generalresearch/models/admin/request.py @@ -0,0 +1,157 @@ +from datetime import datetime, timezone, timedelta +from enum import Enum +from typing import Literal, List, Tuple + +import pandas as pd +from pydantic import BaseModel, Field, model_validator, computed_field + +from generalresearch.models.custom_types import AwareDatetimeISO + + +class ReportType(Enum): + POP_SESSION = "pop_session" + POP_EVENT = "pop_event" + POP_LEDGER = "pop_ledger" + + +class ReportRequest(BaseModel): + report_type: ReportType = Field(default=ReportType.POP_SESSION) + + index0: str = Field( + default="started", + ) + index1: str = Field(default="product_id") + + start: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) - timedelta(days=14) + ) + end: AwareDatetimeISO = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + + interval: Literal["5min", "15min", "1h", "6h", "12h", "1d"] = "1h" + include_open_bucket: bool = Field(default=True) + + @computed_field( + title="Start floor", + description="The datetime that this report starts from", + examples=[datetime(year=2025, month=5, day=1)], + return_type=datetime, + ) + @property + def start_floor(self) -> AwareDatetimeISO: + """Always floor start time to the interval.""" + return self.ts_start.floor(self.interval).to_pydatetime() + + # --- Validation --- + + @model_validator(mode="after") + def check_start_end(self): + + assert self.start < self.end - timedelta( + hours=1 + ), "Reports less than 1 hour are not supported" + + assert ( + self.end - self.start + ).days < 365 * 10, "Report.starts must not be longer than 10 years" + + return self + + @model_validator(mode="after") + def check_start_end_tz(self): + assert self.start.tzinfo == self.end.tzinfo == timezone.utc + return self + + @model_validator(mode="after") + def check_index0_for_schema(self): + if self.report_type == ReportType.POP_SESSION: + assert self.index0 in [ + "started" + ], f"session exports can't split by {self.index0}" + return self + + @model_validator(mode="after") + def check_index1_for_schema(self): + if self.report_type == ReportType.POP_SESSION: + assert self.index1 in [ + "product_id", + "user_id", + "country_iso", + "device_type", + "status", + "status_code_1", + "status_code_2", + ], f"session exports can't split by {self.index1}" + + if self.report_type == ReportType.POP_EVENT: + assert self.index1 in [ + "product_id", + "user_id", + "country_iso", + "device_type", + "source", + "buyer_id", + "survey_id", + "status", + "status_code_1", + "status_code_2", + ], f"wall exports can't split by {self.index1}" + + return self + + # --- Properties --- + @property + def pd_interval(self) -> pd.Interval: + return pd.Interval( + left=pd.Timestamp(self.start_floor), + right=pd.Timestamp(self.end), + closed="both", + ) + + @property + def interval_timedelta(self) -> pd.Timedelta: + return pd.Timedelta(self.interval) + + @property + def start_floor_naive(self) -> datetime: + return self.start_floor.replace(tzinfo=None) + + @property + def end_naive(self) -> datetime: + return datetime.now(tz=None) + + @property + def ts_start(self) -> pd.Timestamp: + return pd.Timestamp(self.start) + + @property + def ts_start_floor(self) -> pd.Timestamp: + return pd.Timestamp(self.start_floor) + + @property + def ts_end(self) -> pd.Timestamp: + return pd.Timestamp(self.end) + + @property + def finish(self) -> pd.Timestamp: + return self.end + + @property + def ts_finish(self) -> pd.Timestamp: + return pd.Timestamp(self.end) + + # --- Methods --- + def buckets(self) -> pd.DatetimeIndex: + """ + Returns all bucket start times. + """ + return pd.date_range( + start=self.ts_start_floor, + end=self.ts_end, + freq=self.interval, + tz=timezone.utc, + ) + + def bucket_ranges(self) -> List[Tuple[pd.Timestamp, pd.Timestamp]]: + """Returns list of (start, end) tuples for each bucket.""" + starts = self.buckets() + return [(s, s + self.interval_timedelta) for s in starts] diff --git a/generalresearch/models/cint/__init__.py b/generalresearch/models/cint/__init__.py new file mode 100644 index 0000000..ea3e500 --- /dev/null +++ b/generalresearch/models/cint/__init__.py @@ -0,0 +1,7 @@ +from pydantic import Field + +from typing_extensions import Annotated + +CintQuestionIdType = Annotated[ + str, Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") +] diff --git a/generalresearch/models/cint/question.py b/generalresearch/models/cint/question.py new file mode 100644 index 0000000..f840a46 --- /dev/null +++ b/generalresearch/models/cint/question.py @@ -0,0 +1,244 @@ +import json +from datetime import datetime, timezone +from enum import Enum +from typing import Optional, List, Literal, Dict, Any +from uuid import UUID + +from pydantic import Field, BaseModel, model_validator, field_validator +from typing_extensions import Self + +from generalresearch.models import Source, string_utils +from generalresearch.models.cint import CintQuestionIdType +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.thl.profiling.marketplace import ( + MarketplaceQuestion, + MarketplaceUserQuestionAnswer, +) + + +class CintQuestionType(str, Enum): + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + # Dummy means they're calculated + DUMMY = "d" + TEXT_ENTRY = "t" + NUMERIC_ENTRY = "n" + + @classmethod + def from_api(cls, a: int): + API_TYPE_MAP = { + "Single Punch": CintQuestionType.SINGLE_SELECT, + "Multi Punch": CintQuestionType.MULTI_SELECT, + "Dummy": CintQuestionType.DUMMY, + # What's the difference between dummy and calculated dummy? I thought dummy + # was calculated? who knows + "Calculated Dummy": CintQuestionType.DUMMY, + "Open Ended": CintQuestionType.TEXT_ENTRY, + "Numeric - Open-end": CintQuestionType.NUMERIC_ENTRY, + # This seems to be invalid as there are no options??? + "Grid": None, + } + return API_TYPE_MAP[a] if a in API_TYPE_MAP else None + + +class CintUserQuestionAnswer(MarketplaceUserQuestionAnswer): + question_id: CintQuestionIdType = Field() + question_type: Optional[CintQuestionType] = Field(default=None) + # Did this answer come from us asking, or was it passed back from the marketplace + from_thl: bool = Field(default=True) + + +class CintQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+|-3105|true|false$", + frozen=True, + description="This is called precode in their API", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + order: int = Field() + + +class CintQuestion(MarketplaceQuestion): + question_id: CintQuestionIdType = Field( + description="The unique identifier for the qualification", + frozen=True, + examples=["741"], + ) + question_name: str = Field(examples=["STANDARD_GAMING_TYPE"]) + question_text: str = Field( + max_length=1024, + min_length=1, + description="The text shown to respondents", + frozen=False, + examples=["What kind(s) of video/computer games do you play?"], + ) + question_type: CintQuestionType = Field( + description="The type of question asked", + frozen=True, + examples=[CintQuestionType.MULTI_SELECT], + ) + options: Optional[List[CintQuestionOption]] = Field( + default=None, min_length=1, frozen=True + ) + option_mask: str = Field(examples=["000000000000000000"]) + classification_code: Optional[str] = Field(examples=["ELE"], default=None) + # This comes from the API! not us + created_at: AwareDatetimeISO = Field(description="Called create_date in API") + + source: Literal[Source.CINT] = Source.CINT + + @property + def internal_id(self) -> str: + return self.question_id + + @field_validator("question_name", "question_text", mode="after") + def remove_nbsp(cls, s: Optional[str]) -> Optional[str]: + return string_utils.remove_nbsp(s) + + @model_validator(mode="after") + def check_type_options_agreement(self) -> Self: + if self.question_type in { + CintQuestionType.TEXT_ENTRY, + CintQuestionType.NUMERIC_ENTRY, + }: + assert self.options is None, "TEXT_ENTRY/NUMERICAL shouldn't have options" + elif self.question_type == CintQuestionType.DUMMY: + # These are calculated. Sometimes they have options? idk + pass + else: + assert self.options is not None, "missing options" + return self + + @field_validator("options") + @classmethod + def order_options(cls, options): + if options: + options.sort(key=lambda x: x.order) + return options + + @field_validator("options") + @classmethod + def validate_options(cls, options): + if options: + ids = {x.id for x in options} + assert len(ids) == len(options), "options.id must be unique" + orders = {x.order for x in options} + assert len(orders) == len(options), "options.order must be unique" + return options + + @classmethod + def from_api(cls, d: dict, country_iso: str, language_iso: str) -> Self: + options = None + created_at = datetime.strptime( + d["create_date"], "%Y-%m-%dT%H:%M:%S%z" + ).astimezone(timezone.utc) + if d.get("question_options"): + options = [ + CintQuestionOption( + id=r["precode"], text=r["text"], order=r.get("order", order) + ) + for order, r in enumerate(d["question_options"]) + ] + # Sometimes the order from the api is incorrect + orders = {opt.order for opt in options} + if len(orders) != len(options): + for idx, opt in enumerate(options): + opt.order = idx + + return cls( + question_id=str(d["id"]), + question_name=d["name"], + question_text=d["question_text"], + question_type=CintQuestionType.from_api(d["question_type"]), + country_iso=country_iso, + language_iso=language_iso, + options=options, + option_mask=d["option_mask"], + created_at=created_at, + classification_code=d["classification_code"], + ) + + @classmethod + def from_db(cls, d: dict) -> Self: + options = None + if d["options"]: + options = [ + CintQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + if d.get("created_at"): + d["created_at"] = d["created_at"].replace(tzinfo=timezone.utc) + return cls( + question_id=d["question_id"], + question_name=d["question_name"], + question_text=d["question_text"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + option_mask=d["option_mask"], + created_at=d["created_at"], + classification_code=d["classification_code"], + category_id=UUID(d["category_id"]).hex if d["category_id"] else None, + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json") + d["options"] = json.dumps(d["options"]) + if self.created_at: + d["created_at"] = self.created_at.replace(tzinfo=None) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + ) + + upk_type_selector_map = { + CintQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + # CintQuestionType.DUMMY: ( + # UpkQuestionType.MULTIPLE_CHOICE, + # UpkQuestionSelectorMC.SINGLE_ANSWER, + # ), + CintQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + CintQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + CintQuestionType.NUMERIC_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=c.order) + for c in self.options + ] + return UpkQuestion(**d) diff --git a/generalresearch/models/cint/survey.py b/generalresearch/models/cint/survey.py new file mode 100644 index 0000000..30b140c --- /dev/null +++ b/generalresearch/models/cint/survey.py @@ -0,0 +1,532 @@ +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from decimal import Decimal +from typing import Optional, Dict, Set, Tuple, List, Literal, Any, Type + +from more_itertools import flatten +from pydantic import ( + NonNegativeInt, + Field, + ConfigDict, + BaseModel, + computed_field, + model_validator, +) +from typing_extensions import Self, Annotated + +from generalresearch.locales import Localelator +from generalresearch.models import Source, TaskCalculationType +from generalresearch.models.cint import CintQuestionIdType +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + CoercedStr, + AlphaNumStr, +) +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class CintCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + + source: Source = Field(default=Source.CINT) + question_id: CintQuestionIdType = Field() + + @classmethod + def from_api(cls, d: Dict[str, Any]) -> Self: + d["question_id"] = str(d["question_id"]) + d["values"] = list(map(str.lower, d["precodes"])) + d["value_type"] = ConditionValueType.LIST + if d["logical_operator"] == "NOT": + # In cint, a not means a negated OR + d["negate"] = True + d["logical_operator"] = "OR" + return cls.model_validate(d) + + +class CintQuota(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=True) + quota_id: CoercedStr = Field(validation_alias="survey_quota_id") + quota_type: Literal["total", "client"] = Field(validation_alias="survey_quota_type") + conversion: Optional[float] = Field(ge=0, le=1, default=None) + number_of_respondents: NonNegativeInt = Field( + description="Number of completes available" + ) + condition_hashes: Optional[List[str]] = Field(min_length=1, default=None) + + def __hash__(self): + return hash(tuple((tuple(self.condition_hashes), self.quota_id))) + + @model_validator(mode="after") + def validate_condition_len(self) -> Self: + if self.quota_type == "total": + assert ( + self.condition_hashes is None + ), "total quota should not have conditions" + elif self.quota_type == "client": + assert len(self.condition_hashes) > 0, "quota must have conditions" + return self + + @property + def is_open(self) -> bool: + return self.number_of_respondents >= 2 + + @classmethod + def from_api(cls, d: Dict) -> Self: + d["survey_quota_type"] = d["survey_quota_type"].lower() + return cls.model_validate(d) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the quota is open. + return self.is_open and self.matches(criteria_evaluation) + + def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Matches means we meet all conditions. + # We can "match" a quota that is closed. In that case, we would not be eligible for the survey. + return all(criteria_evaluation.get(c) for c in self.condition_hashes) + + def matches_optional( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[bool]: + # We need to know if any conditions are unknown to avoid matching a full quota. If any fail, + # then we know we fail regardless of any being unknown. + evals = [criteria_evaluation.get(c) for c in self.condition_hashes] + if False in evals: + return False + if None in evals: + return None + return True + + def matches_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + hash_evals = { + cell: criteria_evaluation.get(cell) for cell in self.condition_hashes + } + if False in hash_evals.values(): + return False, set() + if None in hash_evals.values(): + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + +class CintSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + + survey_id: CoercedStr = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + survey_name: str = Field(max_length=128) + buyer_name: str = Field( + description="Name of the buyer running the survey", + validation_alias="account_name", + ) + buyer_id: CoercedStr = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + + is_live_raw: bool = Field(alias="is_live") + + bid_loi: Optional[int] = Field( + ge=60, le=90 * 60, validation_alias="bid_length_of_interview" + ) + bid_ir: Optional[float] = Field(ge=0, le=1, validation_alias="bid_incidence") + collects_pii: Optional[bool] = Field() + survey_group_ids: Set[CoercedStr] = Field() + + calculation_type: TaskCalculationType = Field( + description="Indicates whether quotas are calculated based on completes or prescreens", + default=TaskCalculationType.COMPLETES, + validation_alias="survey_quota_calc_type", + ) + is_only_supplier_in_group: bool = Field( + description="true indicates that an allocation is reserved for a single supplier" + ) + + cpi: Decimal = Field( + gt=0, + le=100, + decimal_places=2, + max_digits=5, + validation_alias="revenue_per_interview", + description="This is AFTER commission", + ) + gross_cpi: ( + Annotated[ + Decimal, + Field( + gt=0, + le=100, + decimal_places=2, + max_digits=5, + description="This is BEFORE commission", + ), + ] + | None + ) = None + + industry: str = Field(max_length=64) + study_type: str = Field(max_length=64) + + total_client_entrants: NonNegativeInt = Field( + description="Number of total client survey entrants across all suppliers." + ) + total_remaining: NonNegativeInt = Field( + description="Number of completes still available to the supplier" + ) + completion_percentage: float = Field() + conversion: Optional[float] = Field( + ge=0, + le=1, + description="Percentage of respondents who complete the survey after qualifying", + ) + mobile_conversion: Optional[float] = Field( + ge=0, + le=1, + description="Percentage of respondents on a mobile device who complete the survey after qualifying.", + ) + length_of_interview: Optional[NonNegativeInt] = Field( + description="Median time for a respondent to complete the survey, excluding prescreener, in minutes. This " + "value will be zero until 6 completes are achieved." + ) + overall_completes: NonNegativeInt = Field( + description="Number of completes already achieved across all suppliers on the survey." + ) + revenue_per_click: Optional[float] = Field( + description="The Revenue Per Click value of the survey. RPC = (RPI * completes) / system entrants", + default=None, + ) + termination_length_of_interview: Optional[NonNegativeInt] = Field( + description="Median time for a respondent to be termed, in minutes. This value is calculated after six survey " + "entrants and rounded to the nearest whole number. Until six survey entrants are achieved the " + "value will be zero." + ) + + respondent_pids: Set[str] = Field(default_factory=set) + + qualifications: List[str] = Field(default_factory=list) + quotas: List[CintQuota] = Field(default_factory=list) + + source: Literal[Source.CINT] = Field(default=Source.CINT) + + used_question_ids: Set[AlphaNumStr] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as "condition_hashes") throughout + # this survey. In the reduced representation of this task (nearly always, for db i/o, in global_vars) + # this field will be null. + conditions: Optional[Dict[str, CintCondition]] = Field(default=None) + + # These do not come from the API. We set it when we update/create in the db. + created_at: Optional[AwareDatetimeISO] = Field(default=None) + last_updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 open quota (or there are no quotas!) + return self.is_live and ( + any(q.is_open for q in self.quotas) or len(self.quotas) == 0 + ) + + @property + def is_live(self) -> bool: + return self.is_live_raw + + def model_dump(self, **kwargs: Any) -> dict: + data = super().model_dump(**kwargs) + data["is_live"] = data.pop("is_live_raw", None) + return data + + @computed_field + @property + def all_hashes(self) -> Set[str]: + s = set(self.qualifications) + for q in self.quotas: + s.update(set(q.condition_hashes)) if q.condition_hashes else None + return s + + @model_validator(mode="before") + @classmethod + def set_cpi(cls, data: Any): + if data.get("gross_cpi") and not data.get("cpi"): + data["cpi"] = (data["gross_cpi"] * Decimal("0.70")).quantize( + Decimal("0.01") + ) + if data.get("gross_cpi"): + data["gross_cpi"] = data["gross_cpi"].quantize(Decimal("0.01")) + return data + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return CintCondition + + @property + def age_question(self) -> str: + return "42" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: CintCondition( + question_id="43", + values=["1"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: CintCondition( + question_id="43", + values=["2"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + @classmethod + def from_api(cls, d: Dict) -> Optional[Self]: + try: + return cls._from_api(d) + except Exception as e: + logger.warning(f"Unable to parse survey: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: Dict) -> Self: + if "cpi" in d: + d["gross_cpi"] = Decimal(d.pop("cpi")) + if "revenue_per_interview" in d: + assert d["revenue_per_interview"]["currency_code"] == "USD" + d["revenue_per_interview"] = Decimal( + d["revenue_per_interview"]["value"] + ).quantize(Decimal("0.01")) + + language_iso, country_iso = d["country_language"].split("_") + d["country_iso"] = locale_helper.get_country_iso(country_iso.lower()) + d["language_iso"] = locale_helper.get_language_iso(language_iso.lower()) + + d["bid_length_of_interview"] = round(d["bid_length_of_interview"] * 60) + d["length_of_interview"] = round(d["length_of_interview"] * 60) + d["termination_length_of_interview"] = round( + d["termination_length_of_interview"] * 60 + ) + d["bid_incidence"] /= 100 + d["survey_quota_calc_type"] = TaskCalculationType.from_api( + d["survey_quota_calc_type"] + ) + + # Cint/Cint Doesn't believe in using nullable values. Nullify them manually + + # termination_length_of_interview: Median time for a respondent to be termed, + # in minutes. This value is calculated after six survey entrants and rounded + # to the nearest whole number. Until six survey entrants are achieved the + # value will be zero. + if ( + d["termination_length_of_interview"] == 0 + and d["total_client_entrants"] <= 6 + ): + d["termination_length_of_interview"] = None + + # length_of_interview int Median time for a respondent to complete the + # survey, excluding the Cint Exchange (formerly Marketplace) prescreener, + # in minutes. This value will be zero until a complete is achieved. + # Documenation is wrong. it is 6 completes, but still some are not right + if d["length_of_interview"] == 0 and d["overall_completes"] < 6: + d["length_of_interview"] = None + + # conversion: either 1 or 6 completes? not clear + if d["overall_completes"] == 0: + d["conversion"] = None + d["mobile_conversion"] = None + d["revenue_per_click"] = None + + d["conditions"] = dict() + d.setdefault("survey_qualifications", list()) + qualifications = [CintCondition.from_api(q) for q in d["survey_qualifications"]] + for q in qualifications: + d["conditions"][q.criterion_hash] = q + d["qualifications"] = [x.criterion_hash for x in qualifications] + + quotas = [] + for quota in d["survey_quotas"]: + if quota["survey_quota_type"] == "Total": + quotas.append(CintQuota.from_api(quota)) + else: + criteria = [CintCondition.from_api(q) for q in quota["questions"]] + quota["condition_hashes"] = [x.criterion_hash for x in criteria] + quotas.append(CintQuota.from_api(quota)) + for q in criteria: + d["conditions"][q.criterion_hash] = q + d["quotas"] = quotas + + now = datetime.now(tz=timezone.utc) + d["created_at"] = now + d["last_updated"] = now + + return cls.model_validate(d) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_isos", + "language_isos", + "source", + "conditions", + }, + ) + d["qualifications"] = json.dumps(d["qualifications"]) + d["quotas"] = json.dumps(d["quotas"]) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["survey_group_ids"] = json.dumps(sorted(d["survey_group_ids"])) + d["respondent_pids"] = json.dumps(sorted(d["respondent_pids"])) + d["last_updated"] = self.last_updated + d["created_at"] = self.created_at + return d + + @classmethod + def from_mysql(cls, d: Dict[str, Any]) -> Self: + d["created_at"] = d["created_at"].replace(tzinfo=timezone.utc) + d["last_updated"] = d["last_updated"].replace(tzinfo=timezone.utc) + d["qualifications"] = json.loads(d["qualifications"]) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + d["quotas"] = json.loads(d["quotas"]) + d["survey_group_ids"] = json.loads(d["survey_group_ids"]) + d["respondent_pids"] = json.loads(d["respondent_pids"]) + return cls.model_validate(d) + + def passes_qualifications( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + # We have to match all quals + return all(criteria_evaluation.get(q) for q in self.qualifications) + + def passes_qualifications_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + hash_evals = {q: criteria_evaluation.get(q) for q in self.qualifications} + evals = set(hash_evals.values()) + # We have to match all. So if any are False, we know we don't pass + if False in evals: + return False, set() + # If any are None, we don't know + if None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + def passes_quotas(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Many surveys have 0 quotas. Quotas are exclusionary. + # They can NOT match a quota where currently_open=0 + any_pass = True + for q in self.quotas: + if q.quota_type == "total": + matches = q.is_open + else: + matches = q.matches_optional(criteria_evaluation) + if matches in {True, None} and not q.is_open: + # We also cannot be unknown for this quota, b/c we might fall into it, which would be a fail. + return False + return any_pass + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Many surveys have 0 quotas. Quotas are exclusionary. + # They can NOT match a quota where currently_open=0 + total_quota = [q for q in self.quotas if q.quota_type == "total"][0] + if not total_quota.is_open: + return False, set() + quotas = [q for q in self.quotas if q.quota_type != "total"] + if len(quotas) == 0: + return True, set() + quota_eval = { + quota: quota.matches_soft(criteria_evaluation) for quota in quotas + } + evals = set(g[0] for g in quota_eval.values()) + if any(m[0] is True and not q.is_open for q, m in quota_eval.items()): + # matched a full quota + return False, set() + if any(m[0] is None and not q.is_open for q, m in quota_eval.items()): + # Unknown match for full quota + if True in evals: + # we match 1 other, so the missing are only this type + return None, set( + flatten( + [ + m[1] + for q, m in quota_eval.items() + if m[0] is None and not q.is_open + ] + ) + ) + else: + # we don't match any quotas, so everything is unknown + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + if True in evals: + return True, set() + if None in evals: + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + return False, set() + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + return ( + self.is_open + and self.passes_qualifications(criteria_evaluation) + and self.passes_quotas(criteria_evaluation) + ) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # We check is_open when putting the survey in global_vars. Don't need to check again. + # if self.is_open is False: + # return False, set() + pass_quals, h_quals = self.passes_qualifications_soft(criteria_evaluation) + if pass_quals is False: + # short-circuit fail + return False, set() + pass_quotas, h_quotas = self.passes_quotas_soft(criteria_evaluation) + if pass_quals and pass_quotas: + return True, set() + elif pass_quotas is False: + return False, set() + else: + return None, h_quals | h_quotas diff --git a/generalresearch/models/cint/task_collection.py b/generalresearch/models/cint/task_collection.py new file mode 100644 index 0000000..91db35f --- /dev/null +++ b/generalresearch/models/cint/task_collection.py @@ -0,0 +1,71 @@ +from typing import List, Set + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models.cint.survey import CintSurvey +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +LANGUAGE_ISOS: Set[str] = Localelator().get_all_languages() + +CintTaskCollectionSchema = DataFrameSchema( + columns={ + "survey_name": Column(str, Check.str_length(min_value=1, max_value=128)), + "is_live": Column(bool), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "buyer_id": Column(str), + "buyer_name": Column(str), + "study_type": Column(str), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + "total_client_entrants": Column(int), + "overall_completes": Column(int), + "length_of_interview": Column( + "Int32", Check.between(0, 90 * 60), nullable=True + ), + "conversion": Column(float, Check.between(0, 1), nullable=True), + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "created_at": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "last_updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class CintTaskCollection(TaskCollection): + items: List[CintSurvey] + _schema = CintTaskCollectionSchema + + def to_row(self, s: CintSurvey): + d = s.model_dump( + mode="json", + include=set(CintTaskCollectionSchema.columns.keys()) | {"survey_id"}, + ) + d["cpi"] = float(s.cpi) + return d + + def to_df(self) -> pd.DataFrame: + rows = [] + for s in self.items: + rows.append(self.to_row(s)) + if rows: + return pd.DataFrame.from_records(rows, index="survey_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/custom_types.py b/generalresearch/models/custom_types.py new file mode 100644 index 0000000..54208a1 --- /dev/null +++ b/generalresearch/models/custom_types.py @@ -0,0 +1,282 @@ +import json +from datetime import datetime, timezone, timedelta +from typing import Any, Optional, Set, Literal +from uuid import UUID + +from pydantic import ( + AwareDatetime, + StringConstraints, + TypeAdapter, + HttpUrl, + IPvAnyAddress, + Field, + AnyUrl, +) +from pydantic.functional_serializers import PlainSerializer +from pydantic.functional_validators import AfterValidator, BeforeValidator +from pydantic.networks import UrlConstraints +from pydantic_core import Url +from typing_extensions import Annotated + +from generalresearch.models import DeviceType, Source + + +# if TYPE_CHECKING: +# from generalresearch.models import DeviceType + + +def convert_datetime_to_iso_8601_with_z_suffix(dt: datetime) -> str: + # By default, datetimes are serialized with the %f optional. We don't + # want that because then the deserialization fails if the datetime + # didn't have microseconds. + return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + +def convert_str_dt(v: Any) -> Optional[AwareDatetime]: + # By default, pydantic is unable to handle tz-aware isoformat str. Attempt + # to parse a str that was dumped using the iso8601 format with Z suffix. + if v is not None and type(v) is str: + assert v.endswith("Z") and "T" in v, "invalid format" + return datetime.strptime(v, "%Y-%m-%dT%H:%M:%S.%fZ").replace( + tzinfo=timezone.utc + ) + return v + + +def assert_utc(v: AwareDatetime) -> AwareDatetime: + if isinstance(v, datetime): + # We need utcoffset b/c FastAPI parses datetimes using FixedTimezone + assert v.tzinfo == timezone.utc or v.tzinfo.utcoffset(v) == timedelta( + 0 + ), "Timezone is not UTC" + v = v.astimezone(timezone.utc) + return v + + +InclExcl = Literal["exclude", "include"] + +# Our custom AwareDatetime that correctly serializes and deserializes +# to an ISO8601 str with timezone +AwareDatetimeISO = Annotated[ + AwareDatetime, + BeforeValidator(convert_str_dt), + AfterValidator(assert_utc), + PlainSerializer( + lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + when_used="json-unless-none", + ), +] + +# ISO 3166-1 alpha-2 (two-letter codes, lowercase) +# "Like" b/c it matches the format, but we're not explicitly checking +# it is one of our supported values. See models.thl.locales for that. +CountryISOLike = Annotated[ + str, StringConstraints(max_length=2, min_length=2, pattern=r"^[a-z]{2}$") +] +# 3-char ISO 639-2/B, lowercase +LanguageISOLike = Annotated[ + str, StringConstraints(max_length=3, min_length=3, pattern=r"^[a-z]{3}$") +] + + +def check_valid_uuid(v: str) -> str: + try: + assert UUID(v).hex == v + except Exception: + raise ValueError("Invalid UUID") + return v + + +def is_valid_uuid(v: str) -> bool: + try: + assert UUID(v).hex == v + except Exception: + return False + return True + + +# Our custom field that stores a UUID4 as the .hex string representation +UUIDStr = Annotated[ + str, + StringConstraints(min_length=32, max_length=32), + AfterValidator(check_valid_uuid), +] +# Accepts the non-hex representation and coerces +UUIDStrCoerce = Annotated[ + str, + StringConstraints(min_length=32, max_length=32), + BeforeValidator(lambda value: TypeAdapter(UUID).validate_python(value).hex), + AfterValidator(check_valid_uuid), +] + +# Same thing as UUIDStr with HttpUrl field. It is confusing that this +# is not a str https://github.com/pydantic/pydantic/discussions/6395 +HttpUrlStr = Annotated[ + str, + BeforeValidator(lambda value: str(TypeAdapter(HttpUrl).validate_python(value))), +] + +HttpsUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=["https"])] +HttpsUrlStr = Annotated[ + str, + BeforeValidator(lambda value: str(TypeAdapter(HttpsUrl).validate_python(value))), +] + +# Same thing as UUIDStr with IPvAnyAddress field. It is confusing that this is not a str +IPvAnyAddressStr = Annotated[ + str, + BeforeValidator( + lambda value: str(TypeAdapter(IPvAnyAddress).validate_python(value).exploded) + ), +] + + +def coerce_int_to_str(data: Any) -> Any: + """Transform input int to str, return other types as is""" + if isinstance(data, int): + return str(data) + return data + + +# This is a string field, but accepts integers that can be coerced into strings. +CoercedStr = Annotated[str, BeforeValidator(coerce_int_to_str)] + +# Serializers that can transform a collection of str into a comma separated +# str bidirectionally +to_comma_sep_str = PlainSerializer(lambda x: ",".join(sorted(list(x))), return_type=str) +enum_to_comma_sep_str = PlainSerializer( + lambda x: ",".join(sorted([str(y.value) for y in x])), return_type=str +) +from_comma_sep_str = BeforeValidator( + lambda x: set(x.split(",") if x != "" else []) if isinstance(x, str) else x +) + +# This is a set of DeviceType, that serializes and de-serializes into a +# (sorted) comma-separated str +DeviceTypes = Annotated[Set[DeviceType], enum_to_comma_sep_str, from_comma_sep_str] + +# This is a set of alphanumeric strings, that serializes and de-serializes +# into a (sorted) comma-separated str +AlphaNumStr = Annotated[str, StringConstraints(max_length=32, min_length=1)] + +# a string like an IP address, but we don't need to validate that it is +# actually an IP address. +IPLikeStr = Annotated[str, StringConstraints(max_length=39, min_length=2)] + + +def assert_dask_auth(v: Url) -> Url: + # Even if we're using tls and a SSL cert, Dask doesn't have the concept + # of user authentication + assert [v.username, v.password] == [ + None, + None, + ], "User & Password are not supported" + return v + + +def assert_sentry_auth(v: Url) -> Url: + assert v.username, "Sentry URL requires a user key" + assert len(v.username) > 10, "Sentry user key seems bad" + assert v.password is None, "Sentry password is not supported" + assert int(v.path[1:]), "Sentry project id needs to be a number (I think)" + assert v.port == 443, "https required" + assert v.fragment is None + return v + + +SentryDsn = Annotated[ + Url, + UrlConstraints( + allowed_schemes=["https"], + default_host="ingest.us.sentry.io", + default_port=443, + ), + AfterValidator(assert_sentry_auth), +] + +MySQLOrMariaDsn = Annotated[ + AnyUrl, + UrlConstraints(allowed_schemes=["mysql", "mariadb"]), +] + +DaskDsn = Annotated[ + Url, + UrlConstraints( + allowed_schemes=["tcp", "tls"], + default_host="127.0.0.1", + default_port=8786, + ), + AfterValidator(assert_dask_auth), +] + +InfluxDsn = Annotated[ + Url, + UrlConstraints( + allowed_schemes=["influxdb"], + default_host="127.0.0.1", + default_port=8086, + ), +] + +AlphaNumStrSet = Annotated[Set[AlphaNumStr], to_comma_sep_str, from_comma_sep_str] +IPLikeStrSet = Annotated[Set[IPLikeStr], to_comma_sep_str, from_comma_sep_str] +UUIDStrSet = Annotated[Set[UUIDStr], to_comma_sep_str, from_comma_sep_str] + +list_models_to_json_str = PlainSerializer( + lambda x: json.dumps([y.model_dump(mode="json") for y in x]), + return_type=str, + when_used="json", +) +json_str_to_model = BeforeValidator( + lambda x: json.loads(x) if isinstance(x, str) else x +) +json_str_to_set = BeforeValidator( + lambda x: set(json.loads(x)) if isinstance(x, str) else x +) + +EnumNameSerializer = PlainSerializer( + lambda e: e.name, return_type="str", when_used="unless-none" +) + +# These are used to make it explicit which attributes are pk/fk values. +BigAutoInteger = Annotated[int, Field(strict=True, gt=0, lt=9223372036854775807)] + + +def validate_survey_key(v: str) -> str: + """ + Variously called a Survey.natural_key or in Web3.0 language a CURIE + """ + # Must contain exactly one colon + if v.count(":") != 1: + raise ValueError("survey_key must be ':'") + + source, survey_id = v.split(":", 1) + + try: + Source(source) + except ValueError: + raise ValueError(f"invalid source '{source}'") + + if not (1 <= len(survey_id) <= 32): + raise ValueError("survey_id must be 1–32 characters") + + return v + + +SurveyKey = Annotated[ + str, + StringConstraints( + min_length=3, # 1-char source: "c:x" + max_length=35, # 2-char source: "tt:" + 32 + ), + AfterValidator(validate_survey_key), +] + +PropertyCode = Annotated[ + str, + StringConstraints( + min_length=3, # 1-char source: "c:x" + max_length=64, # DB max field length + pattern=r"^[a-z]{1,2}\:.*", + ), +] diff --git a/generalresearch/models/device.py b/generalresearch/models/device.py new file mode 100644 index 0000000..cc15eee --- /dev/null +++ b/generalresearch/models/device.py @@ -0,0 +1,15 @@ +from user_agents import parse as parse_ua + +from generalresearch.models import DeviceType + + +def parse_device_from_useragent(user_agent: str) -> DeviceType: + ua = parse_ua(user_agent) + if ua.is_mobile: + return DeviceType.MOBILE + elif ua.is_tablet: + return DeviceType.TABLET + elif ua.is_pc: + return DeviceType.DESKTOP + else: + return DeviceType.UNKNOWN diff --git a/generalresearch/models/dynata/__init__.py b/generalresearch/models/dynata/__init__.py new file mode 100644 index 0000000..c6d3a67 --- /dev/null +++ b/generalresearch/models/dynata/__init__.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class DynataStatus(str, Enum): + OPEN = "OPEN" + PAUSED = "PAUSED" + CLOSED = "CLOSED" diff --git a/generalresearch/models/dynata/question.py b/generalresearch/models/dynata/question.py new file mode 100644 index 0000000..76670ce --- /dev/null +++ b/generalresearch/models/dynata/question.py @@ -0,0 +1,269 @@ +# https://developers.dynata.com/docs/rex-respondent-gateway/dc5b33f20a1c9-get-attribute-info +import json +import logging +import re +from datetime import timedelta +from enum import Enum +from functools import cached_property +from typing import List, Optional, Literal, Any, Dict, Set + +from pydantic import BaseModel, Field, model_validator, field_validator, PositiveInt + +from generalresearch.models import Source, MAX_INT32 +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +TAG_RE = re.compile(r"<[^>]+>") + + +def clean_text(s: str): + # Some have a bunch of stupid html tags like 'What type of phone do you use?' + # thank you stackoverflow + return TAG_RE.sub("", s).replace("\n", "").replace(" ", "") + + +class DynataQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + frozen=True, + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + + # Order does not come back explicitly in the API, and the options are not ordered at all. We will + # order the responses when converting to UpkQuestion + + @field_validator("text", mode="after") + def clean_text(cls, s: str): + return clean_text(s) + + +class DynataQuestionType(str, Enum): + """ + From the API: {'geo', 'multi_select', 'multi_select_searchable', 'none', + 'single_select', 'single_select_grid', 'single_select_searchable', 'zip'} + These are of course not defined anywhere... + """ + + # single_select, single_select_grid, single_select_searchable, geo + SINGLE_SELECT = "s" + # multi_select, multi_select_searchable + MULTI_SELECT = "m" + # zip + TEXT_ENTRY = "t" + # Some questions are "restricted"/hidden, and we don't know anything but their ID + RESTRICTED = "r" + + # none: Some of these are just invalid, some are calculated GEO questions. + + @staticmethod + def from_api(display_mode: str): + question_type_map = { + "multi_select": DynataQuestionType.MULTI_SELECT, + "multi_select_searchable": DynataQuestionType.MULTI_SELECT, + "single_select": DynataQuestionType.SINGLE_SELECT, + "single_select_grid": DynataQuestionType.SINGLE_SELECT, + "single_select_searchable": DynataQuestionType.SINGLE_SELECT, + "geo": DynataQuestionType.SINGLE_SELECT, + "zip": DynataQuestionType.TEXT_ENTRY, + "text_entry": DynataQuestionType.TEXT_ENTRY, + } + # We don't want to fail if it is None b/c some are calculated GEO questions + return question_type_map.get(display_mode, DynataQuestionType.SINGLE_SELECT) + + +class DynataUserQuestionAnswer(BaseModel): + # This is optional b/c this model can be used for eligibility checks for "anonymous" users, which are represented + # by a list of question answers not associated with an actual user. No default b/c we must explicitly set + # the field to None. + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) + question_id: str = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + # This is optional b/c we do not need it when writing these to the db. When these are fetched from the db + # for use in yield-management, we read this field from the question table. + question_type: Optional[DynataQuestionType] = Field(default=None) + # This may be a pipe-separated string if the question_type is multi. regex means any chars except capital letters + option_id: str = Field(pattern=r"^[^A-Z]*$") + created: AwareDatetimeISO = Field() + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + @cached_property + def options_ids(self) -> Set[str]: + return set(self.option_id.split("|")) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"question_type"}) + d["created"] = self.created.replace(tzinfo=None) + return d + + +class DynataQuestionDependency(BaseModel, frozen=True): + # This is not explained or documented. Going to just store it for now + question_id: str = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + # Some are an empty list. Unclear if this means "any option" or it is broken. + option_ids: List[str] = Field() + + +class DynataQuestion(MarketplaceQuestion): + # This is called "qualification_code" in the API + question_id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + description="The unique identifier for the qualification", + frozen=True, + ) + # In the API: desc + question_name: str = Field( + max_length=255, min_length=1, description="A short name for the question" + ) + description: str = Field(max_length=255, min_length=1) + question_text: str = Field( + max_length=1024, min_length=1, description="The text shown to respondents" + ) + question_type: DynataQuestionType = Field(frozen=True) + options: Optional[List[DynataQuestionOption]] = Field(default=None, min_length=1) + # This does not mean that it doesn't expire, it means undefined. + expiration_duration: Optional[timedelta] = Field(default=None) + parent_dependencies: List[DynataQuestionDependency] = Field(default_factory=list) + + source: Literal[Source.DYNATA] = Source.DYNATA + + @property + def internal_id(self) -> str: + return self.question_id + + @field_validator("question_text", mode="after") + def clean_text(cls, s: str): + return clean_text(s) + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type in { + DynataQuestionType.TEXT_ENTRY, + DynataQuestionType.RESTRICTED, + }: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @classmethod + def create_restricted_question(cls, question_id): + # In a restricted question, we don't know the name/description/etc, but I don't + # want these fields nullable + return cls( + question_id=question_id, + question_type=DynataQuestionType.RESTRICTED, + question_name="unknown", + question_text="unknown", + description="unknown", + is_live=True, + # We don't know what locale these questions are for + country_iso="us", + language_iso="eng", + ) + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + DynataQuestionOption(id=r["id"], text=r["text"]) for r in d["options"] + ] + parent_dependencies = [ + DynataQuestionDependency( + question_id=pd["question_id"], option_ids=pd["option_ids"] + ) + for pd in d["parent_dependencies"] + ] + expiration_duration = ( + timedelta(seconds=d["expiration_duration_sec"]) + if d["expiration_duration_sec"] + else None + ) + return cls( + question_id=d["question_id"], + question_name=d["question_name"], + question_text=d["question_text"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + parent_dependencies=parent_dependencies, + description=d["description"], + is_live=d["is_live"], + category_id=d.get("category_id"), + expiration_duration=expiration_duration, + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + d["parent_dependencies"] = json.dumps(d["parent_dependencies"]) + d["expiration_duration_sec"] = ( + self.expiration_duration.total_seconds() + if self.expiration_duration + else None + ) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + order_exclusive_options, + ) + + upk_type_selector_map = { + DynataQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + DynataQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + DynataQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=n) + for n, c in enumerate(self.options) + ] + q = UpkQuestion(**d) + order_exclusive_options(q) + return q diff --git a/generalresearch/models/dynata/survey.py b/generalresearch/models/dynata/survey.py new file mode 100644 index 0000000..de4b177 --- /dev/null +++ b/generalresearch/models/dynata/survey.py @@ -0,0 +1,656 @@ +from __future__ import annotations + +import json +import logging +from datetime import timezone +from decimal import Decimal +from functools import cached_property +from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Type + +from more_itertools import flatten +from pydantic import ( + Field, + ConfigDict, + BaseModel, + model_validator, + field_validator, + RootModel, + computed_field, +) +from typing_extensions import Self + +from generalresearch.locales import Localelator +from generalresearch.models import TaskCalculationType, Source +from generalresearch.models.custom_types import ( + CoercedStr, + AwareDatetimeISO, + AlphaNumStrSet, + DeviceTypes, + AlphaNumStr, +) +from generalresearch.models.dynata import DynataStatus +from generalresearch.models.thl.demographics import ( + Gender, +) +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + ConditionValueType, + MarketplaceCondition, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class DynataRequirements(BaseModel): + # Requires inviting (recontacting) specific respondents to a follow up survey. + requires_recontact: bool = Field(default=False) + # Requires respondents to provide personally identifiable information (PII) within client survey. + requires_pii_collection: bool = Field(default=False) + # Requires respondents to utilize their webcam to participate. + requires_webcam: bool = Field(default=False) + # Requires use of facial recognition technology with respondents, such as eye tracking. + requires_eye_tracking: bool = Field(default=False) + # Requires partner to allow Dynata to drop a cookie on respondent. + requires_cookie_drops: bool = Field(default=False) + # Requires partner-uploaded respondent PII to expand third-party matched data. + requires_sample_plus: bool = Field(default=False) + # Requires respondents to download a software application. + requires_app_vpn: bool = Field(default=False) + # Requires additional incentives to be manually awarded to respondent by partner outside of the typical online + # survey flow. + requires_manual_rewards: bool = Field(default=False) + + def __repr__(self) -> str: + # Fancy repr that only shows values if they are True + repr_args = list(self.__repr_args__()) + repr_args = [(k, v) for k, v in repr_args if v] + join_str = ", " + repr_str = join_str.join( + repr(v) if a is None else f"{a}={v!r}" for a, v in repr_args + ) + return f"{self.__repr_name__()}({repr_str})" + + +class DynataCondition(MarketplaceCondition): + question_id: Optional[CoercedStr] = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + validation_alias="attribute_id", + ) + + # This comes in the API and is used to match "cells" to quotas they're associated with. Once + # we parse the API response, we don't need this tag id anymore. + tag: Optional[str] = Field(default=None, max_length=36) + + @classmethod + def from_api(cls, cell: Dict[str, Any]) -> "DynataCondition": + """ + We perform some preprocessing before calling this to pull in the data from COLLECTION cells. + """ + if cell["kind"] == "COLLECTION": + raise ValueError( + "this should be converted into a LIST type by the API helper first" + ) + + assert cell["kind"] in { + "VALUE", + "LIST", + "RANGE", + "INEFFABLE", + "ANSWERED", + "RECONTACT", + "INVITE_COLLECTIONS", + "STATIC_INVITE_COLLECTIONS", + }, f"unknown cell kind {cell['kind']}" + d = {k: cell[k] for k in ["tag", "attribute_id", "negate"]} + + if cell["kind"] == "VALUE": + d["values"] = [cell["value"]] + d["value_type"] = ConditionValueType.LIST + d["logical_operator"] = "OR" + return cls.model_validate(d) + + if cell["kind"] == "LIST": + d["values"] = list(map(str.lower, cell["list"])) + d["value_type"] = ConditionValueType.LIST + d["logical_operator"] = cell.get("operator", "OR") + return cls.model_validate(d) + + if cell["kind"] == "RANGE": + d["values"] = [ + "{0}-{1}".format( + cell["range"]["from"] or "inf", cell["range"]["to"] or "inf" + ) + ] + d["value_type"] = ConditionValueType.RANGE + return cls.model_validate(d) + + if cell["kind"] == "INEFFABLE": + d["value_type"] = ConditionValueType.INEFFABLE + d["values"] = [] + return cls.model_validate(d) + + if cell["kind"] == "ANSWERED": + d["value_type"] = ConditionValueType.ANSWERED + d["values"] = [] + return cls.model_validate(d) + + if cell["kind"] in {"INVITE_COLLECTIONS", "STATIC_INVITE_COLLECTIONS"}: + d["values"] = list(map(str.lower, cell["invite_collections"])) + d["value_type"] = ConditionValueType.RECONTACT + d["logical_operator"] = cell["operator"] + d["attribute_id"] = None + return cls.model_validate(d) + + +class DynataQuota(BaseModel): + # This is called a Quota Object in Dynata + model_config = ConfigDict(populate_by_name=True, frozen=True) + + # We don't ever need this + # quota_id: CoercedStr = Field(min_length=1, max_length=64, validation_alias="id") + count: int = Field(description="Limit of completes available") + # Each condition_hash is called in Dynata a "Quota Cell" + # Some quotas have no conditions. I'm not sure how eligibility is supposed to work for this. + condition_hashes: List[str] = Field(min_length=0, default_factory=list) + status: DynataStatus = Field() + + def __hash__(self): + return hash(tuple((tuple(self.condition_hashes), self.count, self.status))) + + @property + def is_open(self) -> bool: + # todo: should we make this configurable somehow? Until we have like a bag-holding score back, + # this has be hardcoded... + min_open_spots = 3 + return self.status == DynataStatus.OPEN and (self.count >= min_open_spots) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # We have to match all conditions (aka cells) within the quota (aka quota object). + return self.is_open and all( + criteria_evaluation.get(c) for c in self.condition_hashes + ) + + def passes_verbose(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + print(f"quota.is_open: {self.is_open}") + print( + ", ".join( + [f"{c}: {criteria_evaluation.get(c)}" for c in self.condition_hashes] + ) + ) + return self.is_open and all( + criteria_evaluation.get(c) for c in self.condition_hashes + ) + + def passes_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + if self.is_open is False: + return False, set() + cell_evals = { + cell: criteria_evaluation.get(cell) for cell in self.condition_hashes + } + evals = set(cell_evals.values()) + # We have to match all. So if any are False, we know we don't pass + if False in evals: + return False, set() + # if any are None, we don't know + elif None in evals: + return None, {cell for cell, ev in cell_evals.items() if ev is None} + else: + return True, set() + + +class DynataQuotaGroup(RootModel): + root: List[DynataQuota] = Field() + + def __iter__(self): + return iter(self.root) + + def __hash__(self): + return hash(tuple(self.root)) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Qualify for ANY quota object within a quota group + return any(quota.passes(criteria_evaluation) for quota in self.root) + + def passes_verbose(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Qualify for ANY quota object within a quota group + for quota in self.root: + print("---") + print(quota.passes_verbose(criteria_evaluation)) + print("---") + return any(quota.passes(criteria_evaluation) for quota in self.root) + + @property + def is_open(self) -> bool: + return any(cell.is_open for cell in self.root) + + def passes_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Qualify for ANY quota object within a quota group + obj_evals = {obj: obj.passes_soft(criteria_evaluation) for obj in self.root} + evals = set(v[0] for v in obj_evals.values()) + # If we match 1 obj, then the others don't matter + if any(evals): + return True, set() + # If we have none passing, and at least 1 unknown, then it is conditional + elif None in evals: + conditional_hashes = set( + flatten([v[1] for v in obj_evals.values() if v[0] is None]) + ) + return None, conditional_hashes + else: + return False, set() + + +class DynataFilterObject(RootModel): + root: List[str] = Field() # list of criterion hashes + + def __iter__(self): + return iter(self.root) + + def __hash__(self): + return hash(tuple(self.root)) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # We have to match all cells within an object. + return all(criteria_evaluation.get(cell) for cell in self.root) + + def passes_verbose(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + for cell in self.root: + print(f"{cell}: {criteria_evaluation.get(cell)}") + # We have to match all cells within an object. + return all(criteria_evaluation.get(cell) for cell in self.root) + + def passes_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + cell_evals = {cell: criteria_evaluation.get(cell) for cell in self.root} + evals = set(cell_evals.values()) + # We have to match all. So if any are False, we know we don't pass + if False in evals: + return False, set() + # if any are None, we don't know + elif None in evals: + return None, {cell for cell, ev in cell_evals.items() if ev is None} + else: + return True, set() + + +class DynataFilterGroup(RootModel): + root: List[DynataFilterObject] = Field() + + def __iter__(self): + return iter(self.root) + + def __hash__(self): + return hash(tuple(self.root)) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # A filter group is matched if we match at least 1 filter objs in the group. + return any(obj.passes(criteria_evaluation) for obj in self.root) + + def passes_verbose(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # A filter group is matched if we match at least 1 filter objs in the group. + for obj in self.root: + print("---") + print(obj.passes_verbose(criteria_evaluation)) + print("---") + return any(obj.passes(criteria_evaluation) for obj in self.root) + + def passes_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + obj_evals = {obj: obj.passes_soft(criteria_evaluation) for obj in self.root} + evals = set(v[0] for v in obj_evals.values()) + # If we match 1 obj, then the others don't matter + if any(evals): + return True, set() + # If we have none passing, and at least 1 unknown, then it is conditional + elif None in evals: + conditional_hashes = set( + flatten([v[1] for v in obj_evals.values() if v[0] is None]) + ) + return None, conditional_hashes + else: + return False, set() + + +class DynataSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + + survey_id: CoercedStr = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$", validation_alias="id" + ) + status: DynataStatus = Field() + + client_id: CoercedStr = Field( + description="Identifier of client requesting the study", max_length=32 + ) + order_number: str = Field(description="Unique project identifier", max_length=32) + project_id: CoercedStr = Field( + max_length=32, + min_length=1, + description="opportunities in the same project have mutual participation exclusions", + ) + group_id: CoercedStr = Field( + description="Identifier of opportunity group", + max_length=32, + min_length=1, + ) + + # There are 91 min surveys. We'll filter them out later + bid_loi: Optional[int] = Field( + default=None, + le=120 * 60, + description="Docs says 'Estimated length of interview', but this is " + "really the bid LOI'", + validation_alias="length_of_interview", + ) + bid_ir: Optional[float] = Field(validation_alias="incidence_rate", ge=0, le=1) + cpi: Decimal = Field(gt=0, le=100, validation_alias="cost_per_interview") + days_in_field: int = Field(description="Expected duration of opportunity in days") + # This isn't checked for eligibility determination + expected_count: int = Field( + validation_alias="completes", + description="Total fielding completes requested", + ) + + calculation_type: TaskCalculationType = Field( + description="Indicates whether the targets are counted per Complete or Survey Start", + validation_alias="evaluation", + ) + category_ids: AlphaNumStrSet = Field(default_factory=set) + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + allowed_devices: DeviceTypes = Field(min_length=1, validation_alias="devices") + live_link: str = Field(description="entry link") + created: AwareDatetimeISO = Field(description="Creation date of opportunity") + + project_exclusions: AlphaNumStrSet = Field(default_factory=set) + category_exclusions: AlphaNumStrSet = Field(default_factory=set) + requirements: DynataRequirements = Field() + + filters: List[DynataFilterGroup] = Field(default_factory=list) + quotas: List[DynataQuotaGroup] = Field(default_factory=list) + + source: Literal[Source.DYNATA] = Field(default=Source.DYNATA) + + used_question_ids: Set[AlphaNumStr] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as "condition_hashes") throughout + # this survey. In the reduced representation of this task (nearly always, for db i/o, in global_vars) + # this field will be null. + conditions: Optional[Dict[str, DynataCondition]] = Field(default=None) + + # These do not come from the API. We set them ourselves + last_updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @computed_field + def is_live(self) -> bool: + return self.status == DynataStatus.OPEN + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 open quota (or there are no quotas!) + return self.is_live and ( + any(q.is_open for q in self.quotas) or len(self.quotas) == 0 + ) + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set() + for fg in self.filters: + for f in fg.root: + s.update(f.root) + for qg in self.quotas: + for q in qg.root: + s.update(set(q.condition_hashes)) + return s + + @field_validator("category_ids", mode="before") + def split_category_ids(cls, v: object) -> object: + if isinstance(v, str): + return v.split("|") + return v + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @model_validator(mode="after") + def set_buyer_id(self): + # In dynata, this is called "client_id", in the generic MarketplaceTask, we're using "buyer_id" + self.buyer_id = self.client_id + return self + + @property + def filters_verbose(self) -> List[List[str]]: + assert self.conditions is not None, "conditions must be set" + res = [] + for filter_group in self.filters: + sub_res = [] + res.append(sub_res) + for filter in filter_group.root: + sub_res.extend([self.conditions[c].minified for c in filter.root]) + return res + + @property + def quotas_verbose(self) -> List[List[Dict[str, Any]]]: + assert self.conditions is not None, "conditions must be set" + res = [] + for quota_group in self.quotas: + sub_res = [] + res.append(sub_res) + for quota in quota_group.root: + q = quota.model_dump(mode="json") + q["conditions"] = [ + self.conditions[c].minified for c in quota.condition_hashes + ] + sub_res.append(q) + return res + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return DynataCondition + + @property + def age_question(self) -> str: + return "80" + + @property + def marketplace_genders(self): + return { + Gender.MALE: DynataCondition( + question_id="1", + values=["1"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: DynataCondition( + question_id="1", + values=["2"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + def is_unchanged(self, other) -> bool: + # Avoiding overloading __eq__ because it looks kind of complicated? I want to be explicit that + # this is not testing object equivalence, just that the objects don't require any db updates. + # We also exclude conditions b/c this is just the condition_hash definitions + return self.model_dump( + exclude={"created", "last_updated", "conditions"} + ) == other.model_dump(exclude={"created", "last_updated", "conditions"}) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_isos", + "language_isos", + "source", + "conditions", + }, + ) + d["filters"] = json.dumps(d["filters"]) + d["quotas"] = json.dumps(d["quotas"]) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["requirements"] = json.dumps(d["requirements"]) + d["created"] = self.created + d["last_updated"] = self.last_updated + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]) -> Self: + d["created"] = d["created"].replace(tzinfo=timezone.utc) + d["last_updated"] = d["last_updated"].replace(tzinfo=timezone.utc) + d["filters"] = json.loads(d["filters"]) + d["quotas"] = json.loads(d["quotas"]) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + d["requirements"] = json.loads(d["requirements"]) + return cls.model_validate(d) + + def passes_filters(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # We have to match all filter groups + return all(group.passes(criteria_evaluation) for group in self.filters) + + def passes_filters_verbose( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + # We have to match all filter groups + for group in self.filters: + print("+++") + group.passes_verbose(criteria_evaluation) + print("+++") + return all(group.passes(criteria_evaluation) for group in self.filters) + + def passes_filters_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # We have to match all filter groups + group_eval = { + group: group.passes_soft(criteria_evaluation) for group in self.filters + } + evals = set(g[0] for g in group_eval.values()) + if False in evals: + return False, set() + elif None in evals: + conditional_hashes = set( + flatten([v[1] for v in group_eval.values() if v[0] is None]) + ) + return None, conditional_hashes + else: + return True, set() + + def passes_quotas(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # We have to match all quota groups + return all( + quota_group.passes(criteria_evaluation) for quota_group in self.quotas + ) + + def passes_quotas_verbose( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + # We have to match all quota groups + for quota_group in self.quotas: + print("+++") + quota_group.passes_verbose(criteria_evaluation) + print("+++") + return all( + quota_group.passes(criteria_evaluation) for quota_group in self.quotas + ) + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # We have to match all quota groups + group_eval = { + quota: quota.passes_soft(criteria_evaluation) for quota in self.quotas + } + evals = set(g[0] for g in group_eval.values()) + if False in evals: + return False, set() + elif None in evals: + conditional_hashes = set( + flatten([v[1] for v in group_eval.values() if v[0] is None]) + ) + return None, conditional_hashes + else: + return True, set() + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + return ( + self.is_open + and self.passes_filters(criteria_evaluation) + and self.passes_quotas(criteria_evaluation) + ) + + def determine_eligibility_verbose( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + print(f"is_open: {self.is_open}") + print("passes_filters") + print(self.passes_filters_verbose(criteria_evaluation)) + print("passes_quotas") + print(self.passes_quotas_verbose(criteria_evaluation)) + return ( + self.is_open + and self.passes_filters(criteria_evaluation) + and self.passes_quotas(criteria_evaluation) + ) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + if self.is_open is False: + return False, set() + pass_filters, h_filters = self.passes_filters_soft(criteria_evaluation) + pass_quotas, h_quotas = self.passes_quotas_soft(criteria_evaluation) + if pass_filters and pass_quotas: + return True, set() + elif pass_filters is False or pass_quotas is False: + return False, set() + else: + return None, h_filters | h_quotas diff --git a/generalresearch/models/dynata/task_collection.py b/generalresearch/models/dynata/task_collection.py new file mode 100644 index 0000000..bd9b81a --- /dev/null +++ b/generalresearch/models/dynata/task_collection.py @@ -0,0 +1,86 @@ +from typing import List, Dict, Any + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models import TaskCalculationType +from generalresearch.models.dynata import DynataStatus +from generalresearch.models.dynata.survey import DynataSurvey +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS = Localelator().get_all_countries() +LANGUAGE_ISOS = Localelator().get_all_languages() + +DynataTaskCollectionSchema = DataFrameSchema( + columns={ + "status": Column(str, Check.isin(DynataStatus)), + "buyer_id": Column(str), + "order_number": Column(str), + "project_id": Column(str), + "group_id": Column(str), + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "days_in_field": Column(int), + "expected_count": Column(int), + "calculation_type": Column(str, Check.isin(TaskCalculationType)), + "category_ids": Column(str), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + "allowed_devices": Column(str), + "requirements": Column(str), # json dumped str + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "last_updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class DynataTaskCollection(TaskCollection): + items: List[DynataSurvey] + _schema = DynataTaskCollectionSchema + + def to_row(self, s: DynataSurvey) -> Dict[str, Any]: + d = s.model_dump( + mode="json", + exclude={ + "country_isos", + "language_isos", + "filters", + "quotas", + "source", + "conditions", + "is_live", + "project_exclusions", + "category_exclusions", + "live_link", + "client_id", + }, + ) + d["cpi"] = float(s.cpi) + d["requirements"] = s.requirements.model_dump_json() + return d + + def to_df(self) -> pd.DataFrame: + rows = [] + for s in self.items: + rows.append(self.to_row(s)) + if rows: + return pd.DataFrame.from_records(rows, index="survey_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/events.py b/generalresearch/models/events.py new file mode 100644 index 0000000..4504ab1 --- /dev/null +++ b/generalresearch/models/events.py @@ -0,0 +1,299 @@ +from datetime import datetime, timezone, timedelta +from enum import StrEnum +from typing import Union, Literal, Optional, Dict +from uuid import uuid4 + +from pydantic import ( + BaseModel, + Field, + PositiveFloat, + NonNegativeInt, + model_validator, + TypeAdapter, + ConfigDict, +) +from typing_extensions import Annotated + +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + CountryISOLike, + UUIDStr, + AwareDatetimeISO, +) +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallStatusCode2, + SessionStatusCode2, +) + + +class MessageKind(StrEnum): + # An event, with type EventType (enum below) + EVENT = "event" + # A message containing bulk/aggregated stats + STATS = "stats" + # Heartbeats + PING = "ping" + PONG = "pong" + # Must be the first message sent from client + SUBSCRIBE = "subscribe" + + +class EventType(StrEnum): + # Task Lifecycle + # (enter/finish could also be called start/end) + TASK_ENTER = "task.enter" + TASK_FINISH = "task.finish" + # Session Lifecycle + SESSION_ENTER = "session.enter" + SESSION_FINISH = "session.finish" + # Wallet / payments + WALLET_CREDIT = "wallet.credit" + WALLET_DEBIT = "wallet.debit" + + # User + USER_CREATED = "user.created" # A user we've never seen before + USER_ACTIVE = "user.active" + # USER_AUDIT = "user.audit" # Something happened with this user + + +class TaskEnterPayload(BaseModel): + event_type: Literal[EventType.TASK_ENTER] = EventType.TASK_ENTER + + source: Source = Field() + survey_id: str = Field(min_length=1, max_length=32, examples=["127492892"]) + quota_id: Optional[str] = Field( + default=None, + max_length=32, + description="The marketplace's internal quota id", + ) + country_iso: CountryISOLike = Field() + + +class TaskFinishPayload(TaskEnterPayload): + event_type: Literal[EventType.TASK_FINISH] = EventType.TASK_FINISH + + duration_sec: PositiveFloat = Field() + status: Status + status_code_1: Optional[StatusCode1] = None + status_code_2: Optional[WallStatusCode2] = None + cpi: Optional[NonNegativeInt] = Field(le=4000, default=None) + + +class SessionEnterPayload(BaseModel): + event_type: Literal[EventType.SESSION_ENTER] = EventType.SESSION_ENTER + country_iso: CountryISOLike = Field() + + +class SessionFinishPayload(SessionEnterPayload): + event_type: Literal[EventType.SESSION_FINISH] = EventType.SESSION_FINISH + + duration_sec: PositiveFloat = Field() + status: Status + status_code_1: Optional[StatusCode1] = None + status_code_2: Optional[SessionStatusCode2] = None + user_payout: Optional[NonNegativeInt] = Field(default=None, le=4000, ge=0) + + +EventPayload = Annotated[ + Union[ + TaskEnterPayload, + TaskFinishPayload, + SessionEnterPayload, + SessionFinishPayload, + ], + Field(discriminator="event_type"), +] + + +class EventEnvelope(BaseModel): + event_uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + event_type: EventType = Field() + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + version: int = 1 + + product_user_id: Optional[str] = Field( + min_length=3, + max_length=128, + examples=["app-user-9329ebd"], + description="A unique identifier for each user. This is hidden unless" + "the event is for the requesting user.", + default=None, + ) + product_id: UUIDStr = Field(examples=["4fe381fb7186416cb443a38fa66c6557"]) + + payload: EventPayload + + @model_validator(mode="after") + def event_type_matches_payload(self): + if self.event_type != self.payload.event_type: + raise ValueError("event_type must match payload.event_type") + return self + + +class AggregateBySource(BaseModel): + total: NonNegativeInt = Field(default=0) + by_source: Dict[Source, NonNegativeInt] = Field(default_factory=dict) + + @model_validator(mode="after") + def remove_zero(self): + self.by_source = {k: v for k, v in self.by_source.items() if v} + return self + + +class MaxGaugeBySource(BaseModel): + value: Optional[NonNegativeInt] = Field(default=None) + by_source: Dict[Source, NonNegativeInt] = Field(default_factory=dict) + + @model_validator(mode="after") + def remove_zero(self): + self.by_source = {k: v for k, v in self.by_source.items() if v} + return self + + +class TaskStatsSnapshot(BaseModel): + # Counts: Task related + live_task_count: AggregateBySource = Field(default_factory=AggregateBySource) + + task_created_count_last_1h: AggregateBySource = Field( + default_factory=AggregateBySource + ) + task_created_count_last_24h: AggregateBySource = Field( + default_factory=AggregateBySource + ) + + live_tasks_max_payout: MaxGaugeBySource = Field( + description="In integer USDCents", default_factory=MaxGaugeBySource + ) + + +class StatsSnapshot(TaskStatsSnapshot): + model_config = ConfigDict(ser_json_timedelta="float") + + # If this is set, then everything is scoped to this country. + country_iso: Optional[CountryISOLike] = Field(default=None) + + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # Counts: User related + active_users_last_1h: NonNegativeInt = Field( + description="""Count of users (in this product_id) that were active in the past 60 minutes. + Behaviors that trigger active: + - Request an offerwall + - Enter an offerwall bucket + - Request profiling questions + - Submit profiling answers + - Update user profile + """, + default=0, + ) + active_users_last_24h: NonNegativeInt = Field( + description="Count of users (in this product_id) that were active in the past 24 hours.", + default=0, + ) + # decrements upon either 90 min since enter, or upon finish. + in_progress_users: NonNegativeInt = Field( + description="Count of users that are currently doing work at this moment" + ) + signups_last_24h: NonNegativeInt = Field(description="Count of users created") + # Requires db for lookback. Skip for now + # total_users: NonNegativeInt = Field(description="Total count of users over all time") + + # Counts: Activity related + session_enters_last_1h: NonNegativeInt = Field() + session_enters_last_24h: NonNegativeInt = Field() + session_fails_last_1h: NonNegativeInt = Field() + session_fails_last_24h: NonNegativeInt = Field() + session_completes_last_1h: NonNegativeInt = Field() + session_completes_last_24h: NonNegativeInt = Field() + sum_payouts_last_1h: NonNegativeInt = Field(ge=0, description="In integer USDCents") + sum_payouts_last_24h: NonNegativeInt = Field( + ge=0, description="In integer USDCents" + ) + + # Rolling averages + session_avg_payout_last_24h: Optional[NonNegativeInt] = Field( + description="Average (actual) payout of all tasks completed in the past 24 hrs" + ) + session_avg_user_payout_last_24h: Optional[NonNegativeInt] = Field( + description="Average (actual) user payout of all tasks completed in the past 24 hrs" + ) + + session_fail_avg_loi_last_24h: Optional[timedelta] = Field( + description="Average LOI of all tasks terminated in the past 24 hrs (excludes abandons)" + ) + session_complete_avg_loi_last_24h: Optional[timedelta] = Field( + description="Average LOI of all tasks completed in the past 24 hrs" + ) + + # # todo: + # avg_user_earned_last_24h: Optional[NonNegativeFloat] = Field( + # ge=0, + # default=None, + # description="The average amount active users earned in total in the past 24 hrs", + # ) + + +# ---------------- +# Top-level messages +# ---------------- + + +class EventMessage(BaseModel): + kind: Literal[MessageKind.EVENT] = Field(default=MessageKind.EVENT) + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + data: EventEnvelope + + +class StatsMessage(BaseModel): + kind: Literal[MessageKind.STATS] = Field(default=MessageKind.STATS) + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + # The data/StatsSnapshot can optionally be scoped to a country + country_iso: Optional[CountryISOLike] = Field(default=None) + data: StatsSnapshot + + +class PingMessage(BaseModel): + kind: Literal[MessageKind.PING] = Field(default=MessageKind.PING) + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + +class PongMessage(BaseModel): + kind: Literal[MessageKind.PONG] = Field(default=MessageKind.PONG) + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + +class SubscribeMessage(BaseModel): + kind: Literal[MessageKind.SUBSCRIBE] = Field(default=MessageKind.SUBSCRIBE) + product_id: UUIDStr = Field(examples=["4fe381fb7186416cb443a38fa66c6557"]) + + +ServerToClientMessage = Union[EventMessage, StatsMessage, PingMessage] +ServerToClientMessageField = Annotated[ + ServerToClientMessage, + Field(discriminator="kind"), +] +ServerToClientMessageAdapter = TypeAdapter(ServerToClientMessageField) + +ClientToServerMessage = Union[ + SubscribeMessage, + PongMessage, +] +ClientToServerMessageField = Annotated[ + ClientToServerMessage, + Field(discriminator="kind"), +] +ClientToServerMessageAdapter = TypeAdapter(ClientToServerMessageField) diff --git a/generalresearch/models/gr/__init__.py b/generalresearch/models/gr/__init__.py new file mode 100644 index 0000000..713bba6 --- /dev/null +++ b/generalresearch/models/gr/__init__.py @@ -0,0 +1,13 @@ +from generalresearch.models.gr.authentication import GRUser, GRToken +from generalresearch.models.gr.business import Business +from generalresearch.models.gr.team import Team +from generalresearch.models.thl.payout import BrokerageProductPayoutEvent +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.finance import BusinessBalances + +_ = Business, Product, BrokerageProductPayoutEvent, BusinessBalances + +GRUser.model_rebuild() +GRToken.model_rebuild() +Business.model_rebuild() +Team.model_rebuild() diff --git a/generalresearch/models/gr/authentication.py b/generalresearch/models/gr/authentication.py new file mode 100644 index 0000000..ff1d065 --- /dev/null +++ b/generalresearch/models/gr/authentication.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import binascii +import json +import os +from datetime import datetime, timezone +from typing import Optional, List, TYPE_CHECKING, Dict, Union + +from pydantic import AnyHttpUrl +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + field_validator, + NonNegativeInt, +) +from typing_extensions import Self + +from generalresearch.decorators import LOG +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +if TYPE_CHECKING: + from generalresearch.models.gr.business import Business + from generalresearch.models.gr.team import Team + from generalresearch.models.thl.product import Product + + +class Claims(BaseModel): + iss: Optional[str] = Field( + default=None, + description="Issuer: https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1", + ) + + sub: Optional[str] = Field( + default=None, + description="Subject: https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.2", + ) + + aud: Optional[str] = Field( + default=None, + description="Audience: https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.3", + ) + + exp: Optional[NonNegativeInt] = Field( + default=None, + description="Expiration time: https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.4", + ) + + iat: Optional[NonNegativeInt] = Field( + default=None, + description="Issued at: https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.6", + ) + + auth_time: Optional[NonNegativeInt] = Field( + default=None, + description="When authentication occured: https://openid.net/specs/openid-connect-core-1_0.html#IDToken", + ) + + acr: Optional[str] = Field( + default=None, + description="Authentication Context Class Reference: https://openid.net/specs/openid-connect-core-1_0.html#IDToken", + ) + + amr: Optional[List[str]] = Field( + default=None, + description="Authentication Methods References: https://openid.net/specs/openid-connect-core-1_0.html#IDToken", + ) + + c_hash: Optional[str] = Field( + default=None, + description="Code hash value: http://openid.net/specs/openid-connect-core-1_0.html", + ) + + nonce: Optional[str] = Field( + default=None, + description="Value used to associate a Client session with an ID Token: http://openid.net/specs/openid-connect-core-1_0.html", + ) + + at_hash: Optional[str] = Field( + default=None, + description="Access Token hash value: http://openid.net/specs/openid-connect-core-1_0.html", + ) + + sid: Optional[str] = Field( + default=None, + description="Session ID: https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents", + ) + + # --- Properties --- + + @property + def subject(self): + return self.sub + + +class GRUser(BaseModel): + model_config = ConfigDict( + # extra="forbid", + # from_attributes=True, + arbitrary_types_allowed=True + ) + + id: Optional[PositiveInt] = Field(default=None) + sub: Optional[str] = Field(max_length=200) + is_superuser: bool = Field(default=False) + + date_joined: AwareDatetimeISO = Field( + description="When the GR User account signed up." + ) + + # prefetch attributes + businesses: Optional[List["Business"]] = Field(default=None) + teams: Optional[List["Team"]] = Field(default=None) + products: Optional[List["Product"]] = Field(default=None) + token: Optional["GRToken"] = Field(default=None) + claims: Optional["Claims"] = Field(default=None) + + def prefetch_claims( + self, token: str, key: Dict, audience: str, issuer: AnyHttpUrl + ) -> None: + from jose import jwt + + payload = jwt.decode( + token=token, + key=key, + algorithms=["RS256"], + audience=audience, + issuer=issuer, + ) + self.claims = Claims.model_validate(payload) + + def prefetch_businesses( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.business import BusinessManager + + bm = BusinessManager(pg_config=pg_config, redis_config=redis_config) + + if self.is_superuser: + self.businesses = bm.get_all() + else: + self.businesses = bm.get_by_user_id(user_id=self.id) + + def prefetch_teams( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.team import TeamManager + + tm = TeamManager(pg_config=pg_config, redis_config=redis_config) + + if self.is_superuser: + self.teams = tm.get_all() + else: + self.teams = tm.get_by_user(gr_user=self) + + def prefetch_products( + self, + pg_config: PostgresConfig, + thl_pg_config: PostgresConfig, + redis_config: RedisConfig, + ) -> None: + + self.prefetch_businesses(pg_config=pg_config, redis_config=redis_config) + self.prefetch_teams(pg_config=pg_config, redis_config=redis_config) + business_uuids = self.business_uuids + team_uuids = self.team_uuids + + if len(business_uuids + team_uuids) == 0: + self.products = [] + return None + + from generalresearch.managers.thl.product import ProductManager + + pm = ProductManager(pg_config=thl_pg_config) + + business_products = ( + pm.fetch_uuids(business_uuids=business_uuids) if business_uuids else [] + ) + team_products = pm.fetch_uuids(team_uuids=team_uuids) if team_uuids else [] + products = {p.id: p for p in business_products + team_products} + + self.products = sorted(products.values(), key=lambda x: getattr(x, "created")) + + def prefetch_token(self, pg_config: PostgresConfig): + from generalresearch.managers.gr.authentication import ( + GRTokenManager, + ) + + tm = GRTokenManager(pg_config=pg_config) + self.token = tm.get_by_user_id(user_id=self.id) + + def __eq__(self, other: "GRUser") -> bool: + return self.id == other.id + + # --- Validations --- + @field_validator("date_joined") + @classmethod + def date_joined_utc(cls, v: datetime) -> datetime: + return v.replace(tzinfo=timezone.utc) + + # --- Properties --- + @property + def cache_key(self) -> str: + return f"gr_user:{self.id}" + + @property + def business_uuids(self) -> Optional[List[UUIDStr]]: + if self.businesses is None: + LOG.warning("prefetch not run") + return None + + return [b.uuid for b in self.businesses] + + @property + def business_ids(self) -> Optional[List[PositiveInt]]: + if self.businesses is None: + LOG.warning("prefetch not run") + return None + + return [b.id for b in self.businesses] + + @property + def team_uuids(self) -> Optional[List[UUIDStr]]: + if self.teams is None: + LOG.warning("prefetch not run") + return None + + return [t.uuid for t in self.teams] + + @property + def team_ids(self) -> Optional[List[PositiveInt]]: + if self.teams is None: + LOG.warning("prefetch not run") + return None + + return [t.id for t in self.teams] + + @property + def product_uuids(self) -> Optional[List[UUIDStr]]: + if self.products is None: + LOG.warning("prefetch not run") + return None + + return [p.uuid for p in self.products] + + # --- Methods --- + + def set_cache( + self, + pg_config: PostgresConfig, + thl_web_rr: PostgresConfig, + redis_config: RedisConfig, + ) -> None: + ex_secs = 60 * 60 * 24 * 3 # 3 days + + self.prefetch_teams(pg_config=pg_config, redis_config=redis_config) + self.prefetch_businesses(pg_config=pg_config, redis_config=redis_config) + self.prefetch_products( + pg_config=pg_config, + thl_pg_config=thl_web_rr, + redis_config=redis_config, + ) + self.prefetch_token(pg_config=pg_config) + + rc = redis_config.create_redis_client() + + rc.set(name=self.cache_key, value=self.to_redis(), ex=ex_secs) + rc.set( + name=f"{self.cache_key}:team_uuids", + value=json.dumps(self.team_uuids), + ex=ex_secs, + ) + rc.set( + name=f"{self.cache_key}:business_uuids", + value=json.dumps(self.business_uuids), + ex=ex_secs, + ) + rc.set( + name=f"{self.cache_key}:product_uuids", + value=json.dumps(self.product_uuids), + ex=ex_secs, + ) + + return None + + # --- ORM --- + + @classmethod + def from_postgresql(cls, d: dict) -> Self: + d["date_joined"] = d["date_joined"].replace(tzinfo=timezone.utc) + return GRUser.model_validate(d) + + @classmethod + def from_redis(cls, d: Union[str, Dict]) -> Self: + if isinstance(d, str): + d = json.loads(d) + assert isinstance(d, dict) + + d["date_joined"] = datetime.fromisoformat(d["date_joined"]) + + if d.get("token"): + d["token"] = GRToken.from_redis(d["token"]) + + return GRUser.model_validate(d) + + def to_redis(self) -> str: + d = self.model_dump(mode="json", exclude={"businesses", "teams", "products"}) + d["business_uuids"] = self.business_uuids + d["team_uuids"] = self.team_uuids + d["product_uuids"] = self.product_uuids + + return json.dumps(d) + + +class GRToken(BaseModel): + key: str = Field( + min_length=32, + max_length=2_000, + # rest_framework.authtoken.models.py:37 generate_key() + examples=[binascii.hexlify(os.urandom(20)).decode()], + ) + + created: AwareDatetimeISO = Field() + user_id: PositiveInt = Field() + + # --- prefetch field --- + user: Optional["GRUser"] = Field(default=None) + + @property + def sso(self) -> bool: + return GRToken.is_sso(api_key=self.key) + + @staticmethod + def is_sso(api_key: str) -> bool: + return len(api_key) > 255 + + def prefetch_user( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.authentication import ( + GRUserManager, + ) + + gr_um = GRUserManager(pg_config=pg_config, redis_config=redis_config) + + self.user = gr_um.get_by_id(gr_user_id=self.user_id) + + def __eq__(self, other: "GRToken") -> bool: + return self.key == other.key + + @field_validator("created", mode="before") + @classmethod + def created_utc(cls, v: datetime) -> datetime: + return v.replace(tzinfo=timezone.utc) + + # --- Properties --- + + @property + def auth_header(self, key_name="Authorization") -> Dict: + return {key_name: self.key} + + # --- ORM --- + + @classmethod + def from_redis(cls, d: Union[str, Dict]) -> Self: + if isinstance(d, str): + d = json.loads(d) + assert isinstance(d, dict) + + d["created"] = datetime.fromisoformat(d["created"]) + + return GRToken.model_validate(d) diff --git a/generalresearch/models/gr/business.py b/generalresearch/models/gr/business.py new file mode 100644 index 0000000..b07c584 --- /dev/null +++ b/generalresearch/models/gr/business.py @@ -0,0 +1,743 @@ +from __future__ import annotations + +import json +import logging +import os +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Optional, List, TYPE_CHECKING, Union +from uuid import uuid4 + +import pandas as pd +from dask.distributed import Client +from psycopg.cursor import Cursor +from psycopg.rows import dict_row +from pydantic import BaseModel, ConfigDict, Field, PositiveInt +from pydantic.json_schema import SkipJsonSchema +from pydantic_extra_types.phone_numbers import PhoneNumber +from typing_extensions import Self + +from generalresearch.currency import USDCent +from generalresearch.decorators import LOG +from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge +from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, +) +from generalresearch.utils.enum import ReprEnumMeta +from generalresearch.models.admin.request import ReportRequest, ReportType +from generalresearch.models.custom_types import ( + UUIDStr, + UUIDStrCoerce, + AwareDatetime, +) +from generalresearch.models.thl.finance import POPFinancial +from generalresearch.models.thl.ledger import LedgerAccount, OrderBy +from generalresearch.models.thl.payout import BusinessPayoutEvent +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig +from generalresearch.utils.aggregation import group_by_year + +if TYPE_CHECKING: + from generalresearch.models.thl.finance import BusinessBalances + + from generalresearch.models.thl.product import Product + from generalresearch.models.gr.team import Team + from generalresearch.incite.base import GRLDatasets + from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, + ) + from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, + ) + from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerManager, + ) + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + from generalresearch.managers.thl.payout import ( + BusinessPayoutEventManager, + ) + from generalresearch.managers.thl.product import ProductManager + + +class TransferMethod(Enum, metaclass=ReprEnumMeta): + ACH = 0 + WIRE = 1 + + +class BusinessType(str, Enum, metaclass=ReprEnumMeta): + INDIVIDUAL = "i" + COMPANY = "c" + + +class BusinessBankAccount(BaseModel): + model_config = ConfigDict( + use_enum_values=True, + json_encoders={TransferMethod: lambda tm: tm.value}, + ) + + id: SkipJsonSchema[Optional[PositiveInt]] = Field(default=None) + uuid: UUIDStrCoerce = Field(examples=[uuid4().hex]) + + business_id: PositiveInt = Field() + + # 'business' is a Class with values that are fetched from the DB. + # Initialization is deferred until it is actually needed + # (see .prefetch_business()) + business: SkipJsonSchema[Optional["Business"]] = Field(default=None) + + transfer_method: TransferMethod = Field( + description=TransferMethod.as_openapi(), + examples=[TransferMethod.ACH.value], + ) + + # ACH requirements + account_number: Optional[str] = Field( + default=None, + max_length=16, + description="ACH requirements", + examples=[f"{'*' * 9}1234"], + ) + + routing_number: Optional[str] = Field( + default=None, + max_length=9, + description="ACH requirements", + examples=[f"{'*' * 5}1234"], + ) + + # Wire requirements + iban: Optional[str] = Field( + default=None, + max_length=50, + description="Wire requirements", + examples=[None], + ) + swift: Optional[str] = Field( + default=None, + max_length=50, + description="Wire requirements", + examples=[None], + ) + + def prefetch_business( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.business import BusinessManager + + if self.business is None: + bm = BusinessManager(pg_config=pg_config, redis_config=redis_config) + self.business = bm.get_by_id(business_id=self.business_id) + + +class BusinessAddress(BaseModel): + model_config = ConfigDict(extra="ignore") + + id: SkipJsonSchema[Optional[PositiveInt]] = Field(default=None) + uuid: UUIDStrCoerce = Field(examples=[uuid4().hex]) + + line_1: Optional[str] = Field( + default=None, max_length=255, examples=["540 Mariposa"] + ) + + line_2: Optional[str] = Field(default=None, max_length=255, examples=[None]) + + city: Optional[str] = Field( + default=None, max_length=255, examples=["Mountain View"] + ) + + state: Optional[str] = Field( + default=None, + max_length=255, + description="This can only be more than len=2 if it's a state or" + "providence out of the United States", + examples=["CA"], + ) + + postal_code: Optional[str] = Field(default=None, max_length=12, examples=["94041"]) + + phone_number: Optional[PhoneNumber] = Field(default=None) + + country: Optional[str] = Field(default=None, max_length=2, examples=["US"]) + + business_id: PositiveInt = Field() + + +class BusinessContact(BaseModel): + model_config = ConfigDict(extra="ignore") + + name: Optional[str] = Field(default=None) + email: Optional[str] = Field(default=None) + + phone_number: Optional[str] = Field( + default=None, + min_length=10, + max_length=31, + examples=["+1 (888) 888-8888"], + ) + + +class Business(BaseModel): + """This is the Base model to represent a Business,""" + + model_config = ConfigDict(extra="ignore") + + id: SkipJsonSchema[Optional[PositiveInt]] = Field(default=None) + uuid: UUIDStrCoerce = Field(examples=[uuid4().hex]) + + name: str = Field( + min_length=3, + max_length=255, + examples=["General Research Laboratories, LLC"], + ) + + kind: str = Field( + max_length=1, + description=BusinessType.as_openapi(), + examples=[BusinessType.COMPANY.value], + ) + + tax_number: Optional[str] = Field(default=None, max_length=20) + contact: Optional["BusinessContact"] = Field(default=None) + + # Initialization is deferred until it is actually needed + # (see .prefetch_***()) + addresses: Optional[List["BusinessAddress"]] = Field(default=None) + teams: Optional[List["Team"]] = Field(default=None) + products: Optional[List["Product"]] = Field(default=None) + bank_accounts: Optional[List["BusinessBankAccount"]] = Field(default=None) + + # Initialization is deferred until unless it's called + # (see .prebuild_***()) + balance: Optional["BusinessBalances"] = Field(default=None, name="Business Balance") + + payouts_total_str: Optional[str] = Field(default=None) + payouts_total: Optional[USDCent] = Field(default=None) + payouts: Optional[List[BusinessPayoutEvent]] = Field( + default=None, + name="Business Payouts", + description="These are the ACH or Wire payments that were sent to the" + "Business as a single amount, summed for all the Business" + "child Products", + ) + + pop_financial: Optional[List[POPFinancial]] = Field(default=None) + bp_accounts: Optional[List[LedgerAccount]] = Field(default=None) + + def __str__(self) -> str: + return ( + f"Name: {self.name} ({self.uuid})\n" + f"Products: {len(self.products) if self.products else 'Not Loaded'}\n" + f"Ledger Accounts: {len(self.bp_accounts) if self.bp_accounts else 'Not Loaded'}\n" + f"Addresses: {len(self.addresses) if self.addresses else 'Not Loaded'}\n" + f"Teams: {len(self.teams) if self.teams else 'Not Loaded'}\n" + f"Bank Accounts: {len(self.bank_accounts) if self.bank_accounts else 'Not Loaded'}\n" + f"–––\n" + f"Payouts: {len(self.payouts) if self.payouts else 'Not Loaded'}\n" + f"Available Balance: {self.balance.available_balance if self.balance else 'Not Loaded'}\n" + ) + + def __repr__(self): + return f"" + + # --- Prefetch --- + + def prefetch_addresses(self, pg_config: PostgresConfig) -> None: + with pg_config.make_connection() as conn: + with conn.cursor(row_factory=dict_row) as c: + c.execute( + query=f""" + SELECT * + FROM common_businessaddress AS ba + WHERE ba.business_id = %s + LIMIT 1 + """, + params=[self.id], + ) + res = c.fetchall() + + if len(res) == 0: + self.addresses = [] + + self.addresses = [BusinessAddress.model_validate(i) for i in res] + + def prefetch_teams(self, pg_config: PostgresConfig) -> None: + from generalresearch.models.gr.team import Team + + with pg_config.make_connection() as conn: + with conn.cursor(row_factory=dict_row) as c: + c: Cursor + + c.execute( + query=f""" + SELECT t.* + FROM common_team AS t + INNER JOIN common_team_businesses AS tb + ON tb.team_id = t.id + WHERE tb.business_id = %s + """, + params=(self.id,), + ) + + res = c.fetchall() + + if len(res) == 0: + self.teams = [] + + self.teams = [Team.model_validate(i) for i in res] + + def prefetch_products(self, thl_pg_config: PostgresConfig) -> None: + """ + :return: All the Products for this Business + """ + from generalresearch.managers.thl.product import ProductManager + + pm = ProductManager(pg_config=thl_pg_config) + self.products = pm.fetch_uuids(business_uuids=[self.uuid]) + + def prefetch_bank_accounts(self, pg_config: PostgresConfig) -> None: + from generalresearch.managers.gr.business import ( + BusinessBankAccountManager, + ) + + bam = BusinessBankAccountManager(pg_config=pg_config) + self.bank_accounts = bam.get_by_business_id(business_id=self.id) + + def prefetch_bp_accounts(self, lm: LedgerManager, thl_pg_config: PostgresConfig): + # We need to prefetch the Products everytime because there is no way + # of knowing if a new Product has been added since the last time it + # ran. + self.prefetch_products(thl_pg_config=thl_pg_config) + + accounts = lm.get_accounts_if_exists( + qualified_names=[ + f"{lm.currency.value}:bp_wallet:{bpid}" for bpid in self.product_uuids + ] + ) + + assert len(accounts) == len(self.product_uuids) + + self.bp_accounts = accounts + + # --- Prebuild --- + + def prebuild_balance( + self, + thl_pg_config: PostgresConfig, + lm: "LedgerManager", + ds: "GRLDatasets", + client: Client, + pop_ledger: Optional["PopLedgerMerge"] = None, + at_timestamp: Optional[AwareDatetime] = None, + ) -> None: + """ + This returns the Business's Balances that are calculated across + all time. They are inclusive of every transaction that has ever + occurred in relation to any of the Products for this Business + + GRL does not use a Net30 or other time or Monthly styling billing + practice. All financials are calculated in real time and immediately + available based off the real-time calculated Smart Retainer balance. + + Smart Retainer: + GRL's fully automated smart retainer system incorporates the real-time + recon risk exposure on the BPID account. The retainer amount is prone + to change every few hours based off real time traffic characteristics. + The intention is to provide protection against an account immediately + stopping traffic and having up to 2 months worth of reconciliations + continue to roll in. Using the Smart Retainer amount will allow the + most amount of an accounts balance to be deposited into the owner's + account at any frequency without being tied to monthly invoicing. The + goal is to be as aggressive as possible and not hold funds longer than + absolutely required, Smart Retainer accounts are supported for any + volume levels. + """ + LOG.debug(f"Business.prebuild_balance({self.uuid=})") + + self.prefetch_products(thl_pg_config=thl_pg_config) + + accounts: List[LedgerAccount] = lm.get_accounts_if_exists( + qualified_names=[ + f"{lm.currency.value}:bp_wallet:{bpid}" for bpid in self.product_uuids + ] + ) + + if len(accounts) != len(self.products): + raise ValueError("Inconsistent BP Wallet Accounts for Business: ") + + if pop_ledger is None: + from generalresearch.incite.defaults import pop_ledger as plm + + pop_ledger = plm(ds=ds) + + if at_timestamp is None: + at_timestamp = datetime.now(tz=timezone.utc) + assert at_timestamp.tzinfo == timezone.utc + + ddf = pop_ledger.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["account_id"], + filters=[ + ("account_id", "in", [a.uuid for a in accounts]), + ("time_idx", "<=", at_timestamp), + ], + ) + + if ddf is None: + raise AssertionError("Cannot build Business Balance") + + # This is so stupid. Something goes wrong when trying to groupby directly + # on the ddf (says there is a datetime), so drop the small speed + # improvement and simply build the full df, and then group by on + # a pandas df instead of a dask dataframe + # https://g-r-l.slack.com/archives/G8ULA6CV8/p1755898636685149?thread_ts=1755868251.296459&cid=G8ULA6CV8 + # ddf = ddf.groupby("account_id").sum() + df: pd.DataFrame = client.compute(collections=ddf, sync=True) + + if df.empty: + # A Business can have multiple Products. However, none of those + # Products need to have had any ledger transactional events and + # that is still valid. Don't attempt to build a balance, leave it + # as None rather than all zeros + LOG.warning(f"Business({self.uuid=}).prebuild_balance empty dataframe") + return None + + LOG.debug(f"Business.prebuild_balance.groupby() {df.head()}") + df = df.groupby("account_id").sum() + + from generalresearch.models.thl.finance import BusinessBalances + + self.balance = BusinessBalances.from_pandas( + input_data=df, accounts=accounts, thl_pg_config=thl_pg_config + ) + + return None + + def prebuild_payouts( + self, + thl_pg_config: PostgresConfig, + thl_lm: "ThlLedgerManager", + bpem: BusinessPayoutEventManager, + ) -> None: + LOG.debug(f"Business.prebuild_payouts({self.uuid=})") + + self.prefetch_products(thl_pg_config=thl_pg_config) + + self.payouts = bpem.get_business_payout_events_for_products( + thl_ledger_manager=thl_lm, + product_uuids=self.product_uuids, + order_by=OrderBy.DESC, + ) + + self.prebuild_payouts_total() + + def prebuild_payouts_total(self): + assert self.payouts is not None + self.payouts_total = USDCent(sum([po.amount for po in self.payouts])) + self.payouts_total_str = self.payouts_total.to_usd_str() + + return None + + def prebuild_pop_financial( + self, + thl_pg_config: PostgresConfig, + lm: "LedgerManager", + ds: "GRLDatasets", + client: Client, + pop_ledger: Optional["PopLedgerMerge"] = None, + ) -> None: + """This is very similar to the Product POP Financial endpoint; however, + it returns more than one item for a single time interval. This is + because more than a single account will have likely had any + financial activity within that time window. + """ + if self.bp_accounts is None: + self.prefetch_bp_accounts(lm=lm, thl_pg_config=thl_pg_config) + + from generalresearch.models.admin.request import ( + ReportRequest, + ReportType, + ) + + rr = ReportRequest(report_type=ReportType.POP_LEDGER, interval="5min") + + if pop_ledger is None: + from generalresearch.incite.defaults import pop_ledger as plm + + pop_ledger = plm(ds=ds) + + ddf = pop_ledger.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx", "account_id"], + filters=[ + ("account_id", "in", [a.uuid for a in self.bp_accounts]), + ("time_idx", ">=", pop_ledger.start), + ], + ) + if ddf is None: + self.pop_financial = [] + return None + + df = client.compute(collections=ddf, sync=True) + + if df.empty: + self.pop_financial = [] + return None + + df = df.groupby( + [pd.Grouper(key="time_idx", freq=rr.interval), "account_id"] + ).sum() + + self.pop_financial = POPFinancial.list_from_pandas( + input_data=df, accounts=self.bp_accounts + ) + + def prebuild_enriched_session_parquet( + self, + thl_pg_config: PostgresConfig, + ds: "GRLDatasets", + client: Client, + mnt_gr_api: Path, + enriched_session: Optional["EnrichedSessionMerge"] = None, + ) -> None: + self.prefetch_products(thl_pg_config=thl_pg_config) + + if enriched_session is None: + from generalresearch.incite.defaults import ( + enriched_session as es, + ) + + enriched_session = es(ds=ds) + + rr = ReportRequest.model_validate( + { + "start": enriched_session.start, + "interval": "5min", + "type": ReportType.POP_SESSION, + } + ) + df = enriched_session.to_admin_response( + product_ids=self.product_uuids, rr=rr, client=client + ) + + path = Path( + os.path.join(mnt_gr_api, rr.report_type.value, f"{self.file_key}.parquet") + ) + + df.to_parquet( + path=path, + engine="pyarrow", + compression="brotli", + ) + + try: + test = pd.read_parquet(path, engine="pyarrow") + except Exception as e: + raise IOError(f"Parquet verification failed: {e}") + + return None + + def prebuild_enriched_wall_parquet( + self, + thl_pg_config: PostgresConfig, + ds: "GRLDatasets", + client: Client, + mnt_gr_api: Path, + enriched_wall: Optional["EnrichedWallMerge"] = None, + ) -> None: + self.prefetch_products(thl_pg_config=thl_pg_config) + + if enriched_wall is None: + from generalresearch.incite.defaults import ( + enriched_wall as ew, + ) + + enriched_wall = ew(ds=ds) + + rr = ReportRequest.model_validate( + { + "start": enriched_wall.start, + "interval": "5min", + "report_type": ReportType.POP_EVENT, + } + ) + df = enriched_wall.to_admin_response( + product_ids=self.product_uuids, rr=rr, client=client + ) + + path = Path( + os.path.join(mnt_gr_api, rr.report_type.value, f"{self.file_key}.parquet") + ) + + df.to_parquet( + path=path, + engine="pyarrow", + compression="brotli", + ) + + try: + test = pd.read_parquet(path, engine="pyarrow") + except Exception as e: + raise IOError(f"Parquet verification failed: {e}") + + return None + + @classmethod + def required_fields(cls) -> List[str]: + return [ + field_name + for field_name, field_info in cls.model_fields.items() + if field_info.is_required() + ] + + # --- Properties --- + + @property + def product_uuids(self) -> Optional[List[UUIDStr]]: + if self.products is None: + LOG.warning("prefetch not run") + return None + + return [p.uuid for p in self.products] + + @property + def cache_key(self) -> str: + return f"business:{self.uuid}" + + @property + def file_key(self) -> str: + return f"business-{self.uuid}" + + # --- Methods --- + + def set_cache( + self, + pg_config: PostgresConfig, + thl_web_rr: PostgresConfig, + redis_config: RedisConfig, + client: "Client", + ds: "GRLDatasets", + lm: "LedgerManager", + thl_lm: "ThlLedgerManager", + bpem: "BusinessPayoutEventManager", + mnt_gr_api: Union[Path, str], + pop_ledger: Optional["PopLedgerMerge"] = None, + enriched_session: Optional["EnrichedSessionMerge"] = None, + enriched_wall: Optional["EnrichedWallMerge"] = None, + ) -> None: + LOG.debug(f"Business.set_cache({self.uuid=})") + + ex_secs = 60 * 60 * 24 * 3 # 3 days + + self.prefetch_addresses(pg_config=pg_config) + self.prefetch_teams(pg_config=pg_config) + self.prefetch_products(thl_pg_config=thl_web_rr) + self.prefetch_bank_accounts(pg_config=pg_config) + self.prefetch_bp_accounts(lm=lm, thl_pg_config=thl_web_rr) + + self.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=ds, + client=client, + pop_ledger=pop_ledger, + ) + self.prebuild_payouts(thl_pg_config=thl_web_rr, thl_lm=thl_lm, bpem=bpem) + self.prebuild_pop_financial( + thl_pg_config=thl_web_rr, + lm=lm, + ds=ds, + client=client, + pop_ledger=pop_ledger, + ) + + rc = redis_config.create_redis_client() + mapping = self.model_dump(mode="json") + + # For POP Financial data, we want to also break that out by year + res = { + f"pop_financial:{key}": value + for key, value in group_by_year( + records=mapping["pop_financial"], datetime_field="time" + ).items() + } + mapping = mapping | res + + for key in mapping: + mapping[key] = json.dumps(mapping[key]) + rc.hset(name=self.cache_key, mapping=mapping) + + # -- Saves Parquet files + if enriched_session is None: + from generalresearch.incite.defaults import ( + enriched_session as es, + ) + + enriched_session = es(ds=ds) + + self.prebuild_enriched_session_parquet( + thl_pg_config=thl_web_rr, + client=client, + ds=ds, + mnt_gr_api=mnt_gr_api, + enriched_session=enriched_session, + ) + + if enriched_wall is None: + from generalresearch.incite.defaults import enriched_wall as ew + + enriched_wall = ew(ds=ds) + + self.prebuild_enriched_wall_parquet( + thl_pg_config=thl_web_rr, + client=client, + ds=ds, + mnt_gr_api=mnt_gr_api, + enriched_wall=enriched_wall, + ) + + return None + + # --- ORM --- + + @classmethod + def from_redis( + cls, + uuid: UUIDStr, + fields: List[str], + gr_redis_config: RedisConfig, + ) -> Optional[Self]: + keys: List = Business.required_fields() + fields + if "pop_financial" in keys: + # We should explicitly pass the pop_financial years we want. By default, + # at least get this year. + year = datetime.now(tz=timezone.utc).year + keys = list(set(keys) | {f"pop_financial:{year}"}) + rc = gr_redis_config.create_redis_client() + + try: + res: List = rc.hmget(name=f"business:{uuid}", keys=keys) + d = { + val: json.loads(res[idx]) if res[idx] is not None else None + for idx, val in enumerate(keys) + } + + # Extract all pop_financial records + pop_financial = [ + record + for key, value in d.items() + if key.startswith("pop_financial:") and value is not None + for record in value + ] + + result = {k: v for k, v in d.items() if not k.startswith("pop_financial:")} + result["pop_financial"] = pop_financial + + return Business.model_validate(result) + except Exception as e: + logging.exception(e) + return None diff --git a/generalresearch/models/gr/team.py b/generalresearch/models/gr/team.py new file mode 100644 index 0000000..38aff56 --- /dev/null +++ b/generalresearch/models/gr/team.py @@ -0,0 +1,346 @@ +import json +import os +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Optional, Union, List, TYPE_CHECKING +from uuid import uuid4 + +import pandas as pd +from dask.distributed import Client +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + field_validator, +) +from pydantic.json_schema import SkipJsonSchema +from typing_extensions import Self + +from generalresearch.decorators import LOG +from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, +) +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, +) +from generalresearch.utils.enum import ReprEnumMeta +from generalresearch.models.admin.request import ReportRequest, ReportType +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + UUIDStr, + UUIDStrCoerce, +) +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +if TYPE_CHECKING: + from generalresearch.incite.base import GRLDatasets + from generalresearch.models.gr.business import Business + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.thl.product import Product + + +class MembershipPrivilege(Enum, metaclass=ReprEnumMeta): + ADMIN = 0 + MAINTAIN = 1 + READ = 2 + + +class Membership(BaseModel): + """A Membership is the relationship between a GR User and a Team. + + GRUsers do not have direct connections to Businesses or Products, + they're all connected through a Team and a GRUser's relationship to + a Team can have various levels of permissions and rights. + """ + + model_config = ConfigDict(use_enum_values=True) + + id: SkipJsonSchema[Optional[PositiveInt]] = Field( + default=None, + ) + uuid: UUIDStrCoerce = Field(examples=[uuid4().hex]) + + privilege: MembershipPrivilege = Field( + default=MembershipPrivilege.MAINTAIN, + examples=[MembershipPrivilege.READ.value], + description=MembershipPrivilege.as_openapi(), + ) + + owner: bool = Field(default=False, examples=[True]) + + created: AwareDatetimeISO = Field( + description="This is when the User was added to the Team, it's when" + "the Membership was created and not when the GR User " + "account was created." + ) + + user_id: SkipJsonSchema[PositiveInt] = Field(default=None) + + team_id: SkipJsonSchema[PositiveInt] = Field() + + # prefetch attributes + team: SkipJsonSchema[Optional["Team"]] = Field(default=None) + + # --- Validators --- + + @field_validator("created", mode="before") + @classmethod + def created_utc(cls, v: Union[datetime, str]) -> Union[datetime, str]: + if isinstance(v, datetime): + return v.replace(tzinfo=timezone.utc) + return v + + # --- prefetch methods --- + + def prefetch_team( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.team import TeamManager + + tm = TeamManager(pg_config=pg_config, redis_config=redis_config) + self.team = tm.get_by_id(team_id=self.team_id) + + +class Team(BaseModel): + id: SkipJsonSchema[Optional[PositiveInt]] = Field(default=None) + uuid: UUIDStrCoerce = Field(examples=[uuid4().hex]) + name: str = Field(max_length=255, examples=["Team ABC"]) + + # prefetch attributes + memberships: SkipJsonSchema[Optional[List["Membership"]]] = Field(default=None) + gr_users: SkipJsonSchema[Optional[List["GRUser"]]] = Field(default=None) + businesses: SkipJsonSchema[Optional[List["Business"]]] = Field(default=None) + products: SkipJsonSchema[Optional[List["Product"]]] = Field(default=None) + + # --- Prefetch Methods --- + + def prefetch_memberships(self, pg_config: PostgresConfig) -> None: + from generalresearch.managers.gr.team import MembershipManager + + mm = MembershipManager(pg_config=pg_config) + self.memberships = mm.get_by_team_id(team_id=self.id) + + def prefetch_gr_users( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.authentication import ( + GRUserManager, + ) + + gr_um = GRUserManager(pg_config=pg_config, redis_config=redis_config) + + self.gr_users = gr_um.get_by_team(team_id=self.id) + + def prefetch_businesses( + self, pg_config: PostgresConfig, redis_config: RedisConfig + ) -> None: + from generalresearch.managers.gr.business import BusinessManager + + bm = BusinessManager(pg_config=pg_config, redis_config=redis_config) + self.businesses = bm.get_by_team(team_id=self.id) + + def prefetch_products(self, thl_pg_config: PostgresConfig) -> None: + from generalresearch.managers.thl.product import ProductManager + + pm = ProductManager(pg_config=thl_pg_config) + self.products = pm.fetch_uuids(team_uuids=[self.uuid]) + + # --- Prebuild Methods --- + + def prebuild_enriched_session_parquet( + self, + thl_pg_config: PostgresConfig, + ds: "GRLDatasets", + client: Client, + mnt_gr_api: Path, + enriched_session: Optional["EnrichedSessionMerge"] = None, + ) -> None: + self.prefetch_products(thl_pg_config=thl_pg_config) + + if enriched_session is None: + from generalresearch.incite.defaults import ( + enriched_session as es, + ) + + enriched_session = es(ds=ds) + + rr = ReportRequest.model_validate( + { + "start": enriched_session.start, + "interval": "5min", + "type": ReportType.POP_SESSION, + } + ) + df = enriched_session.to_admin_response( + product_ids=self.product_uuids, rr=rr, client=client + ) + + path = Path( + os.path.join(mnt_gr_api, rr.report_type.value, f"{self.file_key}.parquet") + ) + + df.to_parquet( + path=path, + engine="pyarrow", + compression="brotli", + ) + + try: + test = pd.read_parquet(path, engine="pyarrow") + except Exception as e: + raise IOError(f"Parquet verification failed: {e}") + + return None + + def prebuild_enriched_wall_parquet( + self, + thl_pg_config: PostgresConfig, + ds: "GRLDatasets", + client: Client, + mnt_gr_api: Path, + enriched_wall: Optional["EnrichedWallMerge"] = None, + ) -> None: + self.prefetch_products(thl_pg_config=thl_pg_config) + + if enriched_wall is None: + from generalresearch.incite.defaults import ( + enriched_wall as ew, + ) + + enriched_wall = ew(ds=ds) + + rr = ReportRequest.model_validate( + { + "start": enriched_wall.start, + "interval": "5min", + "report_type": ReportType.POP_EVENT, + } + ) + df = enriched_wall.to_admin_response( + product_ids=self.product_uuids, rr=rr, client=client + ) + + path = Path( + os.path.join(mnt_gr_api, rr.report_type.value, f"{self.file_key}.parquet") + ) + + df.to_parquet( + path=path, + engine="pyarrow", + compression="brotli", + ) + + try: + test = pd.read_parquet(path, engine="pyarrow") + except Exception as e: + raise IOError(f"Parquet verification failed: {e}") + + return None + + @classmethod + def required_fields(cls) -> List[str]: + return [ + field_name + for field_name, field_info in cls.model_fields.items() + if field_info.is_required() + ] + + # --- Properties --- + @property + def cache_key(self) -> str: + return f"team:{self.uuid}" + + @property + def file_key(self) -> str: + return f"team-{self.uuid}" + + @property + def product_ids(self) -> Optional[List[UUIDStr]]: + if self.products is None: + LOG.warning("prefetch not run") + return None + + return [p.uuid for p in self.products] + + @property + def product_uuids(self) -> Optional[List[UUIDStr]]: + return self.product_ids + + # --- Methods --- + + def set_cache( + self, + pg_config: PostgresConfig, + thl_web_rr: PostgresConfig, + redis_config: RedisConfig, + client: "Client", + ds: "GRLDatasets", + mnt_gr_api: Union[Path, str], + enriched_session: Optional["EnrichedSessionMerge"] = None, + enriched_wall: Optional["EnrichedWallMerge"] = None, + ) -> None: + ex_secs = 60 * 60 * 24 * 3 # 3 days + + self.prefetch_products(thl_pg_config=thl_web_rr) + self.prefetch_gr_users(pg_config=pg_config, redis_config=redis_config) + self.prefetch_businesses(pg_config=pg_config, redis_config=redis_config) + self.prefetch_memberships(pg_config=pg_config) + + rc = redis_config.create_redis_client() + mapping = self.model_dump(mode="json") + for key in mapping: + mapping[key] = json.dumps(mapping[key]) + rc.hset(name=self.cache_key, mapping=mapping) + + # -- Saves Parquet files + if enriched_session is None: + from generalresearch.incite.defaults import ( + enriched_session as es, + ) + + enriched_session = es(ds=ds) + + self.prebuild_enriched_session_parquet( + thl_pg_config=thl_web_rr, + client=client, + ds=ds, + mnt_gr_api=mnt_gr_api, + enriched_session=enriched_session, + ) + + if enriched_wall is None: + from generalresearch.incite.defaults import enriched_wall as ew + + enriched_wall = ew(ds=ds) + + self.prebuild_enriched_wall_parquet( + thl_pg_config=thl_web_rr, + client=client, + ds=ds, + mnt_gr_api=mnt_gr_api, + enriched_wall=enriched_wall, + ) + + return None + + # --- ORM --- + + @classmethod + def from_redis( + cls, + uuid: UUIDStr, + fields: List[str], + gr_redis_config: RedisConfig, + ) -> Optional[Self]: + keys: List = Team.required_fields() + fields + rc = gr_redis_config.create_redis_client() + + try: + res: List = rc.hmget(name=f"team:{uuid}", keys=keys) + d = {val: json.loads(res[idx]) for idx, val in enumerate(keys)} + return Team.model_validate(d) + except (Exception,) as e: + return None diff --git a/generalresearch/models/innovate/__init__.py b/generalresearch/models/innovate/__init__.py new file mode 100644 index 0000000..054c69d --- /dev/null +++ b/generalresearch/models/innovate/__init__.py @@ -0,0 +1,38 @@ +from enum import Enum + +from pydantic import StringConstraints +from typing_extensions import Annotated + +# Note, this is called the KEY in the Question model +InnovateQuestionID = Annotated[ + str, StringConstraints(min_length=1, max_length=64, pattern=r"^[^A-Z]+$") +] + + +class InnovateStatus(str, Enum): + LIVE = "LIVE" + NOT_LIVE = "NOT_LIVE" + + +class InnovateQuotaStatus(str, Enum): + OPEN = "OPEN" + CLOSED = "CLOSED" + + +class InnovateDuplicateCheckLevel(str, Enum): + # How we should check for de-dupes / survey exclusions. + # https://innovatemr.stoplight.io/docs/supplier-api/ZG9jOjEzNzYxMTg2-statuses-term-reasons-and-categories + # #duplicatedtoken + + JOB = "JOB" # user cannot participate if they have participated in a survey with the same job id + EXCLUDED_SURVEYS = "EX_SURVEYS" # cannot participate if they've done any survey in the "excluded_surveys" + SURVEY = "SURVEY" # only dedupe check is on the survey itself + NA = "NA" # idk how this is different from SURVEY + + @classmethod + def from_api(cls, s: str): + return { + "Job Level": cls.JOB, + "Multi Surveys": cls.EXCLUDED_SURVEYS, + "Survey Level": cls.SURVEY, + }.get(s, cls.NA) diff --git a/generalresearch/models/innovate/question.py b/generalresearch/models/innovate/question.py new file mode 100644 index 0000000..0fd2547 --- /dev/null +++ b/generalresearch/models/innovate/question.py @@ -0,0 +1,244 @@ +# https://innovatemr.stoplight.io/docs/supplier-api/d21fa72c538db-lookup-question-library +from __future__ import annotations + +import json +import logging +from enum import Enum +from typing import List, Optional, Literal, Any, Dict + +from pydantic import BaseModel, Field, model_validator, field_validator + +from generalresearch.models import Source +from generalresearch.models.innovate import InnovateQuestionID +from generalresearch.models.thl.profiling.marketplace import ( + MarketplaceQuestion, + MarketplaceUserQuestionAnswer, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class InnovateUserQuestionAnswer(MarketplaceUserQuestionAnswer): + # Note, this is referred to as the KEY in the Question model + question_id: InnovateQuestionID = Field() + question_type: Optional[InnovateQuestionType] = Field(default=None) + # Did this answer come from us asking, or was it passed back from the marketplace + from_thl: bool = Field(default=True) + + +class InnovateQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + frozen=True, + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + order: int = Field() + + +class InnovateQuestionType(str, Enum): + # API response: {'Multipunch', 'Numeric Open Ended', 'Single Punch'} + # "Numeric Open Ended" must be wrong... It can't be numeric, as UK's postcode question is marked + # as this, but it wants alphanumeric answers. So this is just text_entry. + + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + TEXT_ENTRY = "t" + + @staticmethod + def get_api_map(): + return { + "Single Punch": InnovateQuestionType.SINGLE_SELECT, + "Multipunch": InnovateQuestionType.MULTI_SELECT, + "Numeric Open Ended": InnovateQuestionType.TEXT_ENTRY, + } + + @classmethod + def from_api(cls, a: int): + API_TYPE_MAP = cls.get_api_map() + return API_TYPE_MAP[a] if a in API_TYPE_MAP else None + + +class InnovateQuestion(MarketplaceQuestion): + # Each question has an ID (numerical) and a Name (which they call "Key") which are both unique. The + # key is what is used throughout, so this what will be used as the primary key. + question_key: str = Field( + min_length=1, + max_length=64, + pattern=r"^[^A-Z]+$", + description="Primary identifier that is used throughout Innovate", + frozen=True, + ) + question_id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + description="Numerical identifier for the qualification", + frozen=True, + ) + + question_text: str = Field( + max_length=1024, + min_length=1, + description="The text shown to respondents", + frozen=False, + ) + question_type: InnovateQuestionType = Field( + description="The type of question asked", frozen=True + ) + # This comes from the API field "Category". There are some useful categories in here, but a bunch have + # categories that are not (e.g. NFX - Adhoc, Testing_Cat). We'll store it as a comma-separated string + # here to use it to aid our own real categorization. + tags: Optional[str] = Field(default=None, frozen=True) + options: Optional[List[InnovateQuestionOption]] = Field( + default=None, min_length=1, frozen=True + ) + + source: Literal[Source.INNOVATE] = Source.INNOVATE + + @property + def internal_id(self) -> str: + return self.question_key + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type == InnovateQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @field_validator("options") + @classmethod + def order_options(cls, options): + if options: + options.sort(key=lambda x: x.order) + return options + + @field_validator("question_key", mode="before") + @classmethod + def question_key_lower(cls, v: str) -> str: + if v.lower() != v: + logger.warning(f"question key {v} should be lowercase!") + v = v.lower() + return v + + @classmethod + def from_api( + cls, d: dict, country_iso: str, language_iso: str + ) -> Optional["InnovateQuestion"]: + """ + :param d: Raw response from API + :param country_iso: + :param language_iso: + :return: + """ + try: + return cls._from_api(d, country_iso, language_iso) + except Exception as e: + logger.warning(f"Unable to parse question: {d}. {e}") + return None + + @classmethod + def _from_api( + cls, d: dict, country_iso: str, language_iso: str + ) -> "InnovateQuestion": + # Question AGE returns options even though its marked as a text entry (but only in some locales) + d["QuestionKey"] = d["QuestionKey"].lower() + if d["QuestionKey"] == "age": + d["QuestionOptions"] = [] + + options = None + if d.get("QuestionOptions"): + options = [ + InnovateQuestionOption( + id=str(r["id"]), text=r["OptionText"], order=r["Order"] + ) + for r in d["QuestionOptions"] + ] + tags = ",".join(map(str.strip, d["Category"])) + return cls( + question_id=str(d["QuestionId"]), + question_key=d["QuestionKey"], + question_text=d["QuestionText"], + question_type=InnovateQuestionType.from_api(d["QuestionType"]), + tags=tags, + options=options, + country_iso=country_iso, + language_iso=language_iso, + ) + + @classmethod + def from_db(cls, d: dict) -> "InnovateQuestion": + options = None + if d["options"]: + options = [ + InnovateQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_key=d["question_key"], + question_text=d["question_text"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + is_live=d["is_live"], + category_id=d.get("category_id"), + tags=d["tags"], + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + ) + + upk_type_selector_map = { + InnovateQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + InnovateQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + InnovateQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=c.order) + for c in self.options + ] + return UpkQuestion(**d) diff --git a/generalresearch/models/innovate/survey.py b/generalresearch/models/innovate/survey.py new file mode 100644 index 0000000..e230899 --- /dev/null +++ b/generalresearch/models/innovate/survey.py @@ -0,0 +1,491 @@ +from __future__ import annotations + +import json +import logging +from datetime import timezone, date +from decimal import Decimal +from functools import cached_property +from typing import ( + Optional, + Dict, + Any, + List, + Literal, + Set, + Tuple, + Annotated, + Type, +) + +from more_itertools import flatten +from pydantic import ( + Field, + ConfigDict, + BaseModel, + model_validator, + computed_field, +) +from typing_extensions import Self + +from generalresearch.locales import Localelator +from generalresearch.models import ( + Source, + LogicalOperator, + TaskCalculationType, +) +from generalresearch.models.custom_types import ( + CoercedStr, + AwareDatetimeISO, + AlphaNumStrSet, + DeviceTypes, +) +from generalresearch.models.innovate import ( + InnovateStatus, + InnovateQuotaStatus, + InnovateDuplicateCheckLevel, +) +from generalresearch.models.innovate.question import InnovateQuestionID +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + ConditionValueType, + MarketplaceCondition, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class InnovateCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + # store everything lowercase ! + question_id: Optional[CoercedStr] = Field( + min_length=1, max_length=64, pattern=r"^[^A-Z]+$" + ) + # There isn't really a hard limit, but their API is inconsistent and + # sometimes returns all the options comma-separated instead of as a list. + # Try to catch that. + values: List[Annotated[str, Field(max_length=128)]] = Field() + + @classmethod + def from_api(cls, d: Dict[str, Any]) -> "InnovateCondition": + d["logical_operator"] = LogicalOperator.OR + d["value_type"] = ConditionValueType.LIST + d["negate"] = False + d["values"] = list(set(x.strip().lower() for x in d["values"])) + return cls.model_validate(d) + + +class InnovateQuota(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=True) + + desired_count: int = Field() + remaining_count: int = Field() + complete_count: int = Field() + start_count: int = Field() + + status: InnovateQuotaStatus = Field() + task_calculation_type: TaskCalculationType = Field() + hard_stop: bool = Field() + + condition_hashes: List[str] = Field(min_length=0, default_factory=list) + + def __hash__(self): + return hash(tuple((tuple(self.condition_hashes), self.remaining_count))) + + @property + def is_open(self) -> bool: + min_open_spots = 3 + return ( + self.remaining_count >= min_open_spots + and self.status == InnovateQuotaStatus.OPEN + ) + + @classmethod + def from_api(cls, d: Dict): + return cls.model_validate(d) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the quota is open. + return self.is_open and self.matches(criteria_evaluation) + + def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Matches means we meet all conditions. + # We can "match" a quota that is closed. In that case, we would not be eligible for the survey. + return all(criteria_evaluation.get(c) for c in self.condition_hashes) + + def matches_optional( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[bool]: + # We need to know if any conditions are unknown to avoid matching a full quota. If any fail, + # then we know we fail regardless of any being unknown. + evals = [criteria_evaluation.get(c) for c in self.condition_hashes] + if False in evals: + return False + if None in evals: + return None + return True + + def matches_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + hash_evals = { + cell: criteria_evaluation.get(cell) for cell in self.condition_hashes + } + evals = set(hash_evals.values()) + if False in evals: + return False, set() + if None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + +class InnovateSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + + survey_id: CoercedStr = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + # There is no status returned, using one I make up b/c is_live depends on it, + status: InnovateStatus = Field(default=InnovateStatus.LIVE) + # is_live: bool = Field(default=True) # can't overload the is_live property ... + cpi: Decimal = Field(gt=0, le=100, decimal_places=2, max_digits=5) + buyer_id: CoercedStr = Field(max_length=32) + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + job_id: str = Field(description="basically a project id") + survey_name: str = Field() + + desired_count: int = Field() + remaining_count: int = Field() + supplier_completes_achieved: int = Field() + global_completes: int = Field() + global_starts: int = Field() + global_median_loi: Optional[int] = Field(le=120 * 60) + global_conversion: Optional[float] = Field(ge=0, le=1) + + bid_loi: Optional[int] = Field(default=None, le=120 * 60) + bid_ir: Optional[float] = Field(default=None, ge=0, le=1) + + allowed_devices: DeviceTypes = Field(min_length=1) + + entry_link: str = Field() + category: str = Field() + requires_pii: bool = Field(default=False) + + excluded_surveys: Optional[AlphaNumStrSet] = Field( + description="list of excluded survey ids", default=None + ) + duplicate_check_level: InnovateDuplicateCheckLevel = Field() + + exclude_pids: Optional[AlphaNumStrSet] = Field(default=None) + include_pids: Optional[AlphaNumStrSet] = Field(default=None) + + # idk what these mean + is_revenue_sharing: bool = Field() + group_type: str = Field() + # undocumented, not sure how we use this + off_hour_traffic: Optional[Dict] = Field(default=None) + + qualifications: List[str] = Field(default_factory=list) + quotas: List[InnovateQuota] = Field(default_factory=list) + + source: Literal[Source.INNOVATE] = Field(default=Source.INNOVATE) + + used_question_ids: Set[InnovateQuestionID] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as "condition_hashes") throughout + # this survey. In the reduced representation of this task (nearly always, for db i/o, in global_vars) + # this field will be null. + conditions: Optional[Dict[str, InnovateCondition]] = Field(default=None) + + # These come from the API + created_api: AwareDatetimeISO = Field( + description="When the survey was created in innovate's system" + ) + modified_api: AwareDatetimeISO = Field( + description="When the survey was last updated in innovate's system" + ) + expected_end_date: date = Field() + + # This does not come from the API. We set it when we update this in the db. + created: Optional[AwareDatetimeISO] = Field(default=None) + updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @computed_field + def is_live(self) -> bool: + return self.status == InnovateStatus.LIVE + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 open quota (or there are no quotas!) + return self.is_live and ( + any(q.is_open for q in self.quotas) or len(self.quotas) == 0 + ) + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set(self.qualifications) + for q in self.quotas: + s.update(set(q.condition_hashes)) + return s + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @classmethod + def from_api(cls, d: Dict) -> Optional["InnovateSurvey"]: + try: + return cls._from_api(d) + except Exception as e: + logger.warning(f"Unable to parse survey: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: Dict): + d["conditions"] = dict() + + # If we haven't hit the "detail" endpoint, we won't get this + d.setdefault("qualifications", []) + for q in d["qualifications"]: + d["conditions"][q.criterion_hash] = q + d["qualifications"] = [x.criterion_hash for x in d["qualifications"]] + + quotas = [] + d.setdefault("quotas", []) + for quota in d["quotas"]: + conditions = quota["conditions"] + quota["condition_hashes"] = [x.criterion_hash for x in conditions] + quotas.append(InnovateQuota.from_api(quota)) + for q in conditions: + d["conditions"][q.criterion_hash] = q + d["quotas"] = quotas + return cls.model_validate(d) + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return InnovateCondition + + @property + def age_question(self) -> str: + return "age" + + @property + def marketplace_genders(self): + # There is also a "gender_plus", but it doesn't seem widely used. + return { + Gender.MALE: InnovateCondition( + question_id="gender", + values=["1"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: InnovateCondition( + question_id="gender", + values=["2"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + def __repr__(self) -> str: + # Fancy repr that abbreviates exclude_pids and excluded_surveys + repr_args = list(self.__repr_args__()) + for n, (k, v) in enumerate(repr_args): + if k in {"exclude_pids", "include_pids", "excluded_surveys"}: + if v and len(v) > 6: + v = sorted(v) + v = v[:3] + ["…"] + v[-3:] + repr_args[n] = (k, v) + join_str = ", " + repr_str = join_str.join( + repr(v) if a is None else f"{a}={v!r}" for a, v in repr_args + ) + return f"{self.__repr_name__()}({repr_str})" + + def is_unchanged(self, other) -> bool: + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, + # just that the objects don't require any db updates. We also exclude + # conditions b/c this is just the condition_hash definitions + return self.model_dump( + exclude={"updated", "conditions", "created"} + ) == other.model_dump(exclude={"updated", "conditions", "created"}) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_isos", + "language_isos", + "source", + "conditions", + }, + ) + d["qualifications"] = json.dumps(d["qualifications"]) + d["quotas"] = json.dumps(d["quotas"]) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["off_hour_traffic"] = json.dumps(d["off_hour_traffic"]) + d["modified_api"] = self.modified_api + d["created_api"] = self.created_api + d["updated"] = self.updated + d["created"] = self.created + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]) -> Self: + d["created"] = d["created"].replace(tzinfo=timezone.utc) + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + d["modified_api"] = d["modified_api"].replace(tzinfo=timezone.utc) + d["created_api"] = d["created_api"].replace(tzinfo=timezone.utc) + d["qualifications"] = json.loads(d["qualifications"]) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + d["quotas"] = json.loads(d["quotas"]) + d["off_hour_traffic"] = json.loads(d["off_hour_traffic"]) + return cls.model_validate(d) + + def participation_allowed( + self, att_survey_ids: Set[str], att_job_ids: Set[str] + ) -> bool: + """ + Checks if this user can participate in this survey based on the 'duplicate_check_level'-dictated requirements + :param att_survey_ids: list of the user's previously attempted survey IDs + :param att_job_ids: list of the user's previously attempted survey ID's Job IDs + """ + assert isinstance(att_survey_ids, set), "must pass a set" + assert isinstance(att_job_ids, set), "must pass a set" + if self.survey_id in att_survey_ids: + return False + if self.duplicate_check_level == InnovateDuplicateCheckLevel.JOB: + if self.job_id in att_job_ids: + return False + if self.duplicate_check_level == InnovateDuplicateCheckLevel.EXCLUDED_SURVEYS: + if self.excluded_surveys.intersection(att_survey_ids): + return False + return True + + def passes_qualifications( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + # We have to match all quals + return all(criteria_evaluation.get(q) for q in self.qualifications) + + def passes_qualifications_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + hash_evals = {q: criteria_evaluation.get(q) for q in self.qualifications} + evals = set(hash_evals.values()) + # We have to match all. So if any are False, we know we don't pass + if False in evals: + return False, set() + # If any are None, we don't know + if None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + def passes_quotas(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Many surveys have 0 quotas. Quotas are exclusionary. + # They can NOT match a quota where currently_open=0 + any_pass = True + for q in self.quotas: + matches = q.matches_optional(criteria_evaluation) + if matches in {True, None} and not q.is_open: + # We also cannot be unknown for this quota, b/c we might fall into it, which would be a fail. + return False + return any_pass + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Many surveys have 0 quotas. Quotas are exclusionary. + # They can NOT match a quota where currently_open=0 + if len(self.quotas) == 0: + return True, set() + quota_eval = { + quota: quota.matches_soft(criteria_evaluation) for quota in self.quotas + } + evals = set(g[0] for g in quota_eval.values()) + if any(m[0] is True and not q.is_open for q, m in quota_eval.items()): + # matched a full quota + return False, set() + if any(m[0] is None and not q.is_open for q, m in quota_eval.items()): + # Unknown match for full quota + if True in evals: + # we match 1 other, so the missing are only this type + return None, set( + flatten( + [ + m[1] + for q, m in quota_eval.items() + if m[0] is None and not q.is_open + ] + ) + ) + else: + # we don't match any quotas, so everything is unknown + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + if True in evals: + return True, set() + if None in evals: + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + return False, set() + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + return ( + self.is_open + and self.passes_qualifications(criteria_evaluation) + and self.passes_quotas(criteria_evaluation) + ) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + if self.is_open is False: + return False, set() + pass_quals, h_quals = self.passes_qualifications_soft(criteria_evaluation) + pass_quotas, h_quotas = self.passes_quotas_soft(criteria_evaluation) + if pass_quals and pass_quotas: + return True, set() + elif pass_quals is False or pass_quotas is False: + return False, set() + else: + return None, h_quals | h_quotas diff --git a/generalresearch/models/innovate/task_collection.py b/generalresearch/models/innovate/task_collection.py new file mode 100644 index 0000000..97e7a7c --- /dev/null +++ b/generalresearch/models/innovate/task_collection.py @@ -0,0 +1,97 @@ +from typing import List, Set + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models.innovate import InnovateStatus +from generalresearch.models.innovate.survey import InnovateSurvey +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +LANGUAGE_ISOS: Set[str] = Localelator().get_all_languages() + +InnovateTaskCollectionSchema = DataFrameSchema( + columns={ + "survey_name": Column(str, Check.str_length(min_value=1, max_value=256)), + "status": Column(str, Check.isin(InnovateStatus)), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "buyer_id": Column(str), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + "job_id": Column(str), + "category": Column(str), + "desired_count": Column(int), + "remaining_count": Column(int), + "supplier_completes_achieved": Column(int), + "global_completes": Column(int), + "global_starts": Column(int), + "global_median_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "global_conversion": Column(float, Check.between(0, 1), nullable=True), + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "allowed_devices": Column(str), + "requires_pii": Column(bool), + # exclude_pids is potentially large. We don't need these usually, we just want to know + # if include_pids is set, if so then this is a recontact + # "exclude_pids": Column(bool), + "include_pids": Column(str, nullable=True), + "created_api": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "modified_api": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class InnovateTaskCollection(TaskCollection): + items: List[InnovateSurvey] + _schema = InnovateTaskCollectionSchema + + def to_row(self, s: InnovateSurvey): + d = s.model_dump( + mode="json", + exclude={ + "country_isos", + "language_isos", + "qualifications", + "quotas", + "source", + "conditions", + "is_live", + "excluded_surveys", + "exclude_pids", + "entry_link", + "duplicate_check_level", + "is_revenue_sharing", + "group_type", + "off_hour_traffic", + "expected_end_date", + "created", + }, + ) + d["cpi"] = float(s.cpi) + return d + + def to_df(self) -> pd.DataFrame: + rows = [] + for s in self.items: + rows.append(self.to_row(s)) + if rows: + return pd.DataFrame.from_records(rows, index="survey_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/legacy/__init__.py b/generalresearch/models/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/models/legacy/api_status.py b/generalresearch/models/legacy/api_status.py new file mode 100644 index 0000000..8241eca --- /dev/null +++ b/generalresearch/models/legacy/api_status.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel, Field + +""" + Messed up consistency, and we have multiple different formats + for error reporting for no reason. Faithfully recreating them here... +""" + + +class StatusResponse(BaseModel): + status: Literal["success", "error"] = Field( + description="The status of the API response.", examples=["success"] + ) + msg: Optional[str] = Field( + description="An optional message, if the status is error.", + examples=[""], + default=None, + ) + + +class StatusResponseError(BaseModel): + status: Literal["error"] = Field( + description="The status of the API response.", examples=["error"] + ) + msg: str = Field( + description="An optional message, if the status is error.", + examples=["An error has occurred"], + ) + + +class StatusResponseFailure(BaseModel): + status: Literal["failure"] = Field( + description="The status of the API response.", examples=["failure"] + ) + msg: str = Field( + description="An optional message, if the status is failure.", + examples=["An error has occurred"], + ) + + +class StatusSuccess(BaseModel): + success: bool = Field( + default=True, description="Whether the API response is successful." + ) + + +class StatusSuccessFail(StatusSuccess): + success: bool = Field( + default=False, description="Whether the API response is successful." + ) + + +class StatusInfoResponse(BaseModel): + info: StatusSuccess = Field() + msg: str = Field( + description="An optional message, if success is False", + examples=[""], + default="", + ) + + +class StatusInfoResponseFail(BaseModel): + info: StatusSuccessFail = Field() + msg: str = Field( + description="An optional message, if success is False", + examples=["An error has occurred"], + ) diff --git a/generalresearch/models/legacy/bucket.py b/generalresearch/models/legacy/bucket.py new file mode 100644 index 0000000..5afb17b --- /dev/null +++ b/generalresearch/models/legacy/bucket.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import logging +import math +from datetime import timedelta +from decimal import Decimal +from typing import Optional, Dict, List, Union, Literal, Tuple +from typing_extensions import Self + +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + ConfigDict, + NonNegativeInt, +) + +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + HttpsUrl, + UUIDStr, + PropertyCode, +) +from generalresearch.models.thl.stats import StatisticalSummary + +logger = logging.getLogger() + +Eligibility = Literal["conditional", "unconditional", "ineligible"] + +SourceName = Literal[ + "innovate", + "dynata", + "schlesinger", + "purespectrum", + "morning", + "pollfish", + "precision", + "repdata", + "prodege", +] + + +class CategoryAssociation(BaseModel): + """Used in an offerwall. Stores the association between a category + and a bucket, with a score. + """ + + id: UUIDStr = Field( + description="The category ID", + examples=["c8642a1b86d9460cbe8f7e8ae6e56ee4"], + ) + + label: str = Field( + max_length=255, + description="The category label", + examples=["People & Society"], + ) + + adwords_id: Optional[str] = Field(default=None, max_length=8, examples=["14"]) + + adwords_label: Optional[str] = Field( + default=None, max_length=255, examples=["People & Society"] + ) + + p: float = Field( + ge=0, + le=1, + examples=[1.0], + description="The strength of the association of this bucket" + "with this category. Will sum to 1 within a bucket.", + ) + + +class BucketTask(BaseModel): + """ + This represents one of the "tasks" within a bucket's ordered list of tasks. + """ + + id: str = Field( + min_length=1, + max_length=32, + examples=["6ov9jz3"], + description="The internal task id for this task within the marketplace", + ) + id_code: str = Field( + min_length=3, + max_length=35, + pattern=r"^[a-z]{1,2}\:.*", + examples=["o:6ov9jz3"], + description="The namespaced task id for this task within the marketplace", + ) + source: Source = Field(examples=[Source.POLLFISH]) + loi: int = Field( + gt=1, le=90 * 60, description="expected loi in seconds", examples=[612] + ) + payout: int = Field(gt=1, description="integer cents", examples=[123]) + + @model_validator(mode="after") + def check_id_code(self) -> Self: + assert self.source.value + ":" + self.id == self.id_code, "ids are wrong!!" + return self + + def censor(self): + censor_idx = math.ceil(len(self.id) / 2) + self.id = self.id[:censor_idx] + ("*" * len(self.id[censor_idx:])) + self.id_code = self.source.value + ":" + self.id + + +class BucketBase(BaseModel): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ser_json_timedelta="float", + arbitrary_types_allowed=True, + ) + + id: UUIDStr = Field( + description="Unique identifier this particular bucket", + examples=["5ba2fe5010cc4d078fc3cc0b0cc264c3"], + ) + uri: HttpsUrl = Field( + examples=[ + "https://task.generalresearch.com/api/v1/52d3f63b2709/797df4136c604a6c8599818296aae6d1/?i" + "=5ba2fe5010cc4d078fc3cc0b0cc264c3&b=test&66482fb=e7baf5e" + ], + description="The URL to send a respondent into. Must not edit this URL in any way", + ) + + x: int = Field( + description="For UI. Provides a dimensionality position for the bucket on the x-axis.", + ge=0, + default=0, + examples=[0, 1, 2], + ) + y: int = Field( + description="For UI. Provides a dimensionality position for the bucket on the y-axis.", + ge=0, + default=0, + examples=[0, 1, 2], + ) + name: str = Field( + description="Currently unused. Will always return empty string", + default="", + ) + description: str = Field( + description="Currently unused. Will always return empty string", + default="", + ) + + def censor(self): + if not hasattr(self, "contents"): + return + contents: List[BucketTask] = self.contents + for content in contents: + content.censor() + + +class Bucket(BaseModel): + """ + This isn't returned in any API response. It is used internally to GRL as + the common form to represent a bucket in all offerwalls. Depending on + which offerwall is requested, we'll convert from this format to the + requested format. + """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ser_json_timedelta="float", + arbitrary_types_allowed=True, + ) + + name: Optional[str] = Field(default=None) + description: Optional[str] = Field(default=None) + + # pydantic serializes this to seconds + loi_min: Optional[timedelta] = Field(strict=True, default=None) + loi_max: Optional[timedelta] = Field(strict=True, default=None) + loi_mean: Optional[timedelta] = Field(strict=True, default=None) + loi_q1: Optional[timedelta] = Field(strict=True, default=None) + loi_q2: Optional[timedelta] = Field(strict=True, default=None) + loi_q3: Optional[timedelta] = Field(strict=True, default=None) + # decimal USD. This should not have more than 2 decimal places. + # There is no way to make this "strict" and optional, so we have a separate pre-validator + user_payout_min: Optional[Decimal] = Field(default=None, lt=1000, gt=0) + user_payout_max: Optional[Decimal] = Field(default=None, lt=1000, gt=0) + user_payout_q1: Optional[Decimal] = Field(default=None, lt=1000, gt=0) + user_payout_q2: Optional[Decimal] = Field(default=None, lt=1000, gt=0) + user_payout_q3: Optional[Decimal] = Field(default=None, lt=1000, gt=0) + user_payout_mean: Optional[Decimal] = Field(default=None, lt=1000, gt=0) + + quality_score: Optional[float] = Field(default=None) + + category: List[CategoryAssociation] = Field(default_factory=list) + + contents: Optional[List[BucketTask]] = Field(default=None) + + # This could store things like "is_recontact=False" + metadata: Dict[str, Union[str, float, bool, int]] = Field(default_factory=dict) + + eligibility_criteria: Optional[Tuple[SurveyEligibilityCriterion, ...]] = Field( + description="The reasons the user is eligible for tasks in this bucket", + default=None, + ) + eligibility_explanation: Optional[str] = Field( + default=None, + description="Human-readable text explaining a user's eligibility for tasks in this bucket", + examples=[ + "You are a **47-year-old** **white** **male** with a *college degree*, who's employer's retirement plan is **Fidelity Investments**." + ], + ) + + @field_validator("loi_min", "loi_max", "loi_q1", "loi_q2", "loi_q3") + @classmethod + def check_loi_ranges(cls, v): + if v is not None: + assert v > timedelta(seconds=0), "lois should be greater than 0" + assert v <= timedelta(minutes=90), "lois should be less than 90 minutes" + return v + + @field_validator( + "user_payout_min", + "user_payout_max", + "user_payout_q1", + "user_payout_q2", + "user_payout_q3", + mode="before", + ) + @classmethod + def check_decimal_type(cls, v: Decimal) -> Decimal: + # pydantic is unable to set strict=True, so we'll do that manually here + if v is not None: + assert type(v) == Decimal, f"Must pass a Decimal, not a {type(v)}" + return v + + @field_validator( + "user_payout_min", + "user_payout_max", + "user_payout_q1", + "user_payout_q2", + "user_payout_q3", + mode="after", + ) + @classmethod + def check_payout_decimal_places(cls, v: Decimal) -> Decimal: + if v is not None: + assert ( + v.as_tuple().exponent >= -2 + ), "Must have 2 or fewer decimal places ('XXX.YY')" + # explicitly make sure it is 2 decimal places, after checking that it is already 2 or less. + v = v.quantize(Decimal("0.00")) + return v + + @model_validator(mode="after") + def check_lois(self): + if self.loi_min is not None and self.loi_max is not None: + assert self.loi_min <= self.loi_max, "loi_min should be <= loi_max" + if self.loi_q1 or self.loi_q2 or self.loi_q3: + assert ( + self.loi_q1 and self.loi_q2 and self.loi_q3 + ), "loi_q1, q2, and q3 should all be set or all None" + assert ( + self.loi_min is not None and self.loi_max is not None + ), "If loi_q1, q2, or q3 are set, then loi_min and max should be set" + assert self.loi_q1 >= self.loi_min, "loi_min should be <= loi_q1" + assert self.loi_q2 >= self.loi_q1, "loi_q1 should be <= loi_q2" + assert self.loi_q3 >= self.loi_q2, "loi_q2 should be <= loi_q3" + assert self.loi_max >= self.loi_q3, "loi_q3 should be <= loi_max" + return self + + @model_validator(mode="after") + def check_payouts(self): + if self.user_payout_min is not None and self.user_payout_max is not None: + assert ( + self.user_payout_min <= self.user_payout_max + ), "user_payout_min should be <= user_payout_max" + if self.user_payout_q1 or self.user_payout_q2 or self.user_payout_q3: + assert ( + self.user_payout_q1 and self.user_payout_q2 and self.user_payout_q3 + ), "user_payout_q1, q2, and q3 should all be set or all None" + assert ( + self.user_payout_min is not None and self.user_payout_max is not None + ), "If user_payout_q1, q2, or q3 are set, then user_payout_min and max should be set" + assert ( + self.user_payout_q1 >= self.user_payout_min + ), "user_payout_min should be <= user_payout_q1" + assert ( + self.user_payout_q2 >= self.user_payout_q1 + ), "user_payout_q1 should be <= user_payout_q2" + assert ( + self.user_payout_q3 >= self.user_payout_q2 + ), "user_payout_q2 should be <= user_payout_q3" + assert ( + self.user_payout_max >= self.user_payout_q3 + ), "user_payout_q3 should be <= user_payout_max" + return self + + @field_validator("category") + @classmethod + def check_category(cls, v: List[CategoryAssociation]) -> List[CategoryAssociation]: + assert sum(c.p for c in v) == 1, "sum of category score must be 1" + return v + + @classmethod + def parse_from_offerwall(cls, bucket: Dict): + """ + This isn't really consistent across all offerwalls... Handle three cases: + Could be {'payout': {'min': 123}}, or {'min_payout': 123} or {'payout': 123} + Only min_payout is really required. The others can be optional. + payouts - Should always be integer usd cents. + duration / loi - Should always be seconds. + """ + if "min_payout" in bucket: + return cls.parse_from_offerwall_style1(bucket) + elif "payout" in bucket and type(bucket["payout"]) is dict: + return cls.parse_from_offerwall_style2(bucket) + elif "payout" in bucket and type(bucket["payout"]) is not dict: + return cls.parse_from_offerwall_style3(bucket) + else: + logger.info("unknown bucket format") + return cls() + + @classmethod + def parse_from_offerwall_style1(cls, bucket: Dict): + # {'min_payout': 123} + return cls( + user_payout_min=cls.usd_cents_to_decimal(bucket["min_payout"]), + user_payout_max=cls.usd_cents_to_decimal(bucket.get("max_payout")), + user_payout_q1=cls.usd_cents_to_decimal(bucket.get("q1_payout")), + user_payout_q2=cls.usd_cents_to_decimal(bucket.get("q2_payout")), + user_payout_q3=cls.usd_cents_to_decimal(bucket.get("q3_payout")), + loi_min=( + timedelta(seconds=bucket["min_duration"]) + if bucket.get("min_duration") is not None + else None + ), + loi_max=( + timedelta(seconds=bucket["max_duration"]) + if bucket.get("max_duration") is not None + else None + ), + loi_q1=( + timedelta(seconds=bucket["q1_duration"]) + if bucket.get("q1_duration") is not None + else None + ), + loi_q2=( + timedelta(seconds=bucket["q2_duration"]) + if bucket.get("q2_duration") is not None + else None + ), + loi_q3=( + timedelta(seconds=bucket["q3_duration"]) + if bucket.get("q3_duration") is not None + else None + ), + ) + + @classmethod + def parse_from_offerwall_style2(cls, bucket: Dict): + # {'payout': {'min': 123}} + loi_min_sec = bucket.get("duration", {}).get("min") + loi_max_sec = bucket.get("duration", {}).get("max") + loi_q1_sec = bucket.get("duration", {}).get("q1") + loi_q2_sec = bucket.get("duration", {}).get("q2") + loi_q3_sec = bucket.get("duration", {}).get("q3") + return cls( + user_payout_min=cls.usd_cents_to_decimal(bucket["payout"]["min"]), + user_payout_max=cls.usd_cents_to_decimal(bucket["payout"].get("max")), + user_payout_q1=cls.usd_cents_to_decimal(bucket["payout"].get("q1")), + user_payout_q2=cls.usd_cents_to_decimal(bucket["payout"].get("q2")), + user_payout_q3=cls.usd_cents_to_decimal(bucket["payout"].get("q3")), + loi_min=( + timedelta(seconds=loi_min_sec) if loi_min_sec is not None else None + ), + loi_max=( + timedelta(seconds=loi_max_sec) if loi_max_sec is not None else None + ), + loi_q1=(timedelta(seconds=loi_q1_sec) if loi_q1_sec is not None else None), + loi_q2=(timedelta(seconds=loi_q2_sec) if loi_q2_sec is not None else None), + loi_q3=(timedelta(seconds=loi_q3_sec) if loi_q3_sec is not None else None), + ) + + @classmethod + def parse_from_offerwall_style3(cls, bucket: Dict): + # {'payout': 123, 'duration': 123} + return cls( + user_payout_min=cls.usd_cents_to_decimal(bucket["payout"]), + user_payout_max=None, + loi_min=None, + loi_max=( + timedelta(seconds=bucket["duration"]) + if bucket.get("duration") is not None + else None + ), + ) + + @staticmethod + def usd_cents_to_decimal(v: int): + if v is None: + return None + return Decimal(Decimal(int(v)) / Decimal(100)) + + @staticmethod + def decimal_to_usd_cents(d: Decimal): + if d is None: + return None + return round(d * Decimal(100), 2) + + +class DurationSummary(StatisticalSummary): + """Durations are in integer seconds. + Describes the statistical distribution of expected durations of tasks within this bucket. + """ + + min: int = Field(gt=0, le=60 * 90) + max: int = Field(gt=0, le=60 * 90) + q1: int = Field(gt=0, le=60 * 90) + q2: int = Field(gt=0, le=60 * 90) + q3: int = Field(gt=0, le=60 * 90) + mean: Optional[int] = Field(gt=0, le=60 * 90, default=None) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "min": 112, + "max": 1180, + "q1": 457, + "q2": 650, + "q3": 1103, + "mean": 660, + } + ] + } + } + + @classmethod + def from_bucket(cls, bucket: Bucket): + return cls( + min=bucket.loi_min.total_seconds(), + max=bucket.loi_max.total_seconds(), + q1=bucket.loi_q1.total_seconds(), + q2=bucket.loi_q2.total_seconds(), + q3=bucket.loi_q3.total_seconds(), + mean=( + bucket.loi_mean.total_seconds() if bucket.loi_mean is not None else None + ), + ) + + +class PayoutSummaryDecimal(StatisticalSummary): + """Payouts are in Decimal USD""" + + min: Decimal = Field(gt=0, le=100) + max: Decimal = Field(gt=0, le=100) + q1: Decimal = Field(gt=0, le=100) + q2: Decimal = Field(gt=0, le=100) + q3: Decimal = Field(gt=0, le=100) + mean: Optional[Decimal] = Field(gt=0, le=100, default=None) + + +class PayoutSummary(StatisticalSummary): + """Payouts are in Integer USD Cents""" + + min: int = Field(gt=0, le=10000) + max: int = Field(gt=0, le=10000) + q1: int = Field(gt=0, le=10000) + q2: int = Field(gt=0, le=10000) + q3: int = Field(gt=0, le=10000) + mean: Optional[int] = Field(gt=0, le=10000, default=None) + + model_config = { + "json_schema_extra": { + "examples": [ + { + "min": 14, + "max": 132, + "q1": 45, + "q2": 68, + "q3": 124, + } + ] + } + } + + @classmethod + def from_bucket(cls, bucket: Bucket): + return cls( + min=bucket.decimal_to_usd_cents(bucket.user_payout_min), + max=bucket.decimal_to_usd_cents(bucket.user_payout_max), + q1=bucket.decimal_to_usd_cents(bucket.user_payout_q1), + q2=bucket.decimal_to_usd_cents(bucket.user_payout_q2), + q3=bucket.decimal_to_usd_cents(bucket.user_payout_q3), + mean=( + bucket.decimal_to_usd_cents(bucket.user_payout_mean) + if bucket.user_payout_mean is not None + else None + ), + ) + + +class SurveyEligibilityCriterion(BaseModel): + """ + Explanatory record of which question answers contributed + to a user's eligibility for a survey. + This is INSUFFICIENT for determining eligibility to a task + as it IGNORES logical operators, dependencies between criteria, + and other requirements. It is only intended for the UI. + """ + + model_config = ConfigDict(validate_assignment=True) + + question_id: Optional[UUIDStr] = Field( + examples=["71a367fb71b243dc89f0012e0ec91749"] + ) + property_code: Optional[PropertyCode] = Field(examples=["c:73629"]) + question_text: str = Field( + examples=[ + "What company administers the retirement plan for your current employer?" + ] + ) + # The answer(s) that were considered qualifying + qualifying_answer: Tuple[str, ...] = Field( + description="User answer(s) that satisfied at least one eligibility rule", + examples=["121"], + ) + qualifying_answer_label: Optional[Tuple[str, ...]] = Field( + examples=["Fidelity Investments"] + ) + explanation: Optional[str] = Field( + default=None, + description="Human-readable text explaining how a user's answer to this question affects eligibility", + examples=[ + "The company that administers your employer's retirement plan is **Fidelity Investments**." + ], + ) + explanation_fragment: Optional[str] = Field( + default=None, + exclude=True, + description="For internal use", + examples=["who's retirement plan is administered by **Fidelity Investments**"], + ) + # Rank more "interesting"/rare/salient criterion first. + rank: Optional[NonNegativeInt] = Field( + default=None, + description="Lower values are shown more prominently in the UI", + ) + + +class TopNBucket(BucketBase): + category: List[CategoryAssociation] = Field(default_factory=list) + duration: DurationSummary = Field() + payout: PayoutSummary = Field() + quality_score: float = Field( + ge=0, + le=1, + examples=[0.29223], + description="A proprietary score to determine the overall quality of the tasks that " + "are within the bucket. " + "Higher is better.", + ) + + @classmethod + def from_bucket(cls, bucket: Bucket): + return cls.model_validate( + { + "id": bucket.id, + "uri": bucket.uri, + "duration": DurationSummary.from_bucket(bucket), + "payout": PayoutSummary.from_bucket(bucket), + "quality_score": bucket.quality_score, + "category": bucket.category, + } + ) + + +class SingleEntryBucket(BucketBase): + x: int = Field(exclude=True, default=0) + y: int = Field(exclude=True, default=0) + name: int = Field(exclude=True, default="") + description: int = Field(exclude=True, default="") + + +class TopNPlusBucket(BucketBase): + category: List[CategoryAssociation] = Field(default_factory=list) + contents: List[BucketTask] = Field() + duration: DurationSummary = Field() + payout: PayoutSummary = Field() + quality_score: float = Field() + currency: str = Field( + description="This will always be 'USD'", default="USD", examples=["USD"] + ) + + eligibility_criteria: Tuple[SurveyEligibilityCriterion, ...] = Field( + description="The reasons the user is eligible for tasks in this bucket", + default_factory=tuple, + ) + eligibility_explanation: Optional[str] = Field( + default=None, + description="Human-readable text explaining a user's eligibility for tasks in this bucket", + examples=[ + "You are a **47-year-old** **white** **male** with a *college degree*, who's employer's retirement plan is **Fidelity Investments**." + ], + ) + + @field_validator("eligibility_criteria", mode="after") + @classmethod + def eligibility_ranks(cls, criteria): + criteria = list(criteria) + ranks = [c.rank for c in criteria] + if all(r is None for r in ranks): + for i, c in enumerate(criteria): + c.rank = i + return tuple(criteria) + if any(r is None for r in ranks): + raise ValueError("Set all or no ranks in eligibility_criteria") + if len(ranks) != len(set(ranks)): + raise ValueError("Duplicate ranks") + return tuple(sorted(criteria, key=lambda c: c.rank)) + + @classmethod + def from_bucket(cls, bucket: Bucket): + return cls.model_validate( + { + "id": bucket.id, + "uri": bucket.uri, + "duration": DurationSummary.from_bucket(bucket), + "payout": PayoutSummary.from_bucket(bucket), + "quality_score": bucket.quality_score, + "category": bucket.category, + "contents": bucket.contents, + "eligibility_criteria": bucket.eligibility_criteria, + "eligibility_explanation": bucket.eligibility_explanation, + } + ) + + +class TopNPlusRecontactBucket(BucketBase): + category: List[CategoryAssociation] = Field(default_factory=list) + contents: List[BucketTask] = Field() + duration: DurationSummary = Field() + payout: PayoutSummary = Field() + quality_score: float = Field() + is_recontact: bool = Field() + currency: str = Field( + description="This will always be 'USD'", default="USD", examples=["USD"] + ) + + @classmethod + def from_bucket(cls, bucket: Bucket): + return cls.model_validate( + { + "id": bucket.id, + "uri": bucket.uri, + "duration": DurationSummary.from_bucket(bucket), + "payout": PayoutSummary.from_bucket(bucket), + "quality_score": bucket.quality_score, + "category": bucket.category, + "contents": bucket.contents, + "is_recontact": bucket.metadata.get("is_recontact", False), + } + ) + + +class SoftPairBucket(BucketBase): + uri: Optional[HttpsUrl] = Field( + examples=[None], + description="The URL to send a respondent into. Must not edit this URL in any way. If the eligibility is " + "conditional or ineligible, the uri will be null.", + ) + + category: List[CategoryAssociation] = Field(default_factory=list) + contents: List[BucketTask] = Field() + + eligibility: Eligibility = Field(examples=["conditional"]) + missing_questions: List[str] = Field( + default_factory=list, examples=[["fb20fd4773304500b39c4f6de0012a5a"]] + ) + loi: int = Field(description="this is the max loi of the contents", examples=[612]) + payout: int = Field( + description="this is the min payout of the contents", examples=[123] + ) + + x: int = Field(exclude=True, default=0) + y: int = Field(exclude=True, default=0) + name: int = Field(exclude=True, default="") + description: int = Field(exclude=True, default="") + + +class MarketplaceBucket(BucketBase): + category: List[CategoryAssociation] = Field(default_factory=list) + contents: List[BucketTask] = Field() + duration: DurationSummary = Field() + payout: PayoutSummary = Field() + source: SourceName = Field( + description="this is the source of the contents", examples=["pollfish"] + ) + + +class TimeBucksBucket(BucketBase): + duration: int = Field( + gt=0, le=60 * 90, description="The bucket's q1 duration, in seconds" + ) + min_payout: int = Field( + gt=0, le=100_00, description="The bucket's min payout, in usd cents" + ) + currency: str = Field( + description="This will always be 'USD'", default="USD", examples=["USD"] + ) + + +class OneShotOfferwallBucket(BaseModel): + model_config = ConfigDict( + extra="forbid", validate_assignment=True, ser_json_timedelta="float" + ) + + id: UUIDStr = Field( + description="Unique identifier this particular bucket", + examples=["5ba2fe5010cc4d078fc3cc0b0cc264c3"], + ) + uri: HttpsUrl = Field( + examples=[ + "https://task.generalresearch.com/api/v1/52d3f63b2709/797df4136c604a6c8599818296aae6d1/?i" + "=5ba2fe5010cc4d078fc3cc0b0cc264c3&b=test&66482fb=e7baf5e" + ], + description="The URL to send a respondent into. Must not edit this URL in any way", + ) + duration: int = Field( + gt=0, + le=60 * 90, + description="The bucket's expected duration, in seconds", + ) + min_payout: int = Field( + gt=0, le=100_00, description="The bucket's min payout, in usd cents" + ) + + +class OneShotSoftPairOfferwallBucket(OneShotOfferwallBucket): + eligibility: Eligibility = Field(examples=["conditional"]) + missing_questions: List[str] = Field( + default_factory=list, examples=[["fb20fd4773304500b39c4f6de0012a5a"]] + ) + + +class WXETOfferwallBucket(BaseModel): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ser_json_timedelta="float", + ) + + id: UUIDStr = Field( + description="Unique identifier this particular bucket", + examples=["5ba2fe5010cc4d078fc3cc0b0cc264c3"], + ) + uri: HttpsUrl = Field( + examples=[ + "https://task.generalresearch.com/api/v1/52d3f63b2709/797df4136c604a6c8599818296aae6d1/?i" + "=5ba2fe5010cc4d078fc3cc0b0cc264c3&b=test&66482fb=e7baf5e" + ], + description="The URL to send a respondent into. Must not edit this URL in any way", + ) + duration: int = Field( + gt=0, + le=60 * 90, + description="The bucket's expected duration, in seconds", + ) + min_payout: int = Field( + gt=0, le=10000, description="The bucket's min payout, in usd cents" + ) diff --git a/generalresearch/models/legacy/definitions.py b/generalresearch/models/legacy/definitions.py new file mode 100644 index 0000000..1755d2a --- /dev/null +++ b/generalresearch/models/legacy/definitions.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class OfferwallReason(str, Enum): + USER_BLOCKED = "USER_BLOCKED" + HIGH_RECON_RATE = "HIGH_RECON_RATE" + UNCOMMON_DEMOGRAPHICS = "UNCOMMON_DEMOGRAPHICS" + UNDER_MINIMUM_AGE = "UNDER_MINIMUM_AGE" + EXHAUSTED_HIGH_VALUE_SUPPLY = "EXHAUSTED_HIGH_VALUE_SUPPLY" + ALL_ELIGIBLE_ATTEMPTED = "ALL_ELIGIBLE_ATTEMPTED" + LOW_CURRENT_SUPPLY = "LOW_CURRENT_SUPPLY" diff --git a/generalresearch/models/legacy/offerwall.py b/generalresearch/models/legacy/offerwall.py new file mode 100644 index 0000000..67213f7 --- /dev/null +++ b/generalresearch/models/legacy/offerwall.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +from typing import List, Dict + +from pydantic import BaseModel, Field, ConfigDict, NonNegativeInt + +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.legacy.bucket import ( + BucketBase, + SoftPairBucket, + TopNBucket, + TimeBucksBucket, + MarketplaceBucket, + TopNPlusBucket, + SingleEntryBucket, + WXETOfferwallBucket, + OneShotOfferwallBucket, + OneShotSoftPairOfferwallBucket, + TopNPlusRecontactBucket, +) +from generalresearch.models.legacy.definitions import OfferwallReason +from generalresearch.models.thl.payout_format import ( + PayoutFormatField, + PayoutFormatType, +) +from generalresearch.models.thl.profiling.upk_question import UpkQuestion + +""" +Not Done: +8531fee24712: jeopardy +""" + + +class OfferWallInfo(BaseModel): + success: bool = Field() + + +class OfferWallResponse(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + info: OfferWallInfo = Field() + offerwall: OfferWall = Field() + + +class OfferWall(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + id: UUIDStr = Field( + description="Unique identifier to reference a generated offerwall", + examples=["7dc1d3aeb4844a6fab17ecd370b8bf1e"], + ) + + availability_count: NonNegativeInt = Field( + description="Total opportunities available for specific bpuid " + "respondent and parameters. This value changes frequently " + "and can be used to determine if a respondent has potential " + "tasks available, regardless of the offerwall type being " + "requested. If the value is 0, no buckets will be generated.", + examples=[42], + ) + + attempted_live_eligible_count: NonNegativeInt = Field( + description=( + "Number of currently live opportunities for which the respondent " + "meets all eligibility criteria but is excluded due to a prior attempt. " + "Only includes surveys that are still live and otherwise eligible; " + "does not include previously attempted surveys that are no longer available." + ), + examples=[7], + default=0, + ) + + buckets: List[BucketBase] = Field(default_factory=list) + + offerwall_reasons: List[OfferwallReason] = Field( + default_factory=list, + description=( + "Explanations describing why so many or few opportunities are available." + ), + examples=[[OfferwallReason.USER_BLOCKED, OfferwallReason.UNDER_MINIMUM_AGE]], + ) + + def censor(self): + for bucket in self.buckets: + bucket.censor() + + +class SingleEntryOfferWall(OfferWall): + """Only returns a single bucket with the top scoring tasks. + + Offerwall code: `5fl8bpv5` + """ + + payout_format: PayoutFormatType = PayoutFormatField + buckets: List[SingleEntryBucket] = Field(default_factory=list, max_length=1) + + +class TopNOfferWall(OfferWall): + """An offerwall with buckets that are clustered by the `split_by` argument + using KMeans clustering. + + Offerwall code: `45b7228a7` + """ + + buckets: List[TopNBucket] = Field(default_factory=list) + payout_format: PayoutFormatType = PayoutFormatField + + +class StarwallOfferWall(OfferWall): + """An offerwall with buckets that are clustered by setting as seeds the + highest scoring surveys for each bin, then the rest are distributed + according to their Euclidean distance using the bucket's features. + + Offerwall code: `b59a2d2b` + """ + + buckets: List[TopNBucket] = Field(default_factory=list) + payout_format: PayoutFormatType = PayoutFormatField + + +class TopNPlusOfferWall(OfferWall): + """Same as the TopNOfferWall, but the buckets include contents. + + Offerwall code: `b145b803` + """ + + buckets: List[TopNPlusBucket] = Field(default_factory=list) + + +class TopNPlusBlockOfferWall(OfferWall): + """Same as the TopNOfferWall, but the buckets include contents and no + buckets are returned if the user is blocked. + + Offerwall code: `d48cce47` + """ + + buckets: List[TopNPlusBucket] = Field(default_factory=list) + + # This incorrectly gets returned only when the user is blocked. It + # shouldn't get returned at all + payout_format: str = Field(exclude=True, default="") + + +class TopNPlusBlockRecontactOfferWall(OfferWall): + """Same as the TopNOfferWall, but the buckets include contents, no buckets + are returned if the user is blocked, and each bucket includes a + `is_recontact` key. + + Offerwall code: `1e5f0af8` + """ + + buckets: List[TopNPlusRecontactBucket] = Field(default_factory=list) + + # This incorrectly gets returned only when the user is blocked. It + # shouldn't get returned at all + payout_format: str = Field(exclude=True, default="") + + +class StarwallPlusOfferWall(OfferWall): + """Same as the StarwallOfferWall, but the buckets include contents. + + Offerwall code: `5481f322` + """ + + buckets: List[TopNPlusBucket] = Field(default_factory=list) + + +class StarwallPlusBlockOfferWall(OfferWall): + """Same as the StarwallOfferWall, but the buckets include contents and no + buckets are returned if the user is blocked. + + Offerwall code: `7fa1b3f4` + """ + + buckets: List[TopNPlusBucket] = Field(default_factory=list) + + # This incorrectly gets returned only when the user is blocked. It + # shouldn't get returned at all + payout_format: str = Field(exclude=True, default="") + + +class StarwallPlusBlockRecontactOfferWall(OfferWall): + """Same as the StarwallOfferWall, but the buckets include contents, no + buckets are returned if the user is blocked, and each bucket includes + a recontact key. + + Offerwall code: `630db2a4` + """ + + buckets: List[TopNPlusRecontactBucket] = Field(default_factory=list) + + # This incorrectly gets returned only when the user is blocked. It + # shouldn't get returned at all + payout_format: str = Field(exclude=True, default="") + + +class SoftPairOfferwall(OfferWall): + """This offerwall contains tasks for which the user has a conditional + eligibility. The questions that a user must answer to determine the + eligibility are included within each bucket. Additionally, the question + definitions are included for convenience. + + Offerwall code: `37d1da64` + """ + + buckets: List[SoftPairBucket] = Field(default_factory=list) + + question_info: Dict[str, "UpkQuestion"] = Field( + default_factory=dict, + examples=[ + # { + # UpkQuestion.model_config["json_schema_extra"]["example"][ + # "question_id" + # ]: UpkQuestion.model_config["json_schema_extra"]["example"] + # } + ], + ) + + # This incorrectly gets returned only when the user is blocked. It + # shouldn't get returned at all + payout_format: str = Field(exclude=True, default="") + + +class MarketplaceOfferwall(OfferWall): + """Returns buckets grouped by marketplace, one per marketplace, with the + tasks ordered by quality. + + Offerwall code: `5fa23085` + """ + + buckets: List[MarketplaceBucket] = Field(default_factory=list) + + +class TimeBucksOfferwall(OfferWall): + """A modification of the TopNOfferwall: + 1) topN split by payout with 10 buckets + 2) remove buckets min_payout > $4 (distribute those surveys to the + other buckets) + 3) duplicate each bucket 3x, with loi and payout jitter. no contents + key, no IQR, just return loi = q1_duration, payout = min_payout + + Offerwall code: `1705e4f8` + """ + + buckets: List[TimeBucksBucket] = Field(default_factory=list) + + +class TimeBucksBlockOfferwall(OfferWall): + """Same as the TimeBucksOfferwall, but no buckets are returned if the + user is blocked. + + Offerwall code: `0af0f7ec` + """ + + buckets: List[TimeBucksBucket] = Field(default_factory=list) + # This incorrectly gets returned only when the user is blocked. It shouldn't get returned at all + payout_format: str = Field(exclude=True, default="") + + +class OneShotOfferwall(OfferWall): + """Each bucket has only 1 single task, and only basic info is returned + about each bucket. + + Offerwall code: `6f27b1ae` + """ + + buckets: List[OneShotOfferwallBucket] = Field(default_factory=list) + + +class OneShotSoftPairOfferwall(SoftPairOfferwall): + """Each bucket has only 1 single task, and only basic info is returned + about each bucket. Supports soft pair + + Offerwall code: `18347426` + """ + + buckets: List[OneShotSoftPairOfferwallBucket] = Field(default_factory=list) + + +class WXETOfferwall(OfferWall): + """Returns buckets from WXET as single tasks + Offerwall code: `55a4e1a9` + """ + + buckets: List[WXETOfferwallBucket] = Field(default_factory=list) + + +class SingleEntryOfferWallResponse(OfferWallResponse): + offerwall: SingleEntryOfferWall = Field() + + +class TopNOfferWallResponse(OfferWallResponse): + offerwall: TopNOfferWall = Field() + + +class TopNPlusOfferWallResponse(OfferWallResponse): + offerwall: TopNPlusOfferWall = Field() + + +class TopNPlusBlockOfferWallResponse(OfferWallResponse): + offerwall: TopNPlusBlockOfferWall = Field() + + +class TopNPlusBlockRecontactOfferWallResponse(OfferWallResponse): + offerwall: TopNPlusBlockRecontactOfferWall = Field() + + +class StarwallOfferWallResponse(OfferWallResponse): + offerwall: StarwallOfferWall = Field() + + +class StarwallPlusOfferWallResponse(OfferWallResponse): + offerwall: StarwallPlusOfferWall = Field() + + +class StarwallPlusBlockOfferWallResponse(OfferWallResponse): + offerwall: StarwallPlusBlockOfferWall = Field() + + +class StarwallPlusBlockRecontactOfferWallResponse(OfferWallResponse): + offerwall: StarwallPlusBlockRecontactOfferWall = Field() + + +class SoftPairOfferwallResponse(OfferWallResponse): + offerwall: SoftPairOfferwall = Field() + + +class MarketplaceOfferwallResponse(OfferWallResponse): + offerwall: MarketplaceOfferwall = Field() + + +class TimeBucksOfferwallResponse(OfferWallResponse): + offerwall: TimeBucksOfferwall = Field() + + +class TimeBucksBlockOfferwallResponse(OfferWallResponse): + offerwall: TimeBucksBlockOfferwall = Field() + + +class OneShotOfferwallResponse(OfferWallResponse): + offerwall: OneShotOfferwall = Field() + + +class OneShotSoftPairOfferwallResponse(OfferWallResponse): + offerwall: OneShotSoftPairOfferwall = Field() + + +class WXETOfferwallResponse(OfferWallResponse): + offerwall: WXETOfferwall = Field() diff --git a/generalresearch/models/legacy/questions.py b/generalresearch/models/legacy/questions.py new file mode 100644 index 0000000..1559f24 --- /dev/null +++ b/generalresearch/models/legacy/questions.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from typing import Dict, List, Optional, TYPE_CHECKING + +from pydantic import ( + BaseModel, + Field, + NonNegativeInt, + model_validator, + StringConstraints, + ConfigDict, + ValidationError, + BeforeValidator, + field_validator, +) +from sentry_sdk import capture_exception +from typing_extensions import Annotated +from typing_extensions import Self + +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.legacy.api_status import StatusResponse +from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionOut, +) +from generalresearch.models.thl.session import Wall +from generalresearch.models.thl.user import User + +if TYPE_CHECKING: + from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, + ) + from generalresearch.managers.thl.wall import WallManager + + +class UpkQuestionResponse(StatusResponse): + questions: List[UpkQuestionOut] = Field() + consent_questions: List[Dict] = Field( + description="For internal use", default_factory=list + ) + special_questions: List[Dict] = Field( + description="For internal use", default_factory=list + ) + count: NonNegativeInt = Field(description="The number of questions returned") + + +AnswerStr = Annotated[ + # TODO: What should the max_length be? TE open ended questions could + # mess with this... + str, + StringConstraints(min_length=1, max_length=5_000), +] + + +class UserQuestionAnswerIn(BaseModel): + """Send the answers to one or more questions for a user. A question is + uniquely specified by the question_id key. The answer is: the choice_id + if the question_type is "MC" the actual entered text if the + question_type is "TE" + + TODO: look up the question_type from the question_id to apply MC or + TE specific validation on the answer(s) + """ + + model_config = ConfigDict( + # This is applied to private empty strings as answers. However, it may + # alter TE input from users in unexpected ways for security or other + # forms of validation checks as it seems to modify the values in place. + str_strip_whitespace=True, + extra="forbid", + frozen=True, + ) + + question_id: UUIDStr = Field(examples=["fb20fd4773304500b39c4f6de0012a5a"]) + + answer: List[AnswerStr] = Field( + min_length=1, + max_length=10, + description="The user's answers to this question. Must pass the " + "choice_id if the question is a Multiple Choice, or the " + "actual text if the question is Text Entry", + examples=[["1"]], + ) + + # --- Validation --- + @model_validator(mode="after") + def single_answer_questions(self): + user_agent_qid = "2fbedb2b9f7647b09ff5e52fa119cc5e" + fingerprint_langs = "4030c52371b04e80b64e058d9c5b82e9" + fingerprint_tz = "a91cb1dea814480dba12d9b7b48696dd" + fingerprint_fingerprint = "1d1e2e8380ac474b87fb4e4c569b48df" + + if self.question_id in { + user_agent_qid, + fingerprint_langs, + fingerprint_tz, + fingerprint_fingerprint, + }: + if len(self.answer) != 1: + raise ValueError("Too many answer values provided") + + return self + + @model_validator(mode="after") + def user_agent_check(self) -> Self: + # TODO: where / how do I want to pass in this Werz user_agent stuff? + user_agent_qid = "2fbedb2b9f7647b09ff5e52fa119cc5e" + + if self.question_id == user_agent_qid: + val = self.answer[0] + # assert val == request.user_agent.to_header(): + pass + + return self + + @field_validator("answer", mode="after") + @classmethod + def no_duplicate_answer_values(cls, v: List[AnswerStr]) -> List[AnswerStr]: + if len(v) != len(set(v)): + raise ValueError("Don't provide duplicate answers") + + return v + + @field_validator("answer", mode="after") + @classmethod + def sort_answer_values(cls, v: List[AnswerStr]) -> List[AnswerStr]: + return sorted(v) + + # --- Properties --- + + # --- Methods --- + + +def preflight(li): + # https://github.com/pydantic/pydantic/discussions/7660 + new_li = [] + for x in li: + try: + x = UserQuestionAnswerIn.model_validate(x) + new_li.append(x) + except ValidationError as e: + capture_exception(error=e) + continue + + return new_li + + +class UserQuestionAnswers(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + product_id: UUIDStr = Field(examples=["4fe381fb7186416cb443a38fa66c6557"]) + + product_user_id: str = Field( + min_length=3, + max_length=128, + examples=["app-user-9329ebd"], + description="A unique identifier for each user, which is set by the " + "Supplier. It should not contain any sensitive information" + "like email or names, and should avoid using any" + "incrementing values.", + ) + + # Notice: There may be an issue where we could have told Suppliers that + # POST /profiling-questions/ that they could use a randomly generated + # session_id... I'm not sure, but it's entirely possible this will start + # to cause issues in production. + session_id: Optional[UUIDStr] = Field( + default=None, + description="The Session ID corresponds to the Wall.uuid. If profiling" + "answers are being submitted directly, this can be None.", + ) + + # We don't apply a default_factory here because there is no valid reason + # why a GRS submission would come valid without any answers. + answers: Annotated[List[UserQuestionAnswerIn], BeforeValidator(preflight)] = Field( + min_length=1, + max_length=100, + description="The list of questions and their answers that are being" + "submitted by the user (if via GRS), or by the Supplier " + "(if via FSB).", + ) + + user: Optional[User] = Field(default=None) + wall: Optional[Wall] = Field(default=None) + + # --- Validation --- + + # A user that doesn't yet exist can submit profiling questions, + # since there is no explicit "Create User" call. If session_id + # is passed, then the user should exist. + # @model_validator(mode="after") + # def user_exists(self): + # if self.user is None: + # raise ValueError("Invalid user") + # return self + + @model_validator(mode="after") + def valid_wall_event(self): + # session_id is Optional, so break early if we can't proceed. + if self.session_id is None: + return self + return self + + # I have this commented out for now because there is an argument to be made + # that a blocked user can or should be able to submit profiling data, or + # at least init a MarketplaceUserQuestionAnswer. + # @model_validator(mode="after") + # def grs_allowed_user(self): + # assert not self.user.blocked, "blocked user can't submit profiling " + # return self + + @field_validator("answers", mode="after") + @classmethod + def no_duplicate_questions(cls, v: List[UserQuestionAnswerIn]): + answer_qids = [qa.question_id for qa in v] + if len(answer_qids) != len(set(answer_qids)): + raise ValueError("Don't provide answers to duplicate questions") + + return v + + # --- Prefetch --- + def prefetch_user(self, um: "UserManager") -> None: + from generalresearch.models.thl.user import User + + res: User = um.get_user_if_exists( + product_id=self.product_id, product_user_id=self.product_user_id + ) + + if res is None: + raise ValidationError("Invalid user") + + self.user = res + + def prefetch_wall(self, wm: "WallManager") -> None: + from generalresearch.models.thl.session import Wall + from generalresearch.models import Source + + res: Optional[Wall] = wm.get_from_uuid_if_exists(wall_uuid=self.session_id) + + if res is None: + raise ValueError("Invalid Event for session_id") + + if res.source != Source.GRS: + raise ValueError("Not a valid GRS event") + + if res.user_id != self.product_user_id: + raise ValueError("Not a valid GRS event for this user") + + # I think it's fair to say a UserQuestionAnswers instance can / should + # only be initialized for a Wall event that exists, but hasn't been + # finished yet. Therefor this is safe to do for legit users for now + if res.finished is not None: + raise ValueError("Not a valid GRS event status") + + self.wall = res diff --git a/generalresearch/models/lucid/__init__.py b/generalresearch/models/lucid/__init__.py new file mode 100644 index 0000000..84a210d --- /dev/null +++ b/generalresearch/models/lucid/__init__.py @@ -0,0 +1,7 @@ +from pydantic import Field + +from typing_extensions import Annotated + +LucidQuestionIdType = Annotated[ + str, Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") +] diff --git a/generalresearch/models/lucid/question.py b/generalresearch/models/lucid/question.py new file mode 100644 index 0000000..b3d9b27 --- /dev/null +++ b/generalresearch/models/lucid/question.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import logging +from enum import Enum +from typing import List, Optional, Literal, TYPE_CHECKING + +from pydantic import BaseModel, Field, model_validator, field_validator +from typing_extensions import Self + +from generalresearch.models import Source +from generalresearch.models.lucid import LucidQuestionIdType +from generalresearch.models.thl.profiling.marketplace import ( + MarketplaceQuestion, +) + +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class LucidQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+|-3105$", + frozen=True, + description="precode", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API + order: int = Field() + + +class LucidQuestionType(str, Enum): + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + TEXT_ENTRY = "t" + # This is text entry, but only numbers + NUMERICAL = "n" + # Dummy means they're calculated + DUMMY = "d" + + +class LucidQuestion(MarketplaceQuestion): + question_id: LucidQuestionIdType = Field( + description="The unique identifier for the qualification", frozen=True + ) + question_text: str = Field( + max_length=1024, + min_length=1, + description="The text shown to respondents", + frozen=False, + ) + question_type: LucidQuestionType = Field( + description="The type of question asked", frozen=True + ) + options: Optional[List[LucidQuestionOption]] = Field( + default=None, min_length=1, frozen=True + ) + + source: Literal[Source.LUCID] = Source.LUCID + + @property + def internal_id(self) -> str: + return self.question_id + + @model_validator(mode="after") + def check_type_options_agreement(self) -> Self: + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type in { + LucidQuestionType.TEXT_ENTRY, + LucidQuestionType.NUMERICAL, + }: + assert self.options is None, "TEXT_ENTRY/NUMERICAL shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @field_validator("options") + @classmethod + def order_options(cls, options): + if options: + options.sort(key=lambda x: x.order) + return options + + @classmethod + def from_db(cls, d: dict) -> Self: + options = None + if d["options"]: + options = [ + LucidQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_text=d["question_text"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + ) + + def to_upk_question(self) -> "UpkQuestion": + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + ) + + upk_type_selector_map = { + LucidQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + LucidQuestionType.DUMMY: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + LucidQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + LucidQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + LucidQuestionType.NUMERICAL: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=c.order) + for c in self.options + ] + return UpkQuestion(**d) diff --git a/generalresearch/models/lucid/survey.py b/generalresearch/models/lucid/survey.py new file mode 100644 index 0000000..6d81254 --- /dev/null +++ b/generalresearch/models/lucid/survey.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import Optional, Dict, Set, Tuple, List + +from pydantic import NonNegativeInt, Field, ConfigDict, BaseModel + +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + UUIDStr, + CoercedStr, + BigAutoInteger, +) +from generalresearch.models.thl.locales import CountryISO, LanguageISO +from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, +) + + +class LucidCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + + id: BigAutoInteger = Field() + source: Source = Field(default=Source.LUCID) + question_id: Optional[CoercedStr] = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + ) + country_iso: CountryISO = Field() + language_iso: LanguageISO = Field() + + @property + def criterion_hash(self) -> None: + # We use the integer ID throughout. Make sure we don't accidentally use this + raise ValueError() + + def __hash__(self): + # this is so it can be put into a set / dictionary key + return hash(self.id) + + @classmethod + def from_mysql(cls, x): + x["value_type"] = ConditionValueType.LIST + x["negate"] = False + x["values"] = x.pop("pre_codes").split("|") + x["question_id"] = str(x["question_id"]) + return cls.model_validate(x) + + +class LucidQualification(BaseModel): + criterion: int = Field() + modified: AwareDatetimeISO = Field(description="modified or created") + + +class LucidQuota(BaseModel): + id: BigAutoInteger = Field() + uuid: UUIDStr = Field() + upper_limit: NonNegativeInt = Field(examples=[20]) + criteria: List[int] = Field(min_length=1, max_length=25) + modified: AwareDatetimeISO = Field(description="modified or created") + # We'll look this up with a special mysql query. If None, it means + # that we don't know. + finish_count: Optional[int] = Field(default=None) + + def __hash__(self): + return hash(self.id) + + @property + def is_open(self) -> bool: + return self.upper_limit > self.finish_count + + def passes(self, criteria_evaluation: Dict[int, Optional[bool]]) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the quota is open. + return self.is_open and self.matches(criteria_evaluation) + + def matches(self, criteria_evaluation: Dict[int, Optional[bool]]) -> bool: + # Matches means we meet all conditions. + # We can "match" a quota that is closed. In that case, we would not be eligible for the survey. + return all(criteria_evaluation.get(c) for c in self.criteria) + + # def matches_optional( + # self, criteria_evaluation: Dict[int, Optional[bool]] + # ) -> Optional[bool]: + # # We need to know if any conditions are unknown to avoid matching a full quota. If any fail, + # # then we know we fail regardless of any being unknown. + # evals = [criteria_evaluation.get(c) for c in self.criteria] + # if False in evals: + # return False + # if None in evals: + # return None + # return True + + def matches_soft( + self, criteria_evaluation: Dict[int, Optional[bool]] + ) -> Tuple[Optional[bool], Set[int]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + hash_evals = {cell: criteria_evaluation.get(cell) for cell in self.criteria} + evals = set(hash_evals.values()) + if False in evals: + return False, set() + if None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() diff --git a/generalresearch/models/marketplace/__init__.py b/generalresearch/models/marketplace/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/models/marketplace/summary.py b/generalresearch/models/marketplace/summary.py new file mode 100644 index 0000000..0dd3404 --- /dev/null +++ b/generalresearch/models/marketplace/summary.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from abc import ABC +from typing import List, Optional, Dict, Literal, Collection + +import numpy as np +from pydantic import BaseModel, Field, computed_field, ConfigDict +from typing_extensions import Self + +from generalresearch.models.thl.stats import StatisticalSummary + + +class MarketplaceSummary(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + inventory: MarketplaceInventorySummary = Field( + description="Inventory of the marketplace" + ) + user_activity: Optional[str] = Field( + description="User activity of the marketplace", default=None + ) + + +class MarketplaceInventorySummary(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + live_tasks: List[CountStat] = Field( + default_factory=list, + description="The count of tasks that are currently live", + ) + live_gen_pop_tasks: List[CountStat] = Field( + default_factory=list, + description="The count of gen-pop tasks that are currently live", + ) + tasks_created: List[CountStat] = Field( + default_factory=list, + description="The count of tasks created", + ) + required_finishes: List[CountStat] = Field( + default_factory=list, + description="Number of finishes needed across all live tasks", + ) + + payout: List[StatisticalSummaryStat] = Field( + default_factory=list, + description="The distribution of payouts for all live tasks", + ) + expected_duration: List[StatisticalSummaryStat] = Field( + default_factory=list, + description="The distribution of expected durations for all live tasks", + ) + required_finishes_per_task: List[StatisticalSummaryStat] = Field( + default_factory=list, + description="The distribution of required finishes on all live tasks", + ) + + +FacetKey = Literal["country_iso", "day", "month"] + + +class Stat(BaseModel, ABC): + facet: Dict[FacetKey, str | int | float] = Field( + examples=[{"country_iso": "us"}], description="The grouping criteria" + ) + + +class CountStat(Stat): + count: int = Field(description="The count value for the given metric and facet") + + +class StatisticalSummaryStat(Stat): + value: StatisticalSummaryValue = Field( + description="Statistical Summary for the given metric and facet" + ) + + +class StatisticalSummaryValue(StatisticalSummary): + min: float = Field() + max: float = Field() + mean: float = Field() + q1: float = Field() + q2: float = Field(description="equal to the median") + q3: float = Field() + + @classmethod + def from_values(cls, values: Collection[int | float]) -> Self: + values = sorted(values) + return cls( + min=min(values), + max=max(values), + q1=np.percentile(values, 25), + q2=np.percentile(values, 50), + q3=np.percentile(values, 75), + mean=np.mean(values), + ) + + @computed_field + @property + def lower_whisker(self) -> float: + return self.q1 - (1.5 * self.iqr) + + @computed_field + @property + def upper_whisker(self) -> float: + return self.q3 + (1.5 * self.iqr) + + +d = MarketplaceSummary( + inventory=MarketplaceInventorySummary( + live_tasks=[ + CountStat( + facet={"country_iso": "us"}, + count=10, + ), + CountStat( + facet={"country_iso": "ca"}, + count=2, + ), + CountStat(facet={}, count=15), + ], + tasks_created=[ + CountStat( + facet={"day": "2024-11-02"}, + count=5, + ), + CountStat( + facet={"day": "2024-11-02", "country_iso": "us"}, + count=4, + ), + CountStat( + facet={"day": "2024-11-01", "country_iso": "us"}, + count=4, + ), + ], + payout=[ + StatisticalSummaryStat( + facet={}, + value=StatisticalSummaryValue( + min=14, q1=40, q2=96, q3=123, max=420, mean=100 + ), + ), + StatisticalSummaryStat( + facet={"country_iso": "us"}, + value=StatisticalSummaryValue( + min=16, q1=42, q2=98, q3=123, max=400, mean=100 + ), + ), + ], + ) +) diff --git a/generalresearch/models/morning/__init__.py b/generalresearch/models/morning/__init__.py new file mode 100644 index 0000000..2c61c49 --- /dev/null +++ b/generalresearch/models/morning/__init__.py @@ -0,0 +1,16 @@ +from enum import Enum + +from pydantic import StringConstraints +from typing_extensions import Annotated + +# This is text-based, in lowercase. e.g. 'age', 'household_income' +MorningQuestionID = Annotated[ + str, StringConstraints(min_length=1, max_length=64, pattern=r"^[^A-Z]+$") +] + + +class MorningStatus(str, Enum): + DRAFT = "draft" + ACTIVE = "active" # aka LIVE + PAUSED = "paused" + CLOSED = "closed" diff --git a/generalresearch/models/morning/question.py b/generalresearch/models/morning/question.py new file mode 100644 index 0000000..77ea209 --- /dev/null +++ b/generalresearch/models/morning/question.py @@ -0,0 +1,207 @@ +import json +from enum import Enum +from typing import List, Optional, Dict, Literal, Any +from uuid import UUID + +from pydantic import BaseModel, Field, model_validator, field_validator +from typing_extensions import Self + +from generalresearch.locales import Localelator +from generalresearch.models import Source +from generalresearch.models.morning import MorningQuestionID +from generalresearch.models.thl.profiling.marketplace import ( + MarketplaceUserQuestionAnswer, + MarketplaceQuestion, +) + +# todo: we could validate that the country_iso / language_iso exists ... +locale_helper = Localelator() + + +class MorningQuestionOption(BaseModel, frozen=True): + # API limit is 50, db limit is 32 + id: str = Field( + min_length=1, + max_length=32, + pattern=r"^[\w\s\.\-]+$", + description="The unique identifier for a response to a qualification", + serialization_alias="option_id", + ) + text: str = Field( + min_length=1, + description="The response text shown to respondents", + serialization_alias="option_text", + ) + # Order does not come back explicitly in the API, instead they are already ordered. We're + # adding this for db sort purposes to explicitly order them. We use the API's order. + order: int = Field() + + +class MorningQuestionType(str, Enum): + # The db stores these as a single letter + + # Geographic questions represent geographic areas within a country. + # These behave like multiple_choice questions + geographic = "g" + # The 's' is for "single-select". Morning does not support "multi-select" + # multiple choice, but if they did, we would use 'm' for "multi-select". + multiple_choice = "s" + # Questions whose answers are submitted by respondents. The ID and + # response text are both defined as the exact text that was typed by + # the respondent. Text entry responses are not case-sensitive + text_entry = "t" + + +class MorningUserQuestionAnswer(MarketplaceUserQuestionAnswer): + question_id: MorningQuestionID = Field() + question_type: Optional[MorningQuestionType] = Field(default=None) + # Did this answer come from us asking, or was it passed back from the + # marketplace? Note, morning doesn't "pass back" answers, but we can + # retrieve a user's profile through API, so it is possible to populate + # this from_thl False + from_thl: bool = Field(default=True) + + +class MorningQuestion(MarketplaceQuestion): + # API limit is 100, db limit is 64 + id: str = Field( + min_length=1, + max_length=64, + pattern=r"^[a-z0-9_\s\.]+$", + description="The unique identifier for the qualification", + serialization_alias="question_id", + frozen=True, + ) + # API has no limit, db limit is 64 + name: str = Field( + max_length=64, + min_length=1, + serialization_alias="question_name", + description="The human-readable short label for the qualification", + frozen=True, + ) + text: str = Field( + min_length=1, + description="The text shown to respondents", + serialization_alias="question_text", + frozen=True, + ) + type: MorningQuestionType = Field( + description="The type of question asked", + serialization_alias="question_type", + frozen=True, + ) + # API calls this "responses", but I think that is a confusing name + options: Optional[List[MorningQuestionOption]] = Field( + default=None, min_length=1, frozen=True + ) + + source: Literal[Source.MORNING_CONSULT] = Source.MORNING_CONSULT + + @property + def internal_id(self) -> str: + return self.id + + @model_validator(mode="after") + def check_type_options_agreement(self) -> Self: + # If type == "text_entry", options is None. Otherwise, must be set. + if self.type == MorningQuestionType.text_entry: + assert self.options is None + else: + assert self.options is not None + return self + + @field_validator("options") + @classmethod + def order_options(cls, options): + if options: + options.sort(key=lambda x: x.order) + return options + + @classmethod + def from_api(cls, d: dict, country_iso: str, language_iso: str): + options = None + if d.get("responses"): + options = [ + MorningQuestionOption(id=r["id"], text=r["text"], order=order) + for order, r in enumerate(d["responses"]) + ] + return cls( + id=d["id"], + name=d["name"], + text=d["text"], + type=MorningQuestionType[d["type"]], + country_iso=country_iso, + language_iso=language_iso, + options=options, + ) + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + MorningQuestionOption( + id=r["option_id"], text=r["option_text"], order=r["order"] + ) + for r in d["options"] + ] + return cls( + id=d["question_id"], + name=d["question_name"], + text=d["question_text"], + type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + is_live=d["is_live"], + category_id=( + UUID(d.get("category_id")).hex if d.get("category_id") else None + ), + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestionSelectorHIDDEN, + UpkQuestion, + ) + + upk_type_selector_map = { + # multiple select doesn't exist in morning, only single select + MorningQuestionType.multiple_choice: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + MorningQuestionType.text_entry: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + MorningQuestionType.geographic: ( + UpkQuestionType.HIDDEN, + UpkQuestionSelectorHIDDEN.HIDDEN, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.text, + } + if self.type == MorningQuestionType.multiple_choice: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=c.order) + for c in self.options + ] + return UpkQuestion(**d) diff --git a/generalresearch/models/morning/survey.py b/generalresearch/models/morning/survey.py new file mode 100644 index 0000000..6f61661 --- /dev/null +++ b/generalresearch/models/morning/survey.py @@ -0,0 +1,556 @@ +from __future__ import annotations + +import json +import logging +from datetime import timezone +from decimal import Decimal +from functools import cached_property +from typing import ( + Optional, + Dict, + Any, + List, + Set, + Annotated, + Tuple, + Literal, + Type, +) + +from pydantic import ( + Field, + ConfigDict, + BaseModel, + computed_field, + NonNegativeInt, + model_validator, + PositiveInt, + PrivateAttr, +) +from typing_extensions import Self + +from generalresearch.locales import Localelator +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + UUIDStrCoerce, +) +from generalresearch.models.morning import MorningStatus, MorningQuestionID +from generalresearch.models.morning.question import MorningQuestion +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.locales import ( + CountryISO, + CountryISOs, + LanguageISOs, +) +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + ConditionValueType, + MarketplaceCondition, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class MorningExclusion(BaseModel): + group_id: UUIDStrCoerce = Field( + description="The unique identifier for the exclusion group" + ) + # The length of time in days to lock out a respondent who has successfully + # completed another bid in the same exclusion group. + # When omitted, respondents who have ever participated in that exclusion + # group will be disallowed from entering the current bid. The value may + # also be set to 0 to signal group exclusion for future bids without + # excluding participants from the current bid. + lockout_period: NonNegativeInt = Field(description="length of time in days") + + +class MorningStatistics(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + + # `length_of_interview` changes meaning after 5 completes. We can use + # `estimated_length_of_interview` and `median_length_of_interview` instead + # to get the bid and obs values. + # The bid loi is the same for all quotas in the bid, so we'll put it on the + # bid bid_loi: int = Field(validation_alias="estimated_length_of_interview", + # le=120 * 60) + # If num_completes == 0 , this gets returned as 0. it should be None + obs_median_loi: Optional[NonNegativeInt] = Field( + validation_alias="median_length_of_interview", default=None, le=120 * 60 + ) + + # API returns 100 until 5 completes! Should be None. + # This is calculated as the total completes divided by the total number of + # finished sessions that passed the prescreener. + qualified_conversion: Optional[float] = Field( + ge=0, le=1, description="conversion rate of qualified respondents" + ) + + # Panelists should only be sent to a bid or quota if this number is greater + # than zero. In-progress sessions are taken into account + num_available: NonNegativeInt = Field( + description="The number of completes that are currently available to fill" + ) + + num_completes: NonNegativeInt = Field( + description="The number of people who have successfully completed the survey" + ) + num_failures: NonNegativeInt = Field( + description="The number of people who have been rejected for an unknown reason" + ) + # This includes respondents who are in the prescreener or the survey, but + # have not yet completed or been rejected. + num_in_progress: NonNegativeInt = Field( + description="The number of people with active sessions" + ) + num_over_quotas: NonNegativeInt = Field( + description="The number of respondents who have been terminated for meeting a quota " + "which is already full" + ) + num_qualified: NonNegativeInt = Field( + description="The number of respondents who qualified for a quota, including over " + "quotas" + ) + num_quality_terminations: NonNegativeInt = Field( + description="The number of respondents who have been terminated for quality reasons" + ) + num_timeouts: NonNegativeInt = Field( + description="The number of respondents who have been timed out" + ) + + # Not using: length_of_interview (meaning changes after 5 completes) + + @model_validator(mode="after") + def check_api_default(self) -> Self: + # API returns stupid default values instead of None + if self.num_completes < 5: + self.obs_median_loi = None + self.qualified_conversion = None + else: + assert self.obs_median_loi is not None + assert self.qualified_conversion is not None + return self + + +class MorningTaskStatistics(MorningStatistics): + # This is the "statistics" for the "Bid" aka the survey/task. It contains + # all the field as for the quota statistics plus extra fields that are not + # relevant to quotas. + + # API returns 100 until 5 completes! Should be None ... + system_conversion: Optional[float] = Field( + description="conversion rate of the system. completes divided by total number of entrants to the system", + ge=0, + le=1, + ) + num_entrants: NonNegativeInt = Field( + description="The number of people who have entered the respondent router and successfully reached the " + "prescreener. This includes respondents who have not yet qualified" + ) + num_screenouts: NonNegativeInt = Field( + description="Number of screenouts, including those screened out in the prescreener and those screened out in " + "the survey" + ) + # this is for the bid only. the quotas dont have bid lois + bid_loi: PositiveInt = Field( + validation_alias="estimated_length_of_interview", le=120 * 60 + ) + + # Not using: incidence_rate & length_of_interview (meaning changes after 5 completes), earnings_per_click (can + # calculate from the other values) + + +class MorningCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + question_id: Optional[MorningQuestionID] = Field(validation_alias="id") + values: List[Annotated[str, Field(max_length=128)]] = Field( + validation_alias="response_ids" + ) + value_type: ConditionValueType = Field(default=ConditionValueType.LIST) + + +class MorningQuota(MorningStatistics, MarketplaceTask): + model_config = ConfigDict(populate_by_name=True, frozen=False) + + id: UUIDStrCoerce = Field() + cpi: Decimal = Field( + gt=0, + le=100, + decimal_places=2, + max_digits=5, + validation_alias="cost_per_interview", + ) + condition_hashes: List[str] = Field(min_length=1, default_factory=list) + + # since the Quota is the MarketplaceTask, it needs these fields, copied from the Bid + source: Literal[Source.MORNING_CONSULT] = Field(default=Source.MORNING_CONSULT) + used_question_ids: Set[MorningQuestionID] = Field(default_factory=set) + country_iso: CountryISO = Field(frozen=True) + country_isos: CountryISOs = Field() + language_isos: LanguageISOs = Field(frozen=True) + buyer_id: UUIDStrCoerce = Field() + + # Min spots a quota should have open to be OPEN + _min_open_spots: int = PrivateAttr(default=1) + + def __hash__(self): + return hash(self.id) + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + if isinstance(data["language_isos"], str): + data["language_isos"] = set(data["language_isos"].split(",")) + data["language_iso"] = sorted(data["language_isos"])[0] + return data + + @property + def internal_id(self) -> str: + return self.id + + @computed_field + def is_live(self) -> bool: + return True + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + return set(self.condition_hashes) + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return MorningCondition + + @property + def age_question(self) -> str: + return "age" + + @property + def marketplace_genders( + self, + ) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: MorningCondition( + question_id="gender", + values=["1"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: MorningCondition( + question_id="gender", + values=["2"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + @property + def is_open(self) -> bool: + # num_available includes in-progress (they're already deducted) + return self.num_available >= self._min_open_spots + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the quota is open. + return self.is_open and self.matches(criteria_evaluation) + + # TODO: I did some speed tests. This is faster than how this is implemented + # in sago/spectrum/dynata/etc. We should generalize this logic instead of + # copying/pasting it 7 times. (matches, matches_optional and _soft) + def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Matches means we meet all conditions. + # In Morning, all quotas are mutually exclusive. so if it doesn't + # matter if we match a closed quota, b/c that means that we won't + # match any other quota anyway + return self.matches_optional(criteria_evaluation) is True + + def matches_optional( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[bool]: + for c in self.condition_hashes: + eval_value = criteria_evaluation.get(c) + if eval_value is False: + return False + if eval_value is None: + return None + return True + + def matches_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], List[str]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + unknowns = list() + for c in self.condition_hashes: + eval_value = criteria_evaluation.get(c) + if eval_value is False: + return False, list() + if eval_value is None: + unknowns.append(c) + if unknowns: + return None, unknowns + return True, unknowns + + +class MorningBid(MorningTaskStatistics): + """ + This is the top-level task in Morning Consult; what we would normally call + a survey. A survey can have 1 or more quotas. Each quota has its own CPI + and targeting. We use the quota as the generic task throughout THL because + the quota has a unique ID which we'll use for targeting. + """ + + model_config = ConfigDict(populate_by_name=True) + + id: UUIDStrCoerce = Field() + status: MorningStatus = Field( + default=MorningStatus.ACTIVE, validation_alias="state" + ) + + # A survey has 1 country and one or more languages + country_iso: CountryISO = Field(frozen=True) + language_isos: LanguageISOs = Field(frozen=True) + + buyer_account_id: UUIDStrCoerce = Field() + buyer_id: UUIDStrCoerce = Field() + name: str = Field(min_length=1, max_length=100) + supplier_exclusive: bool = Field(default=False) + survey_type: str = Field(min_length=1, max_length=32) + timeout: PositiveInt = Field(le=24 * 60 * 60) + topic_id: str = Field(min_length=1, max_length=64) + + exclusions: List[MorningExclusion] = Field(default_factory=list) + + quotas: List[MorningQuota] = Field(default_factory=list) + + source: Literal[Source.MORNING_CONSULT] = Field(default=Source.MORNING_CONSULT) + + used_question_ids: Set[MorningQuestionID] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as + # "condition_hashes") throughout this survey. In the reduced representation + # of this task (nearly always, for db i/o, in global_vars) this field will + # be null. + conditions: Optional[Dict[str, MorningCondition]] = Field(default=None) + + # This doesn't get stored in the db directly + experimental_single_use_qualifications: Optional[List[MorningQuestion]] = Field( + default=None + ) + + # These come from the API + expected_end: AwareDatetimeISO = Field(validation_alias="end_date") + created_api: AwareDatetimeISO = Field(validation_alias="published_at") + + # This does not come from the API. We set it when we update this in the db. + created: Optional[AwareDatetimeISO] = Field(default=None) + updated: Optional[AwareDatetimeISO] = Field(default=None) + + # ignoring from API: closed_at + + def __hash__(self): + return hash(self.id) + + @computed_field + def is_live(self) -> bool: + return self.status == MorningStatus.ACTIVE + + @property + def is_open(self) -> bool: + # The survey is open if the status is ACTIVE and there is at least 1 + # open quota. + return self.is_live and any(q.is_open for q in self.quotas) + + @property + def language_iso_any(self): + return sorted(self.language_isos)[0] + + @property + def locale(self): + return self.country_iso, self.language_iso_any + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set() + for q in self.quotas: + s.update(set(q.condition_hashes)) + return s + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def setup_quota_fields(cls, data: dict) -> dict: + # These fields get "inherited" by each quota from its bid. + quota_fields = [ + "country_iso", + "language_isos", + "buyer_id", + "bid_loi", + "used_question_ids", + ] + for quota in data["quotas"]: + for field in quota_fields: + if field not in quota: + quota[field] = data[field] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @model_validator(mode="before") + @classmethod + def setup_conditions(cls, data: dict) -> dict: + if "conditions" in data: + return data + data["conditions"] = dict() + for quota in data["quotas"]: + if "qualifications" in quota: + quota_conditions = [ + MorningCondition.model_validate(q) for q in quota["qualifications"] + ] + quota["condition_hashes"] = [c.criterion_hash for c in quota_conditions] + data["conditions"].update( + {c.criterion_hash: c for c in quota_conditions} + ) + if "_experimental_single_use_qualifications" in quota: + quota_conditions = [ + MorningCondition.model_validate(q) + for q in quota["_experimental_single_use_qualifications"] + ] + quota["condition_hashes"].extend( + [c.criterion_hash for c in quota_conditions] + ) + data["conditions"].update( + {c.criterion_hash: c for c in quota_conditions} + ) + return data + + @model_validator(mode="before") + @classmethod + def clean_alias(cls, data: Dict) -> Dict: + # Make sure fields are named certain ways, so we don't have to check + # aliases within other validators + if "estimated_length_of_interview" in data: + data["bid_loi"] = data.pop("estimated_length_of_interview") + return data + + @model_validator(mode="after") + def sort_quotas(self) -> Self: + # sort the quotas so that we can do comparisons on bids to see if anything has changed + self.quotas = sorted(self.quotas, key=lambda x: x.id) + return self + + def is_unchanged(self, other) -> bool: + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, + # just that the objects don't require any db updates. We also exclude + # conditions b/c this is just the condition_hash definitions + return self.model_dump( + exclude={"updated", "conditions", "created"} + ) == other.model_dump(exclude={"updated", "conditions", "created"}) + + def is_changed(self, other) -> bool: + return not self.is_unchanged(other) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump( + mode="json", + exclude={ + "all_hashes": True, + "country_isos": True, + "source": True, + "conditions": True, + "quotas": { + "__all__": { + "all_hashes", + "used_question_ids", + "is_live", + "country_isos", + "language_isos", + } + }, + }, + ) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["exclusions"] = json.dumps(d["exclusions"]) + for q in d["quotas"]: + q["condition_hashes"] = json.dumps(q["condition_hashes"]) + d["expected_end"] = self.expected_end + d["created_api"] = self.created_api + d["updated"] = self.updated + d["created"] = self.created + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]): + d["created"] = d["created"].replace(tzinfo=timezone.utc) + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + d["expected_end"] = d["expected_end"].replace(tzinfo=timezone.utc) + d["created_api"] = d["created_api"].replace(tzinfo=timezone.utc) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + d["exclusions"] = json.loads(d["exclusions"]) + return cls.model_validate(d) + + def passes_quotas( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[str]: + # Quotas are mutually-exclusive. A user can only possibly match 1 quota. + # Returns the passing quota ID or None (if user doesn't pass any quota) + for q in self.quotas: + if q.passes(criteria_evaluation): + return q.id + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Optional[List[str]], Optional[Set[str]]]: + """ + Quotas are mutually-exclusive. A user can only possibly match 1 quota. As such, all unknown + questions on any quota will be the same unknowns on all. + Returns (the eligibility (True/False/None), passing quota ID or None (if eligibility is not True), + unknown_hashes (or None)) + """ + unknown_quotas = [] + unknown_hashes = set() + for q in self.quotas: + if q.is_open: + elig, quota_unknown_hashes = q.matches_soft(criteria_evaluation) + if elig is True: + return True, [q.id], None + if elig is None: + unknown_quotas.append(q.id) + unknown_hashes.update(quota_unknown_hashes) + if unknown_quotas: + return None, unknown_quotas, unknown_hashes + return False, None, None + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[str]: + if not self.is_open: + return None + return self.passes_quotas(criteria_evaluation) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Optional[List[str]], Optional[Set[str]]]: + if not self.is_open: + return False, None, None + return self.passes_quotas_soft(criteria_evaluation) diff --git a/generalresearch/models/morning/task_collection.py b/generalresearch/models/morning/task_collection.py new file mode 100644 index 0000000..174f5e1 --- /dev/null +++ b/generalresearch/models/morning/task_collection.py @@ -0,0 +1,140 @@ +from typing import List, Set + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models.morning import MorningStatus +from generalresearch.models.morning.survey import MorningBid +from generalresearch.models.thl.survey.task_collection import ( + create_empty_df_from_schema, + TaskCollection, +) + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +LANGUAGE_ISOS: Set[str] = Localelator().get_all_languages() + +bid_stats_columns = { + "system_conversion": Column(float, Check.between(0, 1), nullable=True), + "num_entrants": Column(int, Check.ge(0)), + "num_screenouts": Column(int, Check.ge(0)), + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), +} + +# Used for "counts", should be a non-negative integer. +CountColumn = Column(int, Check.ge(0)) +stats_columns = { + "num_available": CountColumn, + "num_completes": CountColumn, + "num_failures": CountColumn, + "num_in_progress": CountColumn, + "num_over_quotas": CountColumn, + "num_qualified": CountColumn, + "num_quality_terminations": CountColumn, + "num_timeouts": CountColumn, + "obs_median_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "qualified_conversion": Column(float, Check.between(0, 1), nullable=True), +} + +bid_columns = { + "bid.id": Column(str, Check.str_length(min_value=1, max_value=32)), # uuid-hex + "status": Column(str, Check.isin(MorningStatus)), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_isos": Column(str), # comma-separated list of [3 letter, lowercase] + "buyer_account_id": Column(str), # uuid-hex + "buyer_id": Column(str), # uuid-hex + "name": Column(str, Check.str_length(min_value=1, max_value=256)), + "supplier_exclusive": Column(bool), + "survey_type": Column(str, Check.str_length(min_value=1, max_value=32)), + "topic_id": Column(str, Check.str_length(min_value=1, max_value=64)), + "timeout": Column(int, Check.ge(0)), + "created_api": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "expected_end": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), +} +quota_columns = { + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support +} + +columns = ( + bid_columns + | quota_columns + | {"bid." + k: v for k, v in bid_stats_columns.items()} + | {"bid." + k: v for k, v in stats_columns.items()} + | {"quota." + k: v for k, v in stats_columns.items()} +) + +# In Morning, each row is 1 quota! +MorningTaskCollectionSchema = DataFrameSchema( + columns=columns, + checks=[], + # this should be a uuid-hex + index=Index( + str, + name="quota_id", + checks=Check.str_length(min_value=1, max_value=32), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class MorningTaskCollection(TaskCollection): + items: List[MorningBid] + _schema = MorningTaskCollectionSchema + + def to_rows(self, bid: MorningBid): + stats_fields = list(stats_columns.keys()) + bid_stats_fields = list(bid_stats_columns.keys()) + bid_fields = [ + # 'id', # we have to rename this + "status", + "country_iso", + "language_isos", + "buyer_account_id", + "buyer_id", + "name", + "supplier_exclusive", + "survey_type", + "topic_id", + "timeout", + "created_api", + "expected_end", + "updated", + ] + quota_fields = list(quota_columns.keys()) + rows = [] + bid_dict = dict() + for k in bid_fields: + bid_dict[k] = getattr(bid, k) + bid_dict["bid.id"] = bid.id + bid_dict["language_isos"] = ",".join(sorted(bid.language_isos)) + for k in bid_stats_fields: + bid_dict["bid." + k] = getattr(bid, k) + for k in stats_fields: + bid_dict["bid." + k] = getattr(bid, k) + for quota in bid.quotas: + d = bid_dict.copy() + d["quota_id"] = quota.id + for k in quota_fields: + d[k] = getattr(quota, k) + for k in stats_fields: + d["quota." + k] = getattr(quota, k) + d["cpi"] = float(quota.cpi) + d["used_question_ids"] = list(quota.used_question_ids) + d["all_hashes"] = list(quota.all_hashes) + rows.append(d) + return rows + + def to_df(self): + rows = [] + for s in self.items: + rows.extend(self.to_rows(s)) + if rows: + return pd.DataFrame.from_records(rows, index="quota_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/pollfish/__init__.py b/generalresearch/models/pollfish/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/models/pollfish/question.py b/generalresearch/models/pollfish/question.py new file mode 100644 index 0000000..30f8088 --- /dev/null +++ b/generalresearch/models/pollfish/question.py @@ -0,0 +1,140 @@ +# https://wss.pollfish.com/mediation/documentation +import json +import logging +from enum import Enum +from typing import List, Optional, Literal, Any, Dict + +from pydantic import BaseModel, Field, model_validator + +from generalresearch.models import Source +from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class PollfishQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=64, + pattern=r"^[\w\s\.\-]+$", + frozen=True, + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API, but the responses seem to be ordered + order: int = Field() + + +class PollfishQuestionType(str, Enum): + """ + From the API: {'single_punch', 'multi_punch', 'open_ended'} + """ + + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + TEXT_ENTRY = "t" + + +class PollfishQuestion(MarketplaceQuestion): + question_id: str = Field( + min_length=1, + max_length=64, + pattern=r"^[a-z0-9_\s\.]+$", + description="The unique identifier for the qualification", + frozen=True, + ) + question_text: str = Field( + max_length=1024, min_length=1, description="The text shown to respondents" + ) + question_type: PollfishQuestionType = Field(frozen=True) + options: Optional[List[PollfishQuestionOption]] = Field(default=None, min_length=1) + # This comes from the API field "category" + tags: Optional[str] = Field(default=None, frozen=True) + source: Literal[Source.POLLFISH] = Source.POLLFISH + + @property + def internal_id(self) -> str: + return self.question_id + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type == PollfishQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + PollfishQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_text=d["question_text"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + tags=d["tags"], + is_live=d["is_live"], + category_id=d.get("category_id"), + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + order_exclusive_options, + ) + + upk_type_selector_map = { + PollfishQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + PollfishQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + PollfishQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=n) + for n, c in enumerate(self.options) + ] + q = UpkQuestion(**d) + order_exclusive_options(q) + return q diff --git a/generalresearch/models/precision/__init__.py b/generalresearch/models/precision/__init__.py new file mode 100644 index 0000000..4bb2c6a --- /dev/null +++ b/generalresearch/models/precision/__init__.py @@ -0,0 +1,16 @@ +from enum import Enum + +from pydantic import StringConstraints +from typing_extensions import Annotated + + +class PrecisionStatus(str, Enum): + # I made this up. They use isactive: "Yes" or "no", which I think is stupid + OPEN = "open" + CLOSED = "closed" + + +# Some questions are strings, like 'state', 'gender', and others are numeric +PrecisionQuestionID = Annotated[ + str, StringConstraints(min_length=1, max_length=32, pattern=r"^[^A-Z]+$") +] diff --git a/generalresearch/models/precision/definitions.py b/generalresearch/models/precision/definitions.py new file mode 100644 index 0000000..daa6f64 --- /dev/null +++ b/generalresearch/models/precision/definitions.py @@ -0,0 +1,322 @@ +# These were sent to us in an excel file. Copied here because we don't use excel files +# also available here: https://integrations.precisionsample.com/api.html#API%20Lookup%20Document +# https://integrations.precisionsample.com/PS_GetProjects_API_Lookup_values.xlsx + +from generalresearch.locales import Localelator + +locales = Localelator() + +COUNTRY_CODE_TO_ISO = { + "15": "au", + "38": "ca", + "229": "uk", + "231": "us", + "492": "af", + "493": "ax", + "494": "al", + "495": "dz", + "496": "as", + "497": "ad", + "498": "ao", + "499": "ai", + "500": "aq", + "501": "ag", + "502": "ar", + "503": "am", + "504": "aw", + "505": "at", + "506": "az", + "507": "bs", + "508": "bh", + "509": "bd", + "510": "bb", + "511": "by", + "512": "be", + "513": "bz", + "514": "bj", + "515": "bm", + "516": "bt", + "517": "bo", + "518": "bq", + "519": "ba", + "520": "bw", + "521": "bv", + "522": "br", + "523": "io", + "524": "bn", + "525": "bg", + "526": "bf", + "527": "bi", + "528": "kh", + "529": "cm", + "530": "cv", + "531": "ky", + "532": "cf", + "533": "td", + "534": "cl", + "535": "cn", + "536": "cx", + "537": "cc", + "538": "co", + "539": "km", + "540": "cg", + "541": "cd", + "542": "ck", + "543": "cr", + "544": "ci", + "545": "hr", + "546": "cu", + "547": "cw", + "548": "cy", + "549": "cz", + "550": "dk", + "551": "dj", + "552": "dm", + "553": "do", + "554": "ec", + "555": "eg", + "556": "sv", + "557": "gq", + "558": "er", + "559": "ee", + "560": "et", + "561": "fk", + "562": "fo", + "563": "fj", + "564": "fi", + "565": "fr", + "566": "gf", + "567": "pf", + "568": "tf", + "569": "ga", + "570": "gm", + "571": "ge", + "572": "de", + "573": "gh", + "574": "gi", + "575": "gr", + "576": "gl", + "577": "gd", + "578": "gp", + "579": "gu", + "580": "gt", + "581": "gg", + "582": "gn", + "583": "gw", + "584": "gy", + "585": "ht", + "586": "hm", + "587": "va", + "588": "hn", + "589": "hk", + "590": "hu", + "591": "is", + "592": "in", + "593": "id", + "594": "ir", + "595": "iq", + "596": "ie", + "597": "im", + "598": "il", + "599": "it", + "600": "jm", + "601": "jp", + "602": "je", + "603": "jo", + "604": "kz", + "605": "ke", + "606": "ki", + "607": "kp", + "608": "kr", + "609": "xk", + "610": "kw", + "611": "kg", + "612": "la", + "613": "lv", + "614": "lb", + "615": "ls", + "616": "lr", + "617": "ly", + "618": "li", + "619": "lt", + "620": "lu", + "621": "mo", + "622": "mk", + "623": "mg", + "624": "mw", + "625": "my", + "626": "mv", + "627": "ml", + "628": "mt", + "629": "mh", + "630": "mq", + "631": "mr", + "632": "mu", + "633": "yt", + "634": "mx", + "635": "fm", + "636": "md", + "637": "mc", + "638": "mn", + "639": "me", + "640": "ms", + "641": "ma", + "642": "mz", + "643": "mm", + "644": "na", + "645": "nr", + "646": "np", + "647": "nl", + "648": "an", + "649": "nc", + "650": "nz", + "651": "ni", + "652": "ne", + "653": "ng", + "654": "nu", + "655": "nf", + "656": "mp", + "657": "no", + "658": "om", + "659": "pk", + "660": "pw", + "661": "ps", + "662": "pa", + "663": "pg", + "664": "py", + "665": "pe", + "666": "ph", + "667": "pn", + "668": "pl", + "669": "pt", + "670": "pr", + "671": "qa", + "672": "re", + "673": "ro", + "674": "ru", + "675": "rw", + "676": "bl", + "677": "sh", + "678": "kn", + "679": "lc", + "680": "mf", + "681": "pm", + "682": "vc", + "683": "ws", + "684": "sm", + "685": "st", + "686": "sa", + "687": "sn", + "688": "rs", + "689": "sc", + "690": "sl", + "691": "sg", + "692": "sx", + "693": "sk", + "694": "si", + "695": "sb", + "696": "so", + "697": "za", + "698": "gs", + "699": "ss", + "700": "es", + "701": "lk", + "702": "sd", + "703": "sr", + "704": "sj", + "705": "sz", + "706": "se", + "707": "ch", + "708": "sy", + "709": "tw", + "710": "tj", + "711": "tz", + "712": "th", + "713": "tl", + "714": "tg", + "715": "tk", + "716": "to", + "717": "tt", + "718": "tn", + "719": "tr", + "720": "tm", + "721": "tc", + "722": "tv", + "723": "ug", + "724": "ua", + "725": "ae", + "726": "um", + "727": "uy", + "728": "uz", + "729": "vu", + "730": "ve", + "731": "vn", + "732": "vg", + "733": "vi", + "734": "wf", + "735": "eh", + "736": "ye", + "737": "zm", + "738": "zw", + "1196": "kr", + "1197": "kr", + "3501": "al", + "3502": "ay", + "3503": "eg", + "3504": "gl", + "3505": "mi", + "3506": "md", + "3507": "ni", + "3508": "ri", + "3509": "sc", + "3510": "sa", + "3511": "sh", + "3512": "sm", + "3513": "so", + "3514": "sp", + "3515": "wl", +} +COUNTRY_CODE_TO_ISO["229"] = "gb" # uk is not a valid country +# Filter to use only valid countries +COUNTRY_CODE_TO_ISO = { + k: v for k, v in COUNTRY_CODE_TO_ISO.items() if v in locales.get_all_countries() +} + +COUNTRY_ISO_TO_CODE = {v: k for k, v in COUNTRY_CODE_TO_ISO.items()} + +LANGUAGE_CODE_TO_ISO = { + "114282": "eng", + "114283": "fre", + "114284": "spa", + "114285": "ger", + "114286": "por", + "114287": "rus", + "114288": "ita", + "114289": "dut", + "114290": "jpn", + "114291": "chi", + "114292": "kor", + "131616": "ara", + "131617": "swe", + "136675": "pol", + "137223": "tha", + "137224": "ind", + "137225": "vie", +} +LANGUAGE_ISO_TO_CODE = {v: k for k, v in LANGUAGE_CODE_TO_ISO.items()} + + +def country_code_to_iso(code: int | str) -> str: + return COUNTRY_CODE_TO_ISO[str(code)] + + +def country_iso_to_code(country_iso: str) -> str: + return COUNTRY_ISO_TO_CODE[country_iso] + + +def language_code_to_iso(code: int | str) -> str: + return LANGUAGE_CODE_TO_ISO[str(code)] + + +def language_iso_to_code(language_iso: str) -> str: + return LANGUAGE_ISO_TO_CODE[language_iso] diff --git a/generalresearch/models/precision/question.py b/generalresearch/models/precision/question.py new file mode 100644 index 0000000..5750bb6 --- /dev/null +++ b/generalresearch/models/precision/question.py @@ -0,0 +1,199 @@ +# https://integrations.precisionsample.com/api.html#Get%20Questions +import json +import logging +from enum import Enum +from typing import List, Optional, Literal, Any, Dict + +from pydantic import BaseModel, Field, model_validator, field_validator + +from generalresearch.models import Source, string_utils +from generalresearch.models.precision import PrecisionQuestionID +from generalresearch.models.thl.profiling.marketplace import ( + MarketplaceQuestion, + MarketplaceUserQuestionAnswer, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class PrecisionQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + frozen=True, + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API, but the responses seem to be ordered + order: int = Field() + + +class PrecisionQuestionType(str, Enum): + """ + From the API: {'Drop Down', 'Multi Select', 'Single Select', 'Single Select Matrix', 'Vertical Question'} + Of course undocumented. And there doesn't seem to be a text entry option? + """ + + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + TEXT_ENTRY = "t" + + @classmethod + def from_api(cls, a: int): + API_TYPE_MAP = { + "Drop Down": PrecisionQuestionType.SINGLE_SELECT, + "Multi Select": PrecisionQuestionType.MULTI_SELECT, + "Single Select": PrecisionQuestionType.SINGLE_SELECT, + "Single Select Matrix": PrecisionQuestionType.SINGLE_SELECT, + "Vertical Question": PrecisionQuestionType.SINGLE_SELECT, + } + return API_TYPE_MAP[a] if a in API_TYPE_MAP else None + + +class PrecisionUserQuestionAnswer(MarketplaceUserQuestionAnswer): + question_id: PrecisionQuestionID = Field() + question_type: Optional[PrecisionQuestionType] = Field(default=None) + # Was this answer synchronized with precision's user profile API? + synced: bool = Field(default=False) + + +class PrecisionQuestion(MarketplaceQuestion): + question_id: PrecisionQuestionID = Field( + description="The unique identifier for the qualification" + ) + question_name: Optional[str] = Field(default=None, max_length=128) + question_text: str = Field( + max_length=1024, min_length=1, description="The text shown to respondents" + ) + question_type: PrecisionQuestionType = Field(frozen=True) + options: Optional[List[PrecisionQuestionOption]] = Field(default=None, min_length=1) + # This comes from the API field ProfileName. idk what the possible values are, looks like: + # 'Personal Profile', 'Work Profile', 'Auto Profile', 'Medical Profile', 'Travel & Entertainment'. + # I don't know what, if anything, this is used for. + profile: Optional[str] = Field(default=None, frozen=True) + source: Literal[Source.PRECISION] = Source.PRECISION + + @property + def internal_id(self) -> str: + return self.question_id + + @field_validator("question_text", mode="after") + def remove_nbsp(cls, s: Optional[str]): + return string_utils.remove_nbsp(s) + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type == PrecisionQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @classmethod + def from_api(cls, d: dict) -> Optional["PrecisionQuestion"]: + """ + :param d: Raw response from API + """ + try: + return cls._from_api(d) + except Exception as e: + logger.warning(f"Unable to parse question: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: dict) -> "PrecisionQuestion": + question_type = PrecisionQuestionType.from_api(d["question_type_name"]) + # sometimes an empty option is returned .... ? + options = [ + PrecisionQuestionOption( + id=str(r["option_id"]), text=r["option_text"], order=n + ) + for n, r in enumerate(d["options"]) + if r + ] + return cls( + question_id=str(d["question_id"]), + profile=d.get("ProfileName"), + question_name=d.get("question_name"), + question_text=d["question_text"], + question_type=question_type, + options=options, + country_iso=d["country_iso"], + language_iso=d["language_iso"], + ) + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + PrecisionQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_text=d["question_text"], + question_name=d["question_name"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + is_live=d["is_live"], + category_id=d.get("category_id"), + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + order_exclusive_options, + ) + + upk_type_selector_map = { + PrecisionQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + PrecisionQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + PrecisionQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=n) + for n, c in enumerate(self.options) + ] + q = UpkQuestion(**d) + order_exclusive_options(q) + return q diff --git a/generalresearch/models/precision/survey.py b/generalresearch/models/precision/survey.py new file mode 100644 index 0000000..847093b --- /dev/null +++ b/generalresearch/models/precision/survey.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import json +from datetime import timezone +from functools import cached_property +from typing import Optional, List, Literal, Set, Dict, Any, Tuple, Type + +from more_itertools import flatten +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + BaseModel, + computed_field, + model_validator, +) +from typing_extensions import Annotated + +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + CoercedStr, + UUIDStrCoerce, + AwareDatetimeISO, + AlphaNumStrSet, + DeviceTypes, +) +from generalresearch.models.precision import PrecisionQuestionID, PrecisionStatus +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, +) + + +class PrecisionCondition(MarketplaceCondition): + question_id: Optional[PrecisionQuestionID] = Field() + values: List[Annotated[str, Field(max_length=128)]] = Field() + value_type: ConditionValueType = Field(default=ConditionValueType.LIST) + _CONVERT_LIST_TO_RANGE = ["age"] + + +class PrecisionQuota(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=False) + + name: str = Field() + # Not sure if id or guid are used for anything + id: str = Field() + guid: str = Field() + status: PrecisionStatus = Field(default=PrecisionStatus.OPEN) + desired_count: int = Field(ge=0) + # These 3 fields are "global" ! + achieved_count: int = Field(ge=0) + termination_count: int = Field(ge=0) + overquota_count: int = Field(ge=0) + + condition_hashes: List[str] = Field(min_length=1, default_factory=list) + + # Min spots a quota should have open to be OPEN + _min_open_spots: int = PrivateAttr(default=3) + + def __hash__(self): + return hash(self.guid) + + @property + def is_live(self) -> bool: + return self.status == PrecisionStatus.OPEN + + @property + def is_open(self) -> bool: + min_open_spots = 3 + return self.is_live and self.remaining_count >= min_open_spots + + @property + def remaining_count(self) -> int: + return max(self.desired_count - self.achieved_count, 0) + + # TODO: I did some speed tests. This is faster than how this is implemented + # in sago/spectrum/dynata/etc. We should generalize this logic instead of + # copying/pasting it 7 times. (matches, matches_optional and _soft) + def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Matches means we meet all conditions. + # In Morning, all quotas are mutually exclusive. so if it doesn't + # matter if we match a closed quota, b/c that means that we won't + # match any other quota anyway + return self.matches_optional(criteria_evaluation) is True + + def matches_optional( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[bool]: + for c in self.condition_hashes: + eval_value = criteria_evaluation.get(c) + if eval_value is False: + return False + if eval_value is None: + return None + return True + + def matches_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], List[str]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + unknowns = list() + for c in self.condition_hashes: + eval_value = criteria_evaluation.get(c) + if eval_value is False: + return False, list() + if eval_value is None: + unknowns.append(c) + if unknowns: + return None, unknowns + return True, unknowns + + +class PrecisionSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + # They call this the project ID (a project is a survey) + survey_id: CoercedStr = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$", validation_alias="prj_id" + ) + # Almost always equals the survey_id, but we can use this to retrieve user IDs who should be excluded + group_id: CoercedStr = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$", validation_alias="grouping_id" + ) + # There is no status returned, using one I make up b/c is_live depends on it, + status: PrecisionStatus = Field(default=PrecisionStatus.OPEN) + name: str = Field(validation_alias="prj_name") + survey_guid: UUIDStrCoerce = Field(validation_alias="prj_guid") + + category_id: Optional[str] = Field(validation_alias="sc_id", default=None) + buyer_id: CoercedStr = Field(max_length=16) + + # This seems to always be 0 ... ? + # response_rate: float = Field(ge=0, le=1, validation_alias="rr", description="Invites divided by Completes") + # How is this calculated? makes no sense + # complete_pct: float = Field(ge=0, le=1, validation_alias="cp") + # Also skipping: ismultiple (allowing multiple entrances). How is that even possible? They are all False anyways. + + bid_loi: int = Field(default=None, ge=59, le=120 * 60, validation_alias="loi") + bid_ir: float = Field(ge=0, le=1, validation_alias="ir") + # Be careful with this, it doesn't make any sense. See survey 452481, has 12 completes with a 100% live_ir, + # but the only quotas have 0 completes and 1052 terms. .... ?? + global_conversion: Optional[float] = Field( + ge=0, + le=1, + default=None, + validation_alias="live_ir", + description="completes divide by sum of completes & terms", + ) + + desired_count: int = Field(ge=0, validation_alias="total_completes") + # If achieved_count is 0, the global_conversion should be None + achieved_count: int = Field(ge=0, validation_alias="cc") + + allowed_devices: DeviceTypes = Field(min_length=1) + + entry_link: str = Field(validation_alias="url") + excluded_surveys: Optional[AlphaNumStrSet] = Field( + description="list of excluded survey ids", + default=None, + validation_alias="exclusion_project_id", + ) + + quotas: List[PrecisionQuota] = Field(default_factory=list) + + source: Literal[Source.PRECISION] = Field(default=Source.PRECISION) + + used_question_ids: Set[PrecisionQuestionID] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as "condition_hashes") throughout + # this survey. In the reduced representation of this task (nearly always, for db i/o, in global_vars) + # this field will be null. + conditions: Optional[Dict[str, PrecisionCondition]] = Field(default=None) + + # This comes from the API + expected_end_date: Optional[AwareDatetimeISO] = Field( + default=None, validation_alias="end_date" + ) + + # This does not come from the API. We set it when we update this in the db. + created: Optional[AwareDatetimeISO] = Field(default=None) + updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @computed_field + def is_live(self) -> bool: + return self.status == PrecisionStatus.OPEN + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 open quota (or there are no quotas!) + return self.is_live and ( + any(q.is_open for q in self.quotas) or len(self.quotas) == 0 + ) + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set() + for q in self.quotas: + s.update(set(q.condition_hashes)) + return s + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return PrecisionCondition + + @property + def age_question(self) -> str: + return "age" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: PrecisionCondition( + question_id="gender", + values=["male"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: PrecisionCondition( + question_id="gender", + values=["female"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + def __repr__(self) -> str: + # Fancy repr that abbreviates exclude_pids and excluded_surveys + repr_args = list(self.__repr_args__()) + for n, (k, v) in enumerate(repr_args): + if k in {"excluded_surveys"}: + if v and len(v) > 6: + v = sorted(v) + v = v[:3] + ["…"] + v[-3:] + repr_args[n] = (k, v) + join_str = ", " + repr_str = join_str.join( + repr(v) if a is None else f"{a}={v!r}" for a, v in repr_args + ) + return f"{self.__repr_name__()}({repr_str})" + + def is_unchanged(self, other): + return self.model_dump( + exclude={"updated", "conditions", "created"} + ) == other.model_dump(exclude={"updated", "conditions", "created"}) + + def to_mysql(self): + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_iso", + "language_iso", + "source", + "conditions", + "country_isos", + "language_isos", + }, + ) + d["quotas"] = json.dumps(d["quotas"]) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["expected_end_date"] = self.expected_end_date + d["updated"] = self.updated + d["created"] = self.created + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]): + d["created"] = d["created"].replace(tzinfo=timezone.utc) + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + d["expected_end_date"] = ( + d["expected_end_date"].replace(tzinfo=timezone.utc) + if d["expected_end_date"] + else None + ) + d["quotas"] = json.loads(d["quotas"]) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + return cls.model_validate(d) + + def passes_quotas(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # We have to match 1 or more quota. + # Quotas are exclusionary: they can NOT match a quota where currently_open=0 + any_pass = False + for q in self.quotas: + matches = q.matches(criteria_evaluation) + if matches and not q.is_open: + return False + if matches: + any_pass = True + return any_pass + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Quotas are exclusionary. They can NOT match a quota where currently_open=0 + quota_eval = { + quota: quota.matches_soft(criteria_evaluation) for quota in self.quotas + } + evals = set(g[0] for g in quota_eval.values()) + if any(m[0] is True and not q.is_open for q, m in quota_eval.items()): + # matched a full quota + return False, set() + if any(m[0] is None and not q.is_open for q, m in quota_eval.items()): + # Unknown match for full quota + if True in evals: + # we match 1 other, so the missing are only this type + return None, set( + flatten( + [ + m[1] + for q, m in quota_eval.items() + if m[0] is None and not q.is_open + ] + ) + ) + else: + # we don't match any quotas, so everything is unknown + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + if True in evals: + return True, set() + if None in evals: + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + return False, set() + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + return self.is_open and self.passes_quotas(criteria_evaluation) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Optional[Set[str]]]: + if not self.is_open: + return False, None + return self.passes_quotas_soft(criteria_evaluation) + + def participation_allowed( + self, att_survey_ids: Set[str], att_group_ids: Set[str] + ) -> bool: + """ + Checks if this user can participate in this survey + :param att_survey_ids: list of the user's previously attempted survey IDs + :param att_job_ids: list of the user's previously attempted survey ID's Job IDs + """ + assert isinstance(att_survey_ids, set), "must pass a set" + assert isinstance(att_group_ids, set), "must pass a set" + if self.survey_id in att_survey_ids: + return False + if self.group_id in att_group_ids: + return False + if self.excluded_surveys & att_survey_ids: + return False + return True diff --git a/generalresearch/models/precision/task_collection.py b/generalresearch/models/precision/task_collection.py new file mode 100644 index 0000000..e2942b5 --- /dev/null +++ b/generalresearch/models/precision/task_collection.py @@ -0,0 +1,82 @@ +from typing import List + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models.precision import PrecisionStatus +from generalresearch.models.precision.survey import PrecisionSurvey +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS = Localelator().get_all_countries() +LANGUAGE_ISOS = Localelator().get_all_languages() + +PrecisionTaskCollectionSchema = DataFrameSchema( + columns={ + "status": Column(str, Check.isin(PrecisionStatus)), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "group_id": Column(str), + "name": Column(str), + "survey_guid": Column(str), + "category_id": Column(str), + "buyer_id": Column(str), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), + "country_isos": Column(str), # comma-sep string + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), + "language_isos": Column(str), # comma-sep string + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "global_conversion": Column(float, Check.between(0, 1), nullable=True), + "desired_count": Column(int), + "achieved_count": Column(int), + "allowed_devices": Column(str), + "expected_end_date": Column(dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class PrecisionTaskCollection(TaskCollection): + items: List[PrecisionSurvey] + _schema = PrecisionTaskCollectionSchema + + def to_row(self, s: PrecisionSurvey): + d = s.model_dump( + mode="json", + exclude={ + "qualifications", + "quotas", + "source", + "conditions", + "is_live", + "excluded_surveys", + "entry_link", + }, + ) + d["cpi"] = float(s.cpi) + return d + + def to_df(self): + rows = [] + for s in self.items: + rows.append(self.to_row(s)) + if rows: + return pd.DataFrame.from_records(rows, index="survey_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/prodege/__init__.py b/generalresearch/models/prodege/__init__.py new file mode 100644 index 0000000..c7bc4e7 --- /dev/null +++ b/generalresearch/models/prodege/__init__.py @@ -0,0 +1,37 @@ +from enum import Enum +from typing import Literal + +from pydantic import Field +from typing_extensions import Annotated + +ProdegeQuestionIdType = Annotated[ + str, Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") +] + + +class ProdegeStatus(str, Enum): + LIVE = "LIVE" + # We need another status to mark if a survey we thought was live does not come back + # from the API, we'll mark it as NOT_FOUND + NOT_FOUND = "NOT_FOUND" + # We need another status to mark if a survey is ineligible for entrances (b/c it doesn't have a single + # live quota) + INELIGIBLE = "INELIGIBLE" + + +class ProdegePastParticipationType(str, Enum): + # These come from the "participation_types" key in the survey API response + # which is how we filter by users' past_participation. + CLICK = "click" + COMPLETE = "complete" + DQ = "dq" + OQ = "oq" + + +# This is the value of the 'status' url param in the redirect +# https://developer.prodege.com/surveys-feed/redirects +# Note: there is no status for ProdegePastParticipationType.CLICK b/c that would be an abandonent +# Note: there is no ProdegePastParticipationType for quality (status 4) +ProdgeRedirectStatus = Literal["1", "2", "3", "4"] +# I'm not using the ProdegePastParticipationType for the values here b/c there is not a 1-to-1 mapping. +ProdgeRedirectStatusNameMap = {"1": "complete", "2": "oq", "3": "dq", "4": "quality"} diff --git a/generalresearch/models/prodege/definitions.py b/generalresearch/models/prodege/definitions.py new file mode 100644 index 0000000..0fdd2fc --- /dev/null +++ b/generalresearch/models/prodege/definitions.py @@ -0,0 +1,187 @@ +PG_COUNTRY_TO_ISO = { + 256: "bv", + 257: "cc", + 258: "fk", + 259: "gf", + 261: "gs", + 263: "hm", + 265: "pw", + 266: "um", + 268: "nr", + 269: "pm", + 271: "km", + 273: "st", + 275: "pn", + 277: "bq", + 34: "cn", + 35: "jp", + 36: "th", + 38: "my", + 39: "kr", + 40: "hk", + 41: "tw", + 42: "ph", + 43: "vn", + 46: "se", + 47: "it", + 49: "at", + 50: "nl", + 51: "ae", + 52: "il", + 53: "ua", + 54: "cz", + 55: "ru", + 56: "kz", + 57: "pt", + 58: "gr", + 59: "sa", + 60: "dk", + 62: "no", + 63: "mx", + 64: "bm", + 65: "vi", + 66: "pr", + 68: "sg", + 69: "id", + 70: "np", + 72: "pk", + 73: "ch", + 75: "bs", + 77: "ar", + 78: "uy", + 79: "dm", + 80: "bd", + 81: "tk", + 82: "kh", + 83: "mo", + 85: "af", + 86: "nc", + 89: "wf", + 90: "pl", + 91: "ro", + 92: "tr", + 93: "sk", + 95: "fi", + 96: "am", + 97: "si", + 99: "li", + 100: "qa", + 101: "be", + 102: "ng", + 103: "bg", + 104: "is", + 105: "al", + 106: "cy", + 107: "lu", + 108: "hu", + 109: "ee", + 110: "by", + 111: "lv", + 114: "md", + 116: "lt", + 117: "hr", + 118: "ba", + 121: "az", + 123: "sm", + 124: "br", + 125: "sj", + 126: "za", + 127: "ve", + 128: "co", + 129: "eg", + 130: "cl", + 131: "dz", + 132: "pe", + 133: "kw", + 134: "ma", + 135: "ao", + 138: "ec", + 139: "om", + 140: "do", + 141: "lk", + 142: "tn", + 143: "gt", + 145: "rs", + 147: "cr", + 148: "ke", + 150: "pa", + 151: "jo", + 154: "cm", + 155: "sv", + 156: "bh", + 157: "tt", + 158: "bo", + 159: "gh", + 160: "py", + 161: "ug", + 163: "hn", + 164: "gq", + 165: "jm", + 167: "ad", + 168: "fo", + 170: "gl", + 171: "gg", + 172: "va", + 173: "im", + 174: "mt", + 175: "mc", + 176: "me", + 179: "tm", + 180: "cd", + 182: "bz", + 183: "sn", + 184: "mg", + 188: "ml", + 189: "bj", + 190: "td", + 191: "bw", + 194: "cg", + 196: "gm", + 200: "bf", + 201: "sl", + 203: "ne", + 204: "cf", + 206: "tg", + 207: "bi", + 208: "sc", + 210: "gw", + 213: "dj", + 215: "ni", + 217: "ky", + 219: "mh", + 220: "aq", + 221: "bb", + 222: "aw", + 223: "ai", + 225: "gd", + 228: "tc", + 229: "ag", + 230: "tv", + 233: "vu", + 234: "er", + 236: "sh", + 238: "eh", + 239: "cx", + 241: "io", + 242: "gu", + 245: "ck", + 246: "ki", + 247: "nu", + 249: "tf", + 251: "yt", + 252: "nf", + 253: "as", + 254: "bn", + 255: "bt", + 9: "es", + 8: "fr", + 7: "de", + 6: "in", + 5: "ie", + 2: "ca", + 3: "gb", + 67: "nz", + 4: "au", + 1: "us", +} +ISO_TO_PG_COUNTRY = {v: k for k, v in PG_COUNTRY_TO_ISO.items()} diff --git a/generalresearch/models/prodege/question.py b/generalresearch/models/prodege/question.py new file mode 100644 index 0000000..0ae4548 --- /dev/null +++ b/generalresearch/models/prodege/question.py @@ -0,0 +1,243 @@ +# https://developer.prodege.com/surveys-feed/api-reference/lookup-calls/lookup-questions-by-countryid +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from enum import Enum +from functools import cached_property +from typing import List, Optional, Literal, Any, Dict, Set + +from pydantic import BaseModel, Field, model_validator, ConfigDict, PositiveInt + +from generalresearch.locales import Localelator +from generalresearch.models import Source, MAX_INT32 +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.prodege import ProdegeQuestionIdType +from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class ProdegeUserQuestionAnswer(BaseModel): + # This is optional b/c this model can be used for eligibility checks for "anonymous" users, which are represented + # by a list of question answers not associated with an actual user. No default b/c we must explicitly set + # the field to None. + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) + question_id: ProdegeQuestionIdType = Field() + # This is optional b/c we do not need it when writing these to the db. When these are fetched from the db + # for use in yield-management, we read this field from the prodege_question table. + question_type: Optional[ProdegeQuestionType] = Field(default=None) + # This may be a pipe-separated string if the question_type is multi. regex means any chars except capital letters + option_id: str = Field(pattern=r"^[^A-Z]*$") + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + @cached_property + def options_ids(self) -> Set[str]: + return set(self.option_id.split("|")) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"question_type"}) + d["created"] = self.created.replace(tzinfo=None) + return d + + +class ProdegeQuestionOption(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^([0-9]+)|-1|-3105$", + frozen=True, + validation_alias="option_id", + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + validation_alias="option_text", + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API, but the responses seem to be ordered + order: int = Field() + # Both is_exclusive and is_anchored are returned, but I don't see how they are different. + # We are merging them both into is_exclusive. + is_exclusive: bool = Field(default=False) + + +class ProdegeQuestionType(str, Enum): + """ + {'Derived', 'Multi Punch', 'Numeric - Open End', 'Single Punch', 'Zip Code'} + """ + + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + TEXT_ENTRY = "t" + UNKNOWN = "u" + + @classmethod + def from_api(cls, a: int): + API_TYPE_MAP = { + "Single-Select": ProdegeQuestionType.SINGLE_SELECT, + "Multi-Select": ProdegeQuestionType.MULTI_SELECT, + "Numeric": ProdegeQuestionType.TEXT_ENTRY, + "Text": ProdegeQuestionType.TEXT_ENTRY, + } + return API_TYPE_MAP[a] + + +class ProdegeQuestion(MarketplaceQuestion): + model_config = ConfigDict(extra="ignore", populate_by_name=True) + question_id: ProdegeQuestionIdType = Field( + description="The unique identifier for the qualification", frozen=True + ) + question_name: str = Field(min_length=1, max_length=64, frozen=True) + question_text: str = Field(max_length=1024, min_length=1) + question_type: ProdegeQuestionType = Field(frozen=True) + # This comes from the API category, but is not great (most are "Consumer Lifestyle") + tags: Optional[str] = Field(default=None, frozen=True) + options: Optional[List[ProdegeQuestionOption]] = Field(default=None, min_length=1) + source: Literal[Source.PRODEGE] = Source.PRODEGE + + @property + def internal_id(self) -> str: + return self.question_id + + @model_validator(mode="after") + def check_type_options_agreement(self): + if self.question_type == ProdegeQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @classmethod + def from_api(cls, d: dict, country_iso: str) -> Optional["ProdegeQuestion"]: + """ + :param d: Raw response from API + """ + try: + return cls._from_api(d, country_iso) + except Exception as e: + logger.warning(f"Unable to parse question: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: dict, country_iso: str) -> "ProdegeQuestion": + # The API has no concept of language at all. Questions for a country + # are returned both in english and other languages. Questions do have + # a field 'country_specific', and if True, that generally means the + # question's language is the country's default lang. So we're mostly + # guessing here ... + d["question_id"] = str(d["question_id"]) + d["language_iso"] = ( + "eng" + if d["country_specific"] is False + else (locale_helper.get_default_lang_from_country(country_iso)) + ) + d["country_iso"] = country_iso + d["question_type"] = ProdegeQuestionType.from_api(d["question_type"]) + d["tags"] = d["category"].lower() + if not d["question_text"]: + d["question_text"] = d["question_name"] + if d["question_type"] == ProdegeQuestionType.TEXT_ENTRY: + d["options"] = None + if d["options"]: + d["options"] = [ + ProdegeQuestionOption( + id=str(r["option_id"]), + text=r["option_text"], + order=n, + is_exclusive=r["is_exclusive"] or r["is_anchored"], + ) + for n, r in enumerate(d["options"]) + if r and r["option_text"] + ] + return cls.model_validate(d) + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + ProdegeQuestionOption( + id=r["id"], + text=r["text"], + order=r["order"], + is_exclusive=r.get("is_exclusive", False), + ) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_text=d["question_text"], + question_name=d["question_name"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + is_live=d["is_live"], + category_id=d.get("category_id"), + tags=d.get("tags"), + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + order_exclusive_options, + ) + + upk_type_selector_map = { + ProdegeQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + ProdegeQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + ProdegeQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=n) + for n, c in enumerate(self.options) + ] + q = UpkQuestion(**d) + order_exclusive_options(q) + return q diff --git a/generalresearch/models/prodege/survey.py b/generalresearch/models/prodege/survey.py new file mode 100644 index 0000000..378f0a7 --- /dev/null +++ b/generalresearch/models/prodege/survey.py @@ -0,0 +1,747 @@ +# https://developer.prodege.com/surveys-feed/api-reference/survey-matching/surveys +from __future__ import annotations + +import json +import logging +from collections import defaultdict +from datetime import timezone, datetime +from decimal import Decimal +from functools import cached_property +from typing import List, Optional, Dict, Any, Set, Literal, Tuple, Type + +from pydantic import ( + BaseModel, + Field, + ConfigDict, + computed_field, + model_validator, + field_validator, +) + +from generalresearch.locales import Localelator +from generalresearch.models import LogicalOperator, Source, TaskCalculationType +from generalresearch.models.custom_types import ( + AlphaNumStrSet, + InclExcl, + UUIDStr, + CoercedStr, + AwareDatetimeISO, +) +from generalresearch.models.prodege import ( + ProdegeStatus, + ProdegeQuestionIdType, + ProdgeRedirectStatus, + ProdegePastParticipationType, +) +from generalresearch.models.prodege.definitions import PG_COUNTRY_TO_ISO +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + ConditionValueType, + MarketplaceCondition, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class ProdegeCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True) + + question_id: ProdegeQuestionIdType = Field() + values: List[str] = Field(validation_alias="precodes") + + @classmethod + def from_api(cls, d: Dict[str, Any]) -> "ProdegeCondition": + assert d["operator"] in { + "OR", + "NOT", + "BETWEEN", + }, f"invalid operator: {d['operator']}" + d["precodes"] = [s.lower() for s in d["precodes"]] + if d["operator"] == "BETWEEN": + # They have a logical operator "between". Make this a range type with 1 range + assert len(d["precodes"]) == 2, "wtf between" + d["precodes"] = sorted(d["precodes"]) + d["precodes"] = [d["precodes"][0] + "-" + d["precodes"][1]] + d["value_type"] = ConditionValueType.RANGE + d["operator"] = LogicalOperator.OR + elif d["operator"] == "NOT": + # unclear if this is not(AND) or not(or). assuming not(or) + d["value_type"] = ConditionValueType.LIST + d["operator"] = LogicalOperator.OR + d["negate"] = True + else: + d["value_type"] = ConditionValueType.LIST + # They said if there are no precodes, it accepts any answer... supposedly. + # (https://g-r-l.slack.com/archives/C04FMFTV48N/p1712878104684299) + if len(d["precodes"]) == 0: + d["value_type"] = ConditionValueType.ANSWERED + d["question_id"] = str(d["question_id"]) + return cls.model_validate(d) + + +class ProdegeQuota(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=True) + + # API response is "sample_size" + desired_count: int = Field( + description="The desired total number of respondents", + validation_alias="sample_size", + ) + # API response is "number_of_respondents" + remaining_count: int = Field( + description="The total number of allowed responses that remain from the sample_size", + validation_alias="number_of_respondents", + ) + condition_hashes: List[str] = Field(min_length=0, default_factory=list) + # Each quota can have a different calculation type, instead of on the survey + calculation_type: TaskCalculationType = Field( + description="Indicates whether the targets are counted per Complete or Survey Start", + default=TaskCalculationType.COMPLETES, + ) + quota_id: CoercedStr = Field() + # If the parent_quota_id is None, then this is a parent. There can be multiple parent quotas. + parent_quota_id: Optional[CoercedStr] = Field() + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: Optional[str] = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", default=None + ) + + # There is no explicit status. The quota is closed if the remaining_count is 0 + + def __hash__(self): + return hash(self.quota_id) + + @property + def is_parent(self) -> bool: + return self.parent_quota_id is None + + @property + def is_open(self) -> bool: + min_open_spots = 2 + return self.remaining_count >= min_open_spots + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return ProdegeCondition + + @property + def age_question(self) -> str: + return "1" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: ProdegeCondition( + question_id="3", + values=["1"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: ProdegeCondition( + question_id="3", + values=["2"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + @classmethod + def from_api(cls, d: Dict): + # the API doesn't handle None's correctly? idk + if d["parent_quota_id"] == 0: + d["parent_quota_id"] = None + d["calculation_type"] = TaskCalculationType.prodege_from_api( + d["calculation_type"] + ) + if d.get("country_id"): + d["country_iso"] = PG_COUNTRY_TO_ISO[d["country_id"]] + return cls.model_validate(d) + + def passes( + self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str + ) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the quota is open. + return self.is_open and self.matches( + criteria_evaluation, country_iso=country_iso + ) + + def matches( + self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str + ) -> bool: + # Match means we meet all conditions. + # We can "match" a quota that is closed. In that case, we would fail the parent quota + return self.matches_country(country_iso) and all( + criteria_evaluation.get(c) for c in self.condition_hashes + ) + + def matches_country(self, country_iso: str): + return self.country_iso is None or self.country_iso == country_iso + + def passes_verbose( + self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str + ) -> bool: + print(f"quota.is_open: {self.is_open}") + print( + ", ".join( + [f"{c}: {criteria_evaluation.get(c)}" for c in self.condition_hashes] + ) + ) + return ( + self.matches_country(country_iso) + and self.is_open + and all(criteria_evaluation.get(c) for c in self.condition_hashes) + ) + + def passes_soft( + self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + if self.is_open is False: + return False, set() + if not self.matches_country(country_iso): + return False, set() + hash_evals = { + cell: criteria_evaluation.get(cell) for cell in self.condition_hashes + } + evals = set(hash_evals.values()) + # We have to match all. So if any are False, we know we don't pass + if False in evals: + return False, set() + # if any are None, we don't know + elif None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + else: + return True, set() + + +class ProdegeMaxClicksSetting(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + # The total number of clicks allowed before survey traffic is paused. + cap: int = Field(validation_alias="max_clicks_cap") + # The current remaining number of clicks before survey traffic is paused. + allowed_clicks: int = Field(validation_alias="max_clicks_allowed_clicks") + # The refill rate id for clicks (1: every 30 min, 2: every 1 hour, 3: every 24 hours, 0: one-time setting). + # (not going to bother structuring this, we can't really use it...) + max_click_rate_id: int = Field(validation_alias="max_clicks_max_click_rate_id") + + +class ProdegeUserPastParticipation(BaseModel): + # Represents the participation of a user in a Prodege task. This is stored in the + # prodege_sessionattempthistory table + model_config = ConfigDict(frozen=True) + + survey_id: str = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + started: AwareDatetimeISO = Field() + # This is what is returned in the redirect in the url param "status". + ext_status_code_1: Optional[ProdgeRedirectStatus] = Field(default=None) + + @property + def participation_types(self) -> Set[ProdegePastParticipationType]: + # If the survey is filtering completes, then only a complete counts. But if the survey is filtering + # on clicks, then a person who got a complete ALSO did click. And so, the logic here is that + # participation_types should always include "click". + if self.ext_status_code_1 is None: + return {ProdegePastParticipationType.CLICK} + elif self.ext_status_code_1 == "1": + return { + ProdegePastParticipationType.CLICK, + ProdegePastParticipationType.COMPLETE, + } + elif self.ext_status_code_1 == "2": + return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.OQ} + elif self.ext_status_code_1 == "3": + return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.DQ} + elif self.ext_status_code_1 == "4": + # 4 means "Quality Disqualification". unclear which participation type this is. + return {ProdegePastParticipationType.CLICK, ProdegePastParticipationType.DQ} + + def days_ago(self) -> float: + now = datetime.now(timezone.utc) + return (now - self.started).total_seconds() / (3600 * 24) + + +class ProdegePastParticipation(BaseModel): + model_config = ConfigDict(populate_by_name=True) + # They call a survey a project. + survey_ids: AlphaNumStrSet = Field(validation_alias="participation_project_ids") + filter_type: InclExcl = Field() + # API has a mistake. We treat 0 as null + in_past_days: Optional[int] = Field(default=None) + participation_types: List[ProdegePastParticipationType] = Field() + + """ + e.g. Anyone who got a complete in either of these projects in the past 7 days, + is not allowed to participate in this task. + {'participation_project_ids': [152677146, 152803285], + 'filter_type': 'exclude', + 'in_past_days': 7, + 'participation_types': ['complete']} + """ + + @classmethod + def from_api(cls, d: Dict): + # the API doesn't handle None's correctly? idk + if d["in_past_days"] == 0: + d["in_past_days"] = None + d["participation_project_ids"] = list(map(str, d["participation_project_ids"])) + return cls.model_validate(d) + + def user_participated(self, user_participation: ProdegeUserPastParticipation): + # Given this user's participation event (1 single event), is it being filtered by this survey? + return ( + user_participation.survey_id in self.survey_ids + and ( + (self.in_past_days is None) + or (self.in_past_days > user_participation.days_ago()) + ) + and user_participation.participation_types.intersection( + self.participation_types + ) + ) + + def is_eligible( + self, user_participations: List[ProdegeUserPastParticipation] + ) -> bool: + if self.filter_type == "include": + # User is only eligible if they HAVE participated. Return True as soon as they match anything. + for user_participation in user_participations: + if self.user_participated(user_participation): + return True + return False + else: + # User is only eligible if they HAVE NOT participated. We have to check ALL of their past participations, + # but we can return False as soon as one fails. + for user_participation in user_participations: + if self.user_participated(user_participation): + return False + return True + + +class ProdegeSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + + survey_id: CoercedStr = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$", validation_alias="projectid" + ) + survey_name: str = Field(max_length=256, validation_alias="project_name") + status: ProdegeStatus = Field(default=ProdegeStatus.LIVE) # not returned from API + # API returns more than 2 decimal places, but we are storing it in the db with max 2 ... + cpi: Decimal = Field(gt=0, le=100, decimal_places=2) + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + desired_count: int = Field( + description="The desired total number of respondents", + validation_alias="sample_size", + gt=0, + ) + # Unclear if this is always completes, or not if the TaskCalculationType is STARTS + # Unclear if the survey.remaining_completes is tracking the same thing as the quota.number_of_respondents ... ? + remaining_count: int = Field( + description="The total number of allowed responses that remain from the sample_size", + validation_alias="remaining_completes", + ge=0, + ) + # Unclear if this is always completes, or not if the TaskCalculationType is STARTS + achieved_completes: int = Field( + description="idk, not in the documentation. Seems to show the actual number of" + "achieved completes globally, not just for us.", + ge=0, + default=0, + ) + + # Only the bid or actual value are returned in API res. We're going to + # have to store it in the db if we see it. In API res, these are called + # "loi" and "actual_ir", but the actual IR is only actually the actual + # IR if the "phases" is "actual" :facepalm: + bid_loi: Optional[int] = Field(default=None, le=120 * 60) + bid_ir: Optional[float] = Field(default=None, ge=0, le=1) + actual_loi: Optional[int] = Field(default=None, le=120 * 60) + actual_ir: Optional[float] = Field(default=None, ge=0, le=1) + # Unclear what the difference is bw IR and conversion + conversion_rate: Optional[float] = Field(default=None, ge=0, le=1) + + entrance_url: str = Field( + description="The link survey respondents should be sent to", + validation_alias="surveyurl", + ) + + # This described time-based click rate limiting. + max_clicks_settings: Optional[ProdegeMaxClicksSetting] = Field(default=None) + # This describes the project/surveygroup exclusions + past_participation: Optional[ProdegePastParticipation] = Field(default=None) + # These describe the panelist exclusions/inclusions + include_psids: Optional[Set[UUIDStr]] = Field(default=None) + exclude_psids: Optional[Set[UUIDStr]] = Field(default=None) + + # There are no "qualifications" per se. Instead, everyone has the match a + # parent quota (and its children) qualifications: List[str] = + # Field(default_factory=list) + + # The eligibility is somewhat complex, with parent and children quotas. + # Going to keep it flat here. + quotas: List[ProdegeQuota] = Field(default_factory=list) + + source: Literal[Source.PRODEGE] = Field(default=Source.PRODEGE) + + used_question_ids: Set[ProdegeQuestionIdType] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as + # "condition_hashes") throughout this survey. In the reduced representation + # of this task (nearly always, for db i/o, in global_vars) this field will + # be null. + conditions: Optional[Dict[str, ProdegeCondition]] = Field(default=None) + + # These do not come from the API. We set them. + created: Optional[AwareDatetimeISO] = Field( + description="when we created this survey in our system", default=None + ) + updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @computed_field + def is_live(self) -> bool: + return self.status == ProdegeStatus.LIVE + + @computed_field + def is_recontact(self) -> bool: + return self.include_psids is not None + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 open + # quota (or there are no quotas!), and the remaining_count > 0, and + # the max_clicks (if exists) > 0 + return ( + self.is_live + and (any(q.is_open for q in self.quotas) or len(self.quotas) == 0) + and self.remaining_count >= 2 + and ( + self.max_clicks_settings is None + or self.max_clicks_settings.allowed_clicks > 0 + ) + ) + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + if not data.get("country_isos"): + country_isos = [ + q["country_iso"] for q in data["quotas"] if q.get("country_iso") + ] + if country_isos: + data["country_isos"] = country_isos + data["language_isos"] = [ + locale_helper.get_default_lang_from_country(c) + for c in data["country_isos"] + ] + else: + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return ProdegeCondition + + @property + def age_question(self) -> str: + return "1" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: ProdegeCondition( + question_id="3", values=["1"], value_type=ConditionValueType.LIST + ), + Gender.FEMALE: ProdegeCondition( + question_id="3", values=["2"], value_type=ConditionValueType.LIST + ), + Gender.OTHER: None, + } + + @field_validator("cpi", mode="before") + def round_to_two_decimals(cls, v): + return round(float(v), 2) + + @classmethod + def from_api(cls, d: Dict) -> Optional["ProdegeSurvey"]: + try: + return cls._from_api(d) + except Exception as e: + logger.warning(f"Unable to parse survey: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: Dict): + # handle phases. keys in api response are 'loi' and 'actual_ir' + if d["phases"]["loi_phase"] == "actual": + d["actual_loi"] = d.pop("loi") * 60 + else: + d["bid_loi"] = d.pop("loi") * 60 + if d["phases"]["actual_ir_phase"] == "actual": + d["actual_ir"] = d.pop("actual_ir") / 100 + else: + d["bid_ir"] = d.pop("actual_ir") / 100 + d["conversion_rate"] = ( + d["conversion_rate"] / 100 if d["conversion_rate"] else None + ) + if d.get("country_code"): + d["country_isos"] = [ + locale_helper.get_country_iso(d.pop("country_code").lower()) + ] + d["country_iso"] = sorted(d["country_isos"])[0] + # No languages are returned anywhere for anything + d["language_isos"] = [ + locale_helper.get_default_lang_from_country(d["country_isos"][0]) + ] + d["language_iso"] = locale_helper.get_default_lang_from_country( + d["country_iso"] + ) + + if d.get("past_participation"): + d["past_participation"] = ProdegePastParticipation.from_api( + d["past_participation"] + ) + d["conditions"] = dict() + for quota in d["quotas"]: + quota["condition_hashes"] = [] + for c in quota["targeting_criteria"]: + c["value_type"] = ConditionValueType.LIST + c = ProdegeCondition.from_api(c) + d["conditions"][c.criterion_hash] = c + quota["condition_hashes"].append(c.criterion_hash) + d["quotas"] = [ProdegeQuota.from_api(q) for q in d["quotas"]] + countries = {q.country_iso for q in d["quotas"] if q.country_iso} + if countries: + d["country_iso"] = sorted(countries)[0] + d["country_isos"] = countries + d["language_iso"] = locale_helper.get_default_lang_from_country( + d["country_iso"] + ) + d["language_isos"] = [ + locale_helper.get_default_lang_from_country(c) + for c in d["country_isos"] + ] + return cls.model_validate(d) + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set() + for q in self.quotas: + s.update(set(q.condition_hashes)) + return s + + @property + def quotas_verbose(self) -> List[List[Dict[str, Any]]]: + assert self.conditions is not None, "conditions must be set" + res = [] + for quota_group in self.quotas: + sub_res = [] + res.append(sub_res) + for quota in quota_group.root: + q = quota.model_dump(mode="json") + q["conditions"] = [ + self.conditions[c].minified for c in quota.condition_hashes + ] + sub_res.append(q) + return res + + def is_unchanged(self, other): + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, + # just that the objects don't require any db updates. We also exclude + # conditions b/c this is just the condition_hash definitions + + # This is also especially bad bc the api returns ONLY bid OR actual + # values, and so if a survey is stored with bid values in the db, the + # api doesn't have them, it'll always be changed. Also, the name of + # the survey changes randomly? idk. ignore that too + o1 = self.model_dump( + exclude={"created", "updated", "conditions", "survey_name"} + ) + o2 = other.model_dump( + exclude={"created", "updated", "conditions", "survey_name"} + ) + if o1 == o2: + # We don't have to check bid/actual, b/c we already know it's not changed + return True + + # Ignore bid fields if either one is NULL + for k in ["bid_loi", "bid_ir"]: + if o1.get(k) is None or o2.get(k) is None: + o1.pop(k, None) + o2.pop(k, None) + + return o1 == o2 + + def to_mysql(self): + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_isos", + "language_isos", + "source", + "conditions", + "buyer_id", + "is_recontact", + }, + ) + d["quotas"] = json.dumps(d["quotas"]) + for k in [ + "max_clicks_settings", + "past_participation", + "include_psids", + "exclude_psids", + ]: + d[k] = json.dumps(d[k]) if d[k] else None + d["used_question_ids"] = json.dumps(d["used_question_ids"]) + d["created"] = self.created + d["updated"] = self.updated + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]): + d["created"] = d["created"].replace(tzinfo=timezone.utc) + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + d["quotas"] = json.loads(d["quotas"]) + for k in [ + "max_clicks_settings", + "past_participation", + "include_psids", + "exclude_psids", + "used_question_ids", + ]: + d[k] = json.loads(d[k]) if d[k] else None + # Need to re set countries from quotas here? Or not? + # countries should be a property not a field anyways (todo:) + return cls.model_validate(d) + + # ---- Yield Management ---- + + def passes_quotas( + self, + criteria_evaluation: Dict[str, Optional[bool]], + country_iso: str, + verbose=False, + ) -> bool: + # https://developer.prodege.com/surveys-feed/api-reference/survey-matching/quota-structure + # https://developer.prodege.com/surveys-feed/api-reference/survey-matching/quota-matching-requirements + parent_quotas = [q for q in self.quotas if q.is_parent] + child_quotas = [q for q in self.quotas if not q.is_parent] + quota_map = {q.quota_id: q for q in self.quotas} + parent_children = defaultdict(set) + for q in child_quotas: + parent_children[q.parent_quota_id].add(q.quota_id) + + # To be eligible for a survey, we need to match ANY parent quota. To + # match a parent quota, we need to match at least 1 child quota and + # NOT match any closed children. + passing_parent_quotas = [ + quota + for quota in parent_quotas + if quota.passes(criteria_evaluation, country_iso=country_iso) + ] + if not passing_parent_quotas: + if verbose: + print("No passing parent quotas") + return False + for quota in passing_parent_quotas: + if verbose: + print("parent") + print( + quota.passes_verbose(criteria_evaluation, country_iso=country_iso) + ) + child_quotas = [ + quota_map[quota_id] for quota_id in parent_children[quota.quota_id] + ] + passes = self.passes_child_quotas( + criteria_evaluation, + child_quotas=child_quotas, + country_iso=country_iso, + verbose=verbose, + ) + if passes: + return True + return False + + def passes_child_quotas( + self, + criteria_evaluation: Dict[str, Optional[bool]], + child_quotas: List[ProdegeQuota], + country_iso: str, + verbose=False, + ) -> bool: + if len(child_quotas) == 0: + # If the parent has no children, we pass + return True + + # We have to pass at least 1 child + passes = False + for quota in child_quotas: + if quota.matches(criteria_evaluation, country_iso=country_iso): + if not quota.is_open: + # If we match a closed quota, the parent fails. + if verbose: + print("matched closed quota") + return False + passes = True + # We pass tentatively now, we still have to check the rest to see if we match any closed quotas. + + if verbose: + print( + [ + quota.passes_verbose(criteria_evaluation, country_iso=country_iso) + for quota in child_quotas + ] + ) + + return passes + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str + ) -> bool: + return self.is_open and self.passes_quotas( + criteria_evaluation, country_iso=country_iso + ) + + def print_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]], country_iso: str + ) -> None: + print(f"is_open: {self.is_open}") + print("passes_quotas") + print( + self.passes_quotas( + criteria_evaluation, country_iso=country_iso, verbose=True + ) + ) diff --git a/generalresearch/models/prodege/task_collection.py b/generalresearch/models/prodege/task_collection.py new file mode 100644 index 0000000..4765f29 --- /dev/null +++ b/generalresearch/models/prodege/task_collection.py @@ -0,0 +1,97 @@ +from typing import List, Dict, Any + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models.prodege import ProdegeStatus +from generalresearch.models.prodege.survey import ProdegeSurvey +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS = Localelator().get_all_countries() +LANGUAGE_ISOS = Localelator().get_all_languages() + +ProdegeTaskCollectionSchema = DataFrameSchema( + columns={ + "status": Column(str, Check.isin(ProdegeStatus)), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + "desired_count": Column(int, Check.greater_than(min_value=0)), + "remaining_count": Column(int, Check.greater_than_or_equal_to(min_value=0)), + "achieved_completes": Column(int, Check.greater_than_or_equal_to(min_value=0)), + "bid_loi": Column(int, Check.between(0, 120 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "actual_loi": Column(int, Check.between(0, 120 * 60), nullable=True), + "actual_ir": Column(float, Check.between(0, 1), nullable=True), + "conversion_rate": Column(float, Check.between(0, 1), nullable=True), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + "is_recontact": Column(bool), + # Not including here: entrance_url, max_clicks_settings, past_participation, include_psids, exclude_psids, + # quotas, source, conditions + # Adding a derived field: is_recontact, which is True is include_psids is not None + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=False, + drop_invalid_rows=False, +) + + +class ProdegeTaskCollection(TaskCollection): + items: List[ProdegeSurvey] + _schema = ProdegeTaskCollectionSchema + + @staticmethod + def to_row(s: ProdegeSurvey) -> Dict[str, Any]: + fields = [ + "survey_id", + "status", + "country_iso", + "language_iso", + "cpi", + "desired_count", + "remaining_count", + "achieved_completes", + "bid_loi", + "bid_ir", + "actual_loi", + "actual_ir", + "conversion_rate", + "created", + "updated", + "is_recontact", + "used_question_ids", + "all_hashes", + ] + d = dict() + for k in fields: + d[k] = getattr(s, k) + d["cpi"] = float(d["cpi"]) + d["used_question_ids"] = list(d["used_question_ids"]) + d["all_hashes"] = list(d["all_hashes"]) + return d + + def to_df(self) -> pd.DataFrame: + rows = [] + for s in self.items: + rows.append(self.to_row(s)) + if rows: + df = pd.DataFrame.from_records(rows, index="survey_id") + df["bid_loi"] = df["bid_loi"].astype("Int64") + df["actual_loi"] = df["actual_loi"].astype("Int64") + return df + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/repdata/__init__.py b/generalresearch/models/repdata/__init__.py new file mode 100644 index 0000000..9706b28 --- /dev/null +++ b/generalresearch/models/repdata/__init__.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class RepDataStatus(str, Enum): + LIVE = "LIVE" + DRAFT = "DRAFT" + PAUSED = "PAUSED" + COMPLETE = "COMPLETE" + CANCELLED = "CANCELLED" + # We need another status to mark if a survey we thought was live does not + # come back from the API, we'll mark it as NOT_FOUND + NOT_FOUND = "NOT_FOUND" + # We need another status to mark if a survey is ineligible for entrances + # (b/c it doesn't have a single live stream) and so we are not bothering + # to make API calls to update it + INELIGIBLE = "INELIGIBLE" diff --git a/generalresearch/models/repdata/question.py b/generalresearch/models/repdata/question.py new file mode 100644 index 0000000..8b4eb4e --- /dev/null +++ b/generalresearch/models/repdata/question.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import json +import logging +from enum import Enum +from functools import cached_property +from typing import List, Optional, Literal, Any, Dict, Set +from uuid import UUID + +from pydantic import ( + BaseModel, + Field, + model_validator, + ConfigDict, + field_validator, + PositiveInt, +) + +from generalresearch.models import Source, MAX_INT32 +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class RepDataUserQuestionAnswer(BaseModel): + # This is optional b/c this model can be used for eligibility checks for + # "anonymous" users, which are represented by a list of question answers + # not associated with an actual user. No default b/c we must explicitly + # set the field to None. + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) + question_id: str = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + # This is optional b/c we do not need it when writing these to the db. When + # these are fetched from the db for use in yield-management, we read this + # field from the repdata_question table. + question_type: Optional[RepDataQuestionType] = Field(default=None) + # This may be a pipe-separated string if the question_type is multi. regex + # means any chars except capital letters + option_id: str = Field(pattern=r"^[^A-Z]*$") + created: AwareDatetimeISO = Field() + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + @cached_property + def options_ids(self) -> Set[str]: + return set(self.option_id.split("|")) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"question_type"}) + d["created"] = self.created.replace(tzinfo=None) + return d + + +class RepDataQuestionOption(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^(([0-9]+)|-3105)$", + frozen=True, + validation_alias="Code", + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + validation_alias="OptionName", + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API, but the responses seem + # to be ordered + order: int = Field() + + +class RepDataQuestionType(str, Enum): + """ + {'Derived', 'Multi Punch', 'Numeric - Open End', 'Single Punch', 'Zip Code'} + """ + + SINGLE_SELECT = "s" + MULTI_SELECT = "m" + TEXT_ENTRY = "t" + UNKNOWN = "u" + + @classmethod + def from_api(cls, a: int): + API_TYPE_MAP = { + "Single Punch": RepDataQuestionType.SINGLE_SELECT, + "Multi Punch": RepDataQuestionType.MULTI_SELECT, + "Numeric - Open End": RepDataQuestionType.TEXT_ENTRY, + "Zip Code": RepDataQuestionType.TEXT_ENTRY, + "Derived": RepDataQuestionType.UNKNOWN, + } + return API_TYPE_MAP[a] + + +class RepDataQuestion(MarketplaceQuestion): + model_config = ConfigDict(extra="ignore", populate_by_name=True) + question_id: UUIDStr = Field( + description="The unique identifier for the qualification", + validation_alias="QualificationUD", + frozen=True, + ) + question_name: str = Field( + min_length=1, max_length=64, frozen=True, validation_alias="QualificationName" + ) + lucid_id: Optional[str] = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + validation_alias="StandardGlobalID", + frozen=True, + ) + lucid_name: Optional[str] = Field( + min_length=1, max_length=64, frozen=True, validation_alias="StandardGlobalName" + ) + question_text: str = Field( + max_length=1024, + min_length=1, + description="The text shown to respondents", + validation_alias="QualificationText", + ) + question_type: RepDataQuestionType = Field( + frozen=True, validation_alias="QualificationType" + ) + options: Optional[List[RepDataQuestionOption]] = Field(default=None, min_length=1) + source: Literal[Source.REPDATA] = Source.REPDATA + + @property + def internal_id(self) -> str: + return self.lucid_id + + @field_validator("question_id", mode="before") + @classmethod + def check_uuid_type(cls, v: str | UUID) -> str: + return UUID(v).hex if isinstance(v, str) else v.hex + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type == RepDataQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @classmethod + def from_api( + cls, d: dict, country_iso: str, language_iso: str + ) -> Optional["RepDataQuestion"]: + """ + :param d: Raw response from API + """ + try: + return cls._from_api(d, country_iso, language_iso) + except Exception as e: + logger.warning(f"Unable to parse question: {d}. {e}") + return None + + @classmethod + def _from_api( + cls, d: dict, country_iso: str, language_iso: str + ) -> "RepDataQuestion": + d["QualificationType"] = RepDataQuestionType.from_api(d["QualificationType"]) + # zip code/age has a placeholder invalid option for some reason + if d["QualificationType"] == RepDataQuestionType.TEXT_ENTRY: + d["QualificationOptions"] = None + options = None + if d["QualificationOptions"]: + options = [ + RepDataQuestionOption(id=str(r["Code"]), text=r["OptionName"], order=n) + for n, r in enumerate(d["QualificationOptions"]) + if r + ] + return cls( + **d, options=options, country_iso=country_iso, language_iso=language_iso + ) + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + RepDataQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_text=d["question_text"], + question_name=d["question_name"], + lucid_id=d["lucid_id"], + lucid_name=d["lucid_name"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + is_live=d["is_live"], + category_id=d.get("category_id"), + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + order_exclusive_options, + ) + + upk_type_selector_map = { + RepDataQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + RepDataQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + RepDataQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=n) + for n, c in enumerate(self.options) + ] + q = UpkQuestion(**d) + order_exclusive_options(q) + return q diff --git a/generalresearch/models/repdata/survey.py b/generalresearch/models/repdata/survey.py new file mode 100644 index 0000000..572e696 --- /dev/null +++ b/generalresearch/models/repdata/survey.py @@ -0,0 +1,565 @@ +# docs are a pdf +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from decimal import Decimal +from functools import cached_property +from typing import List, Optional, Dict, Any, Set, Literal, Type +from typing_extensions import Self +from uuid import UUID + +from pydantic import ( + BaseModel, + Field, + field_validator, + ConfigDict, + computed_field, + model_validator, +) + +from generalresearch.grpc import timestamp_from_datetime +from generalresearch.locales import Localelator +from generalresearch.models import ( + LogicalOperator, + Source, + TaskCalculationType, + DeviceType, +) +from generalresearch.models.custom_types import ( + CoercedStr, + UUIDStr, + AwareDatetimeISO, +) +from generalresearch.models.repdata import RepDataStatus +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class RepDataCondition(MarketplaceCondition): + question_id: CoercedStr = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + validation_alias="StandardGlobalQuestionID", + ) + values: List[str] = Field(min_length=1, validation_alias="PreCodes") + value_type: Literal[ConditionValueType.LIST] = Field( + default=ConditionValueType.LIST + ) + + @classmethod + def from_api(cls, d: Dict[str, Any]) -> "RepDataCondition": + if d["Condition"] == "Is": + d["logical_operator"] = LogicalOperator.OR + d["negate"] = False + elif d["Condition"] == "IsNot": + # todo: idk if this is really and, but its safer to say it is (not all values) + d["logical_operator"] = LogicalOperator.AND + d["negate"] = True + else: + raise ValueError(f"unknown condition: {d['Condition']}") + return cls.model_validate(d) + + +class RepDataQuota(BaseModel): + """ + A quota that can be on a stream. The parent stream has a CalculationType, + which dictates the meaning of the fields “Quota”, “QuotaAchieved” and + “QuotaRemaining + """ + + model_config = ConfigDict(populate_by_name=True, frozen=True) + quota_id: CoercedStr = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$", validation_alias="QuotaId" + ) + quota_uuid: UUIDStr = Field(validation_alias="QuotaUd") + name: str = Field(validation_alias="QuotaName") + desired_count: Optional[int] = Field( + default=None, + validation_alias="Quota", + description="Desired completes or starts (depending on calculation_type)", + ) + achieved_count: int = Field( + validation_alias="QuotaAchieved", + description="Achieved completes or starts (depending on calculation_type)", + ) + remaining_count: Optional[int] = Field( + validation_alias="QuotaRemaining", + description="Completes or starts remaining (depending on calculation_type). Should " + "be used as the indicator for whether more respondents are needed to a " + "specific quota. If QuotaRemaining value = 0, then pause. If None, then the quota" + "is completely open (i.e. infinity). Unclear if this is true though (see .is_open)", + ) + conditions: List[RepDataCondition] = Field(min_length=1) + condition_hashes: List[str] = Field(min_length=1, default_factory=list) + + @field_validator("quota_uuid", mode="before") + @classmethod + def check_uuid_type(cls, v: str | UUID) -> str: + return UUID(v).hex if isinstance(v, str) else v.hex + + @model_validator(mode="before") + @classmethod + def set_condition_hashes(cls, data: Any): + if data.get("conditions"): + data["condition_hashes"] = [q.criterion_hash for q in data["conditions"]] + return data + + @property + def is_open(self) -> bool: + # According to the docs, if remaining count is None then the quota is + # open, but this does not seem to be the case. See e.g. stream_id='125928' + return self.remaining_count and self.remaining_count > 0 + + @classmethod + def from_api(cls, quota_res) -> Self: + d = quota_res.copy() + d["conditions"] = [RepDataCondition.from_api(q) for q in d["Questions"]] + # Sometimes this is an empty string. (todo: does that mean 0? who knows?) + d["QuotaAchieved"] = d["QuotaAchieved"] if d["QuotaAchieved"] != "" else 0 + d["Quota"] = d["Quota"] if d["Quota"] != "" else 0 + d["QuotaRemaining"] = d["QuotaRemaining"] if d["QuotaRemaining"] != "" else None + return cls.model_validate(d) + + def to_hashed_quota(self): + d = self.model_dump(mode="json", exclude={"conditions"}) + return RepDataHashedQuota.model_validate(d) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> Optional[bool]: + # We have to match all conditions within the quota. + return self.is_open and all( + criteria_evaluation.get(c) for c in self.condition_hashes + ) + + +class RepDataHashedQuota(RepDataQuota): + conditions: None = Field(default=None, exclude=True) + + +class RepDataStream(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + stream_id: CoercedStr = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$", validation_alias="StreamId" + ) + stream_uuid: UUIDStr = Field(validation_alias="StreamUd") + + stream_name: str = Field(max_length=256, validation_alias="StreamName") + stream_status: RepDataStatus = Field(validation_alias="StreamStatus") + calculation_type: TaskCalculationType = Field( + description="Indicates whether the targets are counted per Complete or Survey Start", + validation_alias="CalculationType", + ) + + qualifications: List[RepDataCondition] = Field(min_length=1) + qualification_hashes: List[str] = Field(min_length=1, default_factory=list) + quotas: List[RepDataQuota] = Field(min_length=1) + hashed_quotas: List[RepDataHashedQuota] = Field(min_length=1, default_factory=list) + + used_question_ids: Set[str] = Field(default_factory=set) + + # Note: The API returns both Expected and ExpectedStreamCompletes which are the same + expected_count: int = Field( + validation_alias="ExpectedStreamCompletes", + description="If CalculationType = COMPLETES, represents the required completes from" + "the suppler, if STARTS, then the required survey starts", + ) + # Note: this is new as of 2024-May + remaining_count: int = Field( + description="Remaining number of Completes or Survey Starts. If “Remaining”= 0, then pause sample", + validation_alias="Remaining", + ) + + cpi: Decimal = Field(gt=0, le=100, validation_alias="CPI") + days_in_field: Optional[int] = Field(validation_alias="DaysInField", default=None) + + # # -------------- # # + # Below here: these fields are useless because it is our own data. + # # -------------- # # + actual_ir: int = Field( + ge=0, + le=100, + validation_alias="ActualIR", + description="In-field survey incidence rate", + ) + actual_loi: int = Field( + ge=0, + le=120 * 60, + validation_alias="ActualLOI", + description="In-field median LOI (in seconds)", + ) + actual_conversion: int = Field( + ge=0, + le=100, + validation_alias="Conversion", + description="Represents the live conversion rate for the supplier", + ) + + actual_complete_count: int = Field( + validation_alias="ActualStreamCompletes", + description="the total number of completes for the supplier", + ) + actual_count: int = Field( + validation_alias="Actual", + description="If CalculationType = COMPLETES, represents the total completes from " + "the supplier, if STARTS, then the total survey starts", + ) + source: Literal[Source.REPDATA] = Field(default=Source.REPDATA) + + # These are copied from the survey so that this can implement the + # MarketplaceTask class ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + @field_validator("stream_uuid", mode="before") + @classmethod + def check_uuid_type(cls, v: str | UUID) -> str: + return UUID(v).hex if isinstance(v, str) else v.hex + + @model_validator(mode="before") + @classmethod + def set_qualification_hashes(cls, data: Any): + if data.get("qualifications"): + data["qualification_hashes"] = [ + q.criterion_hash for q in data["qualifications"] + ] + return data + + @model_validator(mode="before") + @classmethod + def set_hashed_quotas(cls, data: Any): + if data.get("quotas"): + data["hashed_quotas"] = [q.to_hashed_quota() for q in data["quotas"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids"): + return data + s = set() + if data.get("qualifications"): + s.update({q.question_id for q in data["qualifications"]}) + if data.get("quotas"): + for quota in data["quotas"]: + s.update({q.question_id for q in quota.conditions}) + data["used_question_ids"] = s + return data + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @property + def internal_id(self) -> str: + return self.stream_id + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set(self.qualification_hashes.copy()) + for q in self.hashed_quotas: + s.update(set(q.condition_hashes)) + return s + + @property + def all_conditions(self) -> List[RepDataCondition]: + cs = self.qualifications.copy() + for quota in self.quotas: + cs.extend(quota.conditions.copy()) + cs = list({c.criterion_hash: c for c in cs}.values()) + return cs + + @property + def is_open(self) -> bool: + # The stream is open if the status is open and there is at least 1 open + # quota, and the expected_count > actual_count + return ( + self.stream_status == RepDataStatus.LIVE + and any(q.is_open for q in self.hashed_quotas) + and self.remaining_count > 0 + ) + + @property + def is_live(self) -> bool: + return self.stream_status == RepDataStatus.LIVE + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return RepDataCondition + + @property + def age_question(self) -> str: + return "42" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: RepDataCondition( + question_id="43", + values=["1"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: RepDataCondition( + question_id="43", + values=["2"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + @classmethod + def from_api(cls, stream_res, country_iso, language_iso): + # qualifications and quotas need to be added to the stream_res manually + d = stream_res.copy() + d["CalculationType"] = TaskCalculationType.from_api(d["CalculationType"]) + d["ActualLOI"] = d["ActualLOI"] * 60 + d["StreamStatus"] = d["StreamStatus"].upper() + # todo: cpi decimal places? + d["CPI"] = d["CPI"] + d["qualifications"] = [ + RepDataCondition.from_api(q) for q in d["qualifications"] + ] + d["quotas"] = [RepDataQuota.from_api(q) for q in d["quotas"]] + return cls.model_validate( + d | {"country_iso": country_iso, "language_iso": language_iso} + ) + + def to_hashed_stream(self): + d = self.model_dump(mode="json", exclude={"qualifications", "quotas"}) + return RepDataStreamHashed.model_validate(d) + + +class RepDataStreamHashed(RepDataStream): + qualifications: None = Field(default=None, exclude=True) + quotas: None = Field(default=None, exclude=True) + + def to_mysql(self): + d = self.model_dump(mode="json") + d["qualification_hashes"] = json.dumps(d["qualification_hashes"]) + d["hashed_quotas"] = json.dumps(d["hashed_quotas"]) + d["used_question_ids"] = json.dumps(d["used_question_ids"]) + return d + + @classmethod + def from_db(cls, res, survey: RepDataSurveyHashed): + # We need certain fields copied over here so that a stream can exist + # independent of the survey + res["country_iso"] = survey.country_iso + res["language_iso"] = survey.language_iso + return cls.model_validate(res) + + +class RepDataSurvey(BaseModel): + model_config = ConfigDict(populate_by_name=True) + survey_id: CoercedStr = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + validation_alias="SurveyNumber", + ) + survey_uuid: UUIDStr = Field(validation_alias="SurveyUd") + survey_name: str = Field(max_length=256, validation_alias="SurveyName") + project_uuid: UUIDStr = Field( + validation_alias="ProjectUd", description="ID for the parent project" + ) + + survey_status: RepDataStatus = Field(validation_alias="SurveyStatus") + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + estimated_loi: int = Field( + gt=0, + le=90 * 60, + validation_alias="EstimatedLOI", + description="Expected median time that respondents will need to take the " + "Survey from start to finish.", + ) + estimated_ir: int = Field(ge=0, le=100, validation_alias="EstimatedIR") + collects_pii: bool = Field( + validation_alias="PII", description="Indicates whether PII is collected" + ) + + allowed_devices: List[DeviceType] = Field( + min_length=1, validation_alias="Device Compatibility" + ) + + streams: List[RepDataStream] = Field(min_length=1) + hashed_streams: List[RepDataStreamHashed] = Field( + min_length=1, default_factory=list + ) + + # These do not come from the API. We set them ourselves + created: Optional[AwareDatetimeISO] = Field(default=None) + last_updated: Optional[AwareDatetimeISO] = Field(default=None) + + @field_validator("survey_uuid", "project_uuid", mode="before") + @classmethod + def check_uuid_type(cls, v: str | UUID) -> str: + return UUID(v).hex if isinstance(v, str) else v.hex + + @model_validator(mode="before") + @classmethod + def set_hashed_streams(cls, data: Any): + if data.get("streams"): + data["hashed_streams"] = [q.to_hashed_stream() for q in data["streams"]] + return data + + @field_validator("allowed_devices", mode="after") + def sort_allowed_devices(cls, values: List[str]): + return sorted(values) + + @property + def is_open(self) -> bool: + # The survey is open if the status is open and there is at least 1 + # open stream + return self.survey_status == RepDataStatus.LIVE and any( + q.is_open for q in self.hashed_streams + ) + + @property + def is_live(self) -> bool: + # A survey may be live, but it only has 1 stram which is not live. + # And so it is not really live. We have to check this separately. + return self.survey_status == RepDataStatus.LIVE + + @property + def all_hashes(self) -> Set[str]: + s = set() + for stream in self.hashed_streams: + s.update(stream.all_hashes) + return s + + @property + def all_conditions(self) -> List[RepDataCondition]: + cs = list() + for stream in self.streams: + cs.extend(stream.all_conditions) + # dedupe by criterion_hash + cs = list({c.criterion_hash: c for c in cs}.values()) + return cs + + @property + def allowed_devices_str(self) -> str: + return ",".join(map(str, sorted([d.value for d in self.allowed_devices]))) + + @classmethod + def from_api(cls, survey_response) -> Optional["RepDataSurvey"]: + """ + :param survey_response: Raw response from API + """ + try: + return cls._from_api(survey_response) + except Exception as e: + survey_id = survey_response.get("survey_id") or survey_response.get( + "SurveyNumber" + ) + logger.warning(f"Unable to parse survey {survey_id}. {e}") + return None + + @classmethod + def _from_api(cls, survey_response) -> "RepDataSurvey": + d = survey_response.copy() + d["country_iso"] = locale_helper.get_country_iso(d["SurveyCountry"].lower()) + d["language_iso"] = locale_helper.get_language_iso(d["SurveyLanguage"].lower()) + d["EstimatedLOI"] = d["EstimatedLOI"] * 60 + d["SurveyStatus"] = d["SurveyStatus"].upper() + d["allowed_devices"] = [ + DeviceType[x["device_name"].upper()] for x in d["DeviceCompatibility"] + ] + d["streams"] = [ + RepDataStream.from_api( + stream, country_iso=d["country_iso"], language_iso=d["language_iso"] + ) + for stream in d["Streams"] + ] + return cls.model_validate(d) + + def __hash__(self): + # We need this so this obj can be added into a set. + return hash(self.survey_id) + + def is_unchanged(self, other) -> bool: + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, just + # that the objects don't require any db updates + return self.model_dump(exclude={"created", "last_updated"}) == other.model_dump( + exclude={"created", "last_updated"} + ) + + def is_changed(self, other) -> bool: + return not self.is_unchanged(other) + + def to_mysql(self) -> Dict[str, Any]: + return self.to_hashed_survey().to_mysql() + + def to_hashed_survey(self) -> "RepDataSurveyHashed": + d = self.model_dump(mode="json", exclude={"streams"}) + return RepDataSurveyHashed(**d) + + def to_marketplace_task(self): + pass + + +class RepDataSurveyHashed(RepDataSurvey): + streams: None = Field(default=None, exclude=True) + + @classmethod + def from_db(cls, res): + res["allowed_devices"] = [ + DeviceType(int(x)) for x in res["allowed_devices"].split(",") + ] + if res["created"] is not None: + res["created"] = res["created"].replace(tzinfo=timezone.utc) + res["last_updated"] = res["last_updated"].replace(tzinfo=timezone.utc) + return cls.model_validate(res) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True, exclude={"hashed_streams"}) + d["allowed_devices"] = ",".join( + map(str, sorted([d.value for d in self.allowed_devices])) + ) + d["streams"] = [stream.to_mysql() for stream in self.hashed_streams] + if self.created: + d["created"] = self.created.replace(tzinfo=None) + return d + + def to_grpc(self, repdata_pb2): + now = datetime.now(tz=timezone.utc) + timestamp = timestamp_from_datetime(now) + return repdata_pb2.RepDataOpportunity( + json_str=self.model_dump_json(), + timestamp=timestamp, + is_live=self.is_live, + survey_id=self.survey_id, + ) + + @classmethod + def from_grpc(cls, msg): + return cls.model_validate_json(msg.json_str) diff --git a/generalresearch/models/repdata/task_collection.py b/generalresearch/models/repdata/task_collection.py new file mode 100644 index 0000000..7b99638 --- /dev/null +++ b/generalresearch/models/repdata/task_collection.py @@ -0,0 +1,132 @@ +from typing import List + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models import TaskCalculationType +from generalresearch.models.repdata import RepDataStatus +from generalresearch.models.repdata.survey import RepDataSurveyHashed +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS = Localelator().get_all_countries() +LANGUAGE_ISOS = Localelator().get_all_languages() + +RepDataTaskCollectionSchema = DataFrameSchema( + columns={ + # --- These fields come from the Survey object --- + "survey_id": Column( + str, Check.str_length(min_value=1, max_value=16), unique=False + ), + "survey_uuid": Column( + str, Check.str_length(min_value=32, max_value=32), unique=False + ), + "survey_name": Column(str, Check.str_length(min_value=1, max_value=256)), + "project_uuid": Column(str, Check.str_length(min_value=32, max_value=32)), + "survey_status": Column(str, Check.isin(RepDataStatus)), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + "estimated_loi": Column(int, Check.between(0, 90 * 60)), + "estimated_ir": Column(int, Check.between(0, 100)), + "collects_pii": Column(bool), + "allowed_devices": Column(str), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "last_updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + # --- These come from the Stream object --- + # This is the index ---v + # "stream_id": Column(str, Check.str_length(min_value=1, max_value=16), unique=True), + "stream_uuid": Column( + str, Check.str_length(min_value=32, max_value=32), unique=True + ), + "stream_name": Column(str, Check.str_length(min_value=1, max_value=256)), + "stream_status": Column(str, Check.isin(RepDataStatus)), + "remaining_count": Column(int, Check.greater_than_or_equal_to(0)), + "calculation_type": Column(str, Check.isin(TaskCalculationType)), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[ + # # There's only 1 or 2 live surveys, so we can't really assert anything ... + # Check(lambda df: df.shape[0] > 50, + # description="There should always be more than 50 surveys"), + # + # # Check(lambda df: 60 <= df.opp_obs_median_loi.mean() < 30 * 60, + # # description="Survey opp LOI should be 1 - 30 min on average."), + # # Check(lambda df: 60 <= df.quota_obs_median_loi.mean() < 30 * 60, + # # description="Surveys opp quota LOI should be 1 - 30 min on average."), + # + # Check(lambda df: .25 <= df.cpi.mean() < 5, + # description="Surveys CPI should be $.25 - $5 on average."), + # + # Check(lambda df: "us" in df.country_iso.value_counts().index[:3], + # description="United States must be in the top 3 countries."), + ], + index=Index( + str, + name="stream_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=False, + drop_invalid_rows=False, +) + + +class RepDataTaskCollection(TaskCollection): + items: List[RepDataSurveyHashed] + _schema = RepDataTaskCollectionSchema + + def to_rows(self, s: RepDataSurveyHashed): + survey_fields = [ + "survey_id", + "survey_uuid", + "survey_name", + "project_uuid", + "survey_status", + "country_iso", + "language_iso", + "estimated_loi", + "estimated_ir", + "collects_pii", + "created", + "last_updated", + ] + stream_fields = [ + "stream_id", + "stream_uuid", + "stream_name", + "stream_status", + "calculation_type", + "cpi", + "used_question_ids", + "all_hashes", + "remaining_count", + ] + rows = [] + d = dict() + for k in survey_fields: + d[k] = getattr(s, k) + d["allowed_devices"] = s.allowed_devices_str + for stream in s.hashed_streams: + ds = d.copy() + for k in stream_fields: + ds[k] = getattr(stream, k) + ds["cpi"] = float(ds["cpi"]) + ds["used_question_ids"] = list(ds["used_question_ids"]) + ds["all_hashes"] = list(ds["all_hashes"]) + rows.append(ds) + return rows + + def to_df(self): + rows = [] + for s in self.items: + rows.extend(self.to_rows(s)) + if rows: + return pd.DataFrame.from_records(rows, index="stream_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/sago/__init__.py b/generalresearch/models/sago/__init__.py new file mode 100644 index 0000000..292f0f2 --- /dev/null +++ b/generalresearch/models/sago/__init__.py @@ -0,0 +1,13 @@ +from enum import Enum + +from pydantic import Field +from typing_extensions import Annotated + +SagoQuestionIdType = Annotated[ + str, Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") +] + + +class SagoStatus(str, Enum): + LIVE = "LIVE" + NOT_LIVE = "NOT_LIVE" diff --git a/generalresearch/models/sago/question.py b/generalresearch/models/sago/question.py new file mode 100644 index 0000000..8137b38 --- /dev/null +++ b/generalresearch/models/sago/question.py @@ -0,0 +1,284 @@ +# https://developer-beta.market-cube.com/api-details#api=definition-api&operation=get-api-v1-definition-qualification +# -answers-lanaguge-languageid +import json +import logging +from enum import Enum +from functools import cached_property +from typing import List, Optional, Literal, Any, Dict, Set + +from pydantic import ( + BaseModel, + Field, + model_validator, + field_validator, + PositiveInt, + ConfigDict, +) + +from generalresearch.models import Source, string_utils, MAX_INT32 +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.thl.profiling.marketplace import MarketplaceQuestion + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class SagoQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + frozen=True, + description="The unique identifier for a response to a qualification", + ) + # This is returned by the API but does not seem to be used for anything. + # Will keep it any ways. + code: Optional[str] = Field(min_length=1, max_length=16) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API, but the responses seem + # to be ordered + order: int = Field() + + @field_validator("text", mode="after") + def remove_nbsp(cls, s: str): + return string_utils.remove_nbsp(s) + + +class SagoQuestionType(str, Enum): + """ + From the API: + {1: 'Single Punch', 2: 'Multi Punch', 3: 'Open Ended', 4: 'Dummy', + 5: 'Calculated Dummy', 6: 'Range', 7: 'EmailType', 8: 'Info', + 9: 'Compound', 10: 'Calendar', 11: 'Single Punch Image', + 12: 'Multi Punch Image', 14: 'VideoType'} + + Only {1, 2, 3, 6, 7, 8, 12} seem to be used. 8 and 12 seems to be unused. + """ + + # 1 + SINGLE_SELECT = "s" + + # 2 + MULTI_SELECT = "m" + + # 3, 6 (range is just age), 7 (asking for email). + TEXT_ENTRY = "t" + + @classmethod + def from_api(cls, a: int): + API_TYPE_MAP = { + 1: SagoQuestionType.SINGLE_SELECT, + 2: SagoQuestionType.MULTI_SELECT, + 3: SagoQuestionType.TEXT_ENTRY, + 6: SagoQuestionType.TEXT_ENTRY, + 7: SagoQuestionType.TEXT_ENTRY, + } + return API_TYPE_MAP[a] if a in API_TYPE_MAP else None + + +class SagoUserQuestionAnswer(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + # This is optional b/c this model can be used for eligibility checks for + # "anonymous" users, which are represented by a list of question answers + # not associated with an actual user. No default b/c we must explicitly set + # the field to None. + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) + question_id: str = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + + # This is optional b/c we do not need it when writing these to the db. When + # these are fetched from the db for use in yield-management, we read this + # field from the question table. + question_type: Optional[SagoQuestionType] = Field(default=None) + + # This may be a pipe-separated string if the question_type is multi. + # regex means any chars except capital letters + pre_code: str = Field(pattern=r"^[^A-Z]*$", validation_alias="option_id") + created: AwareDatetimeISO = Field() + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + @property + def option_id(self) -> str: + return self.pre_code + + @cached_property + def options_ids(self) -> Set[str]: + return set(self.pre_code.split("|")) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"question_type"}) + d["created"] = self.created.replace(tzinfo=None) + return d + + +class SagoQuestion(MarketplaceQuestion): + question_id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + description="The unique identifier for the qualification", + frozen=True, + ) + question_name: str = Field( + max_length=255, min_length=1, description="A short name for the question" + ) + question_text: str = Field( + max_length=1024, min_length=1, description="The text shown to respondents" + ) + question_type: SagoQuestionType = Field(frozen=True) + options: Optional[List[SagoQuestionOption]] = Field(default=None, min_length=1) + + # This comes from the API field "qualificationCategoryId" + tags: Optional[str] = Field(default=None, frozen=True) + source: Literal[Source.SAGO] = Source.SAGO + + @property + def internal_id(self) -> str: + return self.question_id + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type == SagoQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @field_validator("question_name", "question_text", "tags", mode="after") + def remove_nbsp(cls, s: Optional[str]): + return string_utils.remove_nbsp(s) + + @classmethod + def from_api( + cls, d: dict, country_iso: str, language_iso: str + ) -> Optional["SagoQuestion"]: + """ + :param d: Raw response from API + :param country_iso: + :param language_iso: + :return: + """ + try: + return cls._from_api(d, country_iso, language_iso) + except Exception as e: + logger.warning(f"Unable to parse question: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: dict, country_iso, language_iso) -> "SagoQuestion": + sago_category_to_tags = { + 1: "Standard", + 2: "Custom", + 4: "PID", + 5: "Profile", + 12: "SAGO Standard", + } + question_type = SagoQuestionType.from_api(d["qualificationTypeId"]) + if question_type == SagoQuestionType.TEXT_ENTRY: + # The API returns an option for each of these for some reason + options = None + else: + options = [ + SagoQuestionOption( + id=str(r["answerId"]), + code=r["answerCode"], + text=r["text"].strip(), + order=n, + ) + for n, r in enumerate(d["qualificationAnswers"]) + ] + return cls( + question_id=str(d["qualificationId"]), + question_name=d["name"], + question_text=d["text"], + question_type=question_type, + tags=sago_category_to_tags.get(d["qualificationCategoryId"]), + options=options, + country_iso=country_iso, + language_iso=language_iso, + ) + + @classmethod + def from_db(cls, d: dict): + options = None + if d["options"]: + options = [ + SagoQuestionOption( + id=r["id"], code=r["code"], text=r["text"], order=r["order"] + ) + for r in d["options"] + ] + return cls( + question_id=d["question_id"], + question_text=d["question_text"], + question_name=d["question_name"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + tags=d["tags"], + is_live=d["is_live"], + category_id=d.get("category_id"), + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + order_exclusive_options, + ) + + upk_type_selector_map = { + SagoQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + SagoQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + SagoQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=n) + for n, c in enumerate(self.options) + ] + q = UpkQuestion(**d) + order_exclusive_options(q) + return q diff --git a/generalresearch/models/sago/survey.py b/generalresearch/models/sago/survey.py new file mode 100644 index 0000000..a2e9a8a --- /dev/null +++ b/generalresearch/models/sago/survey.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +import json +import logging +from datetime import timezone +from decimal import Decimal +from functools import cached_property +from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Annotated, Type +from typing_extensions import Self + +from more_itertools import flatten +from pydantic import Field, ConfigDict, BaseModel, model_validator, computed_field + +from generalresearch.locales import Localelator +from generalresearch.models import Source, LogicalOperator +from generalresearch.models.custom_types import ( + CoercedStr, + AwareDatetimeISO, + AlphaNumStrSet, + AlphaNumStr, + DeviceTypes, + IPLikeStrSet, +) +from generalresearch.models.sago import SagoStatus +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + ConditionValueType, + MarketplaceCondition, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class SagoCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + question_id: Optional[CoercedStr] = Field( + min_length=1, max_length=16, pattern=r"^[0-9]+$" + ) + # There isn't really a hard limit, but their API is inconsistent and + # sometimes returns all the options comma-separated instead of as a list. + # Try to catch that. + values: List[Annotated[str, Field(max_length=128)]] = Field() + + _CONVERT_LIST_TO_RANGE = ["59"] + + @classmethod + def from_api(cls, d: Dict[str, Any]) -> "SagoCondition": + d["logical_operator"] = LogicalOperator.OR + d["value_type"] = ConditionValueType(d["value_type"]) + d["negate"] = False + d["values"] = [x.strip().lower() for x in d["values"]] + return cls.model_validate(d) + + +class SagoQuota(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=True) + + # We don't ever need this ... ? + quota_id: str = Field() + + # the docs say nothing about this... are they different in diff quotas??? + cpi: Decimal = Field(gt=0, le=100, decimal_places=2, max_digits=5) + + remaining_count: int = Field() + condition_hashes: List[str] = Field(min_length=0, default_factory=list) + + # There is no explicit status. The quota is closed if the count is 0 + + def __hash__(self) -> int: + return hash(tuple((tuple(self.condition_hashes), self.remaining_count))) + + @property + def is_open(self) -> bool: + min_open_spots = 3 + return self.remaining_count >= min_open_spots + + @classmethod + def from_api(cls, d: Dict) -> Self: + return cls.model_validate(d) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the + # quota is open. + return self.is_open and self.matches(criteria_evaluation) + + def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Matches means we meet all conditions. + # We can "match" a quota that is closed. In that case, we would not be + # eligible for the survey. + return all(criteria_evaluation.get(c) for c in self.condition_hashes) + + def matches_optional( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[bool]: + # We need to know if any conditions are unknown to avoid matching a + # full quota. If any fail, then we know we fail regardless of any + # being unknown. + evals = [criteria_evaluation.get(c) for c in self.condition_hashes] + if False in evals: + return False + if None in evals: + return None + return True + + def matches_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + hash_evals = { + cell: criteria_evaluation.get(cell) for cell in self.condition_hashes + } + evals = set(hash_evals.values()) + if False in evals: + return False, set() + if None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + +class SagoSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + + survey_id: CoercedStr = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + # There is no status returned, using one I make up b/c is_live depends on it, + status: SagoStatus = Field(default=SagoStatus.LIVE) + # is_live: bool = Field(default=True) # can't overload the is_live property ... + cpi: Decimal = Field(gt=0, le=100, decimal_places=2, max_digits=5) + buyer_id: CoercedStr = Field(max_length=32) + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + # unknown what the values actually correspond to. {1, 71, 73, 105, 116} + account_id: str = Field( + description="differentiates Market Cube from Panel Cube accounts" + ) + study_type_id: str = Field() + industry_id: str = Field() + + allowed_devices: DeviceTypes = Field(min_length=1) + collects_pii: bool = Field(default=False) + + survey_exclusions: Optional[AlphaNumStrSet] = Field( + description="list of excluded survey ids", default=None + ) + ip_exclusions: Optional[IPLikeStrSet] = Field( + description="list of excluded IP addresses", default=None + ) + + # Documentation I think is wrong. These are the keys "LOI" and "IR". it + # doesn't say that they are bid or not, but they never seem to change ... + bid_loi: Optional[int] = Field(default=None, le=120 * 60) + bid_ir: Optional[float] = Field(default=None, ge=0, le=1) + + live_link: str = Field() + + # this comes from the Survey Reservation endpoint + remaining_count: int = Field() + + qualifications: List[str] = Field(default_factory=list) + quotas: List[SagoQuota] = Field(default_factory=list) + + source: Literal[Source.SAGO] = Field(default=Source.SAGO) + + used_question_ids: Set[AlphaNumStr] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as + # "condition_hashes") throughout this survey. In the reduced representation + # of this task (nearly always, for db i/o, in global_vars) this field will + # be null. + conditions: Optional[Dict[str, SagoCondition]] = Field(default=None) + + # These come from the API + modified_api: AwareDatetimeISO = Field( + description="When the survey was last updated in sago's system" + ) + + # This does not come from the API. We set it when we update this in the db. + created: Optional[AwareDatetimeISO] = Field(default=None) + updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @computed_field + def is_live(self) -> bool: + return self.status == SagoStatus.LIVE + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 + # open quota (or there are no quotas!) + return self.is_live and ( + any(q.is_open for q in self.quotas) or len(self.quotas) == 0 + ) + + @computed_field + @cached_property + def all_hashes(self) -> Set[str]: + s = set(self.qualifications) + for q in self.quotas: + s.update(set(q.condition_hashes)) + return s + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return SagoCondition + + @property + def age_question(self) -> str: + return "59" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: SagoCondition( + question_id="60", + values=["58"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: SagoCondition( + question_id="60", + values=["59"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + @classmethod + def from_api(cls, d: Dict) -> Optional["SagoSurvey"]: + try: + return cls._from_api(d) + except Exception as e: + logger.warning(f"Unable to parse survey: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: Dict): + return cls.model_validate(d) + + def __repr__(self) -> str: + # Fancy repr that abbreviates ip_exclusions and survey_exclusions + repr_args = list(self.__repr_args__()) + for n, (k, v) in enumerate(repr_args): + if k in {"ip_exclusions", "survey_exclusions"}: + if v and len(v) > 6: + v = sorted(v) + v = v[:3] + ["…"] + v[-3:] + repr_args[n] = (k, v) + join_str = ", " + repr_str = join_str.join( + repr(v) if a is None else f"{a}={v!r}" for a, v in repr_args + ) + return f"{self.__repr_name__()}({repr_str})" + + def is_unchanged(self, other): + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, just + # that the objects don't require any db updates. We also exclude + # conditions b/c this is just the condition_hash definitions + return self.model_dump( + exclude={"updated", "conditions", "created"} + ) == other.model_dump(exclude={"updated", "conditions", "created"}) + + def to_mysql(self): + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_isos", + "language_isos", + "source", + "conditions", + }, + ) + d["qualifications"] = json.dumps(d["qualifications"]) + d["quotas"] = json.dumps(d["quotas"]) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["modified_api"] = self.modified_api + d["updated"] = self.updated + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]): + d["created"] = d["created"].replace(tzinfo=timezone.utc) + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + d["modified_api"] = d["modified_api"].replace(tzinfo=timezone.utc) + d["qualifications"] = json.loads(d["qualifications"]) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + d["quotas"] = json.loads(d["quotas"]) + return cls.model_validate(d) + + def passes_qualifications( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + # We have to match all quals + return all(criteria_evaluation.get(q) for q in self.qualifications) + + def passes_qualifications_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + hash_evals = {q: criteria_evaluation.get(q) for q in self.qualifications} + evals = set(hash_evals.values()) + # We have to match all. So if any are False, we know we don't pass + if False in evals: + return False, set() + # If any are None, we don't know + if None in evals: + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + def passes_quotas(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Many surveys have 0 quotas. Quotas are exclusionary. + # They can NOT match a quota where currently_open=0 + any_pass = True + for q in self.quotas: + matches = q.matches_optional(criteria_evaluation) + if matches in {True, None} and not q.is_open: + # We also cannot be unknown for this quota, b/c we might fall into it, which would be a fail. + return False + return any_pass + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Many surveys have 0 quotas. Quotas are exclusionary. + # They can NOT match a quota where currently_open=0 + if len(self.quotas) == 0: + return True, set() + quota_eval = { + quota: quota.matches_soft(criteria_evaluation) for quota in self.quotas + } + evals = set(g[0] for g in quota_eval.values()) + if any(m[0] is True and not q.is_open for q, m in quota_eval.items()): + # matched a full quota + return False, set() + if any(m[0] is None and not q.is_open for q, m in quota_eval.items()): + # Unknown match for full quota + if True in evals: + # we match 1 other, so the missing are only this type + return None, set( + flatten( + [ + m[1] + for q, m in quota_eval.items() + if m[0] is None and not q.is_open + ] + ) + ) + else: + # we don't match any quotas, so everything is unknown + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + if True in evals: + return True, set() + if None in evals: + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + return False, set() + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + return ( + self.is_open + and self.passes_qualifications(criteria_evaluation) + and self.passes_quotas(criteria_evaluation) + ) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + if self.is_open is False: + return False, set() + pass_quals, h_quals = self.passes_qualifications_soft(criteria_evaluation) + pass_quotas, h_quotas = self.passes_quotas_soft(criteria_evaluation) + if pass_quals and pass_quotas: + return True, set() + elif pass_quals is False or pass_quotas is False: + return False, set() + else: + return None, h_quals | h_quotas diff --git a/generalresearch/models/sago/task_collection.py b/generalresearch/models/sago/task_collection.py new file mode 100644 index 0000000..c2f168a --- /dev/null +++ b/generalresearch/models/sago/task_collection.py @@ -0,0 +1,81 @@ +from typing import List, Set + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models.sago import SagoStatus +from generalresearch.models.sago.survey import SagoSurvey +from generalresearch.models.thl.survey.task_collection import ( + TaskCollection, + create_empty_df_from_schema, +) + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +LANGUAGE_ISOS: Set[str] = Localelator().get_all_languages() + +SagoTaskCollectionSchema = DataFrameSchema( + columns={ + "status": Column(str, Check.isin(SagoStatus)), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "buyer_id": Column(str), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + "account_id": Column(str), + "study_type_id": Column(str), + "industry_id": Column(str), + "allowed_devices": Column(str), + "collects_pii": Column(bool), + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "remaining_count": Column(int), + "created": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class SagoTaskCollection(TaskCollection): + items: List[SagoSurvey] + _schema = SagoTaskCollectionSchema + + def to_row(self, s: SagoSurvey): + d = s.model_dump( + mode="json", + exclude={ + "country_isos", + "language_isos", + "qualifications", + "quotas", + "source", + "conditions", + "is_live", + "survey_exclusions", + "ip_exclusions", + "live_link", + "modified_api", + }, + ) + d["cpi"] = float(s.cpi) + return d + + def to_df(self): + rows = [] + for s in self.items: + rows.append(self.to_row(s)) + if rows: + return pd.DataFrame.from_records(rows, index="survey_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/spectrum/__init__.py b/generalresearch/models/spectrum/__init__.py new file mode 100644 index 0000000..b62c089 --- /dev/null +++ b/generalresearch/models/spectrum/__init__.py @@ -0,0 +1,15 @@ +from enum import Enum + +from pydantic import Field +from typing_extensions import Annotated + +SpectrumQuestionIdType = Annotated[ + str, Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") +] + + +class SpectrumStatus(int, Enum): + DRAFT = 11 + LIVE = 22 + PAUSED = 33 + CLOSED = 44 diff --git a/generalresearch/models/spectrum/question.py b/generalresearch/models/spectrum/question.py new file mode 100644 index 0000000..549fc7e --- /dev/null +++ b/generalresearch/models/spectrum/question.py @@ -0,0 +1,371 @@ +# https://purespectrum.atlassian.net/wiki/spaces/PA/pages/36851836/Get+Attributes+By+Qualification+Code +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from enum import Enum +from functools import cached_property +from typing import List, Optional, Literal, Any, Dict, Set +from uuid import UUID + +from pydantic import ( + BaseModel, + Field, + model_validator, + field_validator, + PositiveInt, +) +from typing_extensions import Self + +from generalresearch.models import Source, string_utils, MAX_INT32 +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.spectrum import SpectrumQuestionIdType +from generalresearch.models.thl.profiling.marketplace import ( + MarketplaceQuestion, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class SpectrumUserQuestionAnswer(BaseModel): + # This is optional b/c this model can be used for eligibility checks + # for "anonymous" users, which are represented by a list of question + # answers not associated with an actual user. No default b/c we must + # explicitly set the field to None. + + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) + question_id: SpectrumQuestionIdType = Field() + # This is optional b/c we do not need it when writing these to the + # db. When these are fetched from the db for use in yield-management, + # we read this field from the spectrum_question table. + question_type: Optional[SpectrumQuestionType] = Field(default=None) + # This may be a pipe-separated string if the question_type is multi. regex + # means any chars except capital letters + option_id: str = Field(pattern=r"^[^A-Z]*$") + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + @cached_property + def options_ids(self) -> Set[str]: + return set(self.option_id.split("|")) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"question_type"}) + d["created"] = self.created.replace(tzinfo=None) + return d + + +class SpectrumQuestionOption(BaseModel): + id: str = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + frozen=True, + description="The unique identifier for a response to a qualification", + ) + text: str = Field( + min_length=1, + max_length=1024, + frozen=True, + description="The response text shown to respondents", + ) + # Order does not come back explicitly in the API, and the API does not + # order them at all. Generally, they should be ordered by the id, but + # this isn't consistent. + order: int = Field() + + @field_validator("text", mode="after") + def remove_nbsp(cls, s: str) -> str: + return string_utils.remove_nbsp(s) + + +class SpectrumQuestionType(str, Enum): + # The documentation defines 4 types (1,2,3,4), however 2 is the same as 1 + # and never comes back in the api, and we also get back 5, 6, and 7, + # which are all undocumented. + + # This is for type 1 or 2 in their docs (singlepunch or singlepunch-alt) + SINGLE_SELECT = "s" + # Type 3 (multipunch) + MULTI_SELECT = "m" + + # Type 5 is undocumented, but seems to be integer free response + # Type 7 is undocumented, but looks to be free-response / open-ended, + # generally data quality related (e.g. Please tell us, how would you + # like to spend your weekend?) + TEXT_ENTRY = "t" + + # Type 4 (range). These all seem to be testing questions, and there is + # nothing to indicate how this should work at all, so we should not + # use this. + # + # RANGE = 'r' + # Type 6 is undocumented, but is a children question that relies on another + # question using unknown also as a catch-all incase they change their + # API randomly + UNKNOWN = "u" + + @staticmethod + def get_api_map() -> Dict[int, SpectrumQuestionType]: + return { + 1: SpectrumQuestionType.SINGLE_SELECT, + 2: SpectrumQuestionType.SINGLE_SELECT, + 3: SpectrumQuestionType.MULTI_SELECT, + 5: SpectrumQuestionType.TEXT_ENTRY, + 7: SpectrumQuestionType.TEXT_ENTRY, + } + + @classmethod + def from_api(cls, a: int): + api_type_map = cls.get_api_map() + return api_type_map[a] if a in api_type_map else None + + +class SpectrumQuestionClass(int, Enum): + CORE = 1 + EXTENDED = 2 + CUSTOM = 3 + + +class SpectrumQuestion(MarketplaceQuestion): + # This is called "qualification_code" in the API + question_id: SpectrumQuestionIdType = Field( + description="The unique identifier for the qualification", frozen=True + ) + # In the API: desc + question_name: str = Field( + max_length=255, + min_length=1, + frozen=True, + description="A short name for the question", + ) + question_text: str = Field( + max_length=1024, + min_length=1, + description="The text shown to respondents", + frozen=False, + ) + question_type: SpectrumQuestionType = Field( + description="The type of question asked", frozen=True + ) + # This comes from the API field "cat". It is not really documented. It + # looks to be a comma-separated str of "tags" or keywords associated + # with a question, but they are freeform and don't pertain to any sort + # of structured schema. This will be useful ChatGPT + tags: Optional[str] = Field(default=None, frozen=True) + options: Optional[List[SpectrumQuestionOption]] = Field( + default=None, min_length=1, frozen=True + ) + # This comes from the API. Of course there are more than what is documented. + # (1 = Core profiling question, 2 = Extended, 3 = Custom, 4 = ???) + class_num: SpectrumQuestionClass = Field(frozen=True) + # This comes from the API. It is when it was created in Spectrum's DB, + # not when we created it + created: Optional[AwareDatetimeISO] = Field(default=None, frozen=True) + + source: Literal[Source.SPECTRUM] = Source.SPECTRUM + + @property + def internal_id(self) -> str: + return self.question_id + + @model_validator(mode="before") + @classmethod + def clean_text_qid_from_api(cls, data: Any): + # Almost all questions have "variable names" in the question text. + # Remove this e.g. 'Are you registered in any of the following US + # political parties? %%1040%%' or 'My household earns approximately + # $%%213%% per year' + s = data["question_text"].strip() + search_str = f"%%{data['question_id']}%%" + if search_str in s: + if s.endswith(search_str): + s = s.replace(search_str, "").strip() + else: + s = s.replace(search_str, "___") + # After we do this, there shouldn't be any others + if "%%" in s: + raise ValueError("question text has unknown variables") + data["question_text"] = s + return data + + @field_validator("question_name", "question_text", "tags", mode="after") + def remove_nbsp(cls, s: Optional[str]): + return string_utils.remove_nbsp(s) + + @model_validator(mode="before") + @classmethod + def crop_name_from_api(cls, data: Any): + # Some of the names are ridiculously long. They aren't used for + # anything to its safe to crop it + data["question_name"] = data["question_name"].strip()[:255] + return data + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.question_type == SpectrumQuestionType.TEXT_ENTRY: + assert self.options is None, "TEXT_ENTRY shouldn't have options" + else: + assert self.options is not None, "missing options" + return self + + @field_validator("options") + @classmethod + def order_options(cls, options): + if options: + options.sort(key=lambda x: x.order) + return options + + @field_validator("options") + @classmethod + def uniquify_options(cls, options: Optional[List[SpectrumQuestionOption]]): + if options: + # The API returns questions with identical option IDs multiple + # times. They seem to all be typo/corrections to the text, so + # it should be safe to just remove the duplicates. We have no + # way of knowing which is the intended text though. + + opt_d = {opt.id: opt for opt in options} + options = list(opt_d.values()) + for n, opt in enumerate(options): + opt.order = n + return options + + @classmethod + def from_api( + cls, d: dict, country_iso: str, language_iso: str + ) -> Optional["SpectrumQuestion"]: + # To not pollute our logs, we know we are skipping any question that + # meets the following conditions: + if not SpectrumQuestionType.from_api(d["type"]): + return None + if d["class"] not in { + x.value for x in SpectrumQuestionClass.__members__.values() + }: + return None + try: + return cls._from_api(d, country_iso, language_iso) + except Exception as e: + logger.warning(f"Unable to parse question: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: dict, country_iso: str, language_iso: str) -> Self: + options = None + if d.get("condition_codes"): + # Sometimes they use the key "name" instead of "text" ... ? + key = "text" if "text" in d["condition_codes"][0] else "name" + # Sometimes options are blank + d["condition_codes"] = [x for x in d["condition_codes"] if x[key]] + options = [ + SpectrumQuestionOption(id=r["id"], text=r[key], order=order) + for order, r in enumerate( + sorted(d["condition_codes"], key=lambda x: int(x["id"])) + ) + ] + + created = ( + datetime.utcfromtimestamp(d["crtd_on"] / 1000).replace(tzinfo=timezone.utc) + if d.get("crtd_on") + else None + ) + return cls( + question_id=str(d["qualification_code"]), + question_text=d["text"], + question_type=SpectrumQuestionType.from_api(d["type"]), + question_name=d["desc"], + tags=d["cat"], + class_num=d["class"], + created=created, + country_iso=country_iso, + language_iso=language_iso, + options=options, + ) + + @classmethod + def from_db(cls, d: dict) -> Self: + options = None + if d["options"]: + options = [ + SpectrumQuestionOption(id=r["id"], text=r["text"], order=r["order"]) + for r in d["options"] + ] + d["created"] = ( + d["created"].replace(tzinfo=timezone.utc) if d["created"] else None + ) + + return cls( + question_id=d["question_id"], + question_name=d["question_name"], + question_text=d["question_text"], + question_type=d["question_type"], + country_iso=d["country_iso"], + language_iso=d["language_iso"], + options=options, + is_live=d["is_live"], + category_id=( + UUID(d.get("category_id")).hex if d.get("category_id") else None + ), + tags=d["tags"], + class_num=d["class_num"], + created=d["created"], + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", by_alias=True) + d["options"] = json.dumps(d["options"]) + if self.created: + d["created"] = self.created.replace(tzinfo=None) + return d + + def to_upk_question(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionType, + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestion, + ) + + upk_type_selector_map = { + SpectrumQuestionType.SINGLE_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.SINGLE_ANSWER, + ), + SpectrumQuestionType.MULTI_SELECT: ( + UpkQuestionType.MULTIPLE_CHOICE, + UpkQuestionSelectorMC.MULTIPLE_ANSWER, + ), + SpectrumQuestionType.TEXT_ENTRY: ( + UpkQuestionType.TEXT_ENTRY, + UpkQuestionSelectorTE.SINGLE_LINE, + ), + } + upk_type, upk_selector = upk_type_selector_map[self.question_type] + d = { + "ext_question_id": self.external_id, + "country_iso": self.country_iso, + "language_iso": self.language_iso, + "type": upk_type, + "selector": upk_selector, + "text": self.question_text, + } + if self.options: + d["choices"] = [ + UpkQuestionChoice(id=c.id, text=c.text, order=c.order) + for c in self.options + ] + return UpkQuestion(**d) diff --git a/generalresearch/models/spectrum/survey.py b/generalresearch/models/spectrum/survey.py new file mode 100644 index 0000000..7bebaa2 --- /dev/null +++ b/generalresearch/models/spectrum/survey.py @@ -0,0 +1,514 @@ +from __future__ import annotations + +import json +import logging +from datetime import timezone +from decimal import Decimal +from typing import Optional, Dict, Any, List, Literal, Set, Tuple, Type +from typing_extensions import Self + +from more_itertools import flatten +from pydantic import Field, ConfigDict, BaseModel, model_validator, computed_field + +from generalresearch.locales import Localelator +from generalresearch.models import TaskCalculationType, Source +from generalresearch.models.custom_types import ( + CoercedStr, + AwareDatetimeISO, + AlphaNumStrSet, + UUIDStrSet, + AlphaNumStr, +) +from generalresearch.models.spectrum import SpectrumStatus +from generalresearch.models.thl.demographics import Gender +from generalresearch.models.thl.survey import MarketplaceTask +from generalresearch.models.thl.survey.condition import ( + ConditionValueType, + MarketplaceCondition, +) + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +locale_helper = Localelator() + + +class SpectrumCondition(MarketplaceCondition): + model_config = ConfigDict(populate_by_name=True, frozen=False, extra="ignore") + + question_id: Optional[CoercedStr] = Field( + min_length=1, + max_length=16, + pattern=r"^[0-9]+$", + validation_alias="qualification_code", + ) + + @model_validator(mode="after") + def change_age_range_to_list(self) -> Self: + """Spectrum uses ranges usually for ages. Ranges take longer to + evaluate b/c they have to be converted into ints and then require + multiple evaluations. Just convert into a list of values which only + requires one easy match. + e.g. convert age values from '20-22|20-21|25-26' to '|20|21|22|25|26|' + """ + if self.question_id == "212" and self.value_type == ConditionValueType.RANGE: + try: + values = [tuple(map(int, v.split("-"))) for v in self.values] + assert all(len(x) == 2 for x in values) + except (ValueError, AssertionError): + return self + self.values = sorted( + {str(val) for tupl in values for val in range(tupl[0], tupl[1] + 1)} + ) + self.value_type = ConditionValueType.LIST + return self + + @classmethod + def from_api(cls, d: Dict[str, Any]) -> "SpectrumCondition": + """Ranges can get returns with a key "units" indicating years or + months. This is ridiculous, and we don't ask for birthdate, so we + can't really get month accuracy. Normalize to years. + """ + if "range_sets" in d: + for rs in d["range_sets"]: + if rs["units"] == 312: + rs["from"] = round(rs["from"] / 12) + rs["to"] = round(rs["to"] / 12) + d["values"] = [ + "{0}-{1}".format(rs["from"] or "inf", rs["to"] or "inf") + for rs in d["range_sets"] + ] + d["value_type"] = ConditionValueType.RANGE + return cls.model_validate(d) + else: + d["values"] = list(map(str.lower, d["condition_codes"])) + d["value_type"] = ConditionValueType.LIST + return cls.model_validate(d) + + +class SpectrumQuota(BaseModel): + model_config = ConfigDict(populate_by_name=True, frozen=True) + + # We don't ever need this. There's also a crtd_on and mod_on field, which + # we ignore. quota_id: UUIDStr = Field() + + # API response is quantities.currently_open + remaining_count: int = Field( + description="Number of completes currently available in the quota. If " + "the value is 0, any respondent matching this quota will be rejected." + ) + condition_hashes: List[str] = Field(min_length=0, default_factory=list) + + # API also returns remaining & achieved, but these are supplier-scoped. + # There is no explicit status. The quota is closed if the count is 0 + + def __hash__(self) -> int: + return hash(tuple((tuple(self.condition_hashes), self.remaining_count))) + + @property + def is_open(self) -> bool: + # currently_open takes into account respondents in progress, so + # theoretically we should just check that there is >0 spots left + min_open_spots = 1 + return self.remaining_count >= min_open_spots + + @classmethod + def from_api(cls, d: Dict) -> Self: + d["remaining_count"] = d["quantities"]["currently_open"] + return cls.model_validate(d) + + def passes(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Passes means we 1) meet all conditions (aka "match") AND 2) the + # quota is open. + return self.is_open and self.matches(criteria_evaluation) + + def matches(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # Matches means we meet all conditions. We can "match" a quota that is + # closed. In that case, we would not be eligible for the survey. + return all(criteria_evaluation.get(c) for c in self.condition_hashes) + + def matches_optional( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Optional[bool]: + # We need to know if any conditions are unknown to avoid matching a + # full quota. If any fail, then we know we fail regardless of any + # being unknown. + evals = [criteria_evaluation.get(c) for c in self.condition_hashes] + if False in evals: + return False + if None in evals: + return None + return True + + def matches_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "matches" (T/F/none) and a list of unknown criterion hashes + hash_evals = { + cell: criteria_evaluation.get(cell) for cell in self.condition_hashes + } + if False in hash_evals.values(): + return False, set() + if None in hash_evals.values(): + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + +class SpectrumSurvey(MarketplaceTask): + model_config = ConfigDict(populate_by_name=True) + # Keys in API response that are undocumented: soft_launch, pds, project_last_complete_date + # Keys in API not used: price_type, buyer_message, last_complete_date (OUR last complete date) + # supplier_completes key is OUR DATA. It contains a "remaining" count, but this is just the + # sum of the quota remaining counts (I think) + + survey_id: CoercedStr = Field(min_length=1, max_length=16, pattern=r"^[0-9]+$") + survey_name: str = Field(max_length=256) + status: SpectrumStatus = Field(validation_alias="survey_status") + + field_end_date: AwareDatetimeISO = Field( + description="When this survey is scheduled to end fielding. May stay open past fielding" + ) + # Most are 232 - "Exciting New" which I assume is the default + category_code: CoercedStr = Field(max_length=3, min_length=3, default="232") + # API calls this "click_balancing" + calculation_type: TaskCalculationType = Field( + description="Indicates whether the targets are counted per Complete or Survey Start", + default=TaskCalculationType.COMPLETES, + ) + + requires_pii: bool = Field( + default=False, description="unclear what pii is", validation_alias="pii" + ) + buyer_id: CoercedStr = Field( + description="Identifier of client requesting the study", max_length=32 + ) + cpi: Decimal = Field(gt=0, le=100, decimal_places=2, max_digits=5) + + # called "survey_grouping" in API. If a respondent has previously taken any + # of these surveys, they will be excluded if that survey was taken in + # the exclusion_period. + survey_exclusions: Optional[AlphaNumStrSet] = Field( + description="list of excluded survey ids", default=None + ) + exclusion_period: int = Field(default=30, description="in days") + + # API does not explicitly return the Bid values. It returns a LOI and IR + # that is the Bid value when the last block is null. As such, sometimes + # it may be set, sometimes not. We'll store it in the db if we see it, + # but then when we update the survey, it may not be returned, and so + # when we update the db, we must not overwrite this with NULL. + # API key: "survey_performance" + + bid_loi: Optional[int] = Field(default=None, le=120 * 60) + bid_ir: Optional[float] = Field(default=None, ge=0, le=1) + overall_loi: Optional[int] = Field(default=None, le=120 * 60) + overall_ir: Optional[float] = Field(default=None, ge=0, le=1) + last_block_loi: Optional[int] = Field(default=None, le=120 * 60) + last_block_ir: Optional[float] = Field(default=None, ge=0, le=1) + + # Undocumented. They sent us an email indicating that this is the last time + # there was a complete for all suppliers on this survey. + project_last_complete_date: Optional[AwareDatetimeISO] = Field(default=None) + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field( + max_length=2, min_length=2, pattern=r"^[a-z]{2}$", frozen=True + ) + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field( + max_length=3, min_length=3, pattern=r"^[a-z]{3}$", frozen=True + ) + + # The API returns 'incl_excl' which is a boolean indicating if the psid + # list is an exclude or include list. If incl_excl = 1, the survey has an + # include list, and only those in the list are eligible. This list gets + # updated everytime someone on the list takes the survey. + include_psids: Optional[UUIDStrSet] = Field(default=None) + exclude_psids: Optional[UUIDStrSet] = Field(default=None) + + qualifications: List[str] = Field(default_factory=list) + quotas: List[SpectrumQuota] = Field(default_factory=list) + + source: Literal[Source.SPECTRUM] = Field(default=Source.SPECTRUM) + + used_question_ids: Set[AlphaNumStr] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as + # "condition_hashes") throughout this survey. In the reduced + # representation of this task (nearly always, for db i/o, in + # global_vars) this field will be null. + conditions: Optional[Dict[str, SpectrumCondition]] = Field(default=None) + + # These come from the API + created_api: AwareDatetimeISO = Field( + description="Creation date of opportunity", validation_alias="crtd_on" + ) + modified_api: AwareDatetimeISO = Field( + description="When the survey was last updated in spectrum's system", + validation_alias="mod_on", + ) + + # This does not come from the API. We set it when we update this in the db. + updated: Optional[AwareDatetimeISO] = Field(default=None) + + @property + def internal_id(self) -> str: + return self.survey_id + + @computed_field + def is_live(self) -> bool: + return self.status == SpectrumStatus.LIVE + + @property + def is_open(self) -> bool: + # The survey is open if the status is OPEN and there is at least 1 + # open quota (or there are no quotas!) + return self.is_live and ( + any(q.is_open for q in self.quotas) or len(self.quotas) == 0 + ) + + @computed_field + @property + def all_hashes(self) -> Set[str]: + s = set(self.qualifications) + for q in self.quotas: + s.update(set(q.condition_hashes)) + return s + + @model_validator(mode="before") + @classmethod + def set_locale(cls, data: Any): + data["country_isos"] = [data["country_iso"]] + data["language_isos"] = [data["language_iso"]] + return data + + @model_validator(mode="before") + @classmethod + def set_used_questions(cls, data: Any): + if data.get("used_question_ids") is not None: + return data + if not data.get("conditions"): + data["used_question_ids"] = set() + return data + data["used_question_ids"] = { + c.question_id for c in data["conditions"].values() if c.question_id + } + return data + + @property + def condition_model(self) -> Type[MarketplaceCondition]: + return SpectrumCondition + + @property + def age_question(self) -> str: + return "212" + + @property + def marketplace_genders(self) -> Dict[Gender, Optional[MarketplaceCondition]]: + return { + Gender.MALE: SpectrumCondition( + question_id="211", + values=["111"], + value_type=ConditionValueType.LIST, + ), + Gender.FEMALE: SpectrumCondition( + question_id="211", + values=["112"], + value_type=ConditionValueType.LIST, + ), + Gender.OTHER: None, + } + + @classmethod + def from_api(cls, d: Dict) -> Optional["SpectrumSurvey"]: + try: + return cls._from_api(d) + except Exception as e: + logger.warning(f"Unable to parse survey: {d}. {e}") + return None + + @classmethod + def _from_api(cls, d: Dict) -> Self: + assert d["click_balancing"] in {0, 1}, "unknown click_balancing value" + d["calculation_type"] = ( + TaskCalculationType.STARTS + if d["click_balancing"] + else TaskCalculationType.COMPLETES + ) + + d["conditions"] = dict() + + # If we haven't hit the "detail" endpoint, we won't get this + d.setdefault("qualifications", []) + qualifications = [SpectrumCondition.from_api(q) for q in d["qualifications"]] + for q in qualifications: + d["conditions"][q.criterion_hash] = q + d["qualifications"] = [x.criterion_hash for x in qualifications] + + quotas = [] + d.setdefault("quotas", []) + for quota in d["quotas"]: + criteria = [SpectrumCondition.from_api(q) for q in quota["criteria"]] + quota["condition_hashes"] = [x.criterion_hash for x in criteria] + quotas.append(SpectrumQuota.from_api(quota)) + for q in criteria: + d["conditions"][q.criterion_hash] = q + d["quotas"] = quotas + return cls.model_validate(d) + + def is_unchanged(self, other) -> bool: + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, just + # that the objects don't require any db updates. We also exclude + # conditions b/c this is just the condition_hash definitions + return self.model_dump(exclude={"updated", "conditions"}) == other.model_dump( + exclude={"updated", "conditions"} + ) + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump( + mode="json", + exclude={ + "all_hashes", + "country_isos", + "language_isos", + "source", + "conditions", + }, + ) + d["qualifications"] = json.dumps(d["qualifications"]) + d["quotas"] = json.dumps(d["quotas"]) + d["used_question_ids"] = json.dumps(sorted(d["used_question_ids"])) + d["created_api"] = self.created_api + d["updated"] = self.updated + d["modified_api"] = self.modified_api + d["field_end_date"] = self.field_end_date + d["project_last_complete_date"] = self.project_last_complete_date + return d + + @classmethod + def from_db(cls, d: Dict[str, Any]) -> Self: + d["created_api"] = d["created_api"].replace(tzinfo=timezone.utc) + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + d["modified_api"] = d["modified_api"].replace(tzinfo=timezone.utc) + d["field_end_date"] = ( + d["field_end_date"].replace(tzinfo=timezone.utc) + if d["field_end_date"] + else None + ) + d["project_last_complete_date"] = ( + d["project_last_complete_date"].replace(tzinfo=timezone.utc) + if d["project_last_complete_date"] + else None + ) + if "qualifications" in d: + d["qualifications"] = json.loads(d["qualifications"]) + if "quotas" in d: + d["quotas"] = json.loads(d["quotas"]) + d["used_question_ids"] = json.loads(d["used_question_ids"]) + return cls.model_validate(d) + + """ + Yield Management/Eligibility Description: + # https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33604951/Respondent+Order+of+Operations + """ + + def passes_qualifications( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + # We have to match all quals + return all(criteria_evaluation.get(q) for q in self.qualifications) + + def passes_qualifications_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # Passes back "passes" (T/F/none) and a list of unknown criterion hashes + hash_evals = {q: criteria_evaluation.get(q) for q in self.qualifications} + # We have to match all. So if any are False, we know we don't pass + + if False in hash_evals.values(): + return False, set() + + # If any are None, we don't know + if None in hash_evals.values(): + return None, {cell for cell, ev in hash_evals.items() if ev is None} + return True, set() + + def passes_quotas(self, criteria_evaluation: Dict[str, Optional[bool]]) -> bool: + # We have to match at least 1 quota, but they can NOT match a quota + # where currently_open=0 + any_pass = False + for q in self.quotas: + matches = q.matches_optional(criteria_evaluation) + if matches in {True, None} and not q.is_open: + # We also cannot be unknown for this quota, b/c we might fall + # into it, which would be a fail. + return False + if matches: + any_pass = True + return any_pass + + def passes_quotas_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + # We have to match at least 1 quota, but they can NOT match a quota + # where currently_open=0 + quota_eval = { + quota: quota.matches_soft(criteria_evaluation) for quota in self.quotas + } + evals = set(g[0] for g in quota_eval.values()) + if any(m[0] is True and not q.is_open for q, m in quota_eval.items()): + # matched a full quota + return False, set() + if any(m[0] is None and not q.is_open for q, m in quota_eval.items()): + # Unknown match for full quota + if True in evals: + # we match 1 other, so the missing are only this type + return None, set( + flatten( + [ + m[1] + for q, m in quota_eval.items() + if m[0] is None and not q.is_open + ] + ) + ) + else: + # we don't match any quotas, so everything is unknown + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + if True in evals: + return True, set() + if None in evals: + return None, set( + flatten([m[1] for q, m in quota_eval.items() if m[0] is None]) + ) + return False, set() + + def determine_eligibility( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> bool: + return ( + self.is_open + and self.passes_qualifications(criteria_evaluation) + and self.passes_quotas(criteria_evaluation) + ) + + def determine_eligibility_soft( + self, criteria_evaluation: Dict[str, Optional[bool]] + ) -> Tuple[Optional[bool], Set[str]]: + if self.is_open is False: + return False, set() + pass_quals, h_quals = self.passes_qualifications_soft(criteria_evaluation) + # Check for not passing quals before bothering to do the rest + if pass_quals is False: + return False, set() + pass_quotas, h_quotas = self.passes_quotas_soft(criteria_evaluation) + if pass_quals and pass_quotas: + return True, set() + elif pass_quals is False or pass_quotas is False: + return False, set() + else: + return None, h_quals | h_quotas diff --git a/generalresearch/models/spectrum/task_collection.py b/generalresearch/models/spectrum/task_collection.py new file mode 100644 index 0000000..8aeac7d --- /dev/null +++ b/generalresearch/models/spectrum/task_collection.py @@ -0,0 +1,110 @@ +from typing import List, Set, Dict + +import pandas as pd +from pandera import Column, DataFrameSchema, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models import TaskCalculationType +from generalresearch.models.spectrum import SpectrumStatus +from generalresearch.models.spectrum.survey import SpectrumSurvey +from generalresearch.models.thl.survey.task_collection import ( + create_empty_df_from_schema, + TaskCollection, +) + +COUNTRY_ISOS: Set[str] = Localelator().get_all_countries() +LANGUAGE_ISOS: Set[str] = Localelator().get_all_languages() + +SpectrumTaskCollectionSchema = DataFrameSchema( + columns={ + "survey_name": Column(str, Check.str_length(min_value=1, max_value=256)), + "status": Column(int, Check.isin(SpectrumStatus)), + "field_end_date": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "category_code": Column(), + "calculation_type": Column(str, Check.isin(TaskCalculationType)), + "requires_pii": Column(bool), + "buyer_id": Column(str), + "cpi": Column(float, Check.between(min_value=0, max_value=100)), + "bid_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "bid_ir": Column(float, Check.between(0, 1), nullable=True), + "overall_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "overall_ir": Column(float, Check.between(0, 1), nullable=True), + "last_block_loi": Column("Int32", Check.between(0, 90 * 60), nullable=True), + "last_block_ir": Column(float, Check.between(0, 1), nullable=True), + "project_last_complete_date": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True + ), + "country_iso": Column(str, Check.isin(COUNTRY_ISOS)), # 2 letter, lowercase + "language_iso": Column(str, Check.isin(LANGUAGE_ISOS)), # 3 letter, lowercase + # exclude_psids is potentially large. We don't need these usually, we just want to know + # if include_psids is set, if so then this is a recontact + # "exclude_psids": Column(bool), + "include_psids": Column(str, nullable=True), + "created_api": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "modified_api": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "updated": Column(dtype=pd.DatetimeTZDtype(tz="UTC")), + "used_question_ids": Column(List[str]), + "all_hashes": Column(List[str]), # set >> list for column support + }, + checks=[], + index=Index( + str, + name="survey_id", + checks=Check.str_length(min_value=1, max_value=16), + unique=True, + ), + strict=True, + coerce=True, + drop_invalid_rows=False, +) + + +class SpectrumTaskCollection(TaskCollection): + items: List[SpectrumSurvey] + _schema = SpectrumTaskCollectionSchema + + def to_rows(self, s: SpectrumSurvey) -> List[Dict]: + fields = [ + "survey_name", + "status", + "field_end_date", + "category_code", + "calculation_type", + "requires_pii", + "buyer_id", + "cpi", + "bid_loi", + "bid_ir", + "overall_loi", + "overall_ir", + "last_block_loi", + "last_block_ir", + "project_last_complete_date", + "country_iso", + "language_iso", + "include_psids", + "created_api", + "modified_api", + "updated", + "used_question_ids", + "all_hashes", + "survey_id", + ] + rows = [] + d = dict() + for k in fields: + d[k] = getattr(s, k) if hasattr(s, k) else None + d["used_question_ids"] = list(s.used_question_ids) + d["cpi"] = float(s.cpi) + d["all_hashes"] = list(d["all_hashes"]) + rows.append(d) + return rows + + def to_df(self): + rows = [] + for s in self.items: + rows.extend(self.to_rows(s)) + if rows: + return pd.DataFrame.from_records(rows, index="survey_id") + else: + return create_empty_df_from_schema(self._schema) diff --git a/generalresearch/models/string_utils.py b/generalresearch/models/string_utils.py new file mode 100644 index 0000000..23c1017 --- /dev/null +++ b/generalresearch/models/string_utils.py @@ -0,0 +1,12 @@ +import unicodedata +from typing import Optional + + +def remove_nbsp(s: Optional[str]) -> Optional[str]: + # Some text comes back from the API with lots of (copied from excel or + # something), and random unicode... + if s: + s = s.replace("\u00a0", " ").strip() + s = unicodedata.normalize("NFKD", s) + + return s diff --git a/generalresearch/models/thl/__init__.py b/generalresearch/models/thl/__init__.py new file mode 100644 index 0000000..875b2bb --- /dev/null +++ b/generalresearch/models/thl/__init__.py @@ -0,0 +1,34 @@ +from decimal import Decimal +from typing import Optional + +from generalresearch.models.thl.finance import ( + ProductBalances, + POPFinancial, +) +from generalresearch.models.thl.payout import ( + BrokerageProductPayoutEvent, + PayoutEvent, +) +from generalresearch.models.thl.product import Product + +_ = ( + Product, + PayoutEvent, + BrokerageProductPayoutEvent, + ProductBalances, + POPFinancial, +) + +Product.model_rebuild() +PayoutEvent.model_rebuild() +BrokerageProductPayoutEvent.model_rebuild() + + +def decimal_to_int_cents(usd: Optional[Decimal]) -> Optional[int]: + return round(usd * 100) if usd is not None else None + + +def int_cents_to_decimal(value: Optional[int], decimals: int = 2) -> Optional[Decimal]: + if value is None: + return None + return (Decimal(value) / Decimal(100)).quantize(Decimal(10) ** -decimals) diff --git a/generalresearch/models/thl/category.py b/generalresearch/models/thl/category.py new file mode 100644 index 0000000..fe330f1 --- /dev/null +++ b/generalresearch/models/thl/category.py @@ -0,0 +1,62 @@ +from typing import Optional +from uuid import uuid4 + +from pydantic import BaseModel, Field, model_validator, PositiveInt +from typing_extensions import Self + +from generalresearch.models.custom_types import UUIDStr + + +class Category(BaseModel, frozen=True): + id: Optional[PositiveInt] = Field(exclude=True, default=None) + + uuid: UUIDStr = Field(examples=[uuid4().hex]) + + adwords_vertical_id: Optional[str] = Field(default=None, max_length=8) + + label: str = Field(max_length=255, examples=["Hair Loss"]) + + # The path is '/' separated string, that shows the full hierarchy. + # e.g. "Hair Loss" has the path: "/Beauty & Fitness/Hair Care/Hair Loss" + path: str = Field( + pattern=r"^\/.*[^\/]$", + examples=["/Beauty & Fitness/Hair Care/Hair Loss"], + ) + + parent_id: Optional[PositiveInt] = Field(default=None, exclude=True) + parent_uuid: Optional[UUIDStr] = Field(default=None, examples=[uuid4().hex]) + + @model_validator(mode="after") + def check_path(self) -> Self: + assert self.label in self.path, "invalid path" + return self + + @model_validator(mode="after") + def check_parent(self) -> Self: + if self.id is not None: + assert self.parent_id != self.id, "you can't be your own parent!" + if self.uuid and self.parent_uuid: + assert self.parent_uuid != self.uuid, "you can't be your own parent!" + return self + + @property + def root_label(self) -> str: + # If path is "/Beauty & Fitness/Hair Care/Hair Loss", this returns "Beauty & Fitness" + return self.path.split("/", 2)[1] + + @property + def parent_path(self) -> Optional[str]: + # If path is "/Beauty & Fitness/Hair Care/Hair Loss", this returns "/Beauty & Fitness/Hair Care" + return self.path.rsplit("/", 1)[0] or None + + @property + def is_root(self) -> bool: + return self.parent_path is None + + def to_offerwall_api(self) -> dict: + return { + "id": self.uuid, + "label": self.label, + "adwords_id": self.adwords_vertical_id, + "adwords_label": self.label if self.adwords_vertical_id else None, + } diff --git a/generalresearch/models/thl/contest/__init__.py b/generalresearch/models/thl/contest/__init__.py new file mode 100644 index 0000000..7cf1f54 --- /dev/null +++ b/generalresearch/models/thl/contest/__init__.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Optional, Dict, Any +from uuid import uuid4 + +from pydantic import ( + BaseModel, + Field, + model_validator, + computed_field, + PositiveInt, +) +from typing_extensions import Self + +from generalresearch.currency import USDCent +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.contest.definitions import ContestPrizeKind +from generalresearch.models.thl.user import User + + +class ContestEntryRule(BaseModel): + """Defines rules the user must meet to be allowed to enter this contest + Only applies if the ContestType is ENTRY! + """ + + max_entry_amount_per_user: Optional[USDCent | PositiveInt] = Field( + description="Maximum total value of entries per user", + default=None, + ) + + max_daily_entries_per_user: Optional[PositiveInt] = Field( + description="Maximum entries per user allowed per day for this contest", + default=None, + ) + + # TODO: Only allow entries if user meets some criteria: gold-membership status, + # ID/phone verified, min_completes etc... + # Maybe these get put in a separate model b/c the could apply if the ContestType is not ENTRY + min_completes: Optional[int] = None + min_membership_level: Optional[int] = None + id_verified: Optional[bool] = None + + +class ContestEndCondition(BaseModel): + """Defines the conditions to evaluate to determine when the contest is over. + Multiple conditions can be set. The contest is over once ANY conditions are met. + """ + + target_entry_amount: USDCent | PositiveInt | None = Field( + default=None, + ge=1, + description="The contest is over once this amount is reached. (sum of all entry amount)", + ) + # In a LeaderboardContest, ends_at equals the leaderboard's end period plus 90 minutes + ends_at: Optional[AwareDatetimeISO] = Field( + default=None, description="The contest is over at this time." + ) + + +class ContestPrize(BaseModel): + kind: ContestPrizeKind = Field( + description=ContestPrizeKind.as_openapi_with_value_descriptions() + ) + name: Optional[str] = Field(default=None) + description: Optional[str] = Field(default=None) + + estimated_cash_value: USDCent = Field( + description="Estimated cash value of prize in USDCents", + ) + cash_amount: Optional[USDCent] = Field( + default=None, + description="If the kind=ContestPrizeKind.CASH, this is the amount of the prize", + ) + promotion_id: Optional[UUIDStr] = Field( + default=None, + description="If the kind=ContestPrizeKind.PROMOTION, this is the promotion ID", + ) + # only if the contest.contest_type = LEADERBOARD + leaderboard_rank: Optional[PositiveInt] = Field( + default=None, + description="The prize is for achieving this rank in the associated " + "leaderboard. The highest rank is 1.", + ) + + @model_validator(mode="after") + def validate_cash_value(self) -> Self: + if self.kind == ContestPrizeKind.CASH: + assert ( + self.estimated_cash_value == self.cash_amount + ), "if kind is CASH, cash_amount must equal estimated_cash_value" + return self + + +class ContestWinner(BaseModel): + """ + In a Raffle, the ContestEntryType can be COUNT or CASH. In the CASH type, + the unit of entry is 1 USDCent (one penny). Implicitly, each penny entered + buys 1 entry into the raffle, and one entry is randomly selected for + each prize. + + A contest should have as many winners as there are prizes + special case 1: there are fewer entries than prizes + special case 2: leaderboard contest with ties + """ + + uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + created_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this user won this prize", + ) + + user: Optional[User] = Field(exclude=True, default=None) + + prize: ContestPrize = Field() + + awarded_cash_amount: Optional[USDCent] = Field( + default=None, + description="The actual amount this user receives. For cash prizes, if there was a tie, " + "this could be different from the prize amount.", + ) + + @computed_field() + @property + def product_user_id(self) -> Optional[str]: + # TODO: we'll have to pull username or censored emails or something + if self.user: + return self.user.product_user_id + + # @computed_field() + # @property + # def censored_product_user_id(self) -> str: + # return censor_product_user_id(self.user) + + def model_dump_mysql(self, contest_id: int) -> Dict[str, Any]: + data = self.model_dump(mode="json", exclude={"user"}) + + data["contest_id"] = contest_id + data["created_at"] = self.created_at + data["user_id"] = self.user.user_id + data["prize"] = self.prize.model_dump_json() + + return data diff --git a/generalresearch/models/thl/contest/contest.py b/generalresearch/models/thl/contest/contest.py new file mode 100644 index 0000000..232a038 --- /dev/null +++ b/generalresearch/models/thl/contest/contest.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import json +from abc import abstractmethod, ABC +from datetime import timezone, datetime +from typing import List, Tuple, Optional, Dict +from uuid import uuid4 + +from pydantic import ( + BaseModel, + Field, + HttpUrl, + ConfigDict, + model_validator, + NonNegativeInt, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.contest import ( + ContestEndCondition, + ContestPrize, + ContestWinner, +) +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestType, + ContestEndReason, +) +from generalresearch.models.thl.locales import CountryISOs + + +class ContestBase(BaseModel, ABC): + """ + This model will be used also as the "Create" API class, so nothing + goes on here that is not settable by an api user. + """ + + model_config = ConfigDict(validate_assignment=True) + + name: str = Field( + max_length=128, description="Name of contest. Can be displayed to user." + ) + description: Optional[str] = Field( + default=None, + max_length=2048, + description="Description of contest. Can be displayed to user.", + ) + + contest_type: ContestType = Field( + description=ContestType.as_openapi_with_value_descriptions() + ) + + end_condition: ContestEndCondition = Field() + """Defines the conditions to win one or more prizes once the contest is ended""" + prizes: List[ContestPrize] = Field(default_factory=list, min_items=1) + + starts_at: AwareDatetimeISO = Field( + description="When the contest starts", + default_factory=lambda: datetime.now(tz=timezone.utc), + ) + + terms_and_conditions: Optional[HttpUrl] = Field(default=None) + + status: ContestStatus = Field(default=ContestStatus.ACTIVE) + + country_isos: Optional[CountryISOs] = Field( + description="Contest is restricted to these countries. If null, all countries are allowed", + default=None, + ) + + def update(self, **kwargs) -> None: + # For dealing with updating multiple fields at once that would + # otherwise break validations + self.model_config["validate_assignment"] = False + for k, v in kwargs.items(): + setattr(self, k, v) + self.model_config["validate_assignment"] = True + self.__class__.model_validate(self) + + +class Contest(ContestBase): + id: Optional[int] = Field( + default=None, + exclude=True, + description="pk in db", + ) + + uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + + product_id: UUIDStr = Field(description="Contest applies only to a single BP") + + created_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this contest was created", + ) + updated_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this contest was last modified. Does not include " + "entries being created/modified", + ) + + ended_at: Optional[AwareDatetimeISO] = Field( + default=None, + description="When the contest ended", + ) + + end_reason: Optional[ContestEndReason] = Field( + default=None, + description="The reason the contest ended", + ) + + all_winners: Optional[List[ContestWinner]] = Field( + default=None, + exclude=True, + description="All prize winners of this contest", + ) + + @model_validator(mode="after") + def validate_end(self): + if self.status == ContestStatus.ACTIVE: + assert self.ended_at is None, "ended_at when status is active" + assert self.end_reason is None, "end_reason when status is active" + assert self.all_winners is None, "all_winners when status is active" + else: + assert self.ended_at, "must set ended_at if contest ended" + assert self.end_reason, "must set end_reason if contest ended" + + return self + + # def is_user_winner(self, user: User): + # assert self.status == ContestStatus.COMPLETED + # result = self.result + # for winner in result.winners: + # if winner.user_id == user.user_id: + # return True + # return False + + def should_end(self) -> Tuple[bool, Optional[ContestEndReason]]: + if self.status == ContestStatus.ACTIVE: + if self.end_condition.ends_at: + if datetime.now(tz=timezone.utc) >= self.end_condition.ends_at: + return True, ContestEndReason.ENDS_AT + + return False, None + + @abstractmethod + def select_winners(self) -> Optional[List[ContestWinner]]: ... + + def end_contest(self) -> None: + e, reason = self.should_end() + if not e: + return None + # todo: Acquire a lock here, b/c this next part involves randomness + # so we can't have it happen more than once + winners = self.select_winners() + if winners is not None: + self.update( + status=ContestStatus.COMPLETED, + ended_at=datetime.now(tz=timezone.utc), + end_reason=reason, + all_winners=winners, + ) + else: + self.update( + status=ContestStatus.COMPLETED, + ended_at=datetime.now(tz=timezone.utc), + end_reason=reason, + ) + return None + + def model_dump_mysql(self, **kwargs) -> Dict: + d = self.model_dump(mode="json", **kwargs) + + d["created_at"] = self.created_at + d["updated_at"] = self.updated_at + d["starts_at"] = self.starts_at + if self.ended_at: + d["ended_at"] = self.ended_at + d["end_condition"] = self.end_condition.model_dump_json() + d["prizes"] = json.dumps([p.model_dump(mode="json") for p in self.prizes]) + + return d + + @classmethod + def model_validate_mysql(cls, data) -> Self: + data = {k: v for k, v in data.items() if k in cls.model_fields.keys()} + if isinstance(data["end_condition"], dict): + data["end_condition"] = ContestEndCondition.model_validate( + data["end_condition"] + ) + data["prizes"] = [ContestPrize.model_validate(p) for p in data["prizes"]] + return cls.model_validate(data) + + @property + def prize_count(self) -> NonNegativeInt: + return len(self.prizes) + + +class ContestUserView(Contest): + """This is the user's 'view' of a contest.""" + + product_user_id: str = Field() + + # TODO: this could show a more detailed ContestWinner model, maybe + # including like shipping status or whatever + user_winnings: List[ContestWinner] = Field( + description="The prizes won in this contest by the requested user", + default_factory=list, + ) + + def is_user_eligible(self, country_iso: str) -> Tuple[bool, str]: + now = datetime.now(tz=timezone.utc) + + assert country_iso.lower() == country_iso + if now < self.starts_at: + return False, "contest has not yet started" + if self.status != ContestStatus.ACTIVE: + return False, "contest not active" + if self.country_isos is not None and country_iso not in self.country_isos: + return False, "ineligible country" + + return True, "" diff --git a/generalresearch/models/thl/contest/contest_entry.py b/generalresearch/models/thl/contest/contest_entry.py new file mode 100644 index 0000000..146f06f --- /dev/null +++ b/generalresearch/models/thl/contest/contest_entry.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Union, Dict, Any +from uuid import uuid4 + +from pydantic import ( + Field, + BaseModel, + model_validator, + computed_field, +) + +from generalresearch.currency import USDCent +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.contest.definitions import ContestEntryType +from generalresearch.models.thl.user import User + + +class ContestEntryCreate(BaseModel): + entry_type: ContestEntryType = Field() + # The meaning of this field is dictated by the contest's ContestEntryType + amount: Union[USDCent, int] = Field( + description="The amount of the entry in integer counts or USD Cents", + gt=0, + default=None, + ) + # This is used in the Create Entry API. We'll look up the user and set + # user_id. When we return this model in the API, user_id is excluded + product_user_id: str = Field( + min_length=3, + max_length=128, + examples=["app-user-9329ebd"], + description="A unique identifier for each user, which is set by the " + "Supplier. It should not contain any sensitive information" + "like email or names, and should avoid using any" + "incrementing values.", + ) + + +class ContestEntry(BaseModel): + uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + created_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + updated_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + + # entry_type and amount are the same as on ContestEntryCreate + entry_type: ContestEntryType = Field() + + # The meaning of this field is dictated by the contest's ContestEntryType + amount: Union[USDCent, int] = Field( + description="The amount of the entry in integer counts or USD Cents", + gt=0, + default=None, + ) + + # user_id used internally, for DB joins/index + user: User = Field(exclude=True) + + @model_validator(mode="before") + @classmethod + def validate_amount_type(cls, data: Dict) -> Dict: + from generalresearch.models.thl.contest.definitions import ( + ContestEntryType, + ) + + amount = data.get("amount") + entry_type = data.get("entry_type") + + if entry_type == ContestEntryType.COUNT: + assert isinstance(amount, int) and not isinstance( + amount, USDCent + ), "amount must be int in ContestEntryType.COUNT" + elif entry_type == ContestEntryType.CASH: + # This may be coming from the DB, in which case it is an int. + data["amount"] = USDCent(data["amount"]) + return data + + @computed_field() + def amount_str(self) -> str: + from generalresearch.models.thl.contest.definitions import ( + ContestEntryType, + ) + + if self.entry_type == ContestEntryType.COUNT: + return str(self.amount) + + elif self.entry_type == ContestEntryType.CASH: + return self.amount.to_usd_str() + + @computed_field() + @property + def censored_product_user_id(self) -> str: + from generalresearch.models.thl.contest.utils import ( + censor_product_user_id, + ) + + return censor_product_user_id(user=self.user) + + def model_dump_mysql(self, contest_id: int) -> Dict[str, Any]: + data = self.model_dump(mode="json", exclude={"user"}) + data["contest_id"] = contest_id + data["created_at"] = self.created_at + data["updated_at"] = self.updated_at + data["user_id"] = self.user.user_id + return data diff --git a/generalresearch/models/thl/contest/definitions.py b/generalresearch/models/thl/contest/definitions.py new file mode 100644 index 0000000..1a71408 --- /dev/null +++ b/generalresearch/models/thl/contest/definitions.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from enum import Enum + +from generalresearch.utils.enum import ReprEnumMeta + + +class ContestStatus(str, Enum): + ACTIVE = "active" + COMPLETED = "completed" + CANCELLED = "cancelled" + + +class ContestType(str, Enum, metaclass=ReprEnumMeta): + """There are 3 contest types. They have a common base, with some unique + configurations and behaviors for each. + """ + + # Explicit entries, winner(s) by random draw among entries. aka "random draw". + RAFFLE = "raffle" + + # Winner(s) by rank in a leaderboard. No entries. + LEADERBOARD = "leaderboard" + + # Reward is guaranteed for everyone who passes a threshold / meets some criteria + MILESTONE = "milestone" + + +class ContestEndReason(str, Enum): + """ + Defines why a contest ended + """ + + # Contest was cancelled. There are no winners. + CANCELLED = "cancelled" + + # Contest reached the target entry amount. + TARGET_ENTRY_AMOUNT = "target_entry_amount" + + # Contest reached the target end date. + ENDS_AT = "ends_at" + + # Contest reached the max number of winners (only in a milestone contest) + MAX_WINNERS = "max_winners" + + +class ContestPrizeKind(str, Enum, metaclass=ReprEnumMeta): + # A physical prize (e.g. a iPhone, cash in the mail, dinner with Max) + PHYSICAL = "physical" + + # A promotion is a temporary or special offer that provides extra value + # or benefits (e.g. 20% bonus on completes for the next 7 days) + PROMOTION = "promotion" + + # Money is deposited into user's virtual wallet + CASH = "cash" + + +class ContestEntryTrigger(str, Enum): + """ + Defines what action/event triggers a (possible) entry into the contest (automatically). + This only is valid on milestone contests + """ + + TASK_COMPLETE = "task_complete" + TASK_ATTEMPT = "task_attempt" + REFERRAL = "referral" + + +class ContestEntryType(str, Enum, metaclass=ReprEnumMeta): + """ + All entries into a contest must be of the same type, and match + the entry_type of the Contest itself. + """ + + # Each entry into the contest is an integer "count". In all current use + # cases, the value is 1, but we could change this if needed. + # This could be for e.g. each Task Complete, task attempt, or even each + # referral, etc. + COUNT = "count" + + # Each entry is tracking cash in units of USDCent. + CASH = "cash" + + +class LeaderboardTieBreakStrategy(str, Enum): + """ + Strategies for resolving ties in leaderboard-based contests. + """ + + # All tied users at a rank split the total value of prizes for those ranks. + # All prizes must be CASH + SPLIT_PRIZE_POOL = "split_prize_pool" + + # All tied users receive the full prize for that rank (i.e., duplicate + # prizes are issued). All prizes must be type PROMOTION + DUPLICATE_PRIZES = "duplicate_prizes" + + # First user(s) to reach the score win in case of a tie + # Might be used in case of physical prizes that can't be split + EARLIEST_TO_REACH = "earliest_to_reach" diff --git a/generalresearch/models/thl/contest/examples.py b/generalresearch/models/thl/contest/examples.py new file mode 100644 index 0000000..9748597 --- /dev/null +++ b/generalresearch/models/thl/contest/examples.py @@ -0,0 +1,404 @@ +from typing import Dict + +from pydantic import HttpUrl + +from generalresearch.config import EXAMPLE_PRODUCT_ID +from generalresearch.currency import USDCent + + +def _example_raffle_create(schema: Dict) -> None: + from generalresearch.models.thl.contest.raffle import ( + RaffleContestCreate, + ) + from generalresearch.models.thl.contest import ( + ContestEndCondition, + ContestPrize, + ContestEntryRule, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + from generalresearch.models.thl.contest.contest_entry import ( + ContestEntryType, + ) + + schema["example"] = RaffleContestCreate( + name="Win an iPhone", + description="iPhone winner will be drawn in proportion to entry " + "amount. Contest ends once $800 has been entered.", + contest_type=ContestType.RAFFLE, + end_condition=ContestEndCondition(target_entry_amount=USDCent(800_00)), + prizes=[ + ContestPrize( + kind=ContestPrizeKind.PHYSICAL, + name="iPhone 16", + estimated_cash_value=USDCent(800_00), + ) + ], + starts_at="2025-06-12T21:12:58.061170Z", + terms_and_conditions=None, + entry_rule=ContestEntryRule( + max_entry_amount_per_user=10000, max_daily_entries_per_user=1000 + ), + country_isos={"us", "ca"}, + entry_type=ContestEntryType.CASH, + ).model_dump(mode="json") + + +def _example_raffle(schema: Dict) -> None: + from generalresearch.models.thl.contest.raffle import RaffleContest + from generalresearch.models.thl.contest import ( + ContestEndCondition, + ContestPrize, + ContestEntryRule, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.contest_entry import ( + ContestEntryType, + ) + + schema["example"] = RaffleContest( + name="Win an iPhone", + description="iPhone winner will be drawn in proportion to entry " + "amount. Contest ends once $800 has been entered.", + contest_type=ContestType.RAFFLE, + end_condition=ContestEndCondition(target_entry_amount=USDCent(800_00)), + prizes=[ + ContestPrize( + kind=ContestPrizeKind.PHYSICAL, + name="iPhone 16", + estimated_cash_value=USDCent(800_00), + ) + ], + starts_at="2025-06-12T21:12:58.061170Z", + terms_and_conditions=None, + entry_rule=ContestEntryRule( + max_entry_amount_per_user=10000, max_daily_entries_per_user=1000 + ), + country_isos={"us", "ca"}, + entry_type=ContestEntryType.CASH, + status=ContestStatus.ACTIVE, + uuid="ce3968b8e18a4b96af62007f262ed7f7", + created_at="2025-06-12T21:12:58.061205Z", + updated_at="2025-06-12T21:12:58.061205Z", + current_amount=4723, + current_participants=12, + product_id=EXAMPLE_PRODUCT_ID, + ).model_dump(mode="json") + + +def _example_raffle_user_view(schema: Dict) -> None: + from generalresearch.models.thl.contest.raffle import RaffleUserView + from generalresearch.models.thl.contest import ( + ContestEndCondition, + ContestPrize, + ContestEntryRule, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestPrizeKind, + ContestType, + ) + from generalresearch.models.thl.contest.contest_entry import ( + ContestEntryType, + ) + + schema["example"] = RaffleUserView( + name="Win an iPhone", + description="iPhone winner will be drawn in proportion to entry " + "amount. Contest ends once $800 has been entered.", + contest_type=ContestType.RAFFLE, + end_condition=ContestEndCondition(target_entry_amount=USDCent(800_00)), + prizes=[ + ContestPrize( + kind=ContestPrizeKind.PHYSICAL, + name="iPhone 16", + estimated_cash_value=USDCent(800_00), + ) + ], + starts_at="2025-06-12T21:12:58.061170Z", + terms_and_conditions=None, + entry_rule=ContestEntryRule( + max_entry_amount_per_user=10000, max_daily_entries_per_user=1000 + ), + country_isos={"us", "ca"}, + entry_type=ContestEntryType.CASH, + status=ContestStatus.ACTIVE, + uuid="ce3968b8e18a4b96af62007f262ed7f7", + created_at="2025-06-12T21:12:58.061205Z", + updated_at="2025-06-12T21:12:58.061205Z", + current_amount=4723, + current_participants=12, + product_id=EXAMPLE_PRODUCT_ID, + user_amount=420, + user_amount_today=0, + product_user_id="test-user", + ).model_dump(mode="json") + + +def _example_milestone_create(schema: Dict) -> None: + from generalresearch.models.thl.contest.milestone import ( + MilestoneContestCreate, + MilestoneContestEndCondition, + ContestEntryTrigger, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + + schema["example"] = MilestoneContestCreate( + name="Win a 50% bonus for 7 days and a $5 bonus after your first 10 completes!", + description="Only valid for the first 50 users", + contest_type=ContestType.MILESTONE, + end_condition=MilestoneContestEndCondition(max_winners=50), + prizes=[ + ContestPrize( + kind=ContestPrizeKind.PROMOTION, + name="50% bonus on completes for 7 days", + estimated_cash_value=USDCent(0), + ), + ContestPrize( + kind=ContestPrizeKind.CASH, + name="$5.00 Bonus", + cash_amount=USDCent(5_00), + estimated_cash_value=USDCent(5_00), + ), + ], + entry_trigger=ContestEntryTrigger.TASK_COMPLETE, + target_amount=10, + starts_at="2025-06-12T21:12:58.061170Z", + terms_and_conditions=HttpUrl("https://www.example.com"), + ).model_dump(mode="json") + + +def _example_milestone(schema: Dict) -> None: + from generalresearch.models.thl.contest.milestone import ( + MilestoneContest, + MilestoneContestEndCondition, + ContestEntryTrigger, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + + schema["example"] = MilestoneContest( + name="Win a 50% bonus for 7 days and a $5 bonus after your first 10 completes!", + description="Only valid for the first 50 users", + contest_type=ContestType.MILESTONE, + end_condition=MilestoneContestEndCondition(max_winners=50), + prizes=[ + ContestPrize( + kind=ContestPrizeKind.PROMOTION, + name="50% bonus on completes for 7 days", + estimated_cash_value=USDCent(0), + ), + ContestPrize( + kind=ContestPrizeKind.CASH, + name="$5.00 Bonus", + cash_amount=USDCent(5_00), + estimated_cash_value=USDCent(5_00), + ), + ], + entry_trigger=ContestEntryTrigger.TASK_COMPLETE, + target_amount=10, + starts_at="2025-06-12T21:12:58.061170Z", + terms_and_conditions=HttpUrl("https://www.example.com"), + product_id=EXAMPLE_PRODUCT_ID, + uuid="747fe3b709ae460e816821dcb81aebb9", + created_at="2025-06-12T21:12:58.061205Z", + updated_at="2025-06-12T21:12:58.061205Z", + win_count=12, + ).model_dump(mode="json") + + +def _example_milestone_user_view(schema: Dict) -> None: + from generalresearch.models.thl.contest.milestone import ( + MilestoneUserView, + MilestoneContestEndCondition, + ContestEntryTrigger, + ) + from generalresearch.models.thl.contest import ContestPrize + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + + schema["example"] = MilestoneUserView( + name="Win a 50% bonus for 7 days and a $5 bonus after your first 10 completes!", + description="Only valid for the first 50 users", + contest_type=ContestType.MILESTONE, + end_condition=MilestoneContestEndCondition(max_winners=50), + prizes=[ + ContestPrize( + kind=ContestPrizeKind.PROMOTION, + name="50% bonus on completes for 7 days", + estimated_cash_value=USDCent(0), + ), + ContestPrize( + kind=ContestPrizeKind.CASH, + name="$5.00 Bonus", + cash_amount=USDCent(5_00), + estimated_cash_value=USDCent(5_00), + ), + ], + entry_trigger=ContestEntryTrigger.TASK_COMPLETE, + target_amount=10, + starts_at="2025-06-12T21:12:58.061170Z", + terms_and_conditions=HttpUrl("https://www.example.com"), + product_id=EXAMPLE_PRODUCT_ID, + uuid="747fe3b709ae460e816821dcb81aebb9", + created_at="2025-06-12T21:12:58.061205Z", + updated_at="2025-06-12T21:12:58.061205Z", + win_count=12, + user_amount=8, + product_user_id="test-user", + ).model_dump(mode="json") + + +def _example_leaderboard_contest_create(schema: Dict) -> None: + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestCreate, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + + schema["example"] = LeaderboardContestCreate( + name="Prizes for top survey takers this week", + description="$15 1st place, $10 2nd, $5 3rd place US weekly", + contest_type=ContestType.LEADERBOARD, + prizes=[ + ContestPrize( + name="$15 Cash", + estimated_cash_value=USDCent(15_00), + cash_amount=USDCent(15_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=1, + ), + ContestPrize( + name="$10 Cash", + estimated_cash_value=USDCent(10_00), + cash_amount=USDCent(10_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=2, + ), + ContestPrize( + name="$5 Cash", + estimated_cash_value=USDCent(5_00), + cash_amount=USDCent(5_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=3, + ), + ], + leaderboard_key=f"leaderboard:{EXAMPLE_PRODUCT_ID}:us:weekly:2025-05-26:complete_count", + ).model_dump(mode="json") + + return None + + +def _example_leaderboard_contest(schema: Dict) -> None: + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + + schema["example"] = LeaderboardContest( + name="Prizes for top survey takers this week", + description="$15 1st place, $10 2nd, $5 3rd place US weekly", + contest_type=ContestType.LEADERBOARD, + prizes=[ + ContestPrize( + name="$15 Cash", + estimated_cash_value=USDCent(15_00), + cash_amount=USDCent(15_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=1, + ), + ContestPrize( + name="$10 Cash", + estimated_cash_value=USDCent(10_00), + cash_amount=USDCent(10_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=2, + ), + ContestPrize( + name="$5 Cash", + estimated_cash_value=USDCent(5_00), + cash_amount=USDCent(5_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=3, + ), + ], + leaderboard_key=f"leaderboard:{EXAMPLE_PRODUCT_ID}:us:weekly:2025-05-26:complete_count", + product_id=EXAMPLE_PRODUCT_ID, + ).model_dump(mode="json") + + return None + + +def _example_leaderboard_contest_user_view(schema: Dict) -> None: + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestUserView, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestPrizeKind, + ContestType, + ) + + schema["example"] = LeaderboardContestUserView( + name="Prizes for top survey takers this week", + description="$15 1st place, $10 2nd, $5 3rd place US weekly", + contest_type=ContestType.LEADERBOARD, + prizes=[ + ContestPrize( + name="$15 Cash", + estimated_cash_value=USDCent(15_00), + cash_amount=USDCent(15_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=1, + ), + ContestPrize( + name="$10 Cash", + estimated_cash_value=USDCent(10_00), + cash_amount=USDCent(10_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=2, + ), + ContestPrize( + name="$5 Cash", + estimated_cash_value=USDCent(5_00), + cash_amount=USDCent(5_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=3, + ), + ], + leaderboard_key=f"leaderboard:{EXAMPLE_PRODUCT_ID}:us:weekly:2025-05-26:complete_count", + product_id=EXAMPLE_PRODUCT_ID, + product_user_id="test-user", + ).model_dump(mode="json") diff --git a/generalresearch/models/thl/contest/exceptions.py b/generalresearch/models/thl/contest/exceptions.py new file mode 100644 index 0000000..a5cdd42 --- /dev/null +++ b/generalresearch/models/thl/contest/exceptions.py @@ -0,0 +1,2 @@ +class ContestError(Exception): + pass diff --git a/generalresearch/models/thl/contest/io.py b/generalresearch/models/thl/contest/io.py new file mode 100644 index 0000000..8133c1a --- /dev/null +++ b/generalresearch/models/thl/contest/io.py @@ -0,0 +1,47 @@ +from datetime import datetime, timezone +from typing import Union +from uuid import uuid4 + +from generalresearch.models.thl.contest.definitions import ContestType +from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, + LeaderboardContestCreate, + LeaderboardContestUserView, +) +from generalresearch.models.thl.contest.milestone import ( + MilestoneContest, + MilestoneContestCreate, + MilestoneUserView, +) +from generalresearch.models.thl.contest.raffle import ( + RaffleContest, + RaffleContestCreate, + RaffleUserView, +) + +model_cls = { + ContestType.RAFFLE: RaffleContest, + ContestType.MILESTONE: MilestoneContest, + ContestType.LEADERBOARD: LeaderboardContest, +} +user_model_cls = { + ContestType.RAFFLE: RaffleUserView, + ContestType.MILESTONE: MilestoneUserView, + ContestType.LEADERBOARD: LeaderboardContestUserView, +} +ContestCreate = Union[ + RaffleContestCreate, LeaderboardContestCreate, MilestoneContestCreate +] +from generalresearch.models.thl.contest.contest import Contest + + +def contest_create_to_contest( + product_id: str, contest_create: ContestCreate +) -> Contest: + now = datetime.now(tz=timezone.utc) + d = contest_create.model_dump(mode="json") + d["uuid"] = uuid4().hex + d["product_id"] = product_id + d["created_at"] = now + d["updated_at"] = now + return model_cls[contest_create.contest_type].model_validate(d) diff --git a/generalresearch/models/thl/contest/leaderboard.py b/generalresearch/models/thl/contest/leaderboard.py new file mode 100644 index 0000000..0b24190 --- /dev/null +++ b/generalresearch/models/thl/contest/leaderboard.py @@ -0,0 +1,289 @@ +from datetime import datetime, timezone, timedelta +from typing import Optional, Literal, List, Tuple, Dict, Any + +from pydantic import ( + Field, + ConfigDict, + computed_field, + model_validator, + PrivateAttr, +) +from redis import Redis +from typing_extensions import Self + +from generalresearch.decorators import LOG +from generalresearch.managers.leaderboard import country_timezone +from generalresearch.managers.leaderboard.manager import LeaderboardManager +from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, +) +from generalresearch.models.thl.contest import ( + ContestWinner, + ContestEndCondition, +) +from generalresearch.models.thl.contest.contest import ( + Contest, + ContestBase, + ContestUserView, +) +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestType, + ContestPrizeKind, + ContestEndReason, + LeaderboardTieBreakStrategy, +) +from generalresearch.models.thl.contest.examples import ( + _example_leaderboard_contest_user_view, + _example_leaderboard_contest, + _example_leaderboard_contest_create, +) +from generalresearch.models.thl.leaderboard import ( + Leaderboard, + LeaderboardCode, + LeaderboardFrequency, +) + + +class LeaderboardContestCreate(ContestBase): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_leaderboard_contest_create, + ) + + contest_type: Literal[ContestType.LEADERBOARD] = Field( + default=ContestType.LEADERBOARD + ) + + # leaderboard:{product_id}:{country_iso}:{freq.value}:{date_str}:{board_code.value}" + leaderboard_key: str = Field( + description="The specific leaderboard instance this contest is connected to", + examples=[ + "leaderboard:7a9d8d02334449ceb105764f77e1ba97:us:weekly:2025-05-26:complete_count" + ], + ) + + # This is optional here. It'll get calculated from the leaderboard's end time + 90 min. + end_condition: ContestEndCondition = Field(default_factory=ContestEndCondition) + + @model_validator(mode="after") + def check_prize_rank(self) -> Self: + for prize in self.prizes: + assert prize.leaderboard_rank, "prize leaderboard_rank must be set" + + self.prizes.sort(key=lambda x: x.leaderboard_rank) + ranks = {x.leaderboard_rank for x in self.prizes} + assert None not in ranks, "Must have leaderboard_rank defined" + assert min(ranks) == 1, "Must start with rank 1" + assert ranks == set( + range(min(ranks), max(ranks) + 1) + ), "cannot skip prize leaderboard_ranks" + return self + + @model_validator(mode="after") + def validate_leaderboard_key(self) -> Self: + # Force validation + _ = self.leaderboard_model + return self + + @model_validator(mode="after") + def check_end_condition(self) -> Self: + assert ( + not self.end_condition.target_entry_amount + ), "target_entry_amount not valid in leaderboard contest" + # the ends_at will get set automatically from the leaderboard_key + return self + + @property + def leaderboard_key_parts(self) -> Dict: + assert self.leaderboard_key.count(":") == 5, "invalid leaderboard_key" + parts = self.leaderboard_key.split(":") + _, product_id, country_iso, freq_str, date_str, board_code_value = parts + freq = LeaderboardFrequency(freq_str) + board_code = LeaderboardCode(board_code_value) + timezone = country_timezone()[country_iso] + period_start_local = datetime.strptime(date_str, "%Y-%m-%d").replace( + tzinfo=timezone + ) + return { + "freq": freq, + "product_id": product_id, + "board_code": board_code, + "period_start_local": period_start_local, + "country_iso": country_iso, + } + + @property + def leaderboard_model(self) -> Leaderboard: + parts = self.leaderboard_key_parts + # This isn't hitting the db/redis or anything. Just initializing the model, so we can access + # some computed properties. + return Leaderboard.model_validate( + parts | {"row_count": 0, "bpid": parts["product_id"]} + ) + + +class LeaderboardContest(LeaderboardContestCreate, Contest): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_leaderboard_contest, + arbitrary_types_allowed=True, + ) + + # TODO: only this strategy supported for now + tie_break_strategy: Literal[LeaderboardTieBreakStrategy.SPLIT_PRIZE_POOL] = Field( + default=LeaderboardTieBreakStrategy.SPLIT_PRIZE_POOL + ) + + _redis_client: Optional[Redis] = PrivateAttr(default=None) + _user_manager: Optional[UserManager] = PrivateAttr(default=None) + + @model_validator(mode="after") + def validate_product_lb_key(self) -> Self: + assert ( + self.product_id == self.leaderboard_key_parts["product_id"] + ), "leaderboard_key product_id is invalid" + if self.country_isos: + assert ( + len(self.country_isos) == 1 + ), "Can only set 1 country_iso in a leaderboard contest" + assert ( + list(self.country_isos)[0] == self.leaderboard_key_parts["country_iso"] + ), "leaderboard_key country_iso must match the country_isos" + else: + self.country_isos = {self.leaderboard_key_parts["country_iso"]} + return self + + @model_validator(mode="after") + def validate_tie_break(self) -> Self: + if self.tie_break_strategy == LeaderboardTieBreakStrategy.SPLIT_PRIZE_POOL: + assert all( + p.kind == ContestPrizeKind.CASH for p in self.prizes + ), "All prizes must be cash due to the tie-break strategy" + return self + + @model_validator(mode="after") + def set_ends_at(self) -> Self: + ends_at = self.leaderboard_model.period_end_utc + timedelta(minutes=90) + assert self.end_condition.ends_at in { + None, + ends_at, + }, "Do not set the end_condition. It will be calculated" + self.end_condition.ends_at = ends_at + return self + + def get_leaderboard(self) -> Leaderboard: + lbm = self.get_leaderboard_manager() + return lbm.get_leaderboard() + + def get_leaderboard_manager(self) -> LeaderboardManager: + parts = self.leaderboard_key_parts + lbm = LeaderboardManager( + redis_client=self._redis_client, + board_code=parts["board_code"], + country_iso=parts["country_iso"], + freq=parts["freq"], + product_id=parts["product_id"], + within_time=parts["period_start_local"], + ) + return lbm + + def should_end(self) -> Tuple[bool, Optional[ContestEndReason]]: + if self.status == ContestStatus.ACTIVE: + if self.end_condition.ends_at: + if datetime.now(tz=timezone.utc) >= self.end_condition.ends_at: + return True, ContestEndReason.ENDS_AT + + return False, None + + def select_winners(self) -> List[ContestWinner]: + from generalresearch.models.thl.contest.utils import ( + distribute_leaderboard_prizes, + ) + + assert self.should_end(), "contest must be complete to select a winner" + assert ( + self.tie_break_strategy == LeaderboardTieBreakStrategy.SPLIT_PRIZE_POOL + ), "invalid tie break strategy" + redis_client = self._redis_client + user_manager = self._user_manager + assert redis_client and user_manager, "must set redis_client and user_manager" + + lb = self.get_leaderboard() + prize_values = [p.cash_amount for p in self.prizes] + assert all(x for x in prize_values), "invalid prize cash amount" + result = distribute_leaderboard_prizes(prize_values, lb.rows) + user_rank = {r.bpuid: r.rank for r in lb.rows} + winners = [] + prizes = sorted(self.prizes, key=lambda x: x.cash_amount, reverse=True) + for bpuid, cash_value in result.items(): + prize = prizes[user_rank[bpuid] - 1] # lb rank starts at 1 :facepalm: + user = user_manager.get_user( + product_id=self.product_id, product_user_id=bpuid + ) + winners.append( + ContestWinner(user=user, awarded_cash_amount=cash_value, prize=prize) + ) + return winners + + @computed_field + @property + def country_iso(self) -> str: + return self.leaderboard_key.split(":")[2] + + def model_dump_mysql(self) -> Dict[str, Any]: + d = super().model_dump_mysql( + exclude={ + "tie_break_strategy", + "country_iso", + } + ) + return d + + +class LeaderboardContestUserView(LeaderboardContest, ContestUserView): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_leaderboard_contest_user_view, + ) + + @computed_field(description="The current rank of this user in this contest") + @property + def user_rank(self) -> Optional[int]: + if not self._redis_client: + return None + + lb = self.get_leaderboard() + for row in lb.rows: + if row.bpuid == self.product_user_id: + return row.rank + + return None + + def is_user_eligible(self, country_iso: str) -> Tuple[bool, str]: + passes, msg = super().is_user_eligible(country_iso=country_iso) + if not passes: + return False, msg + + if country_iso != self.country_iso: + return False, "Invalid country" + + if self.user_winnings: + return False, "User already won" + + now = datetime.now(tz=timezone.utc) + if self.leaderboard_model.period_end_utc < now: + return False, "Contest is over" + if self.leaderboard_model.period_start_utc > now: + return False, "Contest hasn't started" + + # This would indicate something is wrong, as something else should have done this + e, reason = self.should_end() + if e: + LOG.warning("contest should be over") + return False, "contest is over" + + return True, "" diff --git a/generalresearch/models/thl/contest/milestone.py b/generalresearch/models/thl/contest/milestone.py new file mode 100644 index 0000000..f902401 --- /dev/null +++ b/generalresearch/models/thl/contest/milestone.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import Literal, Optional, Tuple, Dict + +from pydantic import ( + Field, + ConfigDict, + BaseModel, + PositiveInt, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.thl.contest.contest import ( + ContestBase, + Contest, + ContestUserView, +) +from generalresearch.models.thl.contest.contest_entry import ContestEntry +from generalresearch.models.thl.contest.definitions import ( + ContestEntryType, + ContestStatus, + ContestType, + ContestEndReason, + ContestEntryTrigger, +) +from generalresearch.models.thl.contest.examples import ( + _example_milestone_create, + _example_milestone, + _example_milestone_user_view, +) + +logging.basicConfig() +LOG = logging.getLogger() +LOG.setLevel(logging.INFO) + + +class MilestoneEntry(ContestEntry): + # Same as ContestEntry, but always a count. + + entry_type: Literal[ContestEntryType.COUNT] = Field(default=ContestEntryType.COUNT) + + # TODO: Must fix - how can the default be None if it's not Optional... + amount: int = Field( + default=None, + description="The amount of the entry in integer counts", + gt=0, + ) + + +class MilestoneContestEndCondition(BaseModel): + """Defines the conditions to evaluate to determine when the contest is over. + Multiple conditions can be set. The contest is over once ANY conditions are met. + """ + + max_winners: Optional[PositiveInt] = Field( + default=None, + description="The contest will end once this many users have won (i.e. reached" + "the milestone).", + ) + + ends_at: Optional[AwareDatetimeISO] = Field( + default=None, description="The Contest is over at the ends_at time." + ) + + +class MilestoneContestConfig(BaseModel): + """ + Contest configuration specific to a milestone contest + """ + + target_amount: PositiveInt = Field( + description="Each user 'wins' (receives prizes) once this target amount is reached." + ) + entry_trigger: Optional[ContestEntryTrigger] = Field( + description="What user action triggers an entry automatically.", + default=None, + ) + + # These two fields allow something like: "Get a complete in your first 24 hours!" + valid_for: Optional[timedelta] = Field( + description="The time after valid_for_event for which the contest is open", + default=None, + ) + valid_for_event: Optional[Literal["signup"]] = Field(default=None) + + +class MilestoneContestCreate(ContestBase, MilestoneContestConfig): + """Reward is guaranteed for everyone who passes a threshold / meets + some criteria. + + e.g. $5 bonus after 10 lifetime completes, OR "after earning $100", + OR "passing ID verification". + + A milestone has at most 1 entry (contest_entry) table per user + per contest. In that entry, we track the "amount", whether is + it completes, money, whatever, as an integer. + + An instance of a milestone contest is "scoped" to an individual user + (i.e, the entries/balance should only be populated for the user of + interest only) + """ + + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_milestone_create, + ) + + contest_type: Literal[ContestType.MILESTONE] = Field(default=ContestType.MILESTONE) + + end_condition: MilestoneContestEndCondition = Field() + + +class MilestoneContest(MilestoneContestCreate, Contest): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_milestone, + ) + + entry_type: Literal[ContestEntryType.COUNT] = Field(default=ContestEntryType.COUNT) + + # Note: a milestone can only ever be reached ONCE per user. + win_count: int = Field( + description="The number of times the milestone has been reached.", + default=0, + ) + + def should_end(self) -> Tuple[bool, Optional[ContestEndReason]]: + res, msg = super().should_end() + + if res: + return res, msg + + if self.status == ContestStatus.ACTIVE: + if self.end_condition.max_winners: + if self.win_count >= self.end_condition.max_winners: + return True, ContestEndReason.MAX_WINNERS + + return False, None + + def select_winners(self) -> None: + # milestone contest winners are selected as each user reaches the milestone, so this + # just does nothing + return None + + def model_dump_mysql(self): + d = super().model_dump_mysql( + exclude={ + "entry_trigger", + "target_amount", + "valid_for", + "valid_for_event", + } + ) + d["milestone_config"] = MilestoneContestConfig( + entry_trigger=self.entry_trigger, + target_amount=self.target_amount, + valid_for=self.valid_for, + valid_for_event=self.valid_for_event, + ).model_dump_json() + return d + + @classmethod + def model_validate_mysql(cls, data: Dict) -> Self: + data.update( + MilestoneContestConfig.model_validate(data["milestone_config"]).model_dump() + ) + data["end_condition"] = MilestoneContestEndCondition.model_validate( + data["end_condition"] + ) + return super().model_validate_mysql(data) + + +class MilestoneUserView(MilestoneContest, ContestUserView): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_milestone_user_view, + ) + + valid_until: Optional[AwareDatetimeISO] = Field( + default=None, + exclude=True, + description="If valid_for is set, this gets populated wrt this user", + ) + user_amount: int = Field( + description="The total amount for this user for this contest" + ) + + def should_award(self): + if self.status == ContestStatus.ACTIVE: + if self.should_have_awarded(): + return True + return False + + def should_have_awarded(self): + if self.target_amount: + if self.user_amount >= self.target_amount: + return True + return False + + def is_user_eligible(self, country_iso: str) -> Tuple[bool, str]: + passes, msg = super().is_user_eligible(country_iso=country_iso) + if not passes: + return False, msg + + if self.should_have_awarded(): + return False, "User should have won already" + + if self.user_winnings: + return False, "User already won" + + # todo: check valid_for and valid_for_event + # i.e. it hasn't been >24 hrs since user signed up, or whatever + + # This would indicate something is wrong, as something else should have done this + e, reason = self.should_end() + if e: + LOG.warning("contest should be over") + return False, "contest is over" + # TODO: others in self.entry_rule ... min_completes, id_verified, etc. + return True, "" diff --git a/generalresearch/models/thl/contest/raffle.py b/generalresearch/models/thl/contest/raffle.py new file mode 100644 index 0000000..9a01d0f --- /dev/null +++ b/generalresearch/models/thl/contest/raffle.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import logging +import random +from collections import defaultdict +from datetime import datetime, timezone +from typing import Literal, List, Dict, Tuple, Optional, Union + +from pydantic import ( + Field, + model_validator, + computed_field, + field_validator, + ConfigDict, +) +from scipy.stats import hypergeom +from typing_extensions import Self + +from generalresearch.currency import USDCent +from generalresearch.models.thl.contest import ( + ContestEntryRule, + ContestWinner, +) +from generalresearch.models.thl.contest.contest import ( + Contest, + ContestBase, + ContestUserView, +) +from generalresearch.models.thl.contest.contest_entry import ContestEntry +from generalresearch.models.thl.contest.definitions import ( + ContestEntryType, + ContestStatus, + ContestType, + ContestEndReason, +) +from generalresearch.models.thl.contest.examples import ( + _example_raffle_create, + _example_raffle, + _example_raffle_user_view, +) + +logging.basicConfig() +LOG = logging.getLogger() +LOG.setLevel(logging.INFO) + + +class RaffleContestCreate(ContestBase): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_raffle_create, + ) + + contest_type: Literal[ContestType.RAFFLE] = Field(default=ContestType.RAFFLE) + + # Only cash supported for now. We don't have ledger methods to deal with ContestEntryType.COUNT + entry_type: Literal[ContestEntryType.CASH] = Field(default=ContestEntryType.CASH) + entry_rule: ContestEntryRule = Field(default_factory=ContestEntryRule) + + @model_validator(mode="after") + def at_least_1_end_condition(self): + ec = self.end_condition + if not any([ec.target_entry_amount, ec.ends_at]): + raise ValueError("At least one end condition must be specified") + return self + + +class RaffleContest(RaffleContestCreate, Contest): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_raffle, + ) + + entries: List[ContestEntry] = Field(default_factory=list, exclude=True) + + current_amount: Union[int, USDCent] = Field( + default=0, description="Sum of all entry amounts" + ) + current_participants: int = Field( + default=0, description="Count of unique participants" + ) + + @field_validator("entries", mode="after") + def sort_entries(cls, v: List[ContestEntry]): + return sorted(v, key=lambda x: x.created_at) + + @model_validator(mode="after") + def validate_entry_type(self): + assert all( + entry.entry_type == self.entry_type for entry in self.entries + ), f"all entries must be of type {self.entry_type}" + return self + + @field_validator("current_amount", mode="before") + def coerce_current_amount(cls, v, info): + if v is None: + return None + if info.data.get("entry_type") == ContestEntryType.CASH: + return USDCent(v) + elif info.data.get("entry_type") == ContestEntryType.COUNT: + return int(v) + return v + + @model_validator(mode="after") + def validate_end_condition_cash(self): + # Make sure target amount is the right type + if self.end_condition and self.end_condition.target_entry_amount: + if self.entry_type == ContestEntryType.CASH: + self.end_condition.target_entry_amount = USDCent( + self.end_condition.target_entry_amount + ) + else: + self.end_condition.target_entry_amount = int( + self.end_condition.target_entry_amount + ) + return self + + def select_winners(self) -> List["ContestWinner"]: + from generalresearch.models.thl.contest import ContestWinner + + assert self.is_complete(), "contest must be complete to select a winner" + if not self.entries: + return [] + + # Each contest entry is one penny. We need to know how many + # total entries each user has. + # If there is more than 1 prize, the winning entry is subtracted + # from the user's entry count + user_amount = defaultdict(int) + user_id_user = dict() + for entry in self.entries: + user_amount[entry.user.user_id] += entry.amount + user_id_user[entry.user.user_id] = entry.user + + winners = [] + for prize in self.prizes: + # todo: should the prizes be ordered lowest estimated_cash_value + # to highest? or the other way around? + user_id = self.select_winner(user_amount) + winners.append(ContestWinner(user=user_id_user[user_id], prize=prize)) + user_amount[user_id] -= 1 + user_amount = {k: v for k, v in user_amount.items() if v > 0} + if not user_amount: + break + + return winners + + def should_end(self) -> Tuple[bool, Optional["ContestEndReason"]]: + res, msg = super().should_end() + if res: + return res, msg + if self.status == ContestStatus.ACTIVE: + if self.end_condition.target_entry_amount: + if self.current_amount >= self.end_condition.target_entry_amount: + return True, ContestEndReason.TARGET_ENTRY_AMOUNT + return False, None + + @staticmethod + def select_winner(user_amount: Dict[int, int]) -> int: + """ + user_amount: Dict[user_id, amount], is total entry count for each user, + e.g. {1111: 5, 2222: 1, 3333: 2} + returns: user_id of winner + """ + user_idx = [] + total = 0 + for user, amount in user_amount.items(): + total += amount + user_idx.append((user, total)) + # Generate a list of the cumulative sum of entries, indexed + # by each user. e.g. user_idx = [(1111, 5), (2222, 6), (3333, 8)] + # We then generate a random number between 0 and the max, and + # the winner is the first user who's cumcount is greater. + idx = random.randint(1, total) + winner = next(x[0] for x in user_idx if idx <= x[1]) + return winner + + # @property + # def current_entry_count(self) -> int: + # assert self.entry_type == ContestEntryType.COUNT + # # this is only valid if the amounts are 1 + # assert not self.entries or all(e.amount == 1 for e in self.entries) + # return len(self.entries) + + def get_current_participants(self) -> int: + return len({entry.user.user_id for entry in self.entries}) + + def get_current_amount(self) -> Union[int | USDCent]: + return sum([x.amount for x in self.entries]) + + def get_user_amount(self, product_user_id: str) -> Union[int | USDCent]: + # Sum of this user's amounts + return sum( + e.amount for e in self.entries if e.user.product_user_id == product_user_id + ) + + def is_complete(self) -> bool: + """Check if contest has reached any completion condition""" + if self.status == ContestStatus.COMPLETED: + return True + c = self.end_condition + if c.target_entry_amount and self.current_amount >= c.target_entry_amount: + return True + if c.ends_at and datetime.now(tz=timezone.utc) >= c.ends_at: + return True + return False + + def model_dump_mysql(self): + d = super().model_dump_mysql() + d["entry_rule"] = self.entry_rule.model_dump_json() + return d + + @classmethod + def model_validate_mysql(cls, data: Dict) -> Self: + data["entry_rule"] = ContestEntryRule.model_validate(data["entry_rule"]) + return super().model_validate_mysql(data) + + +class RaffleUserView(RaffleContest, ContestUserView): + model_config = ConfigDict( + validate_assignment=True, + extra="forbid", + json_schema_extra=_example_raffle_user_view, + ) + + user_amount: Union[int, USDCent] = Field( + description="The total amount this user has entered" + ) + user_amount_today: Union[int, USDCent] = Field( + description="The total amount this user has entered in the past 24 hours" + ) + + @computed_field( + description="Probability of this user winning 1 or more prizes, if the contest" + "ended right now" + ) + @property + def current_win_probability(self) -> float: + # equals 1 minus the probability of winning 0 prizes + # This is equivalent to user_amount / current_amount if there is only 1 prize. + if self.current_amount == 0: + # otherwise the result is NaN. If there are no entrances yet, return 0 + return 0.0 + return 1 - hypergeom.pmf( + 0, self.current_amount, self.user_amount, self.prize_count + ) + + @computed_field( + description="Probability of this user winning 1 or more prizes, once the contest" + "is projected to end. This value is only calculated if the contest has a target_entry_amount" + "end condition." + ) + @property + def projected_win_probability(self) -> Optional[float]: + if self.end_condition.target_entry_amount is None: + return None + + return 1 - hypergeom.pmf( + 0, + max(self.end_condition.target_entry_amount, self.current_amount), + self.user_amount, + self.prize_count, + ) + + # Not sure how to return this in api response, too confusing. Maybe use later. + # Left for tests only. + @property + def current_prize_count_probability(self) -> Dict[int, float]: + # M: Population size (total entry amount) + M = self.current_amount + # n: number of success states (user's entry amount) + n = self.user_amount + # N: number of draws + N = self.prize_count + + # Probability of drawing k of user's tickets (user winning K times) + probs = {k: hypergeom.pmf(k, M, n, N) for k in range(1, N + 1)} + return probs + + def is_entry_eligible(self, entry: ContestEntry) -> Tuple[bool, str]: + if self.entry_rule.max_entry_amount_per_user: + if ( + self.user_amount + entry.amount + ) > self.entry_rule.max_entry_amount_per_user: + return False, "Entry would exceed max amount per user." + + if self.entry_rule.max_daily_entries_per_user: + if ( + self.user_amount_today + entry.amount + ) > self.entry_rule.max_daily_entries_per_user: + return False, "Entry would exceed max amount per user per day." + return True, "" + + def is_user_eligible(self, country_iso: str) -> Tuple[bool, str]: + passes, msg = super().is_user_eligible(country_iso=country_iso) + if not passes: + return False, msg + + if self.entry_rule.max_entry_amount_per_user: + # Greater or equal b/c we're asking if the user is eligible to + # enter MORE, now! If it equals, nothing is wrong, just that they + # are not eligible anymore. + if self.user_amount >= self.entry_rule.max_entry_amount_per_user: + return False, "Reached max amount per user." + + if self.entry_rule.max_daily_entries_per_user: + if self.user_amount_today >= self.entry_rule.max_daily_entries_per_user: + return False, "Reached max amount today." + + # This would indicate something is wrong, as something else should have done this + e, reason = self.should_end() + if e: + LOG.warning("contest should be over") + return False, "contest is over" + # todo: others in self.entry_rule ... min_completes, id_verified, etc. + return True, "" diff --git a/generalresearch/models/thl/contest/utils.py b/generalresearch/models/thl/contest/utils.py new file mode 100644 index 0000000..1505f7b --- /dev/null +++ b/generalresearch/models/thl/contest/utils.py @@ -0,0 +1,76 @@ +from typing import TYPE_CHECKING, List, Dict + +if TYPE_CHECKING: + from generalresearch.models.thl.user import User + from generalresearch.currency import USDCent + from generalresearch.models.thl.leaderboard import LeaderboardRow + + +def censor_product_user_id(user: "User") -> str: + s = user.product_user_id + + if len(s) >= 24: + return f"{s[:4]}{'*' * (len(s) - 8)}{s[-4:]}" + elif len(s) >= 6: + return f"{s[:1]}{'*' * (len(s) - 2)}{s[-1:]}" + else: + return "*" * len(s) + + +def distribute_leaderboard_prizes( + prizes: List["USDCent"], leaderboard_rows: List["LeaderboardRow"] +) -> Dict[str, "USDCent"]: + """ + Distributes leaderboard prizes among tied users. + The prizes for the tied places are pooled together and divided + equally among all tied participants. + + :param prizes: List of cash value for prizes (in descending order). + :param leaderboard_rows: List of LeaderboardRow, sorted by score descending / rank ascending. + + Returns: + dict: Mapping {user: prize_amount} for all tied users. + + See also: + https://en.wikipedia.org/wiki/Ranking#Standard_competition_ranking_(%221224%22_ranking) + https://www.pgatour.com/fedexcup/overview + + (Points are distributed to those in tying positions using the same method + currently used to distribute prize money when there is a tie. That is, the + total points for each tying position will be averaged and that average will + be distributed to each player in the tying position.) + + """ + from generalresearch.currency import USDCent + + if not prizes or not leaderboard_rows: + return {} + + leaderboard_rows = sorted(leaderboard_rows, key=lambda x: x.rank) + prizes = sorted(prizes, reverse=True) + + result = {} + place = 0 # index into prizes + rank = 1 + + while place < len(prizes): + # Get all users tied for this rank + tie_group = [row for row in leaderboard_rows if row.rank == rank] + + # Determine which prize places this tie group occupies + tie_prizes = prizes[place : place + len(tie_group)] + if not tie_prizes: + break + + # Pool prizes for all places they occupy, then split among the group + total = sum(p for p in tie_prizes) + split = USDCent(round(total / len(tie_group))) + for row in tie_group: + result[row.bpuid] = split + + # Advance prize index by number of tied users (skip the places this group occupied) + place += len(tie_group) + # Advance the rank + rank += 1 + + return result diff --git a/generalresearch/models/thl/definitions.py b/generalresearch/models/thl/definitions.py new file mode 100644 index 0000000..259093b --- /dev/null +++ b/generalresearch/models/thl/definitions.py @@ -0,0 +1,343 @@ +import copy +from enum import Enum + +from generalresearch.utils.enum import ReprEnumMeta + + +class ReservedQueryParameters(str, Enum, metaclass=ReprEnumMeta): + PRODUCT_ID = "product_id" + PRODUCT_USER_ID = "bp_user_id" + BPUID = "bpuid" + + COUNTRY_CODES = "country_codes" + COUNTRY = "country" + COUNTRY_ISO = "country_iso" + LANGUAGE_CODES = "lang_codes" + LANGUAGES = "languages" + + CPI = "cpi" + CURRENCY = "currency" + PAYOUT = "payout" + + DURATION = "duration" + LOI = "loi" + REQUESTED_LOI = "req_loi" + + IP_ADDRESS = "ip" + KEYWORD_ARGS = "kwargs" + + NAME = "name" + DESCRIPTION = "description" + QUALITY = "quality" + QUALITY_CATEGORY = "quality_category" + + STATUS = "status" + TASK_URI = "task_uri" + TASKS = "tasks" + TASK_STATUS_ID = "tsid" + + URI = "uri" + URL = "url" + FORMAT = "format" + ENTRY_LINK = "entry_link" + BUCKET = "b" + INDEX = "i" + INDEX_VERBOSE = "idx" + INFO = "info" + ELLE = "l" + N_BINS = "n_bins" + + +class THLPaths(str, Enum, metaclass=ReprEnumMeta): + + # Endpoints on thl-fsb + TASK_ADJUSTMENT = "f4484dbdf144451ab60cda256ce14266" + ACCESS_CONTROL = "9bf111afe03e40719c5cd0de0dc43c31" + + # Endpoints on thl-core (TO BE REMOVED) + # (under /api/v1/) + GET_GRLIQ_JS = "d9e1d3fbfa934b249abfd71f0f3bd667" + GET_GRLIQ_LOGO = "1fe9fdec9eae43fa848c930972141436" + + # Endpoints on GRL-IQ + # (under /api/) + GET_GRLIQ_JS_INLINE = "d9e1d3fbfa934b249abfd71f0f3bd667" + GET_GRLIQ_JS_ATTR = "4a2954b34cc24f93be3e8b218e323b88" + + +class Status(str, Enum, metaclass=ReprEnumMeta): + """ + The outcome of a session or wall event. If the session is still in progress, the status will be NULL. + """ + + # User completed the job successfully and should be paid something + COMPLETE = "c" + # User did not successfully complete the task. They were rejected by either + # GRL, the marketplace, or the buyer. + FAIL = "f" + # User abandoned the task. This would only get set if the BP lets us know + # the user took some action to exit out of the task + ABANDON = "a" + # User either abandoned the task or was never returned to us for some + # reason. After a pre-determined amount of time (configurable on the BP + # level), any task that does not have a status will time out. + TIMEOUT = "t" + + +class WallAdjustedStatus(str, Enum, metaclass=ReprEnumMeta): + # Task was reconciled to complete + ADJUSTED_TO_COMPLETE = "ac" + # Task was reconciled to incomplete + ADJUSTED_TO_FAIL = "af" + # The cpi for a task was changed. This applies to Wall events ONLY. + CPI_ADJUSTMENT = "ca" + # This is only supported for compatibility reasons, as we currently do not + # do anything with confirmed completes as they have historically been + # meaningless. They only get added to the thl_taskadjustment table, and + # won't get used in the Wall.adjusted_status (for now. The + # WallManager.adjust_status does not support doing anything with this). + CONFIRMED_COMPLETE = "cc" + + +class SessionAdjustedStatus(str, Enum, metaclass=ReprEnumMeta): + """An adjusted_status is set if a session is adjusted by the marketplace + after the original return. A session can be adjusted multiple times. + This is the most recent status. If a session was originally a complete, + was adjusted to incomplete, then back to complete, the adjusted_status + will be None, but the adjusted_timestamp will be set to the most recent + change. + """ + + # Task was reconciled to complete + ADJUSTED_TO_COMPLETE = "ac" + # Task was reconciled to incomplete + ADJUSTED_TO_FAIL = "af" + # The payout was changed. This applies to Sessions ONLY. + PAYOUT_ADJUSTMENT = "pa" + + +class StatusCode1(int, Enum, metaclass=ReprEnumMeta): + """ + __High level status code for outcome of the session.__ + This should only be NULL if the Status is ABANDON or TIMEOUT + """ + + # Do not use 0 because grpc does not distinguish between 0 and None. + + # User terminated in buyer survey + BUYER_FAIL = 1 + # User terminated in buyer survey for quality reasons + BUYER_QUALITY_FAIL = 2 + # User failed in marketplace's prescreener + PS_FAIL = 3 + # User rejected by marketplace for quality reasons + PS_QUALITY = 4 + # User is explicitly blocked by the marketplace. Note: on some marketplaces, + # users can have multiple PS_QUALITY terminations and still complete + # surveys. + PS_BLOCKED = 5 + # User rejected by marketplace for over quota + PS_OVERQUOTA = 6 + # User rejected by marketplace for duplicate + PS_DUPLICATE = 7 + # The user failed within the GRS Platform + GRS_FAIL = 8 + # The user failed within the GRS Platform for quality reasons + GRS_QUALITY_FAIL = 9 + + # The user abandoned/timed out within the GRS Platform + GRS_ABANDON = 10 + # The user abandoned/timed out within the marketplace's pre-screen system. + # Note: On most marketplaces, we have no way of distinguishing between + # this and BUYER_ABANDON. BUYER_ABANDON is used as the default, unless we + # know it is PS_ABANDON. + PS_ABANDON = 11 + # The user abandoned/timed out within the client survey + BUYER_ABANDON = 12 + + # The status code is not documented + UNKNOWN = 13 + # The user completed the task successfully + COMPLETE = 14 + + # Something was wrong upon the user redirecting from the marketplace, e.g. no postback received, + # or url hashing failures. + MARKETPLACE_FAIL = 15 + + # **** Below here should ONLY be used on a Session (not a Wall) **** + + # User failed before being sent into a marketplace + SESSION_START_FAIL = 16 + # User failed between attempts + SESSION_CONTINUE_FAIL = 17 + # User failed before being sent into a marketplace for "security" reasons + SESSION_START_QUALITY_FAIL = 18 + # User failed between attempts for "security" reasons + SESSION_CONTINUE_QUALITY_FAIL = 19 + + +class SessionStatusCode2(int, Enum, metaclass=ReprEnumMeta): + """ + __Status Detail__ + This should be set if the Session.status_code_1 is SESSION_XXX_FAIL + """ + + # Unable to parse either the bucket_id, request_id, or nudge_id from the url + ENTRY_URL_MODIFICATION = 1 + # The client's IP failed maxmind lookup, or we failed to store it for some reason + UNRECOGNIZED_IP = 2 + # User is using an anonymous IP + USER_IS_ANONYMOUS = 3 + # User is blocked + USER_IS_BLOCKED = 4 + # User is rate limited + USER_IS_RATE_LIMITED = 5 + # The client's useragent was not categorized as desktop, mobile, or tablet + UNRECOGNIZED_DEVICE = 6 + # The user clicked after 5 min + OFFERWALL_EXPIRED = 7 + # Something unexpected happened + INTERNAL_ERROR = 8 + # The user requested the offerwall for a different country than their IP + # address indicates + OFFERWALL_COUNTRY_MISMATCH = 9 + # The bucket id indicated in the url does not exist. This is likely due + # to the user clicking on a bucket for an offerwall that has already + # been refreshed. + INVALID_BUCKET_ID = 10 + # Not necessarily the user's fault. We thought we had surveys, but due to + # for e.g. the user entering on a different device than we thought, there + # really are none. If we get a lot of these, then that might indicate + # something is wrong. + NO_TASKS_AVAILABLE = 11 + # The entrance attempt was flagged by GRLIQ as suspicious + ATTEMPT_IS_SUSPICIOUS = 12 + # No GRLIQ forensics post was received + GRLIQ_MISSING = 13 + + +class WallStatusCode2(int, Enum, metaclass=ReprEnumMeta): + """ + This should be set if the Wall.status_code_1 is MARKETPLACE_FAIL + """ + + # The redirect URL (coming back from the marketplace) failed hashing checks + URL_HASHING_CHECK_FAILED = 12 + # The redirect URL was missing required query params or was unparseable + BROKEN_REDIRECT = 16 + # The redirect URL was invalid or inconsistent in some way and as a result + # we could not determine the outcome. This could be if a redirect received + # did not match the user's most recent attempt. + INVALID_REDIRECT = 17 + # The redirect indicated a complete, but no/invalid Postback was received + # from the marketplace + INVALID_MARKETPLACE_POSTBACK = 13 + # No/invalid Postback was received from the marketplace. Used in cases where + # the redirect does not contain a status. + NO_MARKETPLACE_POSTBACK = 18 + # The marketplace indicates the user completed the survey, but we don't + # think this is valid due to speeding. Generally this cutoff is the 95th + # percentile of our calculated CompletionTime survey stat. + COMPLETE_TOO_FAST = 14 + # Something happened during the handling of this redirect (on our side) + INTERNAL_ERROR = 15 + + +WALL_ALLOWED_STATUS_STATUS_CODE = { + Status.COMPLETE: {StatusCode1.COMPLETE}, + Status.FAIL: { + StatusCode1.BUYER_FAIL, + StatusCode1.BUYER_QUALITY_FAIL, + StatusCode1.PS_FAIL, + StatusCode1.PS_QUALITY, + StatusCode1.PS_DUPLICATE, + StatusCode1.PS_OVERQUOTA, + StatusCode1.PS_BLOCKED, + StatusCode1.GRS_FAIL, + StatusCode1.GRS_QUALITY_FAIL, + StatusCode1.UNKNOWN, + StatusCode1.MARKETPLACE_FAIL, + }, + Status.ABANDON: { + StatusCode1.PS_ABANDON, + StatusCode1.BUYER_ABANDON, + StatusCode1.GRS_ABANDON, + }, + Status.TIMEOUT: { + StatusCode1.PS_ABANDON, + StatusCode1.BUYER_ABANDON, + StatusCode1.GRS_ABANDON, + }, +} +SESSION_ALLOWED_STATUS_STATUS_CODE = copy.deepcopy(WALL_ALLOWED_STATUS_STATUS_CODE) +SESSION_ALLOWED_STATUS_STATUS_CODE[Status.FAIL].update( + { + StatusCode1.SESSION_START_FAIL, + StatusCode1.SESSION_START_QUALITY_FAIL, + StatusCode1.SESSION_CONTINUE_FAIL, + StatusCode1.SESSION_CONTINUE_QUALITY_FAIL, + } +) + +WALL_ALLOWED_STATUS_CODE_1_2 = { + StatusCode1.MARKETPLACE_FAIL: { + WallStatusCode2.URL_HASHING_CHECK_FAILED, + WallStatusCode2.INVALID_MARKETPLACE_POSTBACK, + WallStatusCode2.COMPLETE_TOO_FAST, + } +} + + +class ReportValue(int, Enum, metaclass=ReprEnumMeta): + """ + The reason a user reported a task. + """ + + # Used to indicate the user exited the task without giving feedback + REASON_UNKNOWN = 0 + # Task is in the wrong language/country, unanswerable question, won't proceed to + # next question, loading forever, error message + TECHNICAL_ERROR = 1 + # Task ended (completed or failed, and showed the user some dialog + # indicating the task was over), but failed to redirect + NO_REDIRECT = 2 + # Asked for full name, home address, identity on another site, cc# + PRIVACY_INVASION = 3 + # Asked about children, employer, medical issues, drug use, STDs, etc. + UNCOMFORTABLE_TOPICS = 4 + # Asked to install software, signup/login to external site, access webcam, + # promise to pay using external site, etc. + ASKED_FOR_NOT_ALLOWED_ACTION = 5 + # Task doesn't work well on a mobile device + BAD_ON_MOBILE = 6 + # Too long, too boring, confusing, complicated, too many + # open-ended/free-response questions + DIDNT_LIKE = 7 + + +class PayoutStatus(str, Enum, metaclass=ReprEnumMeta): + """The max size of the db field that holds this value is 20, so please + don't add new values longer than that! + """ + + # The user has requested a payout. The money is taken from their + # wallet. A PENDING request can either be APPROVED, REJECTED, or + # CANCELLED. We can also implicitly skip the APPROVED step and go + # straight to COMPLETE or FAILED. + PENDING = "PENDING" + # The request is approved (by us or automatically). Once approved, + # it can be FAILED or COMPLETE. + APPROVED = "APPROVED" + # The request is rejected. The user loses the money. + REJECTED = "REJECTED" + # The user requests to cancel the request, the money goes back into their wallet. + CANCELLED = "CANCELLED" + # The payment was approved, but failed within external payment provider. + # This is an "error" state, as the money won't have moved anywhere. A + # FAILED payment can be tried again and be COMPLETE. + FAILED = "FAILED" + # The payment was sent successfully and (usually) a fee was charged + # to us for it. + COMPLETE = "COMPLETE" + # Not supported # REFUNDED: I'm not sure if this is possible or + # if we'd want to allow it. diff --git a/generalresearch/models/thl/demographics.py b/generalresearch/models/thl/demographics.py new file mode 100644 index 0000000..fd833c5 --- /dev/null +++ b/generalresearch/models/thl/demographics.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import copy +from collections import defaultdict, Counter +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Literal, List, Dict + +import numpy as np + +from generalresearch.models.thl.locales import CountryISO + +if TYPE_CHECKING: + from generalresearch.models.thl.survey import MarketplaceTask + + +@dataclass(frozen=True) +class DemographicTarget: + country: CountryISO | Literal["*"] + gender: Gender | Literal["*"] + age_group: AgeGroup | Literal["*"] + + def __post_init__(self): + assert self.country == self.country.lower(), "country must be lower" + + def to_tags(self): + gender = self.gender.value if isinstance(self.gender, Gender) else "*" + age_group = ( + self.age_group.value if isinstance(self.age_group, AgeGroup) else "*" + ) + return { + "country": self.country, + "gender": gender, + "age_group": age_group, + } + + +class Gender(str, Enum): + """ + The respondent's gender + """ + + MALE = "male" + FEMALE = "female" + OTHER = "other" + + +class AgeGroup(Enum): + """ + The respondent's age. + """ + + AGE_UNDER_18 = (0, 17, "<18") + AGE_18_TO_35 = (18, 35, "18-35") + AGE_36_TO_55 = (36, 55, "36-55") + AGE_56_TO_75 = (56, 75, "56-75") + AGE_OVER_75 = (76, 120, ">75") + + def __init__(self, low: int, high: int, label: str): + # [inclusive, + self._low = low + # exclusive) + self._high = high + self.label = label + + @property + def low(self): + return self._low + + @property + def high(self): + return self._high + + @property + def value(self): + return self.label + + +def calculate_demographic_metrics(opps: List[MarketplaceTask]): + """ + Measurement: marketplace_survey_demographics + tags: source (marketplace) + : (all combinations of): country, gender, age_groups + values/fields: cost (aka cpi) (min, p25, p50, mean, p75, p90, p95, p99, max) + : count, open_count + """ + source = {opp.source for opp in opps} + assert len(source) == 1 + source = list(source)[0] + survey_cpi = defaultdict(list) + target_open = defaultdict(int) + for opp in opps: + is_open = opp.is_open + cpi = float(opp.cpi) + tgs = opp.demographic_targets + for t in tgs: + survey_cpi[t].append(cpi) + if is_open: + target_open[t] += 1 + survey_counter = {k: len(v) for k, v in survey_cpi.items()} + survey_counter = {k: {"count": v} for k, v in survey_counter.items() if v} + + grp_stats = dict() + for grp, costs in survey_cpi.items(): + stats = { + "cost_min": np.min(costs), + "cost_p25": np.percentile(costs, 25), + "cost_p50": np.median(costs), + "cost_mean": np.mean(costs), + "cost_p75": np.percentile(costs, 75), + "cost_p90": np.percentile(costs, 90), + "cost_p95": np.percentile(costs, 95), + "cost_p99": np.percentile(costs, 99), + "cost_max": np.max(costs), + "open_count": target_open[grp], + } + grp_stats[grp] = stats + survey_counter[grp].update(stats) + + # fmt: off + TOP_COUNTRIES = [ + 'us', 'cn', 'au', 'gb', 'kr', 'de', 'at', 'fr', 'es', 'jp', + 'ca', 'ie', 'br', 'mx', 'nl', 'ar', 'nz', 'in', 'sg', 'it', + 'be', 'hk', 'ch', 'co', 'my' + ] + # fmt: on + survey_counter = { + k: v + for k, v in survey_counter.items() + if k.country == "*" or k.country in TOP_COUNTRIES + } + + base = { + "measurement": "marketplace_survey_demographics", + "tags": {"source": source.value}, + "fields": {}, + } + points = [] + for k, v in survey_counter.items(): + d = copy.deepcopy(base) + d["tags"].update(k.to_tags()) + d["fields"].update(v) + points.append(d) + return points + + +def calculate_used_question_metrics( + opps: List[MarketplaceTask], qid_label: Dict[str, str] +): + """ + Measurement: marketplace_survey_targeting + tags: source (marketplace), "type", country (all and individual) + values/fields: {question_label: count} + """ + source = {opp.source for opp in opps} + assert len(source) == 1 + source = list(source)[0] + country_q_counter = defaultdict(Counter) + for opp in opps: + for q in opp.used_question_ids: + if q not in qid_label: + continue + label = qid_label[q] + country_q_counter["*"][label] += 1 + country_q_counter[opp.country_iso][label] += 1 + + points = [] + for country, q in country_q_counter.items(): + points.append( + { + "measurement": "marketplace_survey_targeting", + "tags": { + "source": source.value, + "type": "question_label", + "country": country, + }, + "fields": dict(q), + } + ) + return points diff --git a/generalresearch/models/thl/finance.py b/generalresearch/models/thl/finance.py new file mode 100644 index 0000000..6a24b5e --- /dev/null +++ b/generalresearch/models/thl/finance.py @@ -0,0 +1,881 @@ +import random +from datetime import timezone +from typing import Optional, TYPE_CHECKING, List +from uuid import uuid4 + +import pandas as pd +from pydantic import ( + BaseModel, + Field, + NonNegativeInt, + ConfigDict, + model_validator, + computed_field, + field_validator, +) +from pydantic.json_schema import SkipJsonSchema + +from generalresearch.currency import USDCent +from generalresearch.decorators import LOG +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.definitions import SessionAdjustedStatus +from generalresearch.pg_helper import PostgresConfig + +payout_example = random.randint(150, 750 * 100) +adjustment_example = random.randint(-1_000, 50 * 100) + +if TYPE_CHECKING: + from generalresearch.models.thl.ledger import LedgerAccount + from generalresearch.managers.thl.product import ProductManager + from generalresearch.models.thl.ledger import AccountType, Direction + + +class AdjustmentType(BaseModel): + amount: int = Field( + description="The total amount (in USD cents) that the Brokerage Product" + "has earned within a respective time period from a specific" + "Source of Tasks." + ) + + adjustment: "SessionAdjustedStatus" = Field( + description=SessionAdjustedStatus.as_openapi(), + examples=[SessionAdjustedStatus.ADJUSTED_TO_FAIL.value], + ) + + +class POPFinancial(BaseModel): + """ + We can't use our USDCent class in here because aside from it not + supporting negative values for our adjustments, FastAPI also + complains because it doesn't know how to generate documentation + for it. - Max 2024-06-25 + """ + + # --- Tracking / Tagging --- + product_id: Optional[UUIDStr] = Field(default=None, examples=[uuid4().hex]) + + time: AwareDatetimeISO = Field( + description="The starting time block for the respective 'Period' that" + "this grouping is on. The `time` could be the start of a " + "1 minute or 1 hour block for example." + ) + + # --- Numeric --- + + payout: NonNegativeInt = Field( + default=0, + description="The total amount (in USD cents) that the Brokerage Product" + "has earned within a respective time period.", + examples=[payout_example], + ) + + adjustment: int = Field( + description="The total amount (in USD cents) that the Brokerage Product" + "has had adjusted within a respective time period. Most of" + "the time, this will be negative due to Complete to " + "Incomplete reconciliations. However, it can also be " + "positive due to Incomplete to Complete adjustments.", + examples=[adjustment_example], + ) + + adjustment_types: List[AdjustmentType] = Field() + + expense: int = Field( + description="For Product accounts that are setup with Respondent payouts," + "competitions, user bonuses, or other associated 'costs', those" + "expenses are accounted for here. This will be negative for" + "those types of costs." + ) + + net: int = Field( + description="This is the sum of the Payout total, Adjustment and any " + "Expenses total. It can be positive or negative for any " + "specific time period.", + examples=[payout_example + adjustment_example], + ) + + payment: int = Field( + description="Any ACH or Wire amount that was issued between GRL and " + "the Supplier.", + examples=[3_408_288], + ) + + @staticmethod + def list_from_pandas( + input_data: pd.DataFrame, accounts: List["LedgerAccount"] + ) -> List["POPFinancial"]: + """ + This list can either be for a Product or a Business. The difference + is that the list of accounts will either be len()=1 (Product) or + len()>1 (Business), it's also possible that the business only + has a single Product. + + """ + from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, + ) + + from generalresearch.config import is_debug + + # Validate the input accounts + assert len(accounts) > 0, "Must provide accounts" + from generalresearch.models.thl.ledger import ( + AccountType, + Direction, + ) + + assert all([a.account_type == AccountType.BP_WALLET for a in accounts]) + assert all([a.normal_balance == Direction.CREDIT for a in accounts]) + if not is_debug(): + assert all([a.currency == "USD" for a in accounts]) + + if input_data.empty: + return [] + + assert isinstance(input_data.index, pd.MultiIndex) + assert list(input_data.index.names) == ["time_idx", "account_id"] + assert input_data.columns.to_list() == numerical_col_names + uniq_acct_cnt: int = input_data.index.get_level_values(1).unique().size + + # https://grl.sentry.io/issues/5704598444/?project=4507416823332864 + # I changed this to <= because it is (I think) okay to have missing + # events if there was no period financial activity -- Max 2024-08-12 + assert uniq_acct_cnt <= len(accounts) + + account_product_map = {a.uuid: a.reference_uuid for a in accounts} + + res = [] + for index, row in input_data.reset_index().iterrows(): + index: int # Not useful, just a RangeIndex + row: pd.DataFrame + + row["time_idx"] = row.time_idx.to_pydatetime().replace(tzinfo=timezone.utc) + instance = ProductBalances.from_pandas(row) + + res.append( + POPFinancial( + product_id=account_product_map[row.account_id], + time=row.time_idx, + payout=instance.payout, + adjustment=instance.adjustment, + adjustment_types=[ + AdjustmentType.model_validate( + { + "adjustment": SessionAdjustedStatus.ADJUSTED_TO_COMPLETE, + "amount": instance.adjustment_credit, + } + ), + AdjustmentType.model_validate( + { + "adjustment": SessionAdjustedStatus.ADJUSTED_TO_FAIL, + "amount": instance.adjustment_debit, + } + ), + ], + expense=instance.expense, + net=instance.net, + payment=instance.payment, + ) + ) + + return res + + +class ProductBalances(BaseModel): + model_config = ConfigDict(extra="ignore", populate_by_name=True) + + # --- Tracking / Tagging --- + product_id: Optional[UUIDStr] = Field(default=None, examples=[uuid4().hex]) + last_event: Optional[AwareDatetimeISO] = Field(default=None) + + # --- Numeric --- + + # TODO: will these ever NOT be 0? + mp_payment_credit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="mp_payment.CREDIT" + ) + mp_payment_debit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="mp_payment.DEBIT" + ) + mp_adjustment_credit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="mp_adjustment.CREDIT" + ) + mp_adjustment_debit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="mp_adjustment.DEBIT" + ) + bp_payment_debit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="bp_payment.DEBIT" + ) + plug_credit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="plug.CREDIT" + ) + + plug_debit: SkipJsonSchema[NonNegativeInt] = Field( + default=0, exclude=True, validation_alias="plug.DEBIT" + ) + + bp_payment_credit: NonNegativeInt = Field( + default=0, + validation_alias="bp_payment.CREDIT", + description="The total amount that has been earned by the Task " + "completes, for this Brokerage Product account.", + examples=[18_837], + ) + + adjustment_credit: NonNegativeInt = Field( + default=0, + validation_alias="bp_adjustment.CREDIT", + description="Positive reconciliations issued back to the Brokerage " + "Product account.", + examples=[2], + ) + + adjustment_debit: NonNegativeInt = Field( + default=0, + validation_alias="bp_adjustment.DEBIT", + description="Negative reconciliations for any Task completes", + examples=[753], + ) + + supplier_credit: NonNegativeInt = Field( + default=0, + validation_alias="bp_payout.CREDIT", + description="ACH or Wire amounts issued to GRL from a Supplier to recoup " + "for a negative Brokerage Product balance", + examples=[0], + ) + + supplier_debit: NonNegativeInt = Field( + default=0, + validation_alias="bp_payout.DEBIT", + description="ACH or Wire amounts sent to a Supplier", + examples=[10_000], + ) + + user_bonus_credit: NonNegativeInt = Field( + default=0, + validation_alias="user_bonus.CREDIT", + # TODO: @greg - when would this ever NOT be 0 + description="If a respondent ever pays back an product account.", + examples=[0], + ) + + user_bonus_debit: NonNegativeInt = Field( + default=0, + validation_alias="user_bonus.DEBIT", + description="Pay a user into their wallet balance. There is no fee " + "here. There is only a fee when the user requests a payout." + "The bonus could be as a bribe, winnings for a contest, " + "leaderboard, etc.", + examples=[2_745], + ) + + # --- Hidden helper values --- + + issued_payment: NonNegativeInt = Field( + default=0, + description="This is the amount that we decide to credit as having" + "taken from this Product. If there is any amount not issued" + "it is summed up over the Business to offset any negative" + "balances elsewhere.", + ) + + # --- Validate --- + @model_validator(mode="after") + def check_unknown_fields(self) -> "ProductBalances": + """ + I don't fully understand what these fields are supposed to be + when looking at bp_wallet accounts. However, I know that they're + always 0 so far, so let's assert that so it'll fail if they're + not.. then figure out why... + """ + val = sum( + [ + self.mp_payment_credit, + self.mp_payment_debit, + self.mp_adjustment_credit, + self.mp_adjustment_debit, + self.bp_payment_debit, + self.plug_credit, + ] + ) + + if val > 0: + raise ValueError("review data: unknown field not 0") + + return self + + # --- Properties --- + @computed_field( + title="Task Payouts", + description="The sum amount of all Task payouts", + examples=[18_837], + return_type=int, + ) + @property + def payout(self) -> int: + return self.bp_payment_credit + + @computed_field( + title="Task Payouts USD Str", + examples=["$18,837.00"], + return_type=str, + ) + @property + def payout_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.payout).to_usd_str() + + @computed_field( + title="Task Adjustments", + description="The sum amount of all Task Adjustments", + examples=[-751], + return_type=int, + ) + @property + def adjustment(self) -> int: + return (self.adjustment_credit - self.plug_debit) + (self.adjustment_debit * -1) + + @computed_field( + title="Product Expenses", + description="The sum amount of any associated Product Expenses (eg: " + "user bonuses)", + examples=[-2_745], + return_type=int, + ) + @property + def expense(self) -> int: + return self.user_bonus_credit + (self.user_bonus_debit * -1) + + # --- Properties: account related --- + @computed_field( + title="Net Earnings", + description="The Product's Net Earnings which is equal to the total" + "amount of Task Payouts, with Task Adjustments and any" + "Product Expenses deducted. This can be positive or" + "negative.", + examples=[15341], + return_type=int, + ) + @property + def net(self) -> int: + return self.payout + self.adjustment + self.expense + + @computed_field( + title="Supplier Payments", + description="The sum amount of all Supplier Payments (eg ACH or Wire " + "transfers)", + examples=[10_000], + return_type=NonNegativeInt, + ) + @property + def payment(self): + """We'll consider this positive, even though it's really a deduction + from their balance... they'll want to see it as positive. + """ + return (self.supplier_credit * -1) + self.supplier_debit + + @computed_field( + title="Supplier Payments", + examples=["$10,000"], + return_type=str, + ) + @property + def payment_usd_str(self): + from generalresearch.currency import USDCent + + return USDCent(self.payment).to_usd_str() + + @computed_field( + title="Product Balance", + description="The Product's Balance which is equal to the Product's Net" + "amount with already issued Supplier Payments deducted. " + "This can be positive or negative.", + examples=[5_341], + return_type=int, + ) + @property + def balance(self) -> int: + return self.net + (self.payment * -1) + + @computed_field( + title="Smart Retainer", + description="The Smart Retainer is an about of money that is held by" + "GRL to account for any Task Adjustments that may occur" + "in the future. The amount will always be positive, and" + "if the Product's balance is negative, the retainer will " + "be $0.00 as the Product is not eligible for any Supplier" + "Payments either way.", + examples=[1_335], + return_type=NonNegativeInt, + ) + @property + def retainer(self) -> NonNegativeInt: + if self.balance <= 0: + # We don't need to show a retainer amount if the account is already + # in a financial deficit + return 0 + + return abs(int(self.balance * 0.25)) + + @computed_field( + title="Smart Retainer USD Str", + examples=["$1,335.00"], + return_type=str, + ) + @property + def retainer_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.retainer).to_usd_str() + + @computed_field( + title="Available Balance", + description="The Available Balance is the amount that is currently, and" + "immediately available for withdraw from the Supplier's" + "balance. Supplier Payments are made every Friday for " + "Businesses with an ACH connected Bank Account to GRL, " + "while a Business that requires an International Wire " + "are issued on the last Friday of every Month.", + examples=[4_006], + return_type=NonNegativeInt, + ) + @property + def available_balance(self) -> NonNegativeInt: + if self.balance <= 0: + return 0 + + ab = self.balance - self.retainer + if ab <= 0: + return 0 + + return ab + + @computed_field( + title="Available Balance USD Str", + examples=["$4,006.00"], + return_type=str, + ) + @property + def available_balance_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.available_balance).to_usd_str() + + @computed_field( + title="Recoup", + examples=[282], + return_type="USDCent", + ) + @property + def recoup(self) -> "USDCent": + from generalresearch.currency import USDCent + + if self.balance >= 0: + return USDCent(0) + + return USDCent(abs(self.balance)) + + @computed_field( + title="Recoup Str", + examples=["$2.04"], + return_type=str, + ) + @property + def recoup_usd_str(self) -> str: + return self.recoup.to_usd_str() + + # --- Properties: account related --- + @computed_field( + title="Adjustment Percentage", + description="The percentage of USDCent value that has been adjusted" + "over all time for this Product.", + examples=[0.064938], + return_type=float, + ) + @property + def adjustment_percent(self) -> float: + if self.payout <= 0: + return 0.00 + + return abs(self.adjustment) / self.payout + + @staticmethod + def from_pandas( + input_data: pd.DataFrame | pd.Series, + ): + LOG.debug(f"ProductBalances.from_pandas(input_data={input_data.shape})") + + if isinstance(input_data, pd.Series): + return ProductBalances.model_validate(input_data.to_dict()) + + elif isinstance(input_data, pd.DataFrame): + assert isinstance(input_data.index, pd.DatetimeIndex), "Invalid input data" + + # The pop merge is grouped by 1min intervals. Therefore, if we take + # the maximum of the dt.floor("1min") value and add 1min to it, we + # can assume that the parquet files from incite will include any + # events up to that timestamp + pq_last_event_close = input_data.index.max() + pd.Timedelta(minutes=1) + + pb = ProductBalances.model_validate(input_data.sum().to_dict()) + pb.last_event = pq_last_event_close.to_pydatetime() + return pb + + else: + raise NotImplementedError("Can't handle this input") + + def __str__(self) -> str: + return ( + f"Product: {self.product_id or '—'}\n" + f"Total Payout: ${self.payout / 100:,.2f}\n" + f"Total Adjustment: ${self.adjustment / 100:,.2f}\n" + f"Total Expense: ${self.expense / 100:,.2f}\n" + f"–––\n" + f"Net: ${self.net / 100:,.2f}\n" + f"Balance: ${self.balance / 100:,.2f}\n" + f"Smart Retainer: ${self.retainer / 100:,.2f}\n" + f"Available Balance: ${self.available_balance / 100:,.2f}" + ).replace("$-", "-$") + + +class BusinessBalances(BaseModel): + product_balances: List[ProductBalances] = Field(default_factory=list) + + # --- Validators --- + @field_validator("product_balances") + def required_product_ids(cls, v: List[ProductBalances]): + """The BusinessBalances needs to be able to distinguish between all + the child Products; in order to do this, we need to assert that + they all explicitly are set + """ + + if any([pb.product_id is None for pb in v]): + raise ValueError("'product_id' must be set for BusinessBalance children.") + + return v + + # --- Properties --- + @computed_field( + title="Task Payouts", + description="The sum amount of all Task payouts", + examples=[18_837], + return_type=int, + ) + @property + def payout(self) -> int: + return sum([i.payout for i in self.product_balances]) + + @computed_field( + title="Task Payouts USD Str", + examples=["$18,837"], + return_type=str, + ) + @property + def payout_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.payout).to_usd_str() + + @computed_field( + title="Task Adjustments", + description="The sum amount of all Task Adjustments", + examples=[-751], + return_type=int, + ) + @property + def adjustment(self) -> int: + adjustment_credit = sum([pb.adjustment_credit for pb in self.product_balances]) + plug_debit = sum([pb.plug_debit for pb in self.product_balances]) + adjustment_debit = sum([pb.adjustment_debit for pb in self.product_balances]) + + return (adjustment_credit - plug_debit) + (adjustment_debit * -1) + + @computed_field( + title="Task Adjustments USD Str", + examples=["-$2,745.00"], + return_type=str, + ) + @property + def adjustment_usd_str(self) -> str: + from generalresearch.currency import format_usd_cent + + return format_usd_cent(self.adjustment) + + @computed_field( + title="Business Expenses", + description="The sum amount of any associated Business Expenses (eg: " + "user bonuses)", + examples=[-2_745], + return_type=int, + ) + @property + def expense(self) -> int: + user_bonus_credit = sum([pb.user_bonus_credit for pb in self.product_balances]) + user_bonus_debit = sum([pb.user_bonus_debit for pb in self.product_balances]) + + return user_bonus_credit + (user_bonus_debit * -1) + + @computed_field( + title="Business Expenses USD Str", + examples=["-$2,745.00"], + return_type=str, + ) + @property + def expense_usd_str(self) -> str: + from generalresearch.currency import format_usd_cent + + return format_usd_cent(self.expense) + + # --- Properties: account related --- + @computed_field( + title="Net Earnings", + description="The Business's Net Earnings which is equal to the total" + "amount of Task Payouts, with Task Adjustments and any" + "Product Expenses deducted. This can be positive or" + "negative.", + examples=[15341], + return_type=int, + ) + @property + def net(self) -> int: + return self.payout + self.adjustment + self.expense + + @computed_field( + title="Net Earnings USD Str", + examples=["$15,341"], + return_type=str, + ) + @property + def net_usd_str(self) -> str: + from generalresearch.currency import format_usd_cent + + return format_usd_cent(self.net) + + @computed_field( + title="Supplier Payments", + description="The sum amount of all Supplier Payments (eg ACH or Wire " + "transfers)", + examples=[10_000], + return_type=NonNegativeInt, + ) + @property + def payment(self): + """We'll consider this positive, even though it's really a deduction + from their balance... they'll want to see it as positive. + """ + supplier_credit = sum([pb.supplier_credit for pb in self.product_balances]) + supplier_debit = sum([pb.supplier_debit for pb in self.product_balances]) + + return (supplier_credit * -1) + supplier_debit + + @computed_field( + title="Supplier Payments USD Str", + examples=["$10,000.00"], + return_type=str, + ) + @property + def payment_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.payment).to_usd_str() + + @computed_field( + title="Business Balance", + description="The Business's Balance which is equal to the Business's Net" + "amount with already issued Supplier Payments deducted. " + "This can be positive or negative.", + examples=[5_341], + return_type=int, + ) + @property + def balance(self) -> int: + return self.net + (self.payment * -1) + + @computed_field( + title="Business Balance USD Str", + examples=["$5,341.00"], + return_type=str, + ) + @property + def balance_usd_str(self) -> str: + from generalresearch.currency import format_usd_cent + + return format_usd_cent(self.balance) + + @computed_field( + title="Smart Retainer", + description="The Smart Retainer is an about of money that is held by" + "GRL to account for any Task Adjustments that may occur" + "in the future. The amount will always be positive, and" + "if the Business's balance is negative, the retainer will " + "be $0.00 as the Business is not eligible for any Supplier" + "Payments either way.", + examples=[1_335], + return_type=NonNegativeInt, + ) + @property + def retainer(self) -> NonNegativeInt: + return sum([pb.retainer for pb in self.product_balances]) + + @computed_field( + title="Smart Retainer USD Str", + examples=["$1,335.00"], + return_type=str, + ) + @property + def retainer_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.retainer).to_usd_str() + + @computed_field( + title="Available Balance", + description="The Available Balance is the amount that is currently, and" + "immediately available for withdraw from the Supplier's" + "balance. Supplier Payments are made every Friday for " + "Businesses with an ACH connected Bank Account to GRL, " + "while a Business that requires an International Wire " + "are issued on the last Friday of every Month.", + examples=[4_006], + return_type=NonNegativeInt, + ) + @property + def available_balance(self) -> NonNegativeInt: + if self.balance <= 0: + return 0 + + ab = self.balance - self.retainer + if ab <= 0: + return 0 + + return ab + + @computed_field( + title="Available Balance USD Str", + examples=["$4,006.00"], + return_type=str, + ) + @property + def available_balance_usd_str(self) -> str: + from generalresearch.currency import USDCent + + return USDCent(self.available_balance).to_usd_str() + + # --- Properties: account related --- + @computed_field( + title="Adjustment Percentage", + description="The percentage of USDCent value that has been adjusted" + "over all time for this Product. This is not an aggregation" + "of each of the children Product Balances, but calculated" + "across all traffic of the children", + examples=[0.064938], + return_type=float, + ) + @property + def adjustment_percent(self) -> float: + if self.payout <= 0: + return 0.00 + + return abs(self.adjustment) / self.payout + + @computed_field( + title="Business Recoup Hold", + description="The sum amount of all Supplier Payments (eg ACH or Wire " + "transfers)", + examples=[10_000], + return_type="USDCent", + ) + @property + def recoup(self) -> "USDCent": + """Returns the sum of this Business' recouped amount from any + children Products. + """ + from generalresearch.currency import USDCent + + return USDCent(sum([i.recoup for i in self.product_balances])) + + @computed_field( + title="Business Recoup Hold Str", + examples=["$2.04"], + return_type=str, + ) + @property + def recoup_usd_str(self) -> str: + return self.recoup.to_usd_str() + + # --- Methods --- + + def __str__(self) -> str: + return ( + f"Products: {len(self.product_balances)}\n" + f"Total Payout: ${self.payout / 100:,.2f}\n" + f"Total Adjustment: ${self.adjustment / 100:,.2f}\n" + f"Total Expense: ${self.expense / 100:,.2f}\n" + f"–––\n" + f"Net: ${self.net / 100:,.2f}\n" + f"Balance: ${self.balance / 100:,.2f}\n" + f"Smart Retainer: ${self.retainer / 100:,.2f}\n" + f"Available Balance: ${self.available_balance / 100:,.2f}" + ).replace("$-", "-$") + + # --- Methods --- + @staticmethod + def from_pandas( + input_data: pd.DataFrame, + accounts: List["LedgerAccount"], + thl_pg_config: PostgresConfig, + ) -> "BusinessBalances": + LOG.debug(f"BusinessBalances.from_pandas(input_data={input_data.shape})") + + from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, + ) + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.ledger import ( + AccountType, + Direction, + ) + + # Validate the input accounts + assert len(accounts) > 0, "Must provide accounts" + assert all([a.account_type == AccountType.BP_WALLET for a in accounts]) + assert all([a.normal_balance == Direction.CREDIT for a in accounts]) + from generalresearch.config import is_debug + + if not is_debug(): + assert all([a.currency == "USD" for a in accounts]) + + # Validate the input dataframe + assert input_data.index.name == "account_id" + assert len(input_data.index) <= len(accounts) + assert input_data.columns.to_list() == numerical_col_names + + account_product_map = {a.uuid: a.reference_uuid for a in accounts} + + product_balances = [] + for account_id, series in input_data.iterrows(): + pb = ProductBalances.from_pandas(series) + pb.product_id = account_product_map[account_id] + product_balances.append(pb) + + # Sort the ProductBalances so that they're always in a consistent + # sorted order. + from generalresearch.managers.thl.product import ProductManager + + pm = ProductManager(pg_config=thl_pg_config) + products: List[Product] = pm.get_by_uuids( + product_uuids=[pb.product_id for pb in product_balances] + ) + sorted_products_uuids = [ + p.uuid for p in sorted(products, key=lambda x: x.created) + ] + product_uuid_order = { + value: idx for idx, value in enumerate(sorted_products_uuids) + } + product_balances = sorted( + product_balances, key=lambda pb: product_uuid_order[pb.product_id] + ) + + return BusinessBalances.model_validate({"product_balances": product_balances}) diff --git a/generalresearch/models/thl/grliq.py b/generalresearch/models/thl/grliq.py new file mode 100644 index 0000000..1e769aa --- /dev/null +++ b/generalresearch/models/thl/grliq.py @@ -0,0 +1,10 @@ +from generalresearch.grliq.models.decider import ( + Decider, + AttemptDecision, + GrlIqAttemptResult, +) + +# thl-core is importing these models from here. need to update thl-core then can get rid of this +_ = Decider +_ = AttemptDecision +_ = GrlIqAttemptResult diff --git a/generalresearch/models/thl/ipinfo.py b/generalresearch/models/thl/ipinfo.py new file mode 100644 index 0000000..d8abfc0 --- /dev/null +++ b/generalresearch/models/thl/ipinfo.py @@ -0,0 +1,348 @@ +import ipaddress +from datetime import datetime, timezone +from typing import Optional, Dict, Any, Literal, Tuple + +import geoip2.models +from faker import Faker +from pydantic import ( + BaseModel, + Field, + PositiveInt, + field_validator, + PrivateAttr, + ConfigDict, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + IPvAnyAddressStr, + CountryISOLike, +) +from generalresearch.models.thl.maxmind.definitions import UserType +from generalresearch.pg_helper import PostgresConfig + +fake = Faker() + +PrefixLength = Literal["/128", "/64", "/32"] + + +def normalize_ip(ip: IPvAnyAddressStr) -> Tuple[str, PrefixLength]: + """ + Normalize an IP address for MySQL storage. + + - IPv4: returned unchanged + - IPv6: converted to its /64 network address and returned + in fully expanded (exploded) form + Returns: + (ip, lookup_prefix) + """ + addr = ipaddress.ip_address(ip) + if addr.version == 4: + return ip, "/32" + net64 = ipaddress.IPv6Network((addr, 64), strict=False) + return net64.network_address.exploded, "/64" + + +class IPGeoname(BaseModel): + geoname_id: PositiveInt = Field() + + continent_code: Optional[str] = Field(default=None, max_length=2) + continent_name: Optional[str] = Field(default=None, max_length=32) + + country_iso: CountryISOLike = Field( + description="The ISO code of the country associated with the IP address.", + examples=[fake.country_code().lower()], + ) + country_name: Optional[str] = Field(default=None, max_length=64) + + subdivision_1_iso: Optional[str] = Field( + default=None, + description="The ISO code of the primary subdivision (e.g., state or province).", + max_length=3, + ) + subdivision_1_name: Optional[str] = Field( + default=None, + description="The name of the primary subdivision (e.g., state or province).", + max_length=255, + ) + subdivision_2_iso: Optional[str] = Field( + default=None, + description="The ISO code of the secondary subdivision (if applicable).", + max_length=3, + ) + subdivision_2_name: Optional[str] = Field( + default=None, + description="The name of the secondary subdivision (if applicable).", + max_length=255, + ) + + city_name: Optional[str] = Field( + default=None, + max_length=255, + description="The name of the city associated with the IP address.", + examples=[fake.city()], + ) + metro_code: Optional[int] = Field(default=None) + + time_zone: Optional[str] = Field( + default=None, + max_length=60, + description="The time zone associated with the geographical location.", + examples=[fake.timezone()], + ) + is_in_european_union: Optional[bool] = Field(default=None) + + updated: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + ) + + @field_validator( + "country_iso", + "continent_code", + "subdivision_1_iso", + "subdivision_2_iso", + mode="before", + ) + def make_lower(cls, value: str): + if value is not None: + return value.lower() + return value + + # --- ORM --- + def model_dump_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json") + d["updated"] = self.updated + return d + + @classmethod + def from_mysql(cls, d: Dict) -> Self: + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + + return cls.model_validate(d) + + @classmethod + def from_insights(cls, res: geoip2.models.Insights) -> Self: + geoname_id = res.city.geoname_id + # Some ips don't have city level specificity. grab the first subdivision if it exists + if geoname_id is None and len(res.subdivisions) > 0: + geoname_id = res.subdivisions[0].geoname_id + elif geoname_id is None: + # No city, no subdivision, use the country + geoname_id = res.country.geoname_id + # Some ips have a city but no subdivisions (41.33.89.99) + d = { + "geoname_id": geoname_id, + "continent_code": res.continent.code, + "continent_name": res.continent.name, + "country_iso": res.country.iso_code, + "country_name": res.country.name, + "city_name": res.city.name, + "metro_code": res.location.metro_code, + "time_zone": res.location.time_zone, + "is_in_european_union": res.country.is_in_european_union, + "subdivision_1_iso": None, + "subdivision_1_name": None, + "subdivision_2_iso": None, + "subdivision_2_name": None, + } + if len(res.subdivisions) > 0: + d.update( + { + "subdivision_1_iso": res.subdivisions[0].iso_code, + "subdivision_1_name": res.subdivisions[0].name, + } + ) + if len(res.subdivisions) > 1: + d.update( + { + "subdivision_2_iso": res.subdivisions[1].iso_code, + "subdivision_2_name": res.subdivisions[1].name, + } + ) + return cls.model_validate(d) + + +class IPInformation(BaseModel): + ip: IPvAnyAddressStr = Field() + # This doesn't get stored in mysql/redis, b/c we only look up by the normalized ip + lookup_prefix: Optional[PrefixLength] = Field(default=None, exclude=True) + + geoname_id: Optional[PositiveInt] = Field(default=None) + + country_iso: CountryISOLike = Field( + description="The ISO code of the country associated with the IP address.", + examples=[fake.country_code().lower()], + ) + + registered_country_iso: Optional[CountryISOLike] = Field( + default=None, + description="The ISO code of the country where the IP address is " + "registered.", + examples=[fake.country_code().lower()], + ) + is_anonymous: Optional[bool] = Field( + default=None, + description="Indicates whether the IP address is associated with an " + "anonymous source (e.g., VPN, proxy).", + examples=[False], + ) + is_anonymous_vpn: Optional[bool] = Field(default=None) + is_hosting_provider: Optional[bool] = Field(default=None) + is_public_proxy: Optional[bool] = Field(default=None) + is_tor_exit_node: Optional[bool] = Field(default=None) + is_residential_proxy: Optional[bool] = Field(default=None) + + autonomous_system_number: Optional[PositiveInt] = Field(default=None) + autonomous_system_organization: Optional[str] = Field(default=None, max_length=255) + + domain: Optional[str] = Field(default=None, max_length=255) + isp: Optional[str] = Field( + default=None, + description="The Internet Service Provider associated with the " "IP address.", + examples=["Comcast"], + ) + + mobile_country_code: Optional[str] = Field(default=None, max_length=3) + mobile_network_code: Optional[str] = Field(default=None, max_length=3) + + network: Optional[str] = Field(default=None, max_length=56) + organization: Optional[str] = Field(default=None, max_length=255) + + static_ip_score: Optional[float] = Field( + default=None, + description="A score indicating the likelihood that the IP address is static.", + ) + user_type: Optional[UserType] = Field( + default=None, + description="The type of user associated with the IP address " + "(e.g., 'residential', 'business').", + examples=[UserType.SCHOOL], + ) + postal_code: Optional[str] = Field( + default=None, + description="The postal code associated with the IP address.", + examples=[fake.postcode()], + ) + + latitude: Optional[float] = Field( + description="The latitude coordinate of the IP address location.", + default=None, + examples=[float(fake.latitude())], + ) + longitude: Optional[float] = Field( + description="The longitude coordinate of the IP address location.", + default=None, + examples=[float(fake.longitude())], + ) + + accuracy_radius: Optional[int] = Field( + default=None, + description="The approximate radius of accuracy for the latitude " + "and longitude, in kilometers.", + examples=[fake.random_int(min=25, max=250)], + ) + + updated: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + ) + + _geoname: Optional[IPGeoname] = PrivateAttr(default=None) + + @field_validator("country_iso", "registered_country_iso", mode="before") + def make_lower(cls, value: str): + if value is not None: + return value.lower() + return value + + @property + def basic(self) -> bool: + # This could be almost any field, but we're checking here if maxmind + # insights was run on this record. If not, then most of the optional + # fields will be None + return self.is_anonymous is None + + @property + def geoname(self) -> Optional["IPGeoname"]: + return self._geoname or None + + def normalize_ip(self): + normalized_ip, lookup_prefix = normalize_ip(self.ip) + self.ip = normalized_ip + self.lookup_prefix = lookup_prefix + return None + + # --- prefetch_* --- + def prefetch_geoname( + self, + pg_config: PostgresConfig, + ) -> None: + if self.geoname_id is None: + raise ValueError("Must provide geoname_id") + + from generalresearch.managers.thl.ipinfo import IPGeonameManager + + ip_gm = IPGeonameManager(pg_config=pg_config) + + self._geoname = ip_gm.get_by_id(geoname_id=self.geoname_id) + + return None + + # --- ORM --- + def model_dump_mysql(self): + d = self.model_dump(mode="json", exclude={"geoname"}) + d["updated"] = self.updated + return d + + @classmethod + def from_mysql(cls, d: Dict) -> Self: + d["updated"] = d["updated"].replace(tzinfo=timezone.utc) + + return cls.model_validate(d) + + @classmethod + def from_insights(cls, res: geoip2.models.Insights) -> Self: + geoname_id = res.city.geoname_id + # Some ips don't have city level specificity. grab the first subdivision if it exists + if geoname_id is None and len(res.subdivisions) > 0: + geoname_id = res.subdivisions[0].geoname_id + elif geoname_id is None: + # No city, no subdivision, use the country + geoname_id = res.country.geoname_id + return cls.model_validate( + { + "ip": res.traits.ip_address, + "network": str(res.traits.network), + "geoname_id": geoname_id, + "country_iso": res.country.iso_code.upper(), + "registered_country_iso": ( + res.registered_country.iso_code.upper() + if res.registered_country.iso_code + else None + ), + "is_anonymous": res.traits.is_anonymous, + "is_anonymous_vpn": res.traits.is_anonymous_vpn, + "is_hosting_provider": res.traits.is_hosting_provider, + "is_public_proxy": res.traits.is_public_proxy, + "is_tor_exit_node": res.traits.is_tor_exit_node, + "is_residential_proxy": res.traits.is_residential_proxy, + "autonomous_system_number": res.traits.autonomous_system_number, + "autonomous_system_organization": res.traits.autonomous_system_organization, + "domain": res.traits.domain, + "isp": res.traits.isp, + "mobile_country_code": res.traits.mobile_country_code, + "mobile_network_code": res.traits.mobile_network_code, + "organization": res.traits.organization, + "static_ip_score": res.traits.static_ip_score, + "user_type": res.traits.user_type, + # IP-specific location that may be different for different IPs in the same City + "postal_code": res.postal.code, + "latitude": res.location.latitude, + "longitude": res.location.longitude, + "accuracy_radius": res.location.accuracy_radius, + } + ) + + +class GeoIPInformation(IPInformation, IPGeoname): + model_config = ConfigDict(extra="ignore") diff --git a/generalresearch/models/thl/leaderboard.py b/generalresearch/models/thl/leaderboard.py new file mode 100644 index 0000000..a4c2134 --- /dev/null +++ b/generalresearch/models/thl/leaderboard.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import List, Literal +from uuid import UUID, uuid3 +from zoneinfo import ZoneInfo + +import math +import pandas as pd +from pydantic import ( + BaseModel, + Field, + NonNegativeInt, + model_validator, + computed_field, + AwareDatetime, + field_validator, +) + +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.legacy.api_status import StatusResponse +from generalresearch.models.thl.locales import CountryISO +from generalresearch.utils.enum import ReprEnumMeta + +logger = logging.getLogger() + + +class LeaderboardCode(str, Enum, metaclass=ReprEnumMeta): + """ + The type of leaderboard. What the "values" represent. + """ + + # Number of Completes + COMPLETE_COUNT = "complete_count" + # Largest Single Payout + LARGEST_PAYOUT = "largest_user_payout" + # (Sum of) Total Payouts + SUM_PAYOUTS = "sum_user_payout" + + +class LeaderboardFrequency(str, Enum, metaclass=ReprEnumMeta): + """ + The time period range for the leaderboard. + """ + + # UTC midnight to UTC midnight + DAILY = "daily" + # Sunday Midnight to Sunday Midnight UTC + WEEKLY = "weekly" + # Jan 1 00:00:00 + MONTHLY = "monthly" + + +class LeaderboardRow(BaseModel): + bpuid: str = Field(description="product_user_id", examples=["app-user-9329ebd"]) + + rank: int = Field( + description="The numerical data ranks (1 through n) of the values. Ties " + "are ranked using the lowest rank in the group.", + examples=[1], + ) + + value: int = Field( + description="The value. The meaning of the value is dependent on the LeaderboardCode.", + examples=[7], + ) + + def censor(self): + censor_idx = math.ceil(len(self.bpuid) / 2) + self.bpuid = self.bpuid[:censor_idx] + ("*" * len(self.bpuid[censor_idx:])) + + +class Leaderboard(BaseModel): + """A leaderboard exists independently for each product_id in each country. + Each country is associated with a single timezone. There is a daily, + weekly, and monthly leaderboard. + """ + + id: UUIDStr = Field( + description="Unique ID for this leaderboard", + examples=["845b0074ad533df580ebb9c80cc3bce1"], + default=None, + ) + + name: str = Field( + description="Descriptive name for the leaderboard based on the board_code", + examples=["Number of Completes"], + default=None, + ) + + board_code: LeaderboardCode = Field( + description=LeaderboardCode.as_openapi_with_value_descriptions(), + examples=[LeaderboardCode.COMPLETE_COUNT], + ) + + bpid: UUIDStr = Field( + description="product_id", examples=["4fe381fb7186416cb443a38fa66c6557"] + ) + + country_iso: CountryISO = Field( + description="The country this leaderboard is for.", examples=["us"] + ) + + freq: LeaderboardFrequency = Field( + description=LeaderboardFrequency.as_openapi_with_value_descriptions(), + examples=[LeaderboardFrequency.DAILY], + ) + + timezone_name: str = Field( + description="The timezone for the requested country", + examples=["America/New_York"], + default=None, + ) + + sort_order: Literal["ascending", "descending"] = Field(default="descending") + + row_count: NonNegativeInt = Field( + description="The total number of rows in the leaderboard.", examples=[2] + ) + + rows: List[LeaderboardRow] = Field( + default_factory=list, + examples=[ + [ + LeaderboardRow(bpuid="app-user-9329ebd", value=4, rank=1), + LeaderboardRow(bpuid="app-user-7923skw", value=3, rank=2), + ] + ], + ) + + period_start_local: AwareDatetime = Field( + description="The start of the time period covered by this board in local time, tz-aware", + examples=[ + datetime(2024, 7, 12, 0, 0, 0, 0, tzinfo=ZoneInfo("America/New_York")) + ], + # This can't be excluded or else cacheing doesn't work. + # If we want it not in the API response, we need to make a LeaderboardOut + # exclude=True, + ) + + period_end_local: AwareDatetime = Field( + description="The end of the time period covered by this board in local time, tz-aware", + examples=[ + datetime( + 2024, + 7, + 12, + 23, + 59, + 59, + 999999, + tzinfo=ZoneInfo("America/New_York"), + ) + ], + # exclude=True, + default=None, + ) + + @property + def board_key(self): + product_id = self.bpid + country_iso = self.country_iso + freq = self.freq + board_code = self.board_code + date_str = self.period_start_local.strftime("%Y-%m-%d") + return f"leaderboard:{product_id}:{country_iso}:{freq.value}:{date_str}:{board_code.value}" + + @property + def period_start_utc(self) -> datetime: + # The start of the time period covered by this board in UTC, tz-aware + # e.g. datetime(2024, 7, 12, 4, 0, 0, 0, tzinfo=timezone.utc) + return self.period_start_local.astimezone(timezone.utc) + + @property + def period_end_utc(self) -> datetime: + # The end of the time period covered by this board in UTC, tz-aware + # e.g. datetime(2024, 7, 13, 3, 59, 59, 999999, tzinfo=timezone.utc) + return self.period_end_local.astimezone(timezone.utc) + + @computed_field( + description="(unix timestamp) The start time of the time range this leaderboard covers.", + examples=[1720756800], + ) + def start_timestamp(self) -> int: + return int(self.period_start_utc.timestamp()) + + @computed_field( + description="(unix timestamp) The end time of the time range this leaderboard covers.", + examples=[1720843199], + ) + def end_timestamp(self) -> int: + return int(self.period_end_utc.timestamp()) + + @property + def timezone(self) -> ZoneInfo: + return ZoneInfo(self.timezone_name) + + @computed_field(description="The UTC offset for the timezone", examples=["-0400"]) + @property + def utc_offset(self) -> str: + return self.period_start_local.strftime("%z") + + @computed_field( + description="The start time of the time range this leaderboard covers " + "(local time, in the leaderboard's timezone).", + examples=["2024-07-12T00:00:00-04:00"], + ) + @property + def local_start_time(self) -> str: + return self.period_start_local.isoformat() + + @computed_field( + description="The end time of the time range this leaderboard covers " + "(local time, in the leaderboard's timezone).", + examples=["2024-07-12T23:59:59.999999-04:00"], + ) + @property + def local_end_time(self) -> str: + return self.period_end_local.isoformat() + + @computed_field( + description="A formatted string for time period covered by this " + "leaderboard. Can be used to display to users.", + examples=["2024-02-07 to 2024-02-08"], + ) + @property + def start_end_str(self) -> str: + start = self.period_start_local.strftime("%Y-%m-%d") + end = (self.period_end_local + timedelta(minutes=1)).strftime("%Y-%m-%d") + return f"{start} to {end}" + + @model_validator(mode="after") + def set_name(self): + if self.name is None: + self.name = { + LeaderboardCode.COMPLETE_COUNT: "Number of Completes", + LeaderboardCode.LARGEST_PAYOUT: "Largest Single Payout", + LeaderboardCode.SUM_PAYOUTS: "Total Payouts", + }[self.board_code] + return self + + @model_validator(mode="after") + def set_id(self): + if self.id is None: + self.id = self.generate_leaderboard_id() + return self + + @model_validator(mode="after") + def set_timezone_name(self): + if self.timezone_name is None: + self.timezone_name = self.period_start_local.tzinfo.key + return self + + @model_validator(mode="after") + def validate_period(self): + t = pd.Timestamp(self.period_start_local).tz_localize(tz=None) + freq_pd = { + LeaderboardFrequency.WEEKLY: "W-SUN", + LeaderboardFrequency.DAILY: "D", + LeaderboardFrequency.MONTHLY: "M", + }[self.freq] + period = t.to_period(freq_pd) + period_start_local = period.start_time.to_pydatetime().replace( + tzinfo=self.timezone + ) + period_end_local = ( + period.end_time.replace(nanosecond=0) + .to_pydatetime() + .replace(tzinfo=self.timezone) + ) + assert ( + period_start_local == self.period_start_local + ), f"invalid period_start_local {self.period_start_local}. The period starts at {period_start_local}" + if self.period_end_local is not None: + assert self.period_end_local == period_end_local, "invalid period" + else: + self.period_end_local = period_end_local + return self + + @field_validator("rows") + @classmethod + def sort_rows(cls, rows): + return sorted(rows, key=lambda row: (row.rank, row.bpuid)) + + def generate_leaderboard_id(self) -> str: + # Consistently generates the same UUID for a given leaderboard instance. + # https://docs.python.org/3/library/uuid.html#uuid.uuid3 + u = UUID("abee11ed-2943-4fb3-88c5-943921765dc0") # randomly chosen + name_str = "-".join( + [ + self.board_code.value, + self.bpid, + self.country_iso, + self.freq.value, + str(self.start_timestamp), + ] + ) + return uuid3(u, name_str).hex + + def censor(self): + for row in self.rows: + row.censor() + + +class LeaderboardResponse(StatusResponse): + leaderboard: Leaderboard = Field() + + +class LeaderboardWinner(BaseModel): + rank: int = Field( + description="The user's final rank in the leaderboard", examples=[1] + ) + freq: LeaderboardFrequency = Field( + description=LeaderboardFrequency.as_openapi_with_value_descriptions(), + examples=[LeaderboardFrequency.DAILY], + ) + board_code: LeaderboardCode = Field( + description=LeaderboardCode.as_openapi_with_value_descriptions(), + examples=[LeaderboardCode.COMPLETE_COUNT], + ) + country_iso: CountryISO = Field( + description="The country this leaderboard is for.", examples=["us"] + ) + issued: AwareDatetimeISO = Field( + description="When the prize was issued.", + examples=["2022-10-17T05:59:14.570231Z"], + ) + bpuid: str = Field(description="product_user_id", examples=["app-user-9329ebd"]) + description: str = Field(examples=["Bonus for daily contest"]) + amount: int = Field(description="(USD cents) The reward amount", examples=[1000]) + amount_str: str = Field( + description="The amount as a formatted string in USD. Can be " + "displayed to the user.", + examples=["$10.00"], + ) + contest_start: AwareDatetimeISO = Field( + description="When the leaderboard started", + examples=["2022-10-16T04:00:00Z"], + ) + contest_end: AwareDatetimeISO = Field( + description="When the leaderboard ended", + examples=["2022-10-17T04:00:00Z"], + ) + + +class LeaderboardWinnerResponse(StatusResponse): + winners: List[LeaderboardWinner] = Field(default_factory=list) diff --git a/generalresearch/models/thl/ledger.py b/generalresearch/models/thl/ledger.py new file mode 100644 index 0000000..045eeba --- /dev/null +++ b/generalresearch/models/thl/ledger.py @@ -0,0 +1,625 @@ +from datetime import datetime, timezone +from enum import Enum +from typing import Dict, Optional, List, Literal, Annotated, Union +from uuid import uuid4 + +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + ConfigDict, + PositiveInt, + computed_field, + NonNegativeInt, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, + check_valid_uuid, + HttpsUrlStr, +) +from generalresearch.models.thl.ledger_example import ( + _example_user_tx_payout, + _example_user_tx_bonus, + _example_user_tx_complete, + _example_user_tx_adjustment, +) +from generalresearch.models.thl.pagination import Page +from generalresearch.models.thl.payout_format import ( + PayoutFormatType, + format_payout_format, +) +from generalresearch.utils.enum import ReprEnumMeta + + +class Direction(int, Enum, metaclass=ReprEnumMeta): + """Entries on the debit side will increase debit normal accounts, while + entries on the credit side will decrease them. Conversely, entries on + the credit side will increase credit normal accounts, while entries on + the debit side will decrease them. + + By convention (?), the db will store transactions as debit-normal. For + a credit-normal account, we should flip the signs. + """ + + CREDIT = -1 + DEBIT = 1 + + +class OrderBy(str, Enum, metaclass=ReprEnumMeta): + ASC = "ASC" + + DESC = "DESC" + + +class AccountType(str, Enum, metaclass=ReprEnumMeta): + # Revenue from BP payment commission + BP_COMMISSION = "bp_commission" + # BP wallets (owed balance) + BP_WALLET = "bp_wallet" + # User's wallet + USER_WALLET = "user_wallet" + # Cash account + CASH = "cash" + # Revenue (money coming in) + REVENUE = "revenue" + # Expense + EXPENSE = "expense" + # Contest wallet (holds money entered into Contests) + CONTEST_WALLET = "contest_wallet" + # Line of Credit (LOC) account + CREDIT_LINE = "credit_line" + # wxet account operational funds + WA_WALLET = "wa_wallet" + # wxet account monies used to fund work + WA_BUDGET_POOL = "wa_budget_pool" + # wxet account funds which are being temporarily held + WA_HELD = "wa_held" + # wxet account credit line (LOC) + WA_CREDIT_LINE = "wa_credit_line" + + +class TransactionMetadataColumns(str, Enum): + BONUS = "bonus_id" + # Note: EVENT & EVENT2 represent the same concept. I accidentally made + # this inconsistent. + EVENT = "event_payout" + EVENT2 = "payoutevent" + + PAYOUT_TYPE = "payout_type" + SOURCE = "source" + + SESSION = "thl_session" + WALL = "thl_wall" + + TX_TYPE = "tx_type" + USER = "user" + + CONTEST = "contest" + + +class TransactionType(str, Enum): + """These are used in the Ledger to annotate the type of transaction (in + metadata: tx_type) + """ + + # We receive payment from a marketplace for a task complete + MP_PAYMENT = "mp_payment" + + # We pay a Brokerage Product for a session complete + BP_PAYMENT = "bp_payment" + + # A marketplace adjusts the payment for a task complete + MP_ADJUSTMENT = "mp_adjustment" + + # We adjust the payment to a BP for a session complete + BP_ADJUSTMENT = "bp_adjustment" + + # We pay out a Brokerage Product their balance + BP_PAYOUT = "bp_payout" + + # A user is paid (or penalized!) into their wallet balance for some reason, + # such as a leaderboard award, or reward for reporting a task. (This + # might be called "expenses" in finance reports). + USER_BONUS = "user_bonus" + + # A transaction is made to plug accounting imbalances + PLUG = "plug" + + # Transactions for a user requesting redemption of their wallet balance. + USER_PAYOUT_REQUEST = "user_payout_request" + USER_PAYOUT_COMPLETE = "user_payout_complete" + USER_PAYOUT_CANCEL = "user_payout_cancel" + + # User is entering a Contest (typically a Raffle) + USER_ENTER_CONTEST = "user_enter_contest" + CLOSE_CONTEST = "close_contest" + + # User won a milestone contest + USER_MILESTONE = "user_milestone" + + +class LedgerAccount(BaseModel, validate_assignment=True, frozen=True): + uuid: UUIDStr = Field( + default_factory=lambda: uuid4().hex, + description="A unique identifier for this Ledger Account", + examples=["c3c3566b5b1b4961b63a5670a2dc923d"], + ) + + display_name: str = Field( + max_length=64, + description="Human-readable description of the Ledger Account", + examples=["BP Wallet c3c3566b5b1b4961b63a5670a2dc923d"], + ) + + qualified_name: str = Field(max_length=255) + + account_type: AccountType = Field( + description=AccountType.as_openapi(), + examples=[AccountType.BP_WALLET.value], + ) + + normal_balance: Direction = Field(description=Direction.as_openapi()) + + reference_type: Optional[str] = Field(default=None) + + reference_uuid: Optional[UUIDStr] = Field( + default=None, + description="The associated Product ID or other parent account that" + "this Ledger Account is intended to track transactions for." + "If Wallet mode is enabled, this can also handle tracking" + "individual users.", + examples=["61dd0b086fd048518762757612b4a6d3"], + ) + + currency: str = Field( + default="USD", + max_length=32, + description="GRL's Ledger system allows tracking of transactions in" + "any currency possible. This is useful for tracking" + "points, stars, coins, or any other currency that may be" + "used in a Supplier's platform.", + ) + + @model_validator(mode="after") + def check_qualified_name(self) -> Self: + assert self.qualified_name.startswith( + f"{self.currency}:{self.account_type.value}" + ), "qualified name should start with {currency}:{account_type}" + return self + + @field_validator("currency", mode="after") + def check_currency(cls, currency: str) -> str: + # The currency should be either USD (or "test") or a valid uuid. + from generalresearch.currency import LedgerCurrency + + if currency not in [e.value for e in LedgerCurrency]: + check_valid_uuid(currency) + return currency + + +class LedgerEntry(BaseModel): + id: Optional[int] = Field(default=None) + + direction: Direction + account_uuid: UUIDStr + + amount: PositiveInt = Field( + lt=2**63 - 1, + strict=True, + description="The USDCent amount. A LedgerEntry cannot be made for" + "0 USDCent and it cannot be negative.", + examples=[531], + ) + + # This really shouldn't be Optional, but it has to be in order to + # instantiate this class before the LedgerTransaction exists + transaction_id: Optional[int] = Field(default=None) + + @classmethod + def from_amount(cls, account_uuid: UUIDStr, amount: int): + if amount > 0: + return cls( + direction=Direction.CREDIT, + amount=amount, + account_uuid=account_uuid, + ) + elif amount < 0: + return cls( + direction=Direction.DEBIT, + amount=abs(amount), + account_uuid=account_uuid, + ) + else: + raise ValueError("amount must not be 0") + + # @property + # def hash(self): + # # we'll use this to prevent duplicate transactions. + # s = ';'.join(sorted([self.direction, self.account_uuid, self.amount], + # key=lambda x: (x.direction, x.account_uuid, x.amount))) + # hash_object = hashlib.sha256(s.encode()) + # return hash_object.hexdigest() + + +class LedgerTransaction(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + id: Optional[int] = Field(default=None) + + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When the Transaction (TX) was created into the database." + "This does not represent the exact time for any action" + "which may be responsible for this Transaction (TX), and " + "TX timestamps will likely be a few milliseconds delayed", + ) + + ext_description: Optional[str] = Field(default=None, max_length=255) + tag: Optional[str] = Field(default=None, max_length=255) + metadata: Dict[str, str] = Field(default_factory=dict) + + entries: List[LedgerEntry] = Field( + default_factory=list, + description="A Transaction (TX) is composed of multiple Entry events.", + ) + + # @property + # def hash(self): + # # we'll use this to prevent duplicate transactions. + # metadata_str = ",".join(sorted([f"{k}={v}" for k, v in self.metadata.items()])) + # s = ';'.join([metadata_str] + [entry.hash for entry in self.entries]) + # hash_object = hashlib.sha256(s.encode()) + # return hash_object.hexdigest() + + @field_validator("created", mode="after") + @classmethod + def check_created_future(cls, created: AwareDatetimeISO) -> AwareDatetimeISO: + """Created should not be in the future. This will mess up + LedgerAccountStatement / groupby rollups. + """ + assert ( + datetime.now(tz=timezone.utc) > created + ), "created cannot be in the future" + return created + + @field_validator("entries", mode="after") + @classmethod + def check_entries(cls, entries: List[LedgerEntry]) -> List[LedgerEntry]: + """Transactions should enforce double-entry upon creation. Each + transaction needs to have at least two entries, which, in aggregate, + must affect credit and debit sides in equal amounts. + """ + if entries: + assert len(entries) >= 2, "ledger transaction must have 2 or more entries" + assert ( + sum(x.amount * x.direction for x in entries) == 0 + ), "ledger entries must balance" + return entries + + def model_dump_mysql(self, *args, **kwargs) -> dict: + d = self.model_dump(mode="json", *args, **kwargs) + if "created" in d: + d["created"] = self.created.replace(tzinfo=None) + return d + + def to_user_tx( + self, user_account: LedgerAccount, product_id: str, payout_format: str + ): + from generalresearch.models.thl.wallet import PayoutType + + d = self.model_dump(include={"created"}) + d["tx_type"] = self.metadata.get("tx_type") + d["product_id"] = product_id + d["payout_format"] = payout_format + debits = [ + x + for x in self.entries + if x.direction == Direction.DEBIT and x.account_uuid == user_account.uuid + ] + credits = [ + x + for x in self.entries + if x.direction == Direction.CREDIT and x.account_uuid == user_account.uuid + ] + + if d["tx_type"] == TransactionType.USER_PAYOUT_REQUEST.value: + assert len(debits) == 1 + d["amount"] = debits[0].amount * -1 + d["payout_id"] = self.metadata["payoutevent"] + payout_type = PayoutType(self.metadata["payout_type"].upper()) + if payout_type == PayoutType.AMT_ASSIGNMENT: + d["description"] = "HIT Reward" + elif payout_type == PayoutType.AMT_BONUS: + d["description"] = "HIT Bonus" + else: + raise ValueError(payout_type) + return UserLedgerTransactionUserPayout.model_validate(d) + elif d["tx_type"] == TransactionType.BP_PAYMENT.value: + assert len(credits) == 1 + d["amount"] = credits[0].amount + d["tsid"] = self.metadata.get("thl_session") + return UserLedgerTransactionTaskComplete.model_validate(d) + elif d["tx_type"] == TransactionType.USER_BONUS.value: + assert len(credits) == 1 + d["amount"] = credits[0].amount + return UserLedgerTransactionUserBonus.model_validate(d) + elif d["tx_type"] == TransactionType.BP_ADJUSTMENT.value: + assert len(debits) == 1 or len(credits) == 1 + if len(debits) == 1: + # complete -> fail + d["amount"] = debits[0].amount * -1 + else: + # fail -> complete + d["amount"] = credits[0].amount + d["tsid"] = self.metadata.get("thl_session") + return UserLedgerTransactionTaskAdjustment.model_validate(d) + + +class UserLedgerTransaction(BaseModel): + """ + Represents a LedgerTransaction item that would get shown to a user. This + is only used in wallet-managed accounts. Everything (especially the + amount) is w.r.t the user. + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + created: AwareDatetimeISO = Field(description="When the Transaction was created") + + description: str = Field( + max_length=255, description="External description suitable for UI" + ) + + # This should be a USDCent, but we need to support negative numbers here ... + amount: int = Field( + strict=True, + description=( + "The net amount affecting the user's wallet, in USDCents. " + "Positive means the user's balance increased; negative means it decreased." + ), + examples=[+500, -250], + ) + + # Needed to generate urls + product_id: Optional[str] = Field(default=None, exclude=True) + # Needed to generate amount_string + payout_format: Optional[PayoutFormatType] = Field(default=None, exclude=True) + # The balance in this account immediately after this tx. + # It is optional b/c we'll calculate this from the query + balance_after: Optional[int] = Field(default=None) + + def create_url(self, product_id: str): + raise NotImplementedError() + + @computed_field( + description="A link to where the user can get more details about this transaction", + ) + def url(self) -> Optional[HttpsUrlStr]: + if self.product_id is None: + return None + return self.create_url(product_id=self.product_id) + + @computed_field( + description="The 'amount' with the payout_format applied.", + ) + def amount_string(self) -> Optional[HttpsUrlStr]: + if self.payout_format is None: + return None + return format_payout_format( + payout_format=self.payout_format, payout_int=self.amount + ) + + +class UserLedgerTransactionUserPayout(UserLedgerTransaction): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + json_schema_extra=_example_user_tx_payout, + ) + + tx_type: Literal[TransactionType.USER_PAYOUT_REQUEST] = Field( + default=TransactionType.USER_PAYOUT_REQUEST + ) + + payout_id: UUIDStr = Field( + description="A unique identifier for the payout", + examples=["a3848e0a53d64f68a74ced5f61b6eb68"], + ) + + def create_url(self, product_id: str): + return f"https://fsb.generalresearch.com/{product_id}/cashout/{self.payout_id}/" + + @model_validator(mode="after") + def validate_amount(self): + assert self.amount < 0, ( + "In a user payout, the amount should be negative. This represents the user's " + "wallet balance decreasing because this amount was actually dispersed to them." + ) + return self + + +class UserLedgerTransactionUserBonus(UserLedgerTransaction): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + json_schema_extra=_example_user_tx_bonus, + ) + + tx_type: Literal[TransactionType.USER_BONUS] = Field( + default=TransactionType.USER_BONUS + ) + description: str = Field( + max_length=255, + description="External description suitable for UI", + default="Compensation Bonus", + ) + + def create_url(self, product_id: str): + return None + + @model_validator(mode="after") + def validate_amount(self): + assert self.amount > 0, f"UserLedgerTransactionUserBonus: {self.amount=}" + return self + + +class UserLedgerTransactionTaskComplete(UserLedgerTransaction): + """ + In a BP with user wallet enabled, the task-complete transaction would have + line items for both the credit to the bp_wallet_account and credit to + user_account. This is the user-detail, so we've only caring about the + user's payment. + """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + json_schema_extra=_example_user_tx_complete, + ) + + tx_type: Literal[TransactionType.BP_PAYMENT] = Field( + default=TransactionType.BP_PAYMENT + ) + + description: str = Field( + max_length=255, + description="External description suitable for UI", + default="Task Complete", + ) + + tsid: UUIDStr = Field( + description="A unique identifier for the session", + examples=["a3848e0a53d64f68a74ced5f61b6eb68"], + ) + + def create_url(self, product_id: str): + return f"https://fsb.generalresearch.com/{product_id}/status/{self.tsid}/" + + @model_validator(mode="after") + def validate_amount(self): + assert self.amount >= 0, f"UserLedgerTransactionTaskComplete: {self.amount=}" + return self + + +class UserLedgerTransactionTaskAdjustment(UserLedgerTransaction): + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + json_schema_extra=_example_user_tx_adjustment, + ) + + tx_type: Literal[TransactionType.BP_ADJUSTMENT] = Field( + default=TransactionType.BP_ADJUSTMENT + ) + + description: str = Field( + max_length=255, + description="External description suitable for UI", + default="Task Adjustment", + ) + + tsid: UUIDStr = Field( + description="A unique identifier for the session", + examples=["a3848e0a53d64f68a74ced5f61b6eb68"], + ) + + def create_url(self, product_id: str): + return f"https://fsb.generalresearch.com/{product_id}/status/{self.tsid}/" + + +UserLedgerTransactionType = Annotated[ + Union[ + UserLedgerTransactionUserPayout, + UserLedgerTransactionUserBonus, + UserLedgerTransactionTaskAdjustment, + UserLedgerTransactionTaskComplete, + ], + Field(discriminator="tx_type"), +] + + +class UserLedgerTransactionTypeSummary(BaseModel): + entry_count: NonNegativeInt = Field(default=0) + min_amount: Optional[int] = Field( + description="positive or negative USDCent", default=None + ) + max_amount: Optional[int] = Field( + description="positive or negative USDCent", default=None + ) + total_amount: Optional[int] = Field( + description="positive or negative USDCent", default=None + ) + + +class UserLedgerTransactionTypesSummary(BaseModel): + # Each key is a possible value of the TransactionType enum + bp_adjustment: UserLedgerTransactionTypeSummary = Field( + default_factory=UserLedgerTransactionTypeSummary + ) + bp_payment: UserLedgerTransactionTypeSummary = Field( + default_factory=UserLedgerTransactionTypeSummary + ) + user_bonus: UserLedgerTransactionTypeSummary = Field( + default_factory=UserLedgerTransactionTypeSummary + ) + user_payout_request: UserLedgerTransactionTypeSummary = Field( + default_factory=UserLedgerTransactionTypeSummary + ) + + +class UserLedgerTransactions(Page): + """ + A (paginated) collection that holds transaction models that can be shown to a (wallet-managed) user. + """ + + transactions: List[UserLedgerTransactionType] = Field(default_factory=list) + # The summary is w.r.t an optional time-filter. The transactions are + # paginated so the counts won't necesarily match. In other words, the + # summary is across all transaction in all pages, not this the transactions + # in this page. + summary: UserLedgerTransactionTypesSummary = Field() + + @classmethod + def from_txs( + cls, + user_account: LedgerAccount, + txs: List[LedgerTransaction], + product_id: str, + payout_format: str, + summary: UserLedgerTransactionTypesSummary, + page: int, + size: int, + total: int, + ): + user_txs = [ + tx.to_user_tx( + user_account=user_account, + product_id=product_id, + payout_format=payout_format, + ) + for tx in txs + ] + return cls.model_validate( + { + "transactions": user_txs, + "summary": summary, + "total": total, + "page": page, + "size": size, + } + ) + + +class LedgerAccountStatement(BaseModel): + id: Optional[int] = Field(default=None) + account_uuid: UUIDStr + filter_str: Optional[str] = Field(default=None) + effective_at_lower_bound: AwareDatetimeISO + effective_at_upper_bound: AwareDatetimeISO + starting_balance: int = Field(lt=2**63 - 1, ge=0) + ending_balance: int = Field(lt=2**63 - 1, ge=0) + sql_query: Optional[str] = Field(default=None) diff --git a/generalresearch/models/thl/ledger_example.py b/generalresearch/models/thl/ledger_example.py new file mode 100644 index 0000000..32cd464 --- /dev/null +++ b/generalresearch/models/thl/ledger_example.py @@ -0,0 +1,62 @@ +from datetime import datetime, timezone +from typing import Dict +from uuid import uuid4 + + +def _example_user_tx_payout(schema: Dict) -> None: + from generalresearch.models.thl.ledger import ( + UserLedgerTransactionUserPayout, + ) + + schema["example"] = UserLedgerTransactionUserPayout( + product_id=uuid4().hex, + payout_id=uuid4().hex, + amount=-5, + description="HIT Reward", + payout_format="${payout/100:.2f}", + created=datetime.now(tz=timezone.utc), + ).model_dump(mode="json") + + +def _example_user_tx_bonus(schema: Dict) -> None: + from generalresearch.models.thl.ledger import ( + UserLedgerTransactionUserBonus, + ) + + schema["example"] = UserLedgerTransactionUserBonus( + product_id=uuid4().hex, + amount=100, + description="Compensation Bonus", + payout_format="${payout/100:.2f}", + created=datetime.now(tz=timezone.utc), + ).model_dump(mode="json") + + +def _example_user_tx_complete(schema: Dict) -> None: + from generalresearch.models.thl.ledger import ( + UserLedgerTransactionTaskComplete, + ) + + schema["example"] = UserLedgerTransactionTaskComplete( + product_id=uuid4().hex, + amount=38, + description="Task Complete", + payout_format="${payout/100:.2f}", + created=datetime.now(tz=timezone.utc), + tsid=uuid4().hex, + ).model_dump(mode="json") + + +def _example_user_tx_adjustment(schema: Dict) -> None: + from generalresearch.models.thl.ledger import ( + UserLedgerTransactionTaskAdjustment, + ) + + schema["example"] = UserLedgerTransactionTaskAdjustment( + product_id=uuid4().hex, + amount=-38, + description="Task Adjustment", + payout_format="${payout/100:.2f}", + created=datetime.now(tz=timezone.utc), + tsid=uuid4().hex, + ).model_dump(mode="json") diff --git a/generalresearch/models/thl/locales.py b/generalresearch/models/thl/locales.py new file mode 100644 index 0000000..210a158 --- /dev/null +++ b/generalresearch/models/thl/locales.py @@ -0,0 +1,32 @@ +from typing import Annotated, Set + +from pydantic import AfterValidator + +from generalresearch.locales import Localelator +from generalresearch.models.custom_types import ( + to_comma_sep_str, + from_comma_sep_str, +) + +locale_helper = Localelator() +COUNTRY_ISOS: Set[str] = locale_helper.get_all_countries() +LANGUAGE_ISOS: Set[str] = locale_helper.get_all_languages() + + +def is_valid_country_iso(v: str) -> str: + assert v in COUNTRY_ISOS, f"invalid country_iso: {v}" + return v + + +def is_valid_language_iso(v: str) -> str: + assert v in LANGUAGE_ISOS, f"invalid language_iso: {v}" + return v + + +# ISO 3166-1 alpha-2 (two-letter codes, lowercase) +CountryISO = Annotated[str, AfterValidator(is_valid_country_iso)] +# 3-char ISO 639-2/B, lowercase +LanguageISO = Annotated[str, AfterValidator(is_valid_language_iso)] + +CountryISOs = Annotated[Set[CountryISO], to_comma_sep_str, from_comma_sep_str] +LanguageISOs = Annotated[Set[LanguageISO], to_comma_sep_str, from_comma_sep_str] diff --git a/generalresearch/models/thl/maxmind/__init__.py b/generalresearch/models/thl/maxmind/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/models/thl/maxmind/definitions.py b/generalresearch/models/thl/maxmind/definitions.py new file mode 100644 index 0000000..01431c7 --- /dev/null +++ b/generalresearch/models/thl/maxmind/definitions.py @@ -0,0 +1,22 @@ +from enum import Enum + +from generalresearch.utils.enum import ReprEnumMeta + + +class UserType(Enum, metaclass=ReprEnumMeta): + # https://support.maxmind.com/hc/en-us/articles/4408430082971-IP-Trait-Risk-Data#h_01FN6V8JMQMWZGWNPPAW77ZPY4 + BUSINESS = "business" + CAFE = "cafe" + CELLULAR = "cellular" + COLLEGE = "college" + CDN = "content_delivery_network" + CPN = "consumer_privacy_network" + GOVERNMENT = "government" + HOSTING = "hosting" + LIBRARY = "library" + MILITARY = "military" + RESIDENTIAL = "residential" + ROUTER = "router" + SCHOOL = "school" + SEARCH_ENGINE = "search_engine_spider" + TRAVELER = "traveler" diff --git a/generalresearch/models/thl/offerwall/__init__.py b/generalresearch/models/thl/offerwall/__init__.py new file mode 100644 index 0000000..2da9d43 --- /dev/null +++ b/generalresearch/models/thl/offerwall/__init__.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +import hashlib +import json +from decimal import Decimal +from enum import Enum +from typing import Literal, Optional, Dict, Set, Any + +from pydantic import ( + BaseModel, + Field, + model_validator, + computed_field, + PositiveInt, +) +from typing_extensions import Self + +from generalresearch.models import Source +from generalresearch.models.custom_types import IPvAnyAddressStr +from generalresearch.models.thl.locales import ( + CountryISO, + LanguageISO, + locale_helper, +) +from generalresearch.models.thl.offerwall.behavior import ( + OfferWallBehaviorsType, +) +from generalresearch.models.thl.product import ( + OfferWallCategoryRequest, + OfferWallRequestYieldmanParams, +) +from generalresearch.models.thl.user import User + + +class OfferWallType(str, Enum): + """ + The specific offerwall type + """ + + TOPN_PLUS = "b145b803" + TOPN_PLUS_BLOCK = "d48cce47" + TOPN_PLUS_BLOCK_RECONTACT = "1e5f0af8" + STARWALL_PLUS = "5481f322" + STARWALL_PLUS_BLOCK = "7fa1b3f4" + STARWALL_PLUS_BLOCK_RECONTACT = "630db2a4" + MARKETPLACE = "5fa23085" + TIMEBUCKS = "1705e4f8" + TIMEBUCKS_BLOCK = "0af0f7ec" + SOFTPAIR = "37d1da64" + SOFTPAIR_BLOCK = "7a89dcdb" + ONESHOT = "6f27b1ae" + ONESHOT_SOFTPAIR = "18347426" + WXET = "55a4e1a9" + # LEGACY + SINGLE = "5fl8bpv5" + TOPN = "45b7228a7" + STARWALL = "b59a2d2b" + + +class OfferWallTypeClass(str, Enum): + """ + A higher level "class" to organize similar offerwall types. + For e.g. STARWALL_PLUS_BLOCK, STARWALL_PLUS, STARWALL all use the same + bucket-generation algorithm, and have the same API response, except + for maybe including extra keys or have specific customizations. + """ + + TOPN = "TOPN" + STARWALL = "STARWALL" + MARKETPLACE = "MARKETPLACE" + SOFTPAIR = "SOFTPAIR" + SINGLE = "SINGLE" + + +OFFERWALL_TYPE_CLASS = { + OfferWallType.TOPN: OfferWallTypeClass.TOPN, + OfferWallType.TOPN_PLUS: OfferWallTypeClass.TOPN, + OfferWallType.TOPN_PLUS_BLOCK: OfferWallTypeClass.TOPN, + OfferWallType.TOPN_PLUS_BLOCK_RECONTACT: OfferWallTypeClass.TOPN, + OfferWallType.TIMEBUCKS: OfferWallTypeClass.TOPN, + OfferWallType.TIMEBUCKS_BLOCK: OfferWallTypeClass.TOPN, + OfferWallType.ONESHOT: OfferWallTypeClass.STARWALL, + OfferWallType.WXET: OfferWallTypeClass.STARWALL, + OfferWallType.STARWALL: OfferWallTypeClass.STARWALL, + OfferWallType.STARWALL_PLUS: OfferWallTypeClass.STARWALL, + OfferWallType.STARWALL_PLUS_BLOCK: OfferWallTypeClass.STARWALL, + OfferWallType.STARWALL_PLUS_BLOCK_RECONTACT: OfferWallTypeClass.STARWALL, + OfferWallType.MARKETPLACE: OfferWallTypeClass.MARKETPLACE, + OfferWallType.SOFTPAIR: OfferWallTypeClass.SOFTPAIR, + OfferWallType.SOFTPAIR_BLOCK: OfferWallTypeClass.SOFTPAIR, + OfferWallType.ONESHOT_SOFTPAIR: OfferWallTypeClass.SOFTPAIR, + OfferWallType.SINGLE: OfferWallTypeClass.SINGLE, +} + +# TODO: We could have a class for each offerwalltype, and each has attributes, +# but this is the only attribute I can think of, so just doing this +USER_BLOCK_OFFERWALLS = { + OfferWallType.TOPN_PLUS_BLOCK, + OfferWallType.STARWALL_PLUS_BLOCK, + OfferWallType.TOPN_PLUS_BLOCK_RECONTACT, + OfferWallType.STARWALL_PLUS_BLOCK_RECONTACT, + OfferWallType.TIMEBUCKS_BLOCK, + OfferWallType.ONESHOT, + OfferWallType.ONESHOT_SOFTPAIR, + OfferWallType.SOFTPAIR_BLOCK, + OfferWallType.WXET, +} + + +class OfferWallRequest(BaseModel): + offerwall_type: OfferWallType = Field() + user: User = Field() + + ip: Optional[IPvAnyAddressStr] = Field( + default=None, + description="Respondent's IP address (IPv4 or IPv6). Either 'ip' must be " + "provided, or 'country_iso' must be provided if 'ip' is " + "not provided.", + ) + + country_iso: CountryISO = Field( + description="Respondent's country code (ISO 3166-1 alpha-2, lowercase)" + ) + language_isos: Set[LanguageISO] = Field( + description="Respondent's desired language (ISO 639-2/B, lowercase)", + ) + + behavior: Optional[OfferWallBehaviorsType] = Field( + default=None, + max_length=12, + description="Allows using custom scoring functions. Please " + "discuss directly with GRL.", + ) + + min_payout: Optional[Decimal] = Field( + default=None, + description="Decimal representation of the minimum amount of USD that " + "any of the tasks will pay", + examples=["1.23"], + ) + + duration: Optional[int] = Field( + default=60 * 90, + description="Maximum length of desired task (in seconds).", + gt=0, + ) + + n_bins: Optional[int] = Field( + default=None, + description="Number of bins requested in the offerwall.", + le=100, + gt=0, + ) + + min_bin_size: Optional[int] = Field( + default=None, + description="Minimum number of tasks that need to be in a bucket", + gt=0, + le=100, + ) + + dynamic_min_bin_size: bool = Field( + default=True, + description="Allows the bin size to drop below the min bin size when not enough tasks " + "are available.", + ) + + split_by: Literal["payout", "duration"] = Field( + default="payout", description="Cluster tasks by payout or duration" + ) + + passthrough_kwargs: Dict[str, str] = Field( + default_factory=dict, + description="These are pulled from the url params. They are any 'extra' url params " + "in the getofferwall request. They'll be available through the task_status " + "endpoint. These used to be in the wallsessionmetadata.", + ) + + # Only for soft pair (offerwall_id, max_options, max_questions) + offerwall_id: Optional[str] = Field(default=None) + max_options: Optional[int] = Field( + default=None, + description="Max number of options an allowed question can have (allowed to be asked)", + ) + max_questions: Optional[int] = Field( + default=None, + description="Max number of missing questions on a single bin", + ) + + category_request: OfferWallCategoryRequest = Field( + default_factory=OfferWallCategoryRequest + ) + + yieldman_kwargs: OfferWallRequestYieldmanParams = Field( + default_factory=OfferWallRequestYieldmanParams, + description="These get passed into the scoring function to adjust how filter/score eligible" + "tasks that are used to build an offerwall. There are not setable directly from the " + "url, instead a behavior can be set, which may translate into things here. Or" + "these may be set in the bpc table globally for a BP.", + ) + + marketplaces: Optional[Set[Source]] = Field( + default=None, + description="If set, restrict tasks to those from these marketplaces only.", + ) + + grpc_method: Literal["GetOpportunityIDs", "GetOpportunitiesSoftPairing"] = Field( + description="Which grpc method should be hit for this offerwall", + default="GetOpportunityIDs", + ) + + @model_validator(mode="after") + def check_grpc_method(self) -> Self: + if self.offerwall_type_class == OfferWallTypeClass.SOFTPAIR: + assert self.grpc_method == "GetOpportunitiesSoftPairing", "grpc_method" + else: + assert self.grpc_method == "GetOpportunityIDs", "grpc_method" + return self + + @model_validator(mode="after") + def set_language_isos(self) -> Self: + # Special hook for languages. If no languages were passed, we set the + # lang_codes to 'eng' and the default lang for their country. + if len(self.language_isos) == 0: + self.language_isos = { + "eng", + locale_helper.get_default_lang_from_country(self.country_iso), + } - {None} + return self + + @model_validator(mode="after") + def set_offerwall_defaults(self) -> Self: + # Set specific defaults depending on the offerwall type + if self.offerwall_type_class == OfferWallTypeClass.SOFTPAIR: + self.max_options = self.max_options if self.max_options is not None else 40 + self.max_questions = ( + self.max_questions if self.max_questions is not None else 3 + ) + self.n_bins = self.n_bins if self.n_bins is not None else 12 + self.min_bin_size = ( + self.min_bin_size if self.min_bin_size is not None else 3 + ) + if self.offerwall_type_class == OfferWallTypeClass.MARKETPLACE: + self.min_bin_size = ( + self.min_bin_size if self.min_bin_size is not None else 3 + ) + if self.offerwall_type_class in { + OfferWallTypeClass.STARWALL, + OfferWallTypeClass.TOPN, + }: + self.n_bins = self.n_bins if self.n_bins is not None else 1 + self.min_bin_size = ( + self.min_bin_size if self.min_bin_size is not None else 1 + ) + return self + + @computed_field + def offerwall_type_class(self) -> OfferWallTypeClass: + return OFFERWALL_TYPE_CLASS[self.offerwall_type] + + @property + def request_id(self) -> str: + return hashlib.md5( + json.dumps(self.model_dump(mode="json"), sort_keys=True).encode("utf-8") + ).hexdigest()[:7] + + def to_grpc_request(self) -> Dict[str, Any]: + # We need this so thl-core can refresh an offerwall in order to continue + # a session + d = self.model_dump(mode="json") + kwargs = dict() + keys = [ + "n_bins", + "min_bin_size", + "max_options", + "max_questions", + "behavior", + "min_payout", + "duration", + "dynamic_min_bin_size", + "split_by", + ] + for k in keys: + if getattr(self, k) is not None: + kwargs[k] = str(getattr(self, k)) + d["offerwall_kwargs"] = kwargs + d["start_task"] = { + "product_id": self.user.product_id, + "bp_user_id": self.user.product_user_id, + "req_duration": self.duration, + "country_iso": self.country_iso, + "languages": [{"iso_code": x} for x in self.language_isos], + "kwargs": kwargs, + } + # We can't import protos in here, so the caller has to actually + # cast this dict as a generalresearch_pb2.OfferwallRequest + return { + "start_task": d["start_task"], + "offerwall_type": d["offerwall_type"], + "offerwall_kwargs": d["offerwall_kwargs"], + } + + @property + def product_id(self) -> Optional[str]: + return self.user.product_id + + @property + def product_user_id(self) -> Optional[str]: + return self.user.product_user_id + + @property + def bpuid(self) -> Optional[str]: + return self.user.product_user_id + + @property + def min_opp_count(self) -> PositiveInt: + # This is the min number of surveys we need available to show any + # offerwall at all. + if self.min_bin_size is not None: + return self.min_bin_size + return 3 diff --git a/generalresearch/models/thl/offerwall/base.py b/generalresearch/models/thl/offerwall/base.py new file mode 100644 index 0000000..66149a9 --- /dev/null +++ b/generalresearch/models/thl/offerwall/base.py @@ -0,0 +1,685 @@ +import statistics +from datetime import timedelta +from decimal import Decimal +from string import Formatter +from typing import Optional, List, Any, Set, Dict, Tuple +from uuid import uuid4 + +import numpy as np +import pandas as pd +from pydantic import ( + BaseModel, + Field, + NonNegativeFloat, + NonNegativeInt, + ConfigDict, + field_validator, + model_validator, +) +from typing_extensions import Self, Annotated + +from generalresearch.models import Source +from generalresearch.models.custom_types import UUIDStr, HttpsUrl +from generalresearch.models.legacy.bucket import ( + Bucket as LegacyBucket, + Eligibility, + CategoryAssociation, + DurationSummary, + PayoutSummary, + PayoutSummaryDecimal, + SurveyEligibilityCriterion, +) +from generalresearch.models.legacy.definitions import OfferwallReason +from generalresearch.models.thl.locales import CountryISO +from generalresearch.models.thl.offerwall import ( + OfferWallType, + OfferWallTypeClass, + OFFERWALL_TYPE_CLASS, +) +from generalresearch.models.thl.offerwall.bucket import ( + generate_offerwall_entry_url, +) +from generalresearch.models.thl.profiling.upk_question import UpkQuestion +from generalresearch.models.thl.soft_pair import SoftPairResultType +from generalresearch.models.thl.user import User + + +class MergeTableFeatures(BaseModel): + """ + This is just a pydantic representation of the survey stats features from + a row from the merge table. It isn't meant to be used by itself. + """ + + model_config = ConfigDict(allow_inf_nan=False, populate_by_name=True) + + PRESCREEN_CONVERSION_ALPHA: float = Field( + description="Alpha parameter from a Beta distribution", + alias="PRESCREEN_CONVERSION.alpha", + gt=0, + default=1, + ) + PRESCREEN_CONVERSION_BETA: float = Field( + description="Beta parameter from a Beta distribution", + alias="PRESCREEN_CONVERSION.beta", + ge=0, + default=0, + ) + PRESCREEN_CONVERSION: float = Field( + description="Penalized mean value for the task's prescreen conversion. The penalized mean is the 20th " + "percentile" + "of the inverse cumulative distribution.", + ge=0, + le=1, + default=1, + ) + + CONVERSION_ALPHA: float = Field( + description="Alpha parameter from a Beta distribution", + alias="CONVERSION.alpha", + gt=0, + ) + CONVERSION_BETA: float = Field( + description="Beta parameter from a Beta distribution", + alias="CONVERSION.beta", + gt=0, + ) + CONVERSION: float = Field( + description="Penalized mean value for the task's conversion. The penalized mean is the 20th percentile" + "of the inverse cumulative distribution.", + ge=0, + le=1, + ) + + # Normal distribution, so mu is real number, but this represents the completion time, so it + # has to be positive. We can restrict it more in that me are never going to predict + # time longer than ~~ 2 hours (np.log(120*60)) or <= 0 sec (np.log(1) = 0) + COMPLETION_TIME_MU: float = Field( + description="Mu parameter from a Normal distribution", + alias="COMPLETION_TIME.mu", + gt=1, + le=10, + ) + COMPLETION_TIME_SIGMA: float = Field( + description="Sigma parameter from a Normal distribution", + alias="COMPLETION_TIME.sigma", + gt=0, + lt=10, + ) + COMPLETION_TIME_LOG: float = Field( + description="Penalized mean value for the task's log-transformed completion time. The penalized " + "mean is the 80th percentile of the inverse cumulative distribution.", + ge=0, + le=10, + ) + COMPLETION_TIME: float = Field( + description="Exponential of the COMPLETION_TIME_LOG. This is in seconds.", + gt=0, + le=120 * 60, + ) + # Note: We generally also will have a predicted_loi or something with is just the inverse-log + # of COMPLETION_TIME, so that we can report it in seconds. + + DROPOFF_RATE_ALPHA: float = Field( + description="Alpha parameter from a Beta distribution", + alias="DROPOFF_RATE.alpha", + gt=0, + ) + DROPOFF_RATE_BETA: float = Field( + description="Beta parameter from a Beta distribution", + alias="DROPOFF_RATE.beta", + gt=0, + ) + DROPOFF_RATE: float = Field( + description="Penalized mean value for the task's dropoff/abandonment rate. The penalized mean is the 60th " + "percentile of the inverse cumulative distribution.", + ge=0, + le=1, + ) + + USER_REPORT_COEFF: float = Field( + description="Lower values indicate the task, or similar tasks, have been reported by users.", + ge=0, + le=1, + default=1, + ) + + LONG_FAIL: float = Field( + description="Lower values indicate the task is likely to terminate later", + ge=0, + le=10, + default=1, + ) + + RECON_LIKELIHOOD: float = Field( + description="Likelihood the task will get reconciled.", + ge=0, + le=1, + default=0, + ) + + IS_MOBILE_ELIGIBLE_ALPHA: float = Field( + description="Alpha parameter from a Beta distribution", + alias="IS_MOBILE_ELIGIBLE.alpha", + gt=0, + default=1, + ) + IS_MOBILE_ELIGIBLE_BETA: float = Field( + description="Beta parameter from a Beta distribution", + alias="IS_MOBILE_ELIGIBLE.beta", + ge=0, + default=0, + ) + IS_MOBILE_ELIGIBLE: float = Field( + description="Penalized mean likelihood that the task can be completed on a mobile device", + ge=0, + le=1, + default=1, + ) + + IS_DESKTOP_ELIGIBLE_ALPHA: float = Field( + description="Alpha parameter from a Beta distribution", + alias="IS_DESKTOP_ELIGIBLE.alpha", + gt=0, + default=1, + ) + IS_DESKTOP_ELIGIBLE_BETA: float = Field( + description="Beta parameter from a Beta distribution", + alias="IS_DESKTOP_ELIGIBLE.beta", + ge=0, + default=0, + ) + IS_DESKTOP_ELIGIBLE: float = Field( + description="Penalized mean likelihood that the task can be completed on a Desktop", + ge=0, + le=1, + default=1, + ) + + IS_TABLET_ELIGIBLE_ALPHA: float = Field( + description="Alpha parameter from a Beta distribution", + alias="IS_TABLET_ELIGIBLE.alpha", + gt=0, + default=1, + ) + IS_TABLET_ELIGIBLE_BETA: float = Field( + description="Beta parameter from a Beta distribution", + alias="IS_TABLET_ELIGIBLE.beta", + ge=0, + default=0, + ) + IS_TABLET_ELIGIBLE: float = Field( + description="Penalized mean likelihood that the task can be completed on a Tablet", + ge=0, + le=1, + default=1, + ) + + @model_validator(mode="before") + @classmethod + def set_completion_time_log(cls, data: Dict[str, Any]) -> Dict[str, Any]: + # This isn't actually in the merge table + data["COMPLETION_TIME_LOG"] = np.log(data["COMPLETION_TIME"]) + return data + + +class TaskResult(BaseModel): + """ + This is a task, like as in ScoredTaskResult, but one that does not have + associated scoring features. This is used only for GRS tasks + """ + + model_config = ConfigDict(allow_inf_nan=False, extra="forbid") + + # used in prioritize_with_stats_ids + internal_id: str = Field( + description="This is the survey's id within the marketplace" + ) + source: Source = Field() + country_iso: CountryISO = Field() + buyer_id: Optional[str] = Field(min_length=1, max_length=32, default=None) + + # todo: GRS is allowed to be 0, but all the others can't. make a validator + cpi: Decimal = Field(ge=0, le=100, decimal_places=5, max_digits=7) + + # Only GRS tasks will have this set. All other marketplaces will have + # to make a grpc call to generate this. This is a str b/c it is actually + # a format string. + entry_link: Optional[str] = Field( + default=None, + examples=[ + "https://{domain}/session/?39057c8b=c4ed212601494f8c8836e38a55102d10&c184efc0=test&0bb50182={mid}" + ], + ) + + @model_validator(mode="after") + def validate_cpi(self) -> Self: + if self.cpi == 0: + assert self.source == Source.GRS, "cpi should be >0" + return self + + @model_validator(mode="after") + def validate_entry_link(self) -> Self: + if self.source == Source.GRS: + if self.entry_link: + fmt_str = sorted( + [ + fname + for _, fname, _, _ in Formatter().parse(self.entry_link) + if fname + ] + ) + assert all( + x in {"domain", "mid"} for x in fmt_str + ), "unrecognized format variable" + else: + assert self.entry_link is None, f"entry link not allowed for {self.source}" + return self + + @property + def external_id(self) -> str: + return f"{self.source.value}:{self.internal_id}" + + @property + def id_code(self) -> str: + return self.external_id + + +class ScoredTaskResult(TaskResult, MergeTableFeatures): + """ + This represents a single task, that a user is eligible for, and the task's + associated scoring features. + A list of these are used for further filtering and eventually in order to + generate an offerwall + """ + + model_config = ConfigDict(allow_inf_nan=False, extra="ignore") + + cpi: Decimal = Field(gt=0, le=100, decimal_places=5, max_digits=7) + payout: Decimal = Field(gt=0, le=100, decimal_places=5, max_digits=7) + + loi: float = Field( + description="Same as COMPLETION_TIME, but using the 60th percentile. Also in seconds." + "This is generally used within offerwall creation as a more accurate prediction.", + gt=0, + le=120 * 60, + ) + + # range is 0<->Inf (exclusive), but generally between 0 and single digits + score: NonNegativeFloat = Field( + description="This is the score as outputted by the scoring function", + default=0, + ) + + # also called various places as "p" + scaled_score: float = Field( + ge=0, + le=1, + description="used for offerwall stuff, range 0<->1", + default=0, + ) + + # --- These 3 are generally used for "SoftPair" offerwalls only. However, + # in all other types of offerwalls, by default, the pair type is + # unconditional, we just never check/use this for anything. + pair_type: SoftPairResultType = Field(default=SoftPairResultType.UNCONDITIONAL) + + # The set of marketplace's question codes (internal id) that are unknown. + # This should only be set it SoftPairResultType is conditional + unknown_mp_question_ids: Optional[Set[str]] = Field(default=None) + + # Question ids (from marketplace_question table) for the questions that + # will be asked (that would fulfill the unknown questions specified in + # unknown_mp_question_ids) + unknown_question_ids: Optional[Set[UUIDStr]] = Field(default=None) + + # ---- Soft Pair end ---- + + is_recontact: bool = Field(default=False) + + @field_validator("cpi", mode="before") + def cpi_from_float(cls, v: Decimal) -> Decimal: + return Decimal(v).quantize(Decimal("0.00000")) + + @property + def unknown_mp_qids(self) -> Optional[Set[str]]: + # marketplace's curie-formatted question IDs that are unknown + return ( + {self.source + ":" + q for q in self.unknown_mp_question_ids} + if self.unknown_mp_question_ids is not None + else None + ) + + def to_row(self) -> Dict[str, Any]: + d = self.model_dump(mode="json") + d["id_code"] = self.id_code + return d + + +class ScoredTaskResults(BaseModel): + tasks: List[ScoredTaskResult] = Field() + + @property + def availability_count(self) -> NonNegativeInt: + return len(self.tasks) + + def to_pandas(self) -> pd.DataFrame: + columns = ( + list(ScoredTaskResult.model_fields.keys()) + + list(ScoredTaskResult.model_computed_fields.keys()) + + ["id_code"] + ) + df = pd.DataFrame([x.to_row() for x in self.tasks], columns=columns) + df["payout"] = df["payout"].astype(float) + df["cpi"] = df["cpi"].astype(float) + return df + + def take_top(self, n=100) -> List[ScoredTaskResult]: + return sorted(self.tasks, key=lambda x: x.score, reverse=True)[:n] + + +class OfferwallBucket(BaseModel): + """ + See also py-utils: models.legacy.bucket: Bucket. That is used only in + handling API responses. This class is used internally to handle offerwall + creation/management. + """ + + model_config = ConfigDict( + extra="forbid", validate_assignment=True, ser_json_timedelta="float" + ) + + id: UUIDStr = Field( + description="Unique identifier this particular bucket", + examples=["5ba2fe5010cc4d078fc3cc0b0cc264c3"], + default_factory=lambda: uuid4().hex, + ) + uri: Optional[HttpsUrl] = Field( + examples=[ + "https://task.generalresearch.com/api/v1/52d3f63b2709/797df4136c604a6c8599818296aae6d1/?i" + "=5ba2fe5010cc4d078fc3cc0b0cc264c3&b=test&66482fb=e7baf5e" + ], + description="The URL to send a respondent into. Must not edit this URL in any way", + default=None, + ) + + tasks: List[ScoredTaskResult] = Field() + + category: List[CategoryAssociation] = Field(default_factory=list) + + # Used only in marketplace offerwall + source: Optional[Source] = Field(default=None) + source_name: Optional[str] = Field(default=None) + + # Normally these are calculated. However, in some offerwalls we duplicate + # buckets, so they're not "true" calculated values. + custom_min_payout: ( + Annotated[Decimal, Field(max_digits=5, decimal_places=2, ge=0, le=100)] | None + ) = Field( + description="Custom: Min payout across all tasks", + default=None, + ) + custom_q1_duration: Optional[float] = Field( + description="Custom: Q1 loi across all tasks", + default=None, + gt=0, + le=120 * 60, + ) + + quality_score: float = Field(default=0) + + eligibility_criteria: Optional[Tuple[SurveyEligibilityCriterion, ...]] = Field( + description="The reasons the user is eligible for tasks in this bucket", + default=None, + ) + eligibility_explanation: Optional[str] = Field( + default=None, + description="Human-readable text explaining a user's eligibility for tasks in this bucket", + examples=[ + "You are a **47-year-old** **white** **male** with a *college degree*, who's employer's retirement plan is **Fidelity Investments**." + ], + ) + + @property + def missing_questions(self) -> Set[UUIDStr]: + # Used only in softpair. + # The question id is the question's uuid (in the marketplace_question table / UpkQuestion.id) + # It is just the set union of task.softpair.question_ids for all tasks in this bucket. + if self.tasks[0].pair_type != SoftPairResultType.CONDITIONAL: + return set() + mq = set() + for task in self.tasks: + if task.unknown_question_ids: + mq.update(task.unknown_question_ids) + return mq + + @property + def default_quality_score(self) -> float: + # uses the euclidean norm, which is more influenced by outliers + score = np.array([x.score for x in self.tasks][:5]) + return float(np.sqrt((score**2).sum())) + + @property + def payout(self) -> Optional[Decimal]: + # The payout is the Min payout across all tasks + return min([x.payout for x in self.tasks], default=None) + + @property + def loi(self) -> Optional[float]: + # The loi is the Max LOI across all tasks + return max([x.loi for x in self.tasks], default=None) + + @property + def min_payout(self) -> Decimal: + return self.payout_summary.min + + @property + def max_payout(self) -> Decimal: + return self.payout_summary.max + + @property + def min_duration(self) -> int: + return self.duration_summary.min + + @property + def max_duration(self) -> int: + return self.duration_summary.max + + @property + def sns(self) -> List[str]: + return [t.id_code for t in self.tasks] + + @property + def duration_summary(self) -> DurationSummary: + # TODO: we could cache these and then have a validator that runs to + # update them if the tasks change? idk + # There shouldn't ever be GRS in here anyways, right? + durations = [res.loi for res in self.tasks if res.source != Source.GRS] + durations = durations if durations else [0] + min_duration, q1_duration, q2_duration, q3_duration, max_duration = np.quantile( + durations, [0, 0.25, 0.5, 0.75, 1] + ) + mean_duration = round(statistics.mean(durations)) + return DurationSummary( + min=round(min_duration), + q1=round(q1_duration), + q2=round(q2_duration), + q3=round(q3_duration), + max=round(max_duration), + mean=mean_duration, + ) + + @property + def payout_summary(self) -> PayoutSummaryDecimal: + return PayoutSummaryDecimal( + min=Decimal(self.payout_summary_int.min) / 100, + max=Decimal(self.payout_summary_int.max) / 100, + q1=Decimal(self.payout_summary_int.q1) / 100, + q2=Decimal(self.payout_summary_int.q2) / 100, + q3=Decimal(self.payout_summary_int.q3) / 100, + mean=Decimal(self.payout_summary_int.mean) / 100, + ) + + @property + def payout_summary_int(self) -> PayoutSummary: + payouts = [ + round(res.payout * 100) for res in self.tasks if res.source != Source.GRS + ] + payouts = payouts if payouts else [0] # so min, max, quantile doesnt fail + min_payout, q1_payout, q2_payout, q3_payout, max_payout = np.quantile( + payouts, [0, 0.25, 0.5, 0.75, 1] + ) + mean_payout = round(statistics.mean(payouts)) + return PayoutSummary( + min=round(min_payout), + q1=round(q1_payout), + q2=round(q2_payout), + q3=round(q3_payout), + max=round(max_payout), + mean=mean_payout, + ) + + @property + def eligibility(self) -> Optional[SoftPairResultType]: + # We're assuming there is never a conditional or ineligible survey + # after a unconditional. There can be unconditional surveys + # after conditional surveys, in which case the bucket is still + # conditional. + if self.tasks[0].pair_type is not None: + pair_type = self.tasks[0].pair_type + if pair_type in { + SoftPairResultType.UNCONDITIONAL, + SoftPairResultType.CONDITIONAL, + SoftPairResultType.INELIGIBLE, + }: + return pair_type + else: + raise ValueError(f"Unexpected pair_type {pair_type}") + + @property + def eligibility_str(self) -> Optional[Eligibility]: + return ( + { + SoftPairResultType.UNCONDITIONAL: "unconditional", + SoftPairResultType.CONDITIONAL: "conditional", + SoftPairResultType.INELIGIBLE: "ineligible", + }[self.eligibility] + if self.eligibility is not None + else None + ) + + def to_legacy_bucket(self) -> LegacyBucket: + # The legacy bucket is used in the Session model. I don't want to change it now, + # but there's no reason we couldn't + return LegacyBucket( + loi_min=timedelta(seconds=self.duration_summary.min), + loi_max=timedelta(seconds=self.duration_summary.max), + loi_q1=timedelta(seconds=self.duration_summary.q1), + loi_q2=timedelta(seconds=self.duration_summary.q2), + loi_q3=timedelta(seconds=self.duration_summary.q3), + user_payout_min=self.payout_summary.min, + user_payout_max=self.payout_summary.max, + user_payout_q1=self.payout_summary.q1, + user_payout_q2=self.payout_summary.q2, + user_payout_q3=self.payout_summary.q3, + ) + + def generate_bucket_entry_url( + self, user: User, request_id: Optional[str] = None + ) -> None: + product_id = user.product_id + product_user_id = user.product_user_id + base_enter_url = ( + f"https://task.generalresearch.com/api/v1/52d3f63b2709/{product_id}/?" + ) + if ( + self.eligibility is None + or self.eligibility == SoftPairResultType.UNCONDITIONAL + ): + self.uri = generate_offerwall_entry_url( + base_enter_url, self.id, product_user_id, request_id=request_id + ) + + return None + + # def __repr__(self): + # exclude = { + # "PRESCREEN_CONVERSION_ALPHA", + # "PRESCREEN_CONVERSION_BETA", + # "CONVERSION_ALPHA", + # "CONVERSION_BETA", + # "COMPLETION_TIME_MU", + # "COMPLETION_TIME_SIGMA", + # "COMPLETION_TIME_LOG", + # "DROPOFF_RATE_ALPHA", + # "DROPOFF_RATE_BETA", + # "IS_MOBILE_ELIGIBLE_ALPHA", + # "IS_MOBILE_ELIGIBLE_BETA", + # "IS_DESKTOP_ELIGIBLE_ALPHA", + # "IS_DESKTOP_ELIGIBLE_BETA", + # "IS_TABLET_ELIGIBLE_ALPHA", + # "IS_TABLET_ELIGIBLE_BETA", + # "cpi", + # "source", + # "internal_id", + # "is_recontact", + # "IS_MOBILE_ELIGIBLE", + # "IS_DESKTOP_ELIGIBLE", + # "IS_TABLET_ELIGIBLE", + # "USER_REPORT_COEFF", + # "PRESCREEN_CONVERSION", + # } + # return json.dumps( + # self.model_dump(mode="json", exclude={"tasks": {"__all__": exclude}}), + # indent=4, + # ) + + +class OfferwallBase(BaseModel): + model_config = ConfigDict( + extra="forbid", validate_assignment=True, ser_json_timedelta="float" + ) + + id: UUIDStr = Field( + description="Unique identifier for this offerwall", + default_factory=lambda: uuid4().hex, + ) + offerwall_type: OfferWallType = Field() + buckets: List[OfferwallBucket] = Field() + + # Note: this != the sum(len(tasks) in buckets) b/c we filter out a lot + availability_count: int = Field(default=0, description="Number of available tasks") + attempted_live_eligible_count: NonNegativeInt = Field( + description=( + "Number of currently live opportunities for which the respondent " + "meets all eligibility criteria but is excluded due to a prior attempt. " + "Only includes surveys that are still live and otherwise eligible; " + "does not include previously attempted surveys that are no longer available." + ), + examples=[7], + default=0, + ) + offerwall_reasons: List[OfferwallReason] = Field( + default_factory=list, + description=( + "Explanations describing why so many or few opportunities are available." + ), + examples=[[OfferwallReason.USER_BLOCKED, OfferwallReason.UNDER_MINIMUM_AGE]], + ) + + # Contains the full info about any questions in any bucket's + # missing_questions. + questions: List[UpkQuestion] = Field(default_factory=list) + + @property + def offerwall_type_class(self) -> OfferWallTypeClass: + return OFFERWALL_TYPE_CLASS[self.offerwall_type] + + @property + def task_count(self) -> NonNegativeInt: + return sum(len(b.tasks) for b in self.buckets) + + def generate_bucket_entry_urls(self, user: User, request_id: str) -> None: + for bucket in self.buckets: + bucket.generate_bucket_entry_url(user=user, request_id=request_id) + + return None diff --git a/generalresearch/models/thl/offerwall/behavior.py b/generalresearch/models/thl/offerwall/behavior.py new file mode 100644 index 0000000..e8da334 --- /dev/null +++ b/generalresearch/models/thl/offerwall/behavior.py @@ -0,0 +1,45 @@ +from typing import Any, Dict, Literal + +from pydantic import Field, BaseModel + + +class OfferWallBehavior(BaseModel): + id: str = Field() + name: str = Field() + kwargs: Dict[str, Any] = Field(default_factory=dict) + + +OFFERWALL_BEHAVIOR_PRESETS = [ + OfferWallBehavior( + id="0adc081e", + name="Best for New Users", + kwargs={ + "longfail_factor_adj": 1, + "conversion_factor_adj": 1.5, + "dropoffrate_factor_adj": 1, + }, + ), + OfferWallBehavior( + id="626984a8", name="Dopamine Hit", kwargs={"conversion_factor_adj": 2} + ), + OfferWallBehavior( + id="e3259520", + name="Optimal Surveys!", + kwargs={ + "longfail_factor_adj": 0, + "conversion_factor_adj": 0, + "dropoffrate_factor_adj": 0, + }, + ), + OfferWallBehavior( + id="ffbd76b8", + name="Low Frustration", + kwargs={ + "longfail_factor_adj": 2, + "conversion_factor_adj": 0.5, + "dropoffrate_factor_adj": 2, + }, + ), +] +OFFERWALL_BEHAVIOR_PRESETS_DICT = {x.id: x for x in OFFERWALL_BEHAVIOR_PRESETS} +OfferWallBehaviorsType = Literal["0adc081e", "626984a8", "e3259520", "ffbd76b8"] diff --git a/generalresearch/models/thl/offerwall/bucket.py b/generalresearch/models/thl/offerwall/bucket.py new file mode 100644 index 0000000..dd93cd1 --- /dev/null +++ b/generalresearch/models/thl/offerwall/bucket.py @@ -0,0 +1,20 @@ +from typing import Optional +from urllib.parse import urlencode + + +def generate_offerwall_entry_url( + base_url: str, + obj_id: str, + bp_user_id: str, + request_id: Optional[str] = None, + nudge_id: Optional[str] = None, +): + # for an offerwall entry link, we need the clicked bucket_id and the request hash (so we know + # which GetOfferwall cache to get + query_dict = {"i": obj_id, "b": bp_user_id} + if request_id: + query_dict["66482fb"] = request_id + if nudge_id: + query_dict["5e0e0323"] = nudge_id + enter_url = base_url + urlencode(query_dict) + return enter_url diff --git a/generalresearch/models/thl/offerwall/cache.py b/generalresearch/models/thl/offerwall/cache.py new file mode 100644 index 0000000..6040175 --- /dev/null +++ b/generalresearch/models/thl/offerwall/cache.py @@ -0,0 +1,59 @@ +from datetime import datetime, timezone +from typing import Dict, Any, List, Optional + +from pydantic import BaseModel, Field + +from generalresearch.models import Source +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr +from generalresearch.models.thl.offerwall import OfferWallRequest +from generalresearch.models.thl.offerwall.base import ( + OfferwallBase, + TaskResult, + ScoredTaskResult, +) + + +class GetOfferWallCache(BaseModel): + """ + This object gets cached by thl-grpc when an offerwall request is made. If/when + the user enters a bucket, this object is read in thl-core. It is also + used if the offerwall needs to be "refreshed". + """ + + request: OfferWallRequest = Field() + request_id: str = Field() + offerwall: OfferwallBase = Field() + all_sids: List[str] = Field() + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(timezone.utc) + ) + latest_ip_info: Dict[str, Any] = Field( + description="So we can easily check if user's IP info has changed" + ) + profiling_task: Optional[TaskResult] = Field( + description="Profiling task", default=None + ) + is_avg_offerwall: bool = Field() + + # These only get set once a bucket is clicked. + clicked_timestamp: Optional[AwareDatetimeISO] = Field(default=None) + clicked_bucket: Optional[UUIDStr] = Field(default=None) + + +class SessionInfoCache(BaseModel): + """ + This is used within thl-core to manage a session + """ + + # This starts out as just the tasks within the clicked bucket, but + # will get pruned as tasks are attempted + tasks: List[ScoredTaskResult] = Field() + + started: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # The count of attempts per marketplace + mp_retry_count: Dict[Source, int] = Field(default_factory=dict) + + hard_retry_count: int = Field(default=0) diff --git a/generalresearch/models/thl/pagination.py b/generalresearch/models/thl/pagination.py new file mode 100644 index 0000000..6679cb4 --- /dev/null +++ b/generalresearch/models/thl/pagination.py @@ -0,0 +1,22 @@ +from typing import Optional + +from math import ceil +from pydantic import BaseModel, Field, computed_field + + +class Page(BaseModel): + # Based on fastapi_pagination.Page + page: int = Field(default=1, ge=1, description="Page number") + size: int = Field(default=50, ge=1, le=100, description="Page size") + total: Optional[int] = Field( + default=None, ge=0, description="Total number of results" + ) + + @computed_field(description="Total number of pages") + def pages(self) -> Optional[int]: + if self.size == 0: + return 0 + elif self.total is not None: + return ceil(self.total / self.size) + else: + return None diff --git a/generalresearch/models/thl/payout.py b/generalresearch/models/thl/payout.py new file mode 100644 index 0000000..b6f880f --- /dev/null +++ b/generalresearch/models/thl/payout.py @@ -0,0 +1,353 @@ +import json +from datetime import datetime, timezone +from typing import Dict, Optional, Collection, List +from uuid import uuid4 + +from pydantic import ( + BaseModel, + Field, + computed_field, + field_validator, + PositiveInt, +) +from typing_extensions import Self + +from generalresearch.currency import USDCent +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.ledger import OrderBy +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashMailOrderData, +) +from generalresearch.redis_helper import RedisConfig + + +class PayoutEvent(BaseModel): + """Base Pydantic Model to represent the `event_payout` table + + This table supports multiple different kinds of "Payouts": + + - UserPayoutEvent - A User (survey or task taker) is requesting to + withdraw money from their balance + + - BusinessPayoutEvent - A Supplier gets paid out via ACH / Wire + + - BrokerageProductPayoutEvent - BusinessPayoutEvent is composed of + multiple BrokerageProductPayoutEvents. + """ + + uuid: UUIDStr = Field( + title="Payout Event Unique Identifier", + default_factory=lambda: uuid4().hex, + examples=["9453cd076713426cb68d05591c7145aa"], + ) + + debit_account_uuid: UUIDStr = Field( + description="The LedgerAccount.uuid that money is being requested from. " + "Thie User or Brokerage Product is retrievable through the " + "LedgerAccount.reference_uuid", + examples=["18298cb1583846fbb06e4747b5310693"], + ) + + cashout_method_uuid: UUIDStr = Field( + description="References a row in the account_cashoutmethod table. This " + "is the cashout method that was used to request this " + "payout. (A cashout is the same thing as a payout)", + examples=["a6dc1fc1bf934557b952f253dee12813"], + ) + + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # In the smallest unit of the currency being transacted. For USD, this + # is cents. + amount: PositiveInt = Field( + lt=2**63 - 1, + strict=True, + description="The USDCent amount int. This cannot be 0 or negative", + examples=[531], + ) + + status: Optional[PayoutStatus] = Field( + default=PayoutStatus.PENDING, + description=PayoutStatus.as_openapi(), + examples=[PayoutStatus.COMPLETE], + ) + + # Used for holding an external, payout-type-specific identifier + ext_ref_id: Optional[str] = Field(default=None) + payout_type: PayoutType = Field( + description=PayoutType.as_openapi(), examples=[PayoutType.ACH] + ) + + request_data: Dict = Field( + default_factory=dict, + description="Stores payout-type-specific information that is used to " + "request this payout from the external provider.", + ) + + order_data: Optional[Dict | CashMailOrderData] = Field( + default=None, + description="Stores payout-type-specific order information that is " + "returned from the external payout provider.", + ) + + def update( + self, + status: PayoutStatus, + ext_ref_id: Optional[str] = None, + order_data: Optional[Dict] = None, + ) -> None: + + self.check_status_change_allowed(status) + + # These 3 things are the only modifiable attributes + self.status = status + self.ext_ref_id = ext_ref_id + self.order_data = order_data + + return None + + def check_status_change_allowed(self, status: PayoutStatus) -> None: + + # We may not be changing the status when this method gets called. It's + # possible to be called when we're updating other attributes so + # allow immediate bypass if it isn't actually different. + if self.status == status: + return None + + if self.status in { + PayoutStatus.REJECTED, + PayoutStatus.CANCELLED, + PayoutStatus.COMPLETE, + }: + raise ValueError(f"status {self.status} is final. No changes allowed") + + if self.status == PayoutStatus.PENDING: + assert status != PayoutStatus.PENDING, "status is already PENDING!" + + elif self.status == PayoutStatus.APPROVED: + assert status in { + PayoutStatus.FAILED, + PayoutStatus.COMPLETE, + }, f"status APPROVED can only be FAILED or COMPLETED, not {status}" + + elif self.status == PayoutStatus.FAILED: + assert status in { + PayoutStatus.CANCELLED, + PayoutStatus.COMPLETE, + }, f"status FAILED can only be CANCELLED or COMPLETED, not {status}" + else: + raise ValueError("this shouldn't happen") + + # --- ORM --- + + def model_dump_mysql(self, *args, **kwargs) -> Dict: + d = self.model_dump(mode="json", *args, **kwargs) + + if "created" in d: + d["created"] = self.created.replace(tzinfo=None) + + if d.get("request_data") is not None: + d["request_data"] = json.dumps(self.request_data) + + if d.get("order_data") is not None: + if isinstance(self.order_data, dict): + d["order_data"] = json.dumps(self.order_data) + else: + d["order_data"] = self.order_data.model_dump_json() + + return d + + +class UserPayoutEvent(PayoutEvent): + """A user has requested to be paid from their wallet balance.""" + + # These two fields are copied here from the LedgerAccount through the + # debit_account_uuid for convenience. They will get populated if the + # PayoutEventManager retrieves a PayoutEvent from the db. + # Requires joining on: + # - accounting_cashoutmethod + # - ledger_account + account_reference_type: Optional[str] = Field(default=None) + account_reference_uuid: Optional[UUIDStr] = Field(default=None) + + # By default, this will just be the cashout_method.name. This also is + # populated from the db and so does not need to be set (there is no + # `description` field in event_payout) + description: Optional[str] = Field(default=None) + + @field_validator("payout_type", mode="before") + @classmethod + def normalize_enum(cls, v): + if isinstance(v, str): + try: + return PayoutType[v.upper()] + except KeyError: + raise ValueError(f"Invalid payout_type: {v}") + return v + + +class BrokerageProductPayoutEvent(PayoutEvent): + """The amount + + - created: When the Brokerage Product was paid out + """ + + product_id: UUIDStr = Field( + description="The Brokerage Product that was paid out", + examples=["1108d053e4fa47c5b0dbdcd03a7981e7"], + ) + + @computed_field( + return_type=PayoutType, + description=PayoutType.as_openapi(), + examples=[PayoutType.ACH], + ) + @property + def method(self) -> PayoutType: + return self.payout_type + + @computed_field(return_type=USDCent, examples=["$10,000.000"]) + @property + def amount_usd(self) -> USDCent: + return USDCent(self.amount) + + @computed_field(return_type=str, examples=["$10,000.000"]) + @property + def amount_usd_str(self) -> str: + return self.amount_usd.to_usd_str() + + # --- ORM --- + + @classmethod + def from_payout_event( + cls, + pe: PayoutEvent, + account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, + redis_config: Optional[RedisConfig] = None, + ) -> Self: + + if account_product_mapping is None: + rc = redis_config.create_redis_client() + account_product_mapping: Dict = rc.hgetall(name="pem:account_to_product") + assert isinstance(account_product_mapping, dict) + assert pe.uuid in account_product_mapping.keys() + + d = pe.model_dump() + d["product_id"] = account_product_mapping[pe.debit_account_uuid] + return cls.model_validate(d) + + @classmethod + def from_payout_events( + cls, + payout_events: Collection[PayoutEvent], + order_by=OrderBy, + account_product_mapping: Optional[Dict[UUIDStr, UUIDStr]] = None, + redis_config: Optional[RedisConfig] = None, + ) -> List[Self]: + + if account_product_mapping is None: + rc = redis_config.create_redis_client() + account_product_mapping: Dict = rc.hgetall(name="pem:account_to_product") + assert isinstance(account_product_mapping, dict) + + res = [] + for pe in payout_events: + res.append( + cls.from_payout_event( + pe=pe, account_product_mapping=account_product_mapping + ) + ) + + match order_by: + case OrderBy.ASC: + sorted_list = sorted(res, key=lambda x: x.created, reverse=False) + case OrderBy.DESC: + sorted_list = sorted(res, key=lambda x: x.created, reverse=True) + case _: + raise ValueError("Invalid order provided..") + + return sorted_list + + +class BusinessPayoutEvent(BaseModel): + """A single ACH or Wire event to a Business Bank Account""" + + bp_payouts: List[BrokerageProductPayoutEvent] = Field( + description="Here is the list of Brokerage Product Payouts that" + "this Business Payout includes.", + min_length=1, + ) + + @computed_field( + title="Amount", + description="The amount issued to the Bank Account", + examples=[19_823_43], + return_type=USDCent, + ) + @property + def amount(self) -> USDCent: + return USDCent(sum([p.amount for p in self.bp_payouts])) + + @computed_field( + title="Amount USD Str", + description="The amount issued to the Bank Account as a USD string", + examples=["$19,823.43"], + return_type=str, + ) + @property + def amount_usd_str(self) -> str: + return self.amount.to_usd_str() + + @computed_field( + title="Created", + description="This is equal to the created time of the first" + "Brokerage Product Payout Event.", + return_type=AwareDatetimeISO, + ) + @property + def created(self) -> AwareDatetimeISO: + return self.bp_payouts[0].created + + @computed_field( + title="Line Items", + description="The number of sub-payments", + return_type=PositiveInt, + ) + @property + def line_items(self): + return len(self.bp_payouts) + + @computed_field( + title="External Reference ID", + description="ACH Transaction ID", + return_type=Optional[str], + ) + @property + def ext_ref_id(self): + return self.bp_payouts[0].ext_ref_id + + # --- Validators --- + + @field_validator("bp_payouts", mode="before") + @classmethod + def normalize_enum(cls, v): + """This can be a list of Instances or Python Dictionaries depending + on how it's initialized. + """ + + assert isinstance(v, list) + + def get_field(obj, field): + if isinstance(obj, dict): + return obj.get(field) + return getattr(obj, field, None) + + assert all( + get_field(i, "ext_ref_id") == get_field(v[0], "ext_ref_id") for i in v + ), "Not all group values are the same" + + return v diff --git a/generalresearch/models/thl/payout_format.py b/generalresearch/models/thl/payout_format.py new file mode 100644 index 0000000..f108e46 --- /dev/null +++ b/generalresearch/models/thl/payout_format.py @@ -0,0 +1,96 @@ +import decimal +import re + +from pydantic import AfterValidator, Field +from typing_extensions import Annotated + +# Matches only digits, parenthesis, + , -, *, / and the string payout. +xform_format_re = re.compile(pattern=r"^[\d()+\-*/.]*payout[\d()+\-*/.]*$") + + +def validate_payout_format(payout_format: str) -> str: + # We validate the payout format by just trying to use it with 4 example + # numbers. Each step is checked along the way. This is not really any + # slower to do for real. + assert len(payout_format) < 40, "invalid format" + for payout_int in [0, 1, 200, 2245]: + format_payout_format(payout_format, payout_int) + return payout_format + + +def format_payout_format(payout_format: str, payout_int: int) -> str: + """ + Generate a str representation of a payout. Typically, this would be displayed to a user. + :param payout_format: see BPC_DEFAULTS.payout_format + :param payout_int: The actual value in integer usd cents. + """ + assert isinstance(payout_int, int), "payout_int must be an integer" + try: + lidx = payout_format.index("{") + except ValueError: + raise ValueError("Must wrap payout transformation in {}") + + try: + ridx = payout_format.index("}") + except ValueError: + raise ValueError("Must wrap payout transformation in {}") + + prefix = payout_format[:lidx] + suffix = payout_format[ridx + 1 :] + inside = payout_format[lidx + 1 : ridx] + + try: + xform, formatstr = inside.split(":") + except ValueError as e: + raise ValueError( + "Payout format string must contain ':' to distinguish between transformations and formatting." + ) + + assert xform_format_re.match(xform) is not None, "Invalid transformation" + + try: + # if we only cared about strings, could do: set(xform) <= allowed + # ()+-*/, "payout" + + # x = re.search("(payout)+[\d\(\)\+\-\*\/\ ]*", xform) + # print("xform:", xform, bool(x)) + + payout = decimal.Decimal(eval(xform, {"payout": payout_int})) + + except NameError as e: + raise ValueError("Payout format string must contain 'payout' variable.") + + except ZeroDivisionError as e: + raise ValueError("Cannot divide by zero.") + + except TypeError as e: + # "{payout()*1:}" - TypeError: 'int' object is not callable + raise ValueError("Invalid type reference.") + except Exception as e: + raise ValueError(f"Invalid payout transformation") + + formatstr = f"{{:{formatstr}}}" + + try: + payout_str = prefix + formatstr.format(payout) + suffix + except ValueError: + raise ValueError("Invalid format string.") + return payout_str + + +description = """ +The format describing the str representation of a payout. Typically, this would be displayed to a user. +The payout_format is similar to python format string with a subset of functionality supported. +Only float with a precision are supported along with an optional comma for a thousands separator. +In addition, a mathematical operator can be applied, such as dividing by 100. +Examples are shown assuming payout = 100 (one dollar). +- "{payout*10:,.0f} Points" -> "1,000 Points" +- "${payout/100:.2f}" -> "$1.00" +""" +examples = ["{payout*10:,.0f} Points", "${payout/100:.2f}", "{payout:.0f}"] + +PayoutFormatField = Field(description=description, examples=examples) +PayoutFormatOptionalField = Field( + default=None, description=description, examples=examples +) +PayoutFormatType = Annotated[str, AfterValidator(validate_payout_format)] diff --git a/generalresearch/models/thl/product.py b/generalresearch/models/thl/product.py new file mode 100644 index 0000000..8f1c668 --- /dev/null +++ b/generalresearch/models/thl/product.py @@ -0,0 +1,1427 @@ +from __future__ import annotations + +import copy +import inspect +import json +import warnings +from collections import defaultdict +from decimal import Decimal +from enum import Enum +from functools import partial, cached_property +from typing import ( + Optional, + List, + Callable, + Any, + Literal, + Dict, + Set, + TYPE_CHECKING, +) +from urllib.parse import parse_qs, urlsplit, urlencode, urlunsplit +from uuid import uuid4 + +import math +import pandas as pd +from dask.distributed import Client +from pydantic import ( + BaseModel, + Field, + model_validator, + field_validator, + field_serializer, + NonNegativeInt, + NonNegativeFloat, + computed_field, + ConfigDict, + PositiveInt, + PositiveFloat, +) +from pydantic.json_schema import SkipJsonSchema +from typing_extensions import Self + +from generalresearch.currency import USDCent +from generalresearch.decorators import LOG +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, + HttpsUrlStr, + CountryISOLike, +) +from generalresearch.models.thl.ledger import LedgerAccount +from generalresearch.models.thl.payout_format import ( + PayoutFormatType, + description as payout_format_description, + examples as payout_format_examples, + format_payout_format, +) +from generalresearch.models.thl.supplier_tag import SupplierTag +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.utils import decimal_to_usd_cents +from generalresearch.redis_helper import RedisConfig + +if TYPE_CHECKING: + from generalresearch.models.thl.payout import ( + BrokerageProductPayoutEvent, + ) + from generalresearch.incite.base import GRLDatasets + + from generalresearch.managers.thl.payout import ( + BrokerageProductPayoutEventManager, + ) + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge + from generalresearch.models.thl.finance import ( + ProductBalances, + POPFinancial, + ) + from generalresearch.models.thl.user import User + + +# fmt: off +GRS_SKINS = [ + "mmfwcl.com", "profile.generalresearch.com", + "eureka.generalresearch.com", "opinioncapital.generalresearch.com", + "freeskins.generalresearch.com", "surveys.freeskins.com", + "cheddar.generalresearch.com", "drop.generalresearch.com", + "mobrog.generalresearch.com", "freecash.generalresearch.com", + "inbrain.generalresearch.com", "just.generalresearch.com", + "300large.generalresearch.com", "samplicious.generalresearch.com", + "l.generalresearch.com", "surveys2skins.generalresearch.com", + "surveyjunkie.generalresearch.com", "opinionhero.generalresearch.com", + "ozone.generalresearch.com", "adbloom.generalresearch.com", + "prime.generalresearch.com", "rakuten.generalresearch.com", + "pch.generalresearch.com", "solipay.generalresearch.com", + "widget.generalresearch.com", "inventory.adbloom.co", + "voooice.generalresearch.com", "kashkick.generalresearch.com", + "splendid.generalresearch.com", "monlix.generalresearch.com", + "freeward.generalresearch.com", "surveys.mnlx.me", + "bananabucks.generalresearch.com", "cashcamel.generalresearch.com", + "surveypop.generalresearch.com", "surveymagic.generalresearch.com", + "surveyspin.generalresearch.com", "cube.generalresearch.com", + "innovate.generalresearch.com", "timebucks.generalresearch.com", + "kaching.generalresearch.com", "precision.generalresearch.com", + "bitburst.generalresearch.com", "talk.generalresearch.com", + "theorem.generalresearch.com", "surveys.timewallresearch.com", + "gmo.generalresearch.com", "pinchme.generalresearch.com" +] + + +# fmt: on + + +class OfferwallConfig(BaseModel): + pass + + +class ProfilingConfig(BaseModel): + # called "harmonizer_config" in the old bpc version + + enabled: bool = Field( + default=True, + description="If False, the harmonizer/profiling system is not used at all. This should " + "never be False unless special circumstances", + ) + + grs_enabled: bool = Field( + default=True, + description="""If grs_enabled is False, and is_grs is passed in the profiling-questions call, + then don't actually return any questions. This allows a client to hit the endpoint with no limit + and still get questions. In effect, this means that we'll redirect the user through the GRS + system but won't present them any questions.""", + ) + + n_questions: Optional[PositiveInt] = Field( + default=None, + description="Use to hard code the number of questions to ask. None means use default algorithm.", + ) + + max_questions: PositiveInt = Field( + default=10, + description="The max number of questions we would ask in a session", + ) + + avg_question_count: PositiveFloat = Field( + default=5, + description="The average number of questions to ask in a session", + ) + + # Don't set this to 0, use enabled + task_injection_freq_mult: PositiveFloat = Field( + default=1, + description="Scale how frequently we inject profiling questions, relative to the default." + "1 is default, 2 is twice as often. 10 means always. 0.5 half as often", + ) + + non_us_mult: PositiveFloat = Field( + default=2, + description="Non-us multiplier, used to increase freq and length of profilers in all non-us countries." + "This value is multiplied by task_injection_freq_mult and avg_question_count.", + ) + + hidden_questions_expiration_hours: PositiveInt = Field( + default=7 * 24, + description="How frequently we should refresh hidden questions", + ) + + # todo: nothing uses this + # consent: Dict = Field() + # # Used to configure consent questions + # "consent": { + # "enabled": False, + # "property_code": "" # gr:consent_v1 + # } + # } + + +class UserHealthConfig(BaseModel): + # Users in these countries are "blocked". Blocked in quotes because + # the user doesn't actually get blocked, they just are treated like they + # are blocked. + banned_countries: List[CountryISOLike] = Field(default_factory=list) + + # Decide if a user can be blocked for IP-related triggers such as sharing IPs + # and location history. This should eventually be deprecated and replaced + # with something with more specificity. + allow_ban_iphist: bool = Field(default=True) + + # These are only checked by ym-user-predict, which I'm not sure even works properly. + # To be deprecated ... don't even use them. + userprofit_cutoff: Optional[Decimal] = Field(default=None, exclude=True) + recon_cutoff: Optional[float] = Field(default=None, exclude=True) + droprate_cutoff: Optional[float] = Field(default=None, exclude=True) + conversion_cutoff: Optional[float] = Field(default=None, exclude=True) + + @field_validator("banned_countries", mode="after") + def sort_values(cls, values: List[str]): + return sorted(values) + + +class OfferWallRequestYieldmanParams(BaseModel): + # model_config = ConfigDict(extra='forbid') + # keys: use_stats, use_harmonizer, allow_pii, add_default_lang_eng, first_n_completes_easier_per_day are + # ignored/deprecated + # allow_pii: bool = Field(default=True, description="Allow tasks that request PII. This actually does nothing.") + + # see thl-grpc:yield_management.scoring.score() for more info + conversion_factor_adj: float = Field( + default=0.0, + description="Centered around 0. Higher results in higher weight given to conversion (in the scoring function)", + ) + + dropoffrate_factor_adj: float = Field( + default=0.0, + description="Centered around 0. Higher results in higher penalty given to dropoffs (in the scoring function)", + ) + + longfail_factor_adj: float = Field( + default=0.0, + description="Centered around 0. Higher results in higher penalty given to long fail (in the scoring function)", + ) + + recon_factor_adj: float = Field( + default=0.0, + description="Centered around 0. Higher results in higher penalty given to recons (in the scoring function)", + ) + + recon_likelihood_max: float = Field( + default=0.8, description="Tolerance for recon likelihood (0 to 1)" + ) + + def update(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class OfferWallCategoryRequest(BaseModel): + # Only include these categories + adwords_category: Optional[List[str]] = Field(default=None, examples=[["45", "65"]]) + category: Optional[List[str]] = Field( + default=None, examples=[["98c137e4e90a4d92ac6c00e523eb1b50"]] + ) + # Exclude these categories + exclude_adwords_category: Optional[List[str]] = Field( + default=None, examples=[["1558"]] + ) + exclude_category: Optional[List[str]] = Field( + default=None, + examples=[ + [ + "21536f160f784189be6194ca894f3a65", + "7aa8bf4e71a84dc3b2035f93f9f9c77e", + ] + ], + ) + + @property + def any(self): + return bool( + self.adwords_category + or self.category + or self.exclude_adwords_category + or self.exclude_category + ) + + +class YieldManConfig(BaseModel): + category_request: OfferWallCategoryRequest = Field( + default_factory=OfferWallCategoryRequest + ) + scoring_params: OfferWallRequestYieldmanParams = Field( + default_factory=OfferWallRequestYieldmanParams + ) + + +class SourcesConfig(BaseModel): + """Describes the marketplaces or sources that a BP can access and their + respective configs, + aka 'BP:Marketplace Configs' + """ + + model_config = ConfigDict(frozen=True) + + user_defined: List[SourceConfig] = Field(default_factory=list) + + @model_validator(mode="after") + def validate_user_defined(self): + cs = [c.name for c in self.user_defined] + assert len(cs) == len(set(cs)), "Can only have one SourceConfig per Source!" + return self + + @cached_property + def default_sources(self) -> List[SourceConfig]: + return [SourceConfig.model_validate({"name": s}) for s in Source] + + @cached_property + def sources(self) -> List[SourceConfig]: + # If a BP has no user_defined SourceConfigs, we use the default. Any + # defined in user_defined will replace the default for that + # SourceConfig.name + # This can be a cached_property because the class is frozen. If we ever + # change that, we should make this a property instead. + user_defined = {x.name: x for x in self.user_defined} + default = {x.name: x for x in self.default_sources} + default.update(user_defined) + return list(default.values()) + + +class PayoutConfig(BaseModel): + """Store configuration related to payouts, payout transformation, and user + payout formatting.""" + + payout_format: Optional[PayoutFormatType] = Field( + default=None, + description=payout_format_description, + examples=payout_format_examples, + ) + + payout_transformation: Optional[PayoutTransformation] = Field( + default=None, + description="How the BP's payout is converted to the User's payout", + ) + + @model_validator(mode="before") + @classmethod + def payout_format_default(cls, data: Any): + # If the BP's user payout_transformation is None, the payout_format + # should also be None. If payout_transformation is set, and + # payout_format is none, use $XX.YY + if data.get("payout_transformation") is None: + # Don't assert this b/c it'll fail b/c a lot of BPC in the db has + # this set for no reason + data["payout_format"] = None + else: + if data.get("payout_format") is None: + data["payout_format"] = "${payout/100:.2f}" + return data + + +class SessionConfig(BaseModel): + """Stores configuration related to the Session, a session being a users + experience attempting to do work. + """ + + max_session_len: int = Field( + default=600, + ge=60, + le=90 * 60, + description="The amount of time (in seconds) that a respondent may spend " + "attempting to get into a survey within a session.If NULL, " + "there is no limit.", + ) + + max_session_hard_retry: int = Field( + default=5, + ge=0, + description="The number of surveys that a respondent may attempt within a " + "session before the session is terminated.", + ) + + min_payout: Decimal = Field( + default=Decimal("0.14"), + description="""The minimum amount the user should be paid for a complete. If + no payout transformation is defined, the value is based on the BP's payout. + If a payout transformation is defined, the min_payout is applied on the + user's payout. Note, this is separate and distinct from the payout + transformation's min payout. The payout transformation's min_payout does not + care what the task's actual payout was. This min_payout will prevent + the user from being show any tasks that would pay below this amount.""", + examples=[Decimal("0.50")], + ) + + +class UserCreateConfig(BaseModel): + """Stores configuration for the user creation experience. + + The user creation limit is determined dynamically based on the median + daily completion rate. min_hourly_create_limit & + max_hourly_create_limit can be used to constrain the dynamically + determined rate limit within set values. + """ + + min_hourly_create_limit: NonNegativeInt = Field( + default=0, + description="The smallest allowed value for the hourly user create limit.", + ) + + max_hourly_create_limit: Optional[NonNegativeInt] = Field( + default=None, + description="The largest allowed value for the hourly user create " + "limit. If None, the hourly create limit is unconstrained.", + ) + + def clip_hourly_create_limit(self, limit: int) -> int: + limit = max(self.min_hourly_create_limit, limit) + if self.max_hourly_create_limit is not None: + limit = min(limit, self.max_hourly_create_limit) + return limit + + +class UserWalletConfig(BaseModel): + """ + Stores configuration for the user wallet handling + """ + + enabled: bool = Field( + default=False, description="If enabled, the users' wallets are managed." + ) + + # This field could go in supported_payout_types ---v + amt: bool = Field(default=False, description="Uses Amazon Mechanical Turk") + + supported_payout_types: Set["PayoutType"] = Field( + default={PayoutType.CASH_IN_MAIL, PayoutType.TANGO, PayoutType.PAYPAL} + ) + + min_cashout: Optional[Decimal] = Field( + default=None, + gt=0, + description="Minimum cashout amount. If enabled is True and no min_cashout is " + "set, will default to $0.01.", + examples=[Decimal("10.00")], + ) + + @field_serializer("supported_payout_types", when_used="json") + def serialize_supported_payout_types_in_order( + self, supported_payout_types: Set["PayoutType"] + ) -> Set["PayoutType"]: + return set(sorted(supported_payout_types)) + + @field_validator("min_cashout", mode="after") + @classmethod + def check_payout_decimal_places(cls, v: Decimal) -> Decimal: + if v is not None: + assert ( + v.as_tuple().exponent >= -2 + ), "Must have 2 or fewer decimal places ('XXX.YY')" + # explicitly make sure it is 2 decimal places, after checking that it is + # already 2 or less. + v = v.quantize(Decimal("0.00")) + return v + + @model_validator(mode="after") + def check_enabled(self): + if self.enabled is False: + assert self.amt is False, "amt can't be set if enabled is False" + assert ( + self.min_cashout is None + ), "min_cashout can't be set if enabled is False" + else: + if self.min_cashout is None: + self.min_cashout = Decimal("0.01") + return self + + +class PayoutTransformationPercentArgs(BaseModel): + pct: NonNegativeFloat = Field( + le=1.0, + description="The percentage of the bp_payout to pay the user", + examples=[0.5], + ) + + min_payout: Optional[Decimal] = Field( + default=None, + description="The minimum amount paid for a complete. Note: This does not " + "check that the actual payout was at least this amount.", + examples=[Decimal("0.50")], + ) + + max_payout: Optional[Decimal] = Field( + default=None, + description="The maximum amount paid for a complete", + examples=[Decimal("5.00")], + ) + + @field_validator("min_payout", "max_payout", mode="after") + @classmethod + def check_payout_decimal_places(cls, v: Decimal) -> Decimal: + if v is not None: + assert ( + v.as_tuple().exponent >= -2 + ), "Must have 2 or fewer decimal places ('XXX.YY')" + # explicitly make sure it is 2 decimal places, after checking that it is + # already 2 or less. + v = v.quantize(Decimal("0.00")) + return v + + @field_validator("pct", mode="after") + def validate_payout_transformation(cls, pct: float) -> float: + if pct >= 0.95: + warnings.warn("Are you sure you want to pay respondents >95% of CPI?") + + if pct == 0: + raise ValueError("Disable payout transformation if payout percentage is 0%") + + return pct + + +class PayoutTransformation(BaseModel): + """This model describe how the bp_payout is converted to the user_payout. + If None, the user_payout is None. + + If the user_wallet_enabled is `False`, the user_payout is used to + 1) know how to transform the expected payouts for offerwall buckets + (if min_payout is requested, this is based on the user_payout) + 2) show the user (using the payout_format) how much they made (in + the Task Status Response). + + If the user_wallet_enabled is `True`, then in addition to the above, the + user_payout is the amount actually paid to the user's wallet. + """ + + f: Literal["payout_transformation_percent", "payout_transformation_amt"] = Field( + description="The name of the transformation function to use." + ) + + kwargs: Optional[PayoutTransformationPercentArgs] = Field( + description="The kwargs to pass to the transformation function.", + examples=[{"pct": 0.50, "max_payout": "5.00"}], + default=None, + ) + + def get_payout_transformation_func(self) -> Callable: + """Returns a callable which transforms the bp_payout to the + user_payout. + """ + assert self.f in { + "payout_transformation_percent", + "payout_transformation_amt", + }, f"unsupported f: {self.f}" + if self.f == "payout_transformation_amt": + return self.payout_transformation_amt + else: + return partial( + self.payout_transformation_percent, **self.kwargs.model_dump() + ) + + def payout_transformation_percent( + self, + payout: Decimal, + pct: Decimal = 1, + min_payout: Decimal = 0, + max_payout: Optional[Decimal] = None, + ) -> Decimal: + """Payout transformation for user displayed values""" + if min_payout is None: + min_payout = Decimal(0) + pct = Decimal(pct) + + payout = Decimal(payout) + min_payout = Decimal(min_payout) + max_payout = Decimal(max_payout) if max_payout else None + + payout: Decimal = payout * pct + payout: Decimal = max([payout, min_payout]) + payout: Decimal = min([payout, max_payout]) if max_payout else payout + return payout + + def payout_transformation_amt( + self, payout: Decimal, user_wallet_balance: Optional[Decimal] = None + ) -> Decimal: + """Payout transformation for user displayed values""" + # If user_wallet_balance isn't passed, we are re-calculating this + # (display, adjustment) so ignore the 7-cent rounding. + if user_wallet_balance is None: + return self.payout_transformation_percent(payout=payout, pct=Decimal(".95")) + payout = Decimal(payout) + + payout: Decimal = payout * Decimal("0.95") + new_balance = payout + user_wallet_balance + # If the new_balance is <0, we aren't paying anything, so use the + # full amount + if new_balance < 0: + return payout + + amt = (5 * math.floor((int(new_balance * 100) - 2) / 5)) + 2 + rounded_new_balance = Decimal(amt / 100).quantize(Decimal("0.00")) + payout = rounded_new_balance - user_wallet_balance + if payout < Decimal(0): + return Decimal(0) + + return payout + + +class SourceConfig(BaseModel): + """ + This could also be named "BP:Marketplace Config", as it describes the + config for a BP on a single marketplace. + """ + + name: Source = Field() + active: bool = Field(default=True) + banned_countries: List[CountryISOLike] = Field(default_factory=list) + allow_mobile_ip: bool = Field(default=True) + + allow_pii_only_buyers: bool = Field( + default=False, + description="Allow Tasks from Buyers that want traffic that comes from " + "Suppliers that can identify their users. Only supported on " + "Pure Spectrum.", + ) + + allow_unhashed_buyers: bool = Field( + default=False, + description="Return Tasks from Buyers that don't have URL hashing " + "enabled. Only supported on Pure Spectrum.", + ) + + withhold_profiling: bool = Field( + default=False, + description="For some Products, we may have privacy agreements " + "prohibiting us from sharing information with the inventory" + "Source. If True, don't add MRPQ (Market Research Profiling" + "Question) onto the entry link.", + ) + + # Allows marketplace to return survey as eligible if there are unknown + # question where the user can answer any possible answer and still + # be eligible + pass_unconditional_eligible_unknowns: bool = Field( + default=True, description="Not used at the moment" + ) + + +class Scope(str, Enum): + GLOBAL = "global" + TEAM = "team" + PRODUCT = "product" + + +class IntegrationMode(str, Enum): + # We handle integration, get paid + PLATFORM = "platform" + # "external" credentials, we do not get paid for this activity + PASS_THROUGH = "pass_through" + + +class SupplyConfig(BaseModel): + """Describes the set of policies for how GRL can interact with marketplaces. + This is only used on the special "global product".""" + + model_config = ConfigDict(frozen=False, validate_assignment=True) + + policies: List[SupplyPolicy] = Field(default_factory=list) + + @property + def configs(self): + return self.policies + + @model_validator(mode="after") + def validate_scope_global(self): + gcs = [c.name for c in self.policies if c.scope == Scope.GLOBAL] + assert len(gcs) == len(set(gcs)), "Can only have one GLOBAL policy per Source" + return self + + @model_validator(mode="after") + def validate_scope_team(self): + team_names = [ + (c.name, team_id) + for c in self.policies + if c.scope == Scope.TEAM + for team_id in c.team_ids + ] + assert len(team_names) == len( + set(team_names) + ), "Can only have one TEAM policy per Source per Team" + return self + + @model_validator(mode="after") + def validate_scope_bp(self): + bp_names = [ + (c.name, product_id) + for c in self.policies + if c.scope == Scope.PRODUCT + for product_id in c.product_ids + ] + assert len(bp_names) == len( + set(bp_names) + ), "Can only have one PRODUCT policy per Source per BP" + return self + + @property + def global_scoped_policies(self): + return [c for c in self.policies if c.scope == Scope.GLOBAL] + + @property + def team_scoped_policies(self): + return [c for c in self.policies if c.scope == Scope.TEAM] + + @property + def product_scoped_policies(self): + return [c for c in self.policies if c.scope == Scope.PRODUCT] + + @property + def global_scoped_policies_dict(self) -> Dict[Source, SupplyPolicy]: + return {c.name: c for c in self.policies if c.scope == Scope.GLOBAL} + + @property + def team_scoped_policies_dict( + self, + ) -> Dict[str, Dict[Source, SupplyPolicy]]: + # str in top-level dict is the team_id + d = defaultdict(dict) + for c in self.team_scoped_policies: + for team_id in c.team_ids: + d[team_id][c.name] = c + return d + + @property + def product_scoped_policies_dict( + self, + ) -> Dict[str, Dict[Source, SupplyPolicy]]: + # str in top-level dict is the product_id + d = defaultdict(dict) + for c in self.product_scoped_policies: + for product_id in c.product_ids: + d[product_id][c.name] = c + return d + + def get_policies_for( + self, product_id: str, team_id: str + ) -> Dict[Source, SupplyPolicy]: + """ + Is there a config scoped to this product? If not, + Is there a config scoped to this team? If not, + Use global config. + """ + d = self.global_scoped_policies_dict.copy() + d.update(self.team_scoped_policies_dict.get(team_id, dict())) + d.update(self.product_scoped_policies_dict.get(product_id, dict())) + return d + + def get_config_for_product(self, product: Product) -> MergedSupplyConfig: + product_id = product.id + team_id = product.team_id + policy_dict = copy.deepcopy( + self.get_policies_for(product_id=product_id, team_id=team_id) + ) + # 'supply_dict' is the config GRL is allowed to use for this product/team. + # The specific product's SourcesConfig can still override some things. + sources_dict = {s.name: s for s in product.sources_config.sources} + + return MergedSupplyConfig( + policies=[ + SupplyPolicy.merge_source_config( + supply_policy=policy_dict[source], + source_config=sources_dict[source], + ) + for source in policy_dict.keys() + ] + ) + + +class SupplyPolicy(SourceConfig): + """ + One policy describing how GRL can interact with a marketplaces in a + certain way. This is only used on the special "global product", and then + internally in grpc logic. + """ + + address: List[str] = Field(description="address for the grpc GetOpps call") + + allow_vpn: bool = Field(default=False) + + distribute_harmonizer_active: bool = Field(default=True) + + supplier_id: Optional[str] = Field( + default=None, + description="For some inventory Sources, we may partition traffic using " + "different supplier accounts instead", + ) + + team_ids: Optional[List[UUIDStr]] = Field(default=None) + product_ids: Optional[List[UUIDStr]] = Field(default=None) + + integration_mode: IntegrationMode = Field(default=IntegrationMode.PLATFORM) + + @computed_field( + description="There must be only 1 GLOBAL config per Source. We can have more than " + "one TEAM/PRODUCT config per Source." + ) + @property + def scope(self) -> Scope: + if self.team_ids is not None: + return Scope.TEAM + if self.product_ids is not None: + return Scope.PRODUCT + return Scope.GLOBAL + + @classmethod + def merge_source_config( + cls, supply_policy: SupplyPolicy, source_config: SourceConfig + ) -> Self: + # This function could also be called "apply_bp_overrides". + # We have a SupplyConfig (which describes how GRL is allowed to + # interact with a marketplace). and we retrieved a BP's source_config + # (for this same Source), which can override certain properties. + # Do that here. + + assert supply_policy.name == source_config.name, "Must operate on same Source" + out_config = supply_policy.model_copy() + out_config.active = supply_policy.active and source_config.active + out_config.banned_countries = sorted( + set(supply_policy.banned_countries + source_config.banned_countries) + ) + out_config.allow_mobile_ip = source_config.allow_mobile_ip + out_config.allow_unhashed_buyers = source_config.allow_unhashed_buyers + out_config.allow_pii_only_buyers = source_config.allow_pii_only_buyers + out_config.withhold_profiling = source_config.withhold_profiling + out_config.pass_unconditional_eligible_unknowns = ( + source_config.pass_unconditional_eligible_unknowns + ) + return out_config + + +class MergedSupplyConfig(SupplyConfig): + """ + This is a supply config after it has been merged/harmonized/reconciled with + the Brokerage Product's SourcesConfig. This is what is used to do the + getOpps work. + """ + + # At this point, there will be one single policy per Source + # (vs in the global config, which lists possibly many policies per source (that + # are applied to different scopes)) + @model_validator(mode="after") + def validate_single_policy(self): + sources = [c.name for c in self.policies] + assert len(sources) == len(set(sources)), "Can only have one policy per Source" + return self + + +class Product(BaseModel, validate_assignment=True): + id: UUIDStr = Field( + default_factory=lambda: uuid4().hex, + description="Unique identifier of the Brokerage Product", + examples=["1108d053e4fa47c5b0dbdcd03a7981e7"], + ) + + id_int: SkipJsonSchema[Optional[PositiveInt]] = Field(default=None) + + name: str = Field( + min_length=3, + max_length=255, + description="Name of the Brokerage Product. Must be unique within a Team", + examples=["Website ABC"], + ) + + enabled: bool = Field( + default=True, + description="This is only used to hard block a Product in order to " + "immediately & safely protect against fraud entrances.", + ) + + payments_enabled: bool = Field( + default=True, + description="This is only to determine if ACH or Wire payments should " + "be made to the Produce.", + ) + + created: Optional[AwareDatetimeISO] = Field( + # TODO: make this non-nullable + default=None, + description="When the Product was created, this does necessarily mean " + "it started to retrieve traffic at that time.", + ) + + team_id: Optional[UUIDStr] = Field( + # TODO: make this non-nullable + default=None, + examples=["b96c1209cf4a4baaa27d38082421a039"], + description="The organization (group of generalresearch.com admin " + "accounts) that is allowed to modify and manage this" + "Product", + ) + + business_id: Optional[UUIDStr] = Field( + default=None, + examples=[uuid4().hex], + description="The legal business entity or individual that is " + "responsible for this account, and that receive Supplier" + "Payments for this Product's activity.", + ) + + tags: Set["SupplierTag"] = Field( + default_factory=set, + description="Tags which are used to annotate supplier traffic", + ) + + commission_pct: Decimal = Field( + default=Decimal("0.05"), decimal_places=5, max_digits=6, le=1, ge=0 + ) + + redirect_url: HttpsUrlStr = Field( + description="Where to redirect the user after finishing a session. When a " + "user get redirected back to the supplier, a query param will be " + "added with " + "the name 'tsid', and the value of the TSID for the session. For " + "example: " + "callback_uri: 'https://www.example.com/test/?a=1&b=2' " + "might result in the user getting redirected to: " + "'https://www.example.com/grl-callback/?a=1&b=2&tsid" + "=c6ab6ba1e75b44e2bf5aab00fc68e3b7'.", + examples=["https://www.example.com/grl-callback/?a=1&b=2"], + ) + + # This is called grs_domain in the BP table + harmonizer_domain: HttpsUrlStr = Field( + default="https://profile.generalresearch.com/", + description="This is the domain that is used for the GRS (General " + "Research Survey) platform. This is a simple javascript " + "application which may profile the respondent for any" + "profiling questions, along with collecting any browser" + "based security information. The value is a scheme+domain " + "only (no path).", + ) + + # We can do this b/c SourcesConfig & SupplyConfigs have different top-level keys, + # so it'll try to model validate with each in order. + sources_config: SourcesConfig | SupplyConfig = Field(default_factory=SourcesConfig) + + session_config: SessionConfig = Field(default_factory=SessionConfig) + + payout_config: PayoutConfig = Field(default_factory=PayoutConfig) + + user_wallet_config: UserWalletConfig = Field(default_factory=UserWalletConfig) + + user_create_config: UserCreateConfig = Field(default_factory=UserCreateConfig) + + # these are just empty placeholders + offerwall_config: OfferwallConfig = Field(default_factory=OfferwallConfig) + profiling_config: ProfilingConfig = Field(default_factory=ProfilingConfig) + user_health_config: UserHealthConfig = Field(default_factory=UserHealthConfig) + yield_man_config: YieldManConfig = Field(default_factory=YieldManConfig) + + # Initialization is deferred until unless it's called + # (see .prebuild_***()) + balance: Optional["ProductBalances"] = Field( + default=None, description="Product Balance" + ) + + payouts_total_str: Optional[str] = Field(default=None) + payouts_total: Optional[USDCent] = Field(default=None) + payouts: Optional[List["BrokerageProductPayoutEvent"]] = Field( + default=None, + description="Product Payouts. These are the ACH or Wire payments that were sent to the" + "Business on behalf of this specific Product", + ) + + pop_financial: Optional[List["POPFinancial"]] = Field(default=None) + bp_account: Optional[LedgerAccount] = Field(default=None) + + # --- Validators --- + @field_validator("harmonizer_domain", mode="before") + def harmonizer_domain_https(cls, s: Optional[str]): + # in the db, this has no scheme. accept both with a default of https:// + if s is not None and not (s.startswith("https://") or s.startswith("http://")): + s = f"https://{s}" + return s + + @field_validator("harmonizer_domain", mode="after") + def validate_harmonizer_domain(cls, v: str): + if urlsplit(v).netloc not in GRS_SKINS: + raise ValueError("Unsupported harmonizer_domain") + return v + + @field_validator("harmonizer_domain", mode="after") + def harmonizer_domain_only(cls, s: str): + # maks sure there is no path + url_split = urlsplit(s) + assert ( + url_split.path == "/" + ), f"harmonizer_domain should be a schema+domain only: {url_split.path}" + assert ( + url_split.query == "" + ), f"harmonizer_domain should be a schema+domain only: {url_split.query}" + assert ( + url_split.fragment == "" + ), f"harmonizer_domain should be a schema+domain only: {url_split.fragment}" + return s + + @field_validator("redirect_url", mode="after") + def validate_redirect_url(cls, s: str) -> str: + url_split = urlsplit(s) + query_dict = parse_qs(url_split.query) + assert "tsid" not in query_dict, "URL should not contain a query param 'tsid'" + return s + + # --- Properties --- + @property + def commission(self) -> Decimal: + return self.commission_pct + + @property + def uuid(self) -> UUIDStr: + return self.id + + @property + def business_uuid(self) -> UUIDStr: + return self.business_id + + @property + def team_uuid(self) -> UUIDStr: + return self.team_id + + @property + def callback_uri(self): + return self.redirect_url + + @property + def sources(self): + return self.sources_config.sources + + @property + def sources_dict(self) -> Dict[Source, SourceConfig]: + # This stores the same info as sources but with the keys as a Source + return {x.name: x for x in self.sources} + + # Should make sure nothing uses this and remove it + # @property + # def routers(self): + # return self.sources + + @computed_field + def user_wallet(self) -> UserWalletConfig: + return self.user_wallet_config + + @property + def user_wallet_enabled(self) -> bool: + return self.user_wallet_config.enabled + + @property + def user_wallet_amt(self) -> bool: + # Controls whether AMT-related cashout methods and ledger transactions are + # allowed. + return self.user_wallet_config.amt + + @property + def cache_key(self) -> str: + return f"product:{self.uuid}" + + @property + def file_key(self) -> str: + return f"product-{self.uuid}" + + # --- Prefetch --- + def prefetch_bp_account(self, thl_lm: "ThlLedgerManager"): + account = thl_lm.get_account_or_create_bp_wallet(product=self) + self.bp_account = account + + return None + + # --- Prebuild --- + + def prebuild_balance( + self, + thl_lm: "ThlLedgerManager", + ds: "GRLDatasets", + client: Client, + pop_ledger: Optional["PopLedgerMerge"] = None, + ) -> None: + """ + This returns the Product's Balances that are calculated across + all time. They are inclusive of every transaction that has ever + occurred in relation to this particular Product. + + GRL does not use a Net30 or other time or Monthly styling billing + practice. All financials are calculated in real time and immediately + available based off the real-time calculated Smart Retainer balance. + + Smart Retainer: + GRL's fully automated smart retainer system incorporates the real-time + recon risk exposure on the BPID account. The retainer amount is prone + to change every few hours based off real time traffic characteristics. + The intention is to provide protection against an account immediately + stopping traffic and having up to 2 months worth of reconciliations + continue to roll in. Using the Smart Retainer amount will allow the + most amount of an accounts balance to be deposited into the owner's + account at any frequency without being tied to monthly invoicing. The + goal is to be as aggressive as possible and not hold funds longer than + absolutely required, Smart Retainer accounts are supported for any + volume levels. + """ + LOG.debug(f"Product.prebuild_balance({self.uuid=})") + + from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, + ) + from generalresearch.models.thl.ledger import LedgerAccount + + account: LedgerAccount = thl_lm.get_account_or_create_bp_wallet(product=self) + assert self.id == account.reference_uuid + + if pop_ledger is None: + from generalresearch.incite.defaults import pop_ledger as plm + + pop_ledger = plm(ds=ds) + + ddf = pop_ledger.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", account.uuid), + ], + ) + + if ddf is None: + raise AssertionError("Cannot build Product Balance") + + df = client.compute(collections=ddf, sync=True) + + if df.empty: + # A Product may not have any ledger transactional events. Don't + # attempt to build a balance, leave it as None rather than + # all zeros + LOG.warning(f"Product({self.uuid=}).prebuild_balance empty dataframe") + assert thl_lm.get_account_balance_timerange(account=account) == 0, ( + "If the df is empty, we can also assume that there should be no " + "transactions in the ledger." + ) + return None + + df = df.set_index("time_idx") + from generalresearch.models.thl.finance import ProductBalances + + balance = ProductBalances.from_pandas(df) + balance.product_id = self.uuid + + bal: int = thl_lm.get_account_balance_timerange( + account=account, time_end=balance.last_event + ) + assert bal == balance.balance, "Sql and Parquet Balance inconsistent" + + self.balance = balance + return None + + def prebuild_pop_financial( + self, + thl_lm: "ThlLedgerManager", + ds: "GRLDatasets", + client: Client, + pop_ledger: Optional["PopLedgerMerge"] = None, + ) -> None: + """This is very similar to the Product POP Financial endpoint; however, + it returns more than one item for a single time interval. This is + because more than a single account will have likely had any + financial activity within that time window. + """ + if self.bp_account is None: + self.prefetch_bp_account(thl_lm=thl_lm) + + from generalresearch.models.admin.request import ( + ReportRequest, + ReportType, + ) + from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, + ) + + rr = ReportRequest(report_type=ReportType.POP_LEDGER, interval="5min") + + if pop_ledger is None: + from generalresearch.incite.defaults import pop_ledger as plm + + pop_ledger = plm(ds=ds) + + ddf = pop_ledger.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx", "account_id"], + filters=[ + ("account_id", "==", self.bp_account.uuid), + ("time_idx", ">=", pop_ledger.start), + ], + ) + if ddf is None: + self.pop_financial = [] + return None + + df = client.compute(collections=ddf, sync=True) + + if df.empty: + self.pop_financial = [] + return None + + df = df.groupby( + [pd.Grouper(key="time_idx", freq=rr.interval), "account_id"] + ).sum() + + from generalresearch.models.thl.finance import POPFinancial + + self.pop_financial = POPFinancial.list_from_pandas( + input_data=df, accounts=[self.bp_account] + ) + + return None + + def prebuild_payouts( + self, + thl_lm: "ThlLedgerManager", + bp_pem: "BrokerageProductPayoutEventManager", + ) -> None: + LOG.debug(f"Product.prebuild_payouts({self.uuid=})") + from generalresearch.models.thl.ledger import OrderBy + + self.payouts = bp_pem.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, + product_uuids=[self.uuid], + order_by=OrderBy.DESC, + ) + + self.prebuild_payouts_total() + + def prebuild_payouts_total(self) -> None: + assert self.payouts is not None + + self.payouts_total = USDCent(sum([po.amount for po in self.payouts])) + self.payouts_total_str = self.payouts_total.to_usd_str() + + return None + + # def prebuild_pop(self): + # account = LM.get_account(qualified_name=f"{LM.currency.value}:bp_wallet:{product.id}") + # + # from main import data + # + # gv: GlobalVar = data["gv"] + # + # ddf = gv.pop_ledger.ddf( + # force_rr_latest=False, + # include_partial=True, + # columns=numerical_col_names + ["time_idx"], + # filters=[ + # ("account_id", "==", account.uuid), + # ("time_idx", ">=", rr.start), + # ], + # ) + # + # df = gv.dask_client.compute(collections=ddf, sync=True) + # df = df.set_index("time_idx").resample(rr.freq).sum() + # + # res = [] + # for index, row in df.iterrows(): + # index: pd.Timestamp + # row: pd.DataFrame + # + # dt = index.to_pydatetime().replace(tzinfo=None) + # instance = ProductBalances.from_pandas(row) + # + # res.append( + # { + # "time": dt, + # "payout": instance.payout / 100, + # "adjustment": instance.adjustment / 100, + # "expense": instance.expense / 100, + # "net": (instance.payout + instance.adjustment + instance.expense) / 100, + # } + # ) + # + # df = pd.DataFrame.from_records(res) + + # def financial( + # product: Product = Depends(product_from_path), + # rr: ReportRequest = Depends(rr_from_query), + # ) -> Any: + # account = LM.get_account(qualified_name=f"{LM.currency.value}:bp_wallet:{product.id}") + # + # from main import data + # + # gv: GlobalVar = data["gv"] + # + # ddf = gv.pop_ledger.ddf( + # force_rr_latest=False, + # include_partial=True, + # columns=numerical_col_names + ["time_idx", "account_id"], + # filters=[("account_id", "==", account.uuid), ("time_idx", ">=", rr.start)], + # ) + # + # df = gv.dask_client.compute(collections=ddf, sync=True) + # + # # We only do it this way so it's consistent with the Business.financial view + # df = df.groupby([pd.Grouper(key="time_idx", freq=rr.interval), "account_id"]).sum() + # return POPFinancial.list_from_pandas(df, accounts=[account]) + + # def payments(self): + # """Payments are the amount of money that General Research has sent + # the owner of this Product. + # + # These are typically ACH or Wire payments to company bank accounts. + # These are not respondent payments for Products where + # + # This is Provided in a standard list without any POP Grouping to show + # the exact time and amount of any Issued Payments. + # """ + # + # account = LM.get_account(qualified_name=f"{LM.currency.value}:bp_wallet:{product.id}") + # + # from main import data + # + # gv: GlobalVar = data["gv"] + # ddf = gv.pop_ledger.ddf( + # force_rr_latest=False, + # include_partial=True, + # columns=numerical_col_names + ["time_idx", "account_id"], + # filters=[("account_id", "==", account.uuid)], + # ) + # + # df = gv.dask_client.compute(collections=ddf, sync=True) + + # --- Methods --- + def set_cache( + self, + thl_lm: "ThlLedgerManager", + ds: "GRLDatasets", + client: Client, + bp_pem: "BrokerageProductPayoutEventManager", + redis_config: RedisConfig, + pop_ledger: Optional[PopLedgerMerge] = None, + ) -> None: + LOG.debug(f"Product.set_cache({self.uuid=})") + + ex_secs = 60 * 60 * 24 * 3 # 3 days + + self.prefetch_bp_account(thl_lm=thl_lm) + + self.prebuild_balance( + thl_lm=thl_lm, ds=ds, client=client, pop_ledger=pop_ledger + ) + self.prebuild_payouts(thl_lm=thl_lm, bp_pem=bp_pem) + self.prebuild_pop_financial( + thl_lm=thl_lm, ds=ds, client=client, pop_ledger=pop_ledger + ) + + # Validation steps. Don't save into redis until we confirm against + # the ledger. This allows parquet + db ledger balance checks + # The balance check needs to stop when the last parquet file was + # built, otherwise they'll appear unequal when it's really just + # a delay in the incite merge file not being built yet. + # bal = thl_lm.get_account_balance_timerange(time_end=) + + rc = redis_config.create_redis_client() + rc.set(name=self.cache_key, value=self.model_dump_json(), ex=ex_secs) + + return None + + def determine_bp_payment(self, thl_net: Decimal) -> Decimal: + """ + How much should we pay the BP? + """ + # How much we should get paid by the MPs for all completes in this session ( + # usually 0 or 1 completes) + commission_amount = self.determine_bp_commission(thl_net) + payout = thl_net - commission_amount + payout = payout.quantize(Decimal("0.01")) + return payout + + def determine_bp_commission(self, thl_net: Decimal) -> Decimal: + return (thl_net * self.commission_pct).quantize(Decimal("0.01")) + + def get_payout_transformation_func(self) -> Callable: + """ """ + if self.payout_config.payout_transformation is None: + return lambda x: x + else: + return ( + self.payout_config.payout_transformation.get_payout_transformation_func() + ) + + def calculate_user_payment( + self, bp_payout: Decimal, user_wallet_balance: Optional[Decimal] = None + ) -> Optional[Decimal]: + """ + :param bp_payout: This is the amount we paid to the brokerage product + :return: The amount that should be paid to the user + """ + if self.payout_config.payout_transformation is None: + return None + payout_xform_func = self.get_payout_transformation_func() + kwargs = dict() + if "user_wallet_balance" in inspect.signature(payout_xform_func).parameters: + kwargs["user_wallet_balance"] = user_wallet_balance + user_payout: Decimal = payout_xform_func(bp_payout, **kwargs) + user_payout = user_payout.quantize(Decimal("0.00")) + return user_payout + + def generate_bp_redirect(self, tsid: str): + url_split = urlsplit(self.redirect_url) + query_dict = parse_qs(url_split.query) + query_dict["tsid"] = [tsid] + url_split = list(url_split) + url_split[3] = urlencode(query_dict, doseq=True) + url = urlunsplit(url_split) + return url + + def format_payout_format(self, payout: Decimal) -> Optional[str]: + assert isinstance(payout, Decimal), "payout should be a Decimal" + if self.payout_config.payout_format is None: + return None + payout_int = decimal_to_usd_cents(payout) + return format_payout_format(self.payout_config.payout_format, payout_int) + + # --- ORM --- + + def model_dump_mysql(self, *args, **kwargs) -> Dict[str, Any]: + d = self.model_dump(mode="json", *args, **kwargs) + + if "created" in d: + d["created"] = self.created.replace(tzinfo=None) + + # JSONify these various configuration objects + for k in [ + "user_create_config", + "payout_config", + "session_config", + "sources_config", + "user_wallet_config", + "offerwall_config", + "profiling_config", + "user_health_config", + "yield_man_config", + ]: + if k in d: + d[k] = json.dumps(d[k]) + return d diff --git a/generalresearch/models/thl/profiling/__init__.py b/generalresearch/models/thl/profiling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/models/thl/profiling/marketplace.py b/generalresearch/models/thl/profiling/marketplace.py new file mode 100644 index 0000000..414a437 --- /dev/null +++ b/generalresearch/models/thl/profiling/marketplace.py @@ -0,0 +1,127 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from functools import cached_property +from typing import Any, Dict, Set, Optional, Tuple + +from pydantic import PositiveInt, BaseModel, Field, computed_field, ConfigDict + +from generalresearch.models import MAX_INT32, Source +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + UUIDStr, + CountryISOLike, + LanguageISOLike, +) +from generalresearch.models.thl.locales import CountryISO, LanguageISO + + +class MarketplaceQuestion(BaseModel, ABC): + model_config = ConfigDict(extra="allow") + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: CountryISOLike = Field(frozen=True) + + # 3-char ISO 639-2/B, lowercase + language_iso: LanguageISOLike = Field(frozen=True) + + # To avoid deleting questions, if a question no longer comes back in the + # API response (or in some cases, depending on how the question library + # is retrieved, if the question is not used by any live surveys), we'll + # mark it as not live. + is_live: bool = Field(default=True) + + # This should be an "abstract field", but there is no way to do that, so + # just listing it here. It should be overridden by the implementation + source: Source = Field() + + # Refers to a Category that we annotate. The info is stored in different + # dbs, so it may not be possible to retrieve the Category from the id, + # so we just store the id here. + category_id: Optional[UUIDStr] = Field(default=None) + + # # This doesn't work + # @property + # @abstractmethod + # def source(self) -> Source: + # ... + + @property + @abstractmethod + def internal_id(self) -> str: + """This is the value that is used for this question within the marketplace. Typically, + this is question_id. Innovate uses question_key.""" + ... + + @property + def external_id(self) -> str: + return f"{self.source.value}:{self.internal_id}" + + @property + def _key(self) -> Tuple[str, CountryISOLike, LanguageISOLike]: + """This uniquely identifies a question in a locale. There is a unique index + on this in the db. e.g. (question_id, country_iso, language_iso)""" + return self.internal_id, self.country_iso, self.language_iso + + @abstractmethod + def to_upk_question(self): ... + + @computed_field + def num_options(self) -> Optional[int]: + return len(self.options) if self.options is not None else None + + def __hash__(self): + # We need this so this obj can be added into a set. + return hash(self._key) + + def __repr__(self) -> str: + # Fancy repr that only shows the first and last 3 options if the + # question has more than 6. + repr_args = list(self.__repr_args__()) + for n, (k, v) in enumerate(repr_args): + if k == "options": + if v and len(v) > 6: + v = v[:3] + ["..."] + v[-3:] + repr_args[n] = ("options", v) + join_str = ", " + repr_str = join_str.join( + repr(v) if a is None else f"{a}={v!r}" for a, v in repr_args + ) + return f"{self.__repr_name__()}({repr_str})" + + +class MarketplaceUserQuestionAnswer(BaseModel): + # This is optional b/c this model can be used for eligibility checks for + # "anonymous" users, which are represented by a list of question answers + # not associated with an actual user. No default b/c we must explicitly + # set the field to None. + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32) + + question_id: str = Field() + + # This is optional b/c we do not need it when writing these to the db. When + # these are fetched from the db for use in yield-management, we read this + # field from the marketplace's question table. + # This should be overloaded in each implementation !!! + question_type: Optional[str] = Field(default=None) + + # This may be a pipe-separated string if the question_type is multi. Regex + # means any chars except capital letters + option_id: str = Field(pattern=r"^[^A-Z]*$") + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + country_iso: CountryISO = Field(frozen=True) + language_iso: LanguageISO = Field(frozen=True) + + @cached_property + def options_ids(self) -> Set[str]: + return set(self.option_id.split("|")) + + @property + def pre_code(self) -> str: + return self.option_id + + def to_mysql(self) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"question_type"}) + d["created"] = self.created.replace(tzinfo=None) + return d diff --git a/generalresearch/models/thl/profiling/other_option.py b/generalresearch/models/thl/profiling/other_option.py new file mode 100644 index 0000000..ae58e90 --- /dev/null +++ b/generalresearch/models/thl/profiling/other_option.py @@ -0,0 +1,56 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + ) + +texts_exact = { + "none", + "other", + "never", + "na", + "n/a", + "nothing", + "unsure", + "uncertain", + "unknown", + "decline", +} + +texts_in = { + "none of the above", + "none of these", + "prefer not to answer", + "prefer not to say", + "dont know", + "don't know", + "not applicable", + "other option", + "other response", + "decline to answer", + "rather not say", + "no answer", + "no preference", + "no opinion", + "not sure", + "i don't", + "i dont", + "i do not", +} + + +def option_is_catch_all(c: "UpkQuestionChoice") -> bool: + """ + Exclusive not specifically in the sense that it is a multi-select question + and if this option is selected no others can be selected. But also in the + sense that this option should not be filtered out. It is the "catch all". + Even a multi-select question can have >1 exclusive options. + """ + if c.id == "-3105": + return True + if c.text.lower() in texts_exact: + return True + if any(t in c.text.lower() for t in texts_in): + return True + return False diff --git a/generalresearch/models/thl/profiling/question.py b/generalresearch/models/thl/profiling/question.py new file mode 100644 index 0000000..6f6b270 --- /dev/null +++ b/generalresearch/models/thl/profiling/question.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Optional, Dict, Tuple + +from pydantic import ( + BaseModel, + Field, + ConfigDict, + computed_field, +) + +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, + CountryISOLike, + LanguageISOLike, +) +from generalresearch.models.thl.profiling.upk_question import UpkQuestion + + +class Question(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + id: Optional[UUIDStr] = Field(default=None, alias="question_id") + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: CountryISOLike = Field() + # 3-char ISO 639-2/B, lowercase + language_iso: LanguageISOLike = Field() + + property_code: Optional[str] = Field( + default=None, + description="What marketplace question this question links to", + pattern=r"^[a-z]{1,2}\:.*", + ) + data: UpkQuestion = Field() + is_live: bool = Field() + custom: Dict = Field(default_factory=dict) + last_updated: AwareDatetimeISO = Field() + + @computed_field + @property + def md5sum(self) -> str: + return self.data.md5sum + + def validate_question_answer(self, answer: Tuple[str, ...]) -> Tuple[bool, str]: + return self.data.validate_question_answer(answer=answer) diff --git a/generalresearch/models/thl/profiling/upk_property.py b/generalresearch/models/thl/profiling/upk_property.py new file mode 100644 index 0000000..2b63aef --- /dev/null +++ b/generalresearch/models/thl/profiling/upk_property.py @@ -0,0 +1,94 @@ +from enum import Enum +from functools import cached_property +from typing import List, Optional, Dict +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter + +from generalresearch.models.custom_types import UUIDStr, CountryISOLike +from generalresearch.models.thl.category import Category +from generalresearch.utils.enum import ReprEnumMeta + + +class PropertyType(str, Enum, metaclass=ReprEnumMeta): + # UserProfileKnowledge Item + UPK_ITEM = "i" + # UserProfileKnowledge Numerical + UPK_NUMERICAL = "n" + # UserProfileKnowledge Text + UPK_TEXT = "x" + + # Not used + # UPK_DATETIME = "a" + # UPK_TIME = "t" + # UPK_DATE = "d" + + +class Cardinality(str, Enum, metaclass=ReprEnumMeta): + # Zero or More + ZERO_OR_MORE = "*" + # Zero or One + ZERO_OR_ONE = "?" + + +class UpkItem(BaseModel): + id: UUIDStr = Field(examples=["497b1fedec464151b063cd5367643ffa"]) + label: str = Field(max_length=255, examples=["high_school_completion"]) + description: Optional[str] = Field( + max_length=1024, examples=["Completed high school"], default=None + ) + + +class UpkProperty(BaseModel): + """ + This used to be called "QuestionInfo", which is a bad name, + as this describes a UPK Property, like "educational_attainment", + not the question that asks for your education. + """ + + model_config = ConfigDict(populate_by_name=True) + + property_id: UUIDStr = Field(examples=[uuid4().hex]) + + property_label: str = Field(max_length=255, examples=["educational_attainment"]) + + prop_type: PropertyType = Field( + default=PropertyType.UPK_ITEM, + description=PropertyType.as_openapi_with_value_descriptions(), + ) + + cardinality: Cardinality = Field( + default=Cardinality.ZERO_OR_ONE, + description=Cardinality.as_openapi_with_value_descriptions(), + ) + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: CountryISOLike = Field() + + gold_standard: bool = Field( + default=False, + description="A Gold-Standard question has been enumerated for all " + "possible values (per country) as best as possible by GRL," + "allowing it to be mapped across inventory sources. A " + "property not marked as Gold-Standard may have: 1) " + "marketplace qid associations & 2) category associations, " + "but doesn't have a defined 'range' (list of allowed items" + "in a multiple choice question). " + "This is used for exposing a user's profiling data & for" + "the Nudge API.", + ) + + allowed_items: Optional[List[UpkItem]] = Field(default=None) + + categories: List[Category] = Field(default_factory=list) + + @cached_property + def allowed_items_by_label(self) -> Dict[str, UpkItem]: + return {i.label: i for i in self.allowed_items} + + @cached_property + def allowed_items_by_id(self) -> Dict[UUIDStr, UpkItem]: + return {i.id: i for i in self.allowed_items} + + +ProfilingInfo = TypeAdapter(List[UpkProperty]) diff --git a/generalresearch/models/thl/profiling/upk_question.py b/generalresearch/models/thl/profiling/upk_question.py new file mode 100644 index 0000000..2b952ec --- /dev/null +++ b/generalresearch/models/thl/profiling/upk_question.py @@ -0,0 +1,683 @@ +from __future__ import annotations + +import hashlib +import json +import re +from enum import Enum +from functools import cached_property +from typing import List, Optional, Union, Literal, Dict, Tuple, Set + +from pydantic import ( + BaseModel, + Field, + model_validator, + field_validator, + ConfigDict, + NonNegativeInt, + PositiveInt, +) +from typing_extensions import Annotated + +from generalresearch.models import Source +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.category import Category + + +class UPKImportance(BaseModel): + task_count: Optional[int] = Field( + ge=0, + default=None, + examples=[47], + description="The number of live Tasks that use this UPK Question", + ) + + task_score: Optional[float] = Field( + ge=0, + default=None, + examples=[0.11175522477414712], + description="GRL's internal ranked score for the UPK Question", + ) + + marketplace_task_count: Optional[Dict[Source, NonNegativeInt]] = Field( + default=None, + examples=[{Source.DYNATA: 23, Source.SPECTRUM: 24}], + description="The number of live Tasks that use this UPK Question per marketplace", + ) + + +class PatternValidation(BaseModel): + model_config = ConfigDict(frozen=True) + + message: str = Field(description="Message to display if validation fails") + + pattern: str = Field( + description="Regex string to validate. min_length and max_length are " + "checked separately, even if they are part of the regex." + ) + + +class UpkQuestionChoice(BaseModel): + model_config = ConfigDict(frozen=False, populate_by_name=True) + + # The choice ID uses the marketplace's code. This needs to be >32 for pollfish + id: str = Field( + min_length=1, + max_length=64, + pattern=r"^[\w\s\.\-]+$", + description="The unique identifier for a response to a qualification", + serialization_alias="choice_id", + validation_alias="choice_id", + frozen=True, + ) + + text: str = Field( + min_length=1, + description="The response text shown to respondents", + alias="choice_text", + frozen=True, + ) + + order: NonNegativeInt = Field() + + # Allows you to group answer choices together (used for display or extra logic) + group: Optional[int] = Field(default=None) + + exclusive: bool = Field( + default=False, + description="If answer is exclusive, it can be the only option selected", + ) + + importance: Optional[UPKImportance] = Field(default=None) + + def __hash__(self): + # We don't know the question ID!! Unique within a question only! + return hash(self.id) + + +class UpkQuestionChoiceOut(UpkQuestionChoice): + pass + # importance: Optional[UPKImportance] = Field(default=None, exclude=True) + + +class UpkQuestionType(str, Enum): + # The question has options that the user must select from. A MC question + # can be e.g. Selector.SINGLE_ANSWER or Selector.MULTIPLE_ANSWER to + # indicate only 1 or more than 1 option can be selected respectively. + MULTIPLE_CHOICE = "MC" + # The question has no options; the user must enter text. + TEXT_ENTRY = "TE" + # The question presents a slider of possible values, typically a numerical range. + SLIDER = "SLIDER" + # The question has no UI elements. + HIDDEN = "HIDDEN" + + +class UpkQuestionSelector(str, Enum): + pass + + +class UpkQuestionSelectorMC(UpkQuestionSelector): + SINGLE_ANSWER = "SA" + MULTIPLE_ANSWER = "MA" + DROPDOWN_LIST = "DL" + SELECT_BOX = "SB" + MULTI_SELECT_BOX = "MSB" + + +class UpkQuestionSelectorTE(UpkQuestionSelector): + SINGLE_LINE = "SL" + MULTI_LINE = "ML" + ESSAY_TEXT_BOX = "ETB" + + +class UpkQuestionSelectorSLIDER(UpkQuestionSelector): + HORIZONTAL_SLIDER = "HSLIDER" + VERTICAL_SLIDER = "VSLIDER" + + +class UpkQuestionSelectorHIDDEN(UpkQuestionSelector): + HIDDEN = "HIDDEN" + + +class UpkQuestionConfigurationMC(BaseModel): + model_config = ConfigDict(frozen=True, extra="forbid") + + # --- UpkQuestionType.MULTIPLE_CHOICE Options --- + # A multiple choice question with MA may allow a limited number of options + # to be selected. + # If the selector is SA, this should be set to 1. If the selector is MA, + # then this must be <= len(choices). + type: Literal[UpkQuestionType.MULTIPLE_CHOICE] = Field( + exclude=True, default=UpkQuestionType.MULTIPLE_CHOICE + ) + + max_select: Optional[int] = Field(gt=0, default=None) + + +class UpkQuestionConfigurationTE(BaseModel): + model_config = ConfigDict(frozen=True, extra="forbid") + + # --- UpkQuestionType.TEXT_ENTRY Options --- + type: Literal[UpkQuestionType.TEXT_ENTRY] = Field( + exclude=True, default=UpkQuestionType.TEXT_ENTRY + ) + + # Sets input form attribute; not the same as regex validation + max_length: Optional[PositiveInt] = Field( + default=None, + description="Maximum str length of any input. Meant as an easy, non" + "regex based check.", + ) + + # The text input box must contain this number of chars before submission + # is allowed + min_length: Optional[PositiveInt] = Field( + default=None, + description="Minimum str length of any input. Meant as an easy, non" + "regex based check.", + ) + + @model_validator(mode="after") + def check_options_agreement(self): + if self.max_length is not None and self.min_length is not None: + assert ( + self.min_length <= self.max_length + ), "max_length must be >= min_length" + return self + + +class UpkQuestionConfigurationSLIDER(BaseModel): + model_config = ConfigDict(frozen=True) + + # --- UpkQuestionType.SLIDER Options --- + type: Literal[UpkQuestionType.SLIDER] = Field( + exclude=True, default=UpkQuestionType.SLIDER + ) + + # TODO: constraints. we don't have any of these so not wasting time on this + slider_min: Optional[float] = Field(default=None) + slider_max: Optional[float] = Field(default=None) + slider_start: Optional[float] = Field(default=None) + slider_step: Optional[float] = Field(default=None) + + +class UpkQuestionValidation(BaseModel): + model_config = ConfigDict(frozen=True) + + # --- UpkQuestionType.TEXT_ENTRY Options --- + patterns: Optional[List[PatternValidation]] = Field(min_length=1) + + +SelectorType = Union[ + UpkQuestionSelectorMC, + UpkQuestionSelectorTE, + UpkQuestionSelectorSLIDER, + UpkQuestionSelectorHIDDEN, +] +Configuration = Annotated[ + Union[ + UpkQuestionConfigurationMC, + UpkQuestionConfigurationTE, + UpkQuestionConfigurationSLIDER, + ], + Field(discriminator="type"), +] + +example_upk_question = { + "choices": [ + { + "order": 0, + "choice_id": "1", + "exclusive": False, + "choice_text": "Yes", + }, + {"order": 1, "choice_id": "2", "exclusive": False, "choice_text": "No"}, + ], + "selector": "SA", + "task_count": 49, + "task_score": 3.3401743283265684, + "marketplace_task_count": { + "d": 9, + "w": 20, + "s": 20, + }, + "country_iso": "us", + "question_id": "fb20fd4773304500b39c4f6de0012a5a", + "language_iso": "eng", + "question_text": "Are you registered to vote at your present address, or not?", + "question_type": "MC", + "importance": UPKImportance( + task_count=49, + task_score=3.3401743283265684, + marketplace_task_count={ + Source.DYNATA: 9, + Source.WXET: 20, + Source.SPECTRUM: 20, + }, + ).model_dump(mode="json"), + "categories": [ + Category( + uuid="87b6d819f3ca4815bf1f135b1e829cc6", + adwords_vertical_id="396", + label="Politics", + path="/News/Politics", + parent_uuid="f66dddba61424ce5be2a38731450a0e1", + ).model_dump() + ], +} + + +class UpkQuestion(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + json_schema_extra={"example": example_upk_question}, + # Don't set this to True. Breaks in model validator (infinite recursion) + validate_assignment=False, + ) + + # The id is globally unique + id: Optional[UUIDStr] = Field(default=None, alias="question_id") + + # The format is "{Source}:{question_id}" where Source is 1 or 2 chars, and + # question_id is the marketplace's ID for this question. + ext_question_id: Optional[str] = Field( + default=None, + description="what marketplace question this question links to", + pattern=r"^[a-z]{1,2}\:.*", + ) + + type: UpkQuestionType = Field(alias="question_type") + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: str = Field(max_length=2, min_length=2, pattern=r"^[a-z]{2}$") + # 3-char ISO 639-2/B, lowercase + language_iso: str = Field(max_length=3, min_length=3, pattern=r"^[a-z]{3}$") + + text: str = Field( + min_length=1, + description="The text shown to respondents", + alias="question_text", + ) + + # Don't set a min_length=1 here. We'll allow this to be created, but it + # won't be askable with empty choices. + choices: Optional[List[UpkQuestionChoice]] = Field(default=None) + selector: SelectorType = Field(default=None) + configuration: Optional[Configuration] = Field(default=None) + validation: Optional[UpkQuestionValidation] = Field(default=None) + importance: Optional[UPKImportance] = Field(default=None) + + categories: List[Category] = Field( + default_factory=list, + description="Categories associated with this question", + ) + + explanation_template: Optional[str] = Field( + description="Human-readable template for explaining how a user's answer to this question affects eligibility", + examples=[ + "The company that administers your employer's retirement plan is {answer}." + ], + default=None, + ) + explanation_fragment_template: Optional[str] = Field( + description="A very short, natural-language explanation fragment that can be combined with others into a single sentence", + examples=["whose employer's retirement plan is {answer}"], + default=None, + ) + + @property + def _key(self): + if self.id is None: + raise ValueError("must set .id first") + return self.id, self.country_iso, self.language_iso + + @property + def locale(self) -> str: + return self.country_iso + "_" + self.language_iso + + @property + def source(self) -> Optional[Source]: + if self.ext_question_id: + return Source(self.ext_question_id.split(":", 1)[0]) + + @cached_property + def choices_text_lookup(self): + if self.choices is None: + return None + return {c.id: c.text for c in self.choices} + + @model_validator(mode="before") + @classmethod + def check_configuration_type(cls, data: Dict): + # The model knows what the type of Configuration to grab depending on + # the key 'type' which it expects inside the configuration object. + # Here, we grab the type from the top-level model instead. + config = data.get("configuration") + if isinstance(config, dict) and config.get("type") is None: + data.setdefault("configuration", {}) + data["configuration"]["type"] = data.get("type") or data.get( + "question_type" + ) + return data + + @model_validator(mode="after") + def check_type_options_agreement(self): + # If type == "text_entry", options is None. Otherwise, must be set. + if self.type in {UpkQuestionType.TEXT_ENTRY, UpkQuestionType.HIDDEN}: + if isinstance(self.choices, list) and len(self.choices) == 0: + self.choices = None + assert ( + self.choices is None + ), f"No `choices` are allowed for type `{self.type}`" + else: + assert self.choices is not None, f"`choices` must be set" + return self + + @model_validator(mode="after") + def set_default_selector(self): + if self.selector is None: + if self.type == UpkQuestionType.MULTIPLE_CHOICE: + self.selector = UpkQuestionSelectorMC.SINGLE_ANSWER + elif self.type == UpkQuestionType.TEXT_ENTRY: + self.selector = UpkQuestionSelectorTE.SINGLE_LINE + elif self.type == UpkQuestionType.SLIDER: + self.selector = UpkQuestionSelectorSLIDER.HORIZONTAL_SLIDER + else: + self.selector = UpkQuestionSelectorHIDDEN.HIDDEN + return self + + @model_validator(mode="after") + def check_type_selector_agreement(self): + if self.type == UpkQuestionType.MULTIPLE_CHOICE: + assert isinstance( + self.selector, UpkQuestionSelectorMC + ), f"type `{self.type}` must have selector UpkQuestionSelectorMC" + if self.type == UpkQuestionType.TEXT_ENTRY: + assert isinstance( + self.selector, UpkQuestionSelectorTE + ), f"type `{self.type}` must have selector UpkQuestionSelectorTE" + if self.type == UpkQuestionType.SLIDER: + assert isinstance( + self.selector, UpkQuestionSelectorTE + ), f"type `{self.type}` must have selector UpkQuestionSelectorTE" + if self.type == UpkQuestionType.HIDDEN: + assert isinstance( + self.selector, UpkQuestionSelectorHIDDEN + ), f"type `{self.type}` must have selector UpkQuestionSelectorTE" + return self + + @model_validator(mode="after") + def check_type_validator_agreement(self): + if self.validation and self.validation.patterns is not None: + assert ( + self.type == UpkQuestionType.TEXT_ENTRY + ), "validation.patterns is only allowed on Text Entry Questions" + return self + + @model_validator(mode="after") + def check_config_choices(self): + if self.type == UpkQuestionType.MULTIPLE_CHOICE and self.configuration: + if self.selector in { + UpkQuestionSelectorMC.SINGLE_ANSWER, + UpkQuestionSelectorMC.DROPDOWN_LIST, + UpkQuestionSelectorMC.SELECT_BOX, + }: + assert ( + self.configuration.max_select == 1 + ), f"configuration.max_select must be 1 if the selector is {self.selector.value}" + else: + assert self.configuration.max_select <= len( + self.choices + ), "configuration.max_select must be >= len(choices)" + return self + + @field_validator("choices") + @classmethod + def order_choices(cls, choices): + if choices: + choices.sort(key=lambda x: x.order) + return choices + + @field_validator("choices") + @classmethod + def validate_choices(cls, choices): + if choices: + ids = {x.id for x in choices} + assert len(ids) == len(choices), "choices.id must be unique" + orders = {x.order for x in choices} + assert len(orders) == len(choices), "choices.order must be unique" + return choices + + @field_validator("explanation_template", "explanation_fragment_template") + @classmethod + def validate_explanation_template(cls, v): + if v is None: + return v + if "{answer}" not in v: + raise ValueError("field must include '{answer}'") + return v + + @property + def is_askable(self) -> bool: + if len(self.text) < 5: + # It should have some text that is question-like. 5 is chosen + # because it is the shortest known "real" question (spectrum + # gender = "I'm a") + return False + + if len(self.text) > 1024: + # This usually means it is some sort of ridiculous terms & + # conditions they want the user to agree to, which we don't want + # to support + return False + + # Almost nothing has >1k options, besides location stuff (cities, + # etc.) which should get harmonized. When presenting them, we'll + # filter down options to at most 50. + if self.choices and (len(self.choices) <= 1 or len(self.choices) > 1000): + return False + + return True + + @property + def md5sum(self): + # Used to determine if a question has changed + d = { + "question_text": self.text, + "question_type": self.type.value, + "selector": self.selector.value, + "choices": ( + [{"choice_id": x.id, "choice_text": x.text} for x in self.choices] + if self.choices + else [] + ), + } + return hashlib.md5(json.dumps(d, sort_keys=True).encode("utf-8")).hexdigest() + + def to_api_format(self): + d = self.model_dump(mode="json", exclude_none=True, by_alias=True) + # This doesn't currently get included, I think it could but not sure + # if it would break anything + d.pop("ext_question_id", None) + # API expects task_score and task_count on the top-level + d.update(d.pop("importance", {})) + return d + + def validate_question_answer(self, answer: Tuple[str, ...]) -> Tuple[bool, str]: + """ + Returns (is_valid, error_message). + """ + try: + self._validate_question_answer(answer) + except AssertionError as e: + return False, str(e) + else: + return True, "" + + def _validate_question_answer(self, answer: Tuple[str, ...]) -> None: + """ + If the question is MC, validate: + - validate selector SA vs MA (1 selected vs >1 selected) + - the answers match actual codes in the choices + - validate configuration.max_select + - validate choices.exclusive + If the question is TE, validate that: + - configuration.max_length + - validation.patterns + Throws AssertionError if the answer is invalid, otherwise returns None + """ + answer = tuple(answer) + # There should never be multiple of the same value + assert sorted(set(answer)) == sorted( + answer + ), "Multiple of the same answer submitted" + if self.type == UpkQuestionType.MULTIPLE_CHOICE: + assert len(answer) >= 1, "MC question with no selected answers" + choice_codes = set(x.id for x in self.choices) + if self.selector == UpkQuestionSelectorMC.SINGLE_ANSWER: + assert ( + len(answer) == 1 + ), "Single Answer MC question with >1 selected answers" + elif self.selector == UpkQuestionSelectorMC.MULTIPLE_ANSWER: + assert len(answer) <= len( + self.choices + ), "More options selected than allowed" + assert all( + ans in choice_codes for ans in answer + ), "Invalid Options Selected" + max_select = ( + self.configuration.max_select + if self.configuration + else 0 or len(self.choices) + ) + assert len(answer) <= max_select, "More options selected than allowed" + exclusive_choice = next((x for x in self.choices if x.exclusive), None) + if exclusive_choice: + exclusive_choice_id = exclusive_choice.id + assert ( + answer == (exclusive_choice_id,) + or exclusive_choice_id not in answer + ), "Invalid exclusive selection" + elif self.type == UpkQuestionType.TEXT_ENTRY: + assert len(answer) == 1, "Only one answer allowed" + answer = answer[0] + assert len(answer) > 0, "Must provide answer" + max_length = ( + self.configuration.max_length if self.configuration else 0 or 100000 + ) + assert len(answer) <= max_length, "Answer longer than allowed" + if self.validation and self.validation.patterns: + for pattern in self.validation.patterns: + assert re.search(pattern.pattern, answer), pattern.message + elif self.type == UpkQuestionType.HIDDEN: + pass + + +class UpkQuestionOut(UpkQuestion): + choices: Optional[List[UpkQuestionChoiceOut]] = Field(default=None) + # Return both importance top-level model and extracted keys for now. + # Eventually deprecate one way. + task_count: Optional[int] = Field( + ge=0, + default=None, + examples=[47], + description="The number of live Tasks that use this UPK Question", + ) + + task_score: Optional[float] = Field( + ge=0, + default=None, + examples=[0.11175522477414712], + description="GRL's internal ranked score for the UPK Question", + ) + + marketplace_task_count: Optional[Dict[Source, NonNegativeInt]] = Field( + default=None, + examples=[{Source.DYNATA: 23, Source.SPECTRUM: 24}], + description="The number of live Tasks that use this UPK Question per marketplace", + ) + + @model_validator(mode="after") + def populate_from_importance(self): + # When we return through the api, bring the importance keys to the top-level + if self.importance: + self.task_count = self.importance.task_count + self.task_score = self.importance.task_score + self.marketplace_task_count = self.importance.marketplace_task_count + return self + + +def order_exclusive_options(q: UpkQuestion): + """ + The idea is to call then when doing a MP -> UPK conversion, where the + marketplace doesn't have the order specified. + """ + from generalresearch.models.thl.profiling.other_option import ( + option_is_catch_all, + ) + + if q.choices: + last_choices = [c for c in q.choices if option_is_catch_all(c)] + for c in last_choices: + q.choices.remove(c) + q.choices.append(c) + c.exclusive = True + if last_choices: + for idx, c in enumerate(q.choices): + c.order = idx + + +def trim_options(q: UpkQuestion, max_options: int = 50) -> UpkQuestion: + """Filter weighted MC/SC Options during Offerwall Requests or Refresh + + - Remove any of ZERO importance + - ~50 option HARD limit, keep only the 50 highest scoring + - In soft-pair, take up to requested, or 50 + - Implement N-1 to keep options that are a catch-all / exclusive. + """ + from generalresearch.models.thl.profiling.other_option import ( + option_is_catch_all, + ) + + q = q.model_copy() + if not q.choices: + return q + if q.ext_question_id.startswith("gr:") or q.ext_question_id.startswith("g:"): + return q + + special_choices: Set[UpkQuestionChoice] = { + c for c in q.choices if option_is_catch_all(c) + } + + if q.choices[0].importance is None: + # We're calculating UpkQuestionChoice important on (1) UpkQuestionChoice + # Creation and (2) every 60min, so this should always be set. However, + # if isn't for some reason, don't fail... just show a random set of + # 50 UpkQuestionChoices. Sorry ¯\_(ツ)_/¯ + for c in q.choices: + c.importance = UPKImportance(task_score=1, task_count=1) + + possible_choices = [ + c for c in q.choices if c.importance.task_count > 0 or c in special_choices + ] + if possible_choices: + q.choices = possible_choices + else: + # We can't have a MC question with all choices filtered out. + pass + + if len(q.choices) > max_options: + choices = q.choices + # If there is a Special Choice (eg: "none of the above", "decline to + # answer", "prefer not to say", etc) always include it at the bottom. + idx = max_options - len(special_choices) + choices = set( + sorted(choices, key=lambda x: x.importance.task_score, reverse=True)[:idx] + ) + choices.update(special_choices) + q.choices = sorted(choices, key=lambda x: x.order) + + return q + + +UpkQuestionOut.model_rebuild() diff --git a/generalresearch/models/thl/profiling/upk_question_answer.py b/generalresearch/models/thl/profiling/upk_question_answer.py new file mode 100644 index 0000000..28b9f27 --- /dev/null +++ b/generalresearch/models/thl/profiling/upk_question_answer.py @@ -0,0 +1,116 @@ +from datetime import datetime, timezone +from typing import Optional, Union, Dict +from uuid import uuid4 + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + model_validator, + computed_field, +) +from typing_extensions import Self + +from generalresearch.models import MAX_INT32 +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, + CountryISOLike, +) +from generalresearch.models.thl.profiling.upk_property import ( + PropertyType, + Cardinality, +) + + +class UpkQuestionAnswer(BaseModel): + """ """ + + model_config = ConfigDict(populate_by_name=True) + + user_id: PositiveInt = Field(lt=MAX_INT32) + + question_id: Optional[UUIDStr] = Field( + examples=[uuid4().hex], + description="The ID of the question that was asked in order to determine this", + default=None, + ) + session_id: Optional[UUIDStr] = Field( + examples=[uuid4().hex], + description="The thl_session in which the question was asked", + default=None, + ) + + property_id: UUIDStr = Field(examples=[uuid4().hex]) + + property_label: str = Field(max_length=255, examples=["educational_attainment"]) + + prop_type: PropertyType = Field( + default=PropertyType.UPK_ITEM, + description=PropertyType.as_openapi_with_value_descriptions(), + ) + + cardinality: Cardinality = Field( + default=Cardinality.ZERO_OR_ONE, + description=Cardinality.as_openapi_with_value_descriptions(), + ) + + # ISO 3166-1 alpha-2 (two-letter codes, lowercase) + country_iso: CountryISOLike = Field() + + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # If the property is PropertyType.UPK_ITEM, it should have an item (and no value). + # If the property is UPK_NUMERICAL or UPK_TEXT, it'll have a value (and no item). + item_id: Optional[UUIDStr] = Field( + default=None, examples=["497b1fedec464151b063cd5367643ffa"] + ) + item_label: Optional[str] = Field( + default=None, max_length=255, examples=["high_school_completion"] + ) + value_text: Optional[str] = Field( + default=None, + max_length=1024, + ) + value_num: Optional[float] = Field( + default=None, + ) + + @computed_field + @property + def value(self) -> Optional[Union[str, float]]: + if self.prop_type == PropertyType.UPK_ITEM: + return self.item_label + elif self.prop_type == PropertyType.UPK_TEXT: + return self.value_text + elif self.prop_type == PropertyType.UPK_NUMERICAL: + return self.value_num + + @model_validator(mode="after") + def check_value_vs_item(self) -> Self: + if self.prop_type == PropertyType.UPK_ITEM: + if not self.item_id or not self.item_label: + raise ValueError("item_id and item_label must be provided for UPK_ITEM") + if self.value_num is not None or self.value_text is not None: + raise ValueError("value and value_text must be None for UPK_ITEM") + + elif self.prop_type in { + PropertyType.UPK_NUMERICAL, + PropertyType.UPK_TEXT, + }: + if self.item_id or self.item_label: + raise ValueError("item_id and item_label must be None for non-UPK_ITEM") + if self.prop_type == PropertyType.UPK_NUMERICAL and self.value_num is None: + raise ValueError("value must be provided for UPK_NUMERICAL") + if self.prop_type == PropertyType.UPK_TEXT and self.value_text is None: + raise ValueError("value_text must be provided for UPK_TEXT") + + return self + + def model_dump_mysql(self) -> Dict: + d = self.model_dump(mode="json") + d["created"] = self.created + return d diff --git a/generalresearch/models/thl/profiling/user_info.py b/generalresearch/models/thl/profiling/user_info.py new file mode 100644 index 0000000..609121e --- /dev/null +++ b/generalresearch/models/thl/profiling/user_info.py @@ -0,0 +1,76 @@ +from typing import Optional, List + +from pydantic import BaseModel, ConfigDict, Field +from pydantic.json_schema import SkipJsonSchema + +from generalresearch.models import Source +from generalresearch.models.custom_types import AwareDatetimeISO +from generalresearch.models.thl.profiling.user_question_answer import ( + MarketplaceResearchProfileQuestion, +) +from generalresearch.models.thl.user import User + + +class UserProfileKnowledgeAnswer(BaseModel): + # Returns {id, label, translation} when the prop_type is an item, + # and only {value} if it's a string/text (such as for postalcode) + id: Optional[str] = Field(default=None) + label: Optional[str] = Field(default=None) + translation: Optional[str] = Field(default=None) + + value: Optional[str] = Field(default=None) + + +class UserProfileKnowledge(BaseModel): + property_id: str = Field() + property_label: str = Field() + translation: str = Field() + + answer: List[UserProfileKnowledgeAnswer] = Field(default_factory=list) + + created: AwareDatetimeISO = Field( + description="When the User submitted this Profiling data" + ) + + +class MarketProfileKnowledge(BaseModel): + """ + This is used solely in API responses, so it is simplified. + """ + + source: Source = Field( + max_length=16, description="Marketplace this question is from" + ) + + question_id: str = Field(examples=["gender", "1843", "gender_plus"]) + + answer: List[str] = Field( + default_factory=list, examples=[["male"], ["7657644"], ["1"]] + ) + + created: AwareDatetimeISO = Field( + description="When the User submitted this Profiling data" + ) + + @classmethod + def from_MarketplaceResearchProfileQuestion( + cls, q: MarketplaceResearchProfileQuestion + ): + return cls( + source=q.source, + question_id=q.question_code, + answer=list(q.answer), + created=q.timestamp, + ) + + +class UserInfo(BaseModel): + model_config = ConfigDict() + + user: SkipJsonSchema[Optional[User]] = Field(exclude=True, default=None) + + user_profile_knowledge: List[UserProfileKnowledge] = Field(default_factory=list) + + marketplace_profile_knowledge: List[MarketProfileKnowledge] = Field( + default_factory=list + ) diff --git a/generalresearch/models/thl/profiling/user_question_answer.py b/generalresearch/models/thl/profiling/user_question_answer.py new file mode 100644 index 0000000..b42183d --- /dev/null +++ b/generalresearch/models/thl/profiling/user_question_answer.py @@ -0,0 +1,160 @@ +import json +from datetime import datetime, timezone, timedelta +from typing import Dict, Tuple, Iterator, Optional, Literal, Union, Any + +from pydantic import ( + PositiveInt, + Field, + field_validator, + model_validator, + BaseModel, + ConfigDict, +) +from typing_extensions import Self + +from generalresearch.grpc import timestamp_to_datetime +from generalresearch.models import Source, MAX_INT32 +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr +from generalresearch.models.thl.locales import CountryISO, LanguageISO +from generalresearch.models.thl.profiling.upk_question import UpkQuestion + + +class UserQuestionAnswer(BaseModel): + + model_config = ConfigDict(validate_assignment=True) + + user_id: Optional[PositiveInt] = Field(lt=MAX_INT32, default=None) + question_id: UUIDStr = Field() + answer: Tuple[str, ...] = Field() + timestamp: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + country_iso: Union[CountryISO, Literal["xx"]] = Field() + language_iso: Union[LanguageISO, Literal["xxx"]] = Field() + + # Store a property code associated with this question_id. e.g. "gr:hispanic" or "d:192" + property_code: str = Field() + # Stores any question answers that are calculated from this answer + calc_answers: Optional[Dict[str, Tuple[str, ...]]] = Field(default=None) + + @field_validator("calc_answers") + def sorted_calc_answers(cls, calc_answers) -> Optional[Dict[str, Tuple[str, ...]]]: + if calc_answers is None: + return None + + return {k: tuple(sorted(v)) for k, v in calc_answers.items()} + + @field_validator("calc_answers") + def validate_keys(cls, calc_answers) -> Optional[Dict[str, Tuple[str, ...]]]: + if calc_answers is None: + return None + + assert all( + ":" in k for k in calc_answers.keys() + ), "calc_answers expects the keys to be in format source:question_code" + return calc_answers + + def model_dump_mysql(self, session_id: Optional[str] = None) -> Dict[str, Any]: + d = self.model_dump(mode="json", exclude={"calc_answers", "timestamp"}) + d["answer"] = json.dumps(self.answer) + # Note naming inconsistency here: calc_answer/s + d["calc_answer"] = json.dumps(self.calc_answers) + # Note naming inconsistency here: created vs timestamp + d["created"] = self.timestamp + d["session_id"] = session_id + return d + + def get_mrpqs(self) -> Iterator["MarketplaceResearchProfileQuestion"]: + for k, v in self.calc_answers.items(): + source, question_code = k.split(":", 1) + yield MarketplaceResearchProfileQuestion( + question_code=question_code, + source=source, + country_iso=self.country_iso, + language_iso=self.language_iso, + answer=tuple(sorted(set(v))), + timestamp=self.timestamp, + ) + + @field_validator("answer") + def sorted_answer(cls, answer): + return tuple(sorted(answer)) + + def __hash__(self) -> int: + return hash((self.question_id, self.answer, self.timestamp)) + + def validate_question_answer(self, question: UpkQuestion) -> Tuple[bool, str]: + """ + Returns (is_valid, error_message). + """ + try: + assert question.id == self.question_id, "mismatched question id" + assert ( + question.country_iso == self.country_iso + ), "country_iso doesn't match question's country" + assert ( + question.language_iso == self.language_iso + ), "language_iso doesn't match question's language" + question._validate_question_answer(self.answer) + except AssertionError as e: + return False, str(e) + else: + return True, "" + + def is_stale(self) -> bool: + return self.timestamp < datetime.now(tz=timezone.utc) - timedelta(days=30) + + @classmethod + def from_grpc(cls, msg, default_timestamp: datetime) -> Self: + """ + Handles correctly issues with grpc timestamps + :param msg: "thl.protos.generalresearch_pb2.ProfilingQuestionAnswer" + """ + assert default_timestamp.tzinfo is not None, "must use tz-aware timestamps" + timestamp = timestamp_to_datetime(msg.timestamp) + timestamp = default_timestamp if timestamp < datetime(2000, 1, 1) else timestamp + return cls( + question_id=msg.question_id, + answer=tuple(msg.answer), + timestamp=timestamp, + ) + + +# We can't set a redis list to [] vs None. We'll push this dummy answer into +# the cache to signify the user has no answered questions. It'll get removed +# by the 30 day old check once we pull it back anyways +DUMMY_UQA = UserQuestionAnswer( + question_id="f118edd01cf1476ba7200a175fb4351d", + answer=("0",), + timestamp=datetime(2020, 1, 1, tzinfo=timezone.utc), + country_iso="xx", + language_iso="xxx", + property_code="dummy", + calc_answers=dict(), +) + + +class MarketplaceResearchProfileQuestion(BaseModel): + """Answer submitted to a question by a user, that has been transformed + into a question answer that is specific to a marketplace.""" + + question_code: str = Field( + description="# the question id/code on the marketplace", min_length=1 + ) + source: Source = Field() # the one or two-letter marketplace code + answer: Tuple[str, ...] = Field(min_length=1) + timestamp: AwareDatetimeISO = Field() + country_iso: CountryISO = Field() + language_iso: LanguageISO = Field() + + @model_validator(mode="after") + def validate_keys(self): + assert ( + ":" not in self.question_code + ), "question_code expected to not be in curie format" + return self + + @property + def answer_str(self) -> str: + return "|".join(self.answer) diff --git a/generalresearch/models/thl/report_task.py b/generalresearch/models/thl/report_task.py new file mode 100644 index 0000000..e945b3b --- /dev/null +++ b/generalresearch/models/thl/report_task.py @@ -0,0 +1,50 @@ +import random +from collections import defaultdict +from typing import List, Collection, Optional + +from pydantic import BaseModel, ConfigDict, Field + +from generalresearch.models.thl.definitions import ReportValue +from generalresearch.models.thl.user import BPUIDStr + +# If a report is made with multiple values, we'll take the one with the +# highest priority +REPORT_PRIORITY = defaultdict( + lambda: 2, + { + ReportValue.REASON_UNKNOWN: 0, # lowest priority + ReportValue.TECHNICAL_ERROR: 1, # next highest + ReportValue.DIDNT_LIKE: 1, + }, +) + + +def prioritize_report_values( + report_values: Collection[ReportValue], +) -> Optional[ReportValue]: + if not report_values: + return None + report_values = list(set(report_values)) + random.shuffle(report_values) + return sorted(report_values, key=lambda x: REPORT_PRIORITY[x])[-1] + + +class ReportTask(BaseModel): + model_config = ConfigDict(extra="forbid") + + bpuid: BPUIDStr = Field( + title="product_user_id", + description="The unique identifier for the user, which is set by the " + "Supplier.", + examples=["app-user-9329ebd"], + ) + + reasons: List[ReportValue] = Field( + description=ReportValue.as_openapi_with_value_descriptions(), + examples=[[3, 4]], + default_factory=list, + ) + + notes: str = Field( + default="", examples=["The survey wanted to watch me eat Haejang-guk"] + ) diff --git a/generalresearch/models/thl/session.py b/generalresearch/models/thl/session.py new file mode 100644 index 0000000..0066559 --- /dev/null +++ b/generalresearch/models/thl/session.py @@ -0,0 +1,1344 @@ +import json +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Optional, Dict, Any, Tuple, Union, List, Annotated +from uuid import uuid4 +from typing import TYPE_CHECKING + +from pydantic import ( + BaseModel, + AwareDatetime, + Field, + model_validator, + field_validator, + computed_field, + ConfigDict, + field_serializer, +) +from typing_extensions import Self + +from generalresearch.models import DeviceType, Source +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, + IPvAnyAddressStr, + EnumNameSerializer, +) +from generalresearch.models.legacy.bucket import Bucket +from generalresearch.models.thl import ( + Product, + decimal_to_int_cents, + int_cents_to_decimal, +) +from generalresearch.models.thl.definitions import ( + Status, + SessionAdjustedStatus, + WallAdjustedStatus, + StatusCode1, + ReportValue, + WallStatusCode2, + SessionStatusCode2, + WALL_ALLOWED_STATUS_CODE_1_2, + WALL_ALLOWED_STATUS_STATUS_CODE, +) +from generalresearch.models.thl.user import User + +if TYPE_CHECKING: + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + +logger = logging.getLogger("Wall") + + +class WallBase(BaseModel): + """ + TODO: We want to extend a new test that does more rigorous testing + and usage of any Wall.user_id vs Session.user_id on manually + setting of the attribute vs retrieving any results from the database + """ + + model_config = ConfigDict( + extra="forbid", validate_assignment=True, ser_json_timedelta="float" + ) + + uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + source: Source + buyer_id: Optional[str] = Field(default=None, max_length=32) + req_survey_id: str = Field(max_length=32) + req_cpi: Decimal = Field(decimal_places=5, lt=1000, ge=0) + started: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # These get set on creation, or updated when the wall event is finished. So + # they shouldn't really ever be NULL, but you don't have to pass them in + # on instantiation + survey_id: Optional[str] = Field(max_length=32, default=None) + cpi: Optional[Decimal] = Field(lt=1000, ge=0, default=None) + + # Gets set when a wall is "finished" + finished: Optional[AwareDatetimeISO] = Field(default=None) + status: Optional[Status] = None + status_code_1: Optional[StatusCode1] = None + status_code_2: Optional[WallStatusCode2] = None + + ext_status_code_1: Optional[str] = Field(default=None, max_length=32) + ext_status_code_2: Optional[str] = Field(default=None, max_length=32) + ext_status_code_3: Optional[str] = Field(default=None, max_length=32) + + report_value: Optional[ReportValue] = None + report_notes: Optional[str] = Field(default=None, max_length=255) + + # This is the most recent reconciliation status of the wall event. + # Possible values: 'ac' (adjusted to complete), 'af' (adj to fail) + # If a wall gets adjusted and adjusted back to its original status, the + # adjusted_status = None + adjusted_status: Optional[WallAdjustedStatus] = None + + # This is not really used, it is only important if the requested CPI + # doesn't match the adjusted amount, which shouldn't happen as no + # marketplaces support partial reconciles, and because the req_cpi and + # cpi are set whether or not the survey completed. + # - If adjusted_status = 'ac': adjusted_cpi is the amount paid (should + # equal the `cpi`) + # - If adjusted_status = 'af': adjusted_cpi is 0.00 + adjusted_cpi: Optional[Decimal] = Field(default=None, lt=1000, ge=0) + + # This timestamp gets updated every time there is an adjustment. Even if + # we flip-flop, this will be set (and adjusted_status will be None). + adjusted_timestamp: Optional[AwareDatetimeISO] = Field(default=None) + + # --- Validation --- + + # noinspection PyNestedDecorators + @field_validator("req_cpi", "cpi", "adjusted_cpi", mode="before") + @classmethod + def check_decimal_type(cls, v: Decimal) -> Decimal: + # pydantic is unable to set strict=True, so we'll do that manually here + if v is not None: + assert type(v) == Decimal, f"Must pass a Decimal, not a {type(v)}" + return v + + # noinspection PyNestedDecorators + @field_validator("req_cpi", "cpi", "adjusted_cpi", mode="after") + @classmethod + def check_cpi_decimal_places(cls, v: Decimal) -> Decimal: + if v is not None: + assert ( + v.as_tuple().exponent >= -5 + ), "Must have 5 or fewer decimal places ('XXX.YYYYY')" + return v + + @model_validator(mode="before") + @classmethod + def set_survey_id(cls, data: Any): + # This gets called upon assignment also, so we can't do this. IDK how + # to make it only run upon initialization ... + # + # if data.get('survey_id'): + # assert data.get('survey_id') == data['req_survey_id'], \ + # "upon init, survey_id must equal req_survey_id" + data["survey_id"] = ( + data["survey_id"] if data.get("survey_id") else data["req_survey_id"] + ) + return data + + # noinspection PyNestedDecorators + @field_validator("buyer_id", mode="before") + @classmethod + def set_buyer_id(cls, v: str) -> str: + # Max limit of 32 char, but I don't think we should fail if not, + # we'll just crop it + if v is not None: + v = v[:32] + return v + + @model_validator(mode="after") + def check_timestamps(self): + assert self.started <= datetime.now( + tz=timezone.utc + ), "Started must not be in the future" + if self.finished: + assert self.finished > self.started, "Finished must be after started" + assert self.finished - self.started <= timedelta( + minutes=90 + ), "Maximum wall event time is 90 min" + return self + + @model_validator(mode="after") + def check_ext_statuses(self): + if self.ext_status_code_3 is not None: + assert ( + self.ext_status_code_1 is not None + ), "Set ext_status_code_1 before ext_status_code_3" + assert ( + self.ext_status_code_2 is not None + ), "Set ext_status_code_2 before ext_status_code_3" + if self.ext_status_code_2 is not None: + assert ( + self.ext_status_code_1 is not None + ), "Set ext_status_code_1 before ext_status_code_2" + return self + + @model_validator(mode="after") + def check_status(self): + if self.status in {Status.COMPLETE, Status.FAIL}: + assert self.finished is not None, "finished should be set" + if self.status == Status.COMPLETE: + assert ( + self.status_code_1 == StatusCode1.COMPLETE + ), "status_code_1 should be COMPLETE" + return self + + @model_validator(mode="after") + def check_status_status_code_agreement(self) -> Self: + if self.status_code_1: + options = WALL_ALLOWED_STATUS_STATUS_CODE.get(self.status, {}) + assert ( + self.status_code_1 in options + ), f"If status is {self.status.value}, status_code_1 should be in {options}" + return self + + @model_validator(mode="after") + def check_status_code1_2_agreement(self) -> Self: + if self.status_code_2: + options = WALL_ALLOWED_STATUS_CODE_1_2.get(self.status_code_1, {}) + assert ( + self.status_code_2 in options + ), f"If status_code_1 is {self.status_code_1.value}, status_code_2 should be in {options}" + return self + + # --- Methods --- + + @classmethod + def from_json(cls, s: str) -> Self: + d = json.loads(s) + d["req_cpi"] = Decimal(d["req_cpi"]) + d["cpi"] = Decimal(d["cpi"]) if d.get("cpi") is not None else None + d["adjusted_cpi"] = ( + Decimal(d["adjusted_cpi"]) if d.get("adjusted_cpi") is not None else None + ) + return cls.model_validate(d) + + def is_visible(self) -> bool: + # I don't know what to call this. It's just checking if source != 'g', + # but it could be changed. We need this to determine if a complete + # on this wall event could make the session complete + return self.source != "g" + + def is_visible_complete(self) -> bool: + # I think we could also instead of checking source != 'g', we could + # check if `payout` is not NULL. This would basically have the same + # effect. + # + # return self.status == Status.COMPLETE and self.payout is not None + # and self.payout > 0 + return self.is_visible() and self.status == Status.COMPLETE + + def allow_session(self) -> bool: + if self.status == Status.COMPLETE: + return False + + return True + + def update(self, **kwargs) -> None: + """ + We might have to update multiple fields at once, or else we'll get + validation errors. There doesn't seem to be a clean way of doing this.. + + We need to be careful to not ignore a validation here b/c the + assignment will take either way. + + We shouldn't use this if the same object is being handled by multiple + threads. But I don't envision that happening. + + https://stackoverflow.com/questions/73718577/updating-multiple-pydantic-fields-that-are-validated-together + """ + self.model_config["validate_assignment"] = False + for k, v in kwargs.items(): + setattr(self, k, v) + self.model_config["validate_assignment"] = True + self.__class__.model_validate(self) + + return None + + def finish( + self, + status: Status, + status_code_1: StatusCode1, + status_code_2: Optional[WallStatusCode2] = None, + finished: Optional[datetime] = None, + ext_status_code_1: Optional[str] = None, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, + survey_id: Optional[str] = None, + cpi: Optional[Decimal] = None, + ) -> None: + + # This is just used in tests at the moment. This needs to be adjusted. + if finished is None: + finished = datetime.now(tz=timezone.utc) + + self.update( + status=status, + status_code_1=status_code_1, + status_code_2=status_code_2, + finished=finished, + ext_status_code_1=ext_status_code_1, + ext_status_code_2=ext_status_code_2, + ext_status_code_3=ext_status_code_3, + ) + + if survey_id is not None: + self.survey_id = survey_id + + if cpi is not None: + self.cpi = cpi + + return None + + def annotate_status_codes( + self, + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, + finished: Optional[datetime] = None, + ) -> None: + # This should be called by the wall manager in order to actually update db + from generalresearch import wall_status_codes + + status, status_code_1, status_code_2 = wall_status_codes.annotate_status_code( + self.source, + ext_status_code_1, + ext_status_code_2, + ext_status_code_3, + ) + if finished is None: + finished = datetime.now(tz=timezone.utc) + self.update( + status=status, + status_code_1=status_code_1, + status_code_2=status_code_2, + ext_status_code_1=ext_status_code_1, + ext_status_code_2=ext_status_code_2, + ext_status_code_3=ext_status_code_3, + finished=finished, + ) + + return None + + def is_soft_fail(self) -> bool: + from generalresearch import wall_status_codes + + assert self.status is not None, "status should not be None" + assert self.status_code_1 is not None, "status_code_1 should not be None" + return wall_status_codes.is_soft_fail(self) + + def stop_marketplace_session(self) -> bool: + from generalresearch import wall_status_codes + + assert self.status is not None, "status should not be None" + assert self.status_code_1 is not None, "status_code_1 should not be None" + return wall_status_codes.stop_marketplace_session(self) + + def get_status_after_adjustment(self) -> Status: + if self.adjusted_status in { + WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + WallAdjustedStatus.CPI_ADJUSTMENT, + }: + return Status.COMPLETE + elif self.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL: + return Status.FAIL + elif self.status == Status.COMPLETE: + return Status.COMPLETE + else: + return Status.FAIL + + def get_cpi_after_adjustment(self) -> Decimal: + if self.adjusted_status in { + WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + WallAdjustedStatus.CPI_ADJUSTMENT, + }: + return self.adjusted_cpi + elif self.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL: + return Decimal(0) + elif self.status == Status.COMPLETE: + return self.cpi + else: + return Decimal(0) + + def report( + self, + report_value: ReportValue, + report_notes: Optional[str] = None, + report_timestamp: Optional[AwareDatetime] = None, + ) -> None: + """When a wall event is reported: + + - IF wall event already has a status: + we only set the report_value and don't touch any timestamps + - ELSE (the wall event currently has no status): + we also set the finished timestamp and set the status to ABANDON + + Only 1 report is allowed. If this is called multiple times, the + report_value gets updated each time. + The report_timestamp shouldn't be used in practice. It is only used + to backfill from vendor_wall (where the report_timestamp is the + vw.finished for reported events) + + TODO: Transition this over to use the ReportTask pydantic model. + """ + report_timestamp = ( + report_timestamp if report_timestamp else datetime.now(tz=timezone.utc) + ) + if self.status is None and self.finished is None: + self.status = Status.ABANDON + self.finished = report_timestamp + self.report_value = report_value + self.report_notes = report_notes + + +class Wall(WallBase): + # Avoiding using the Session model here because of cyclic dependency issues. + # We just store the session id (integer) which we can use to look up a + # session. The Session model is the one who has the actual reference to a + # list of Wall models. + session_id: int + + # This is in the Session, but for convenience, add it here too, but just + # the user ID. Any user related operations should be done through the + # Session + user_id: int + + @model_validator(mode="before") + @classmethod + def set_cpi(cls, data: Any): + # if data.get('cpi'): + # assert data.get('cpi') == data['req_cpi'], \ + # "upon init, cpi must equal req_cpi" + data["cpi"] = data["cpi"] if data.get("cpi") else data["req_cpi"] + return data + + @model_validator(mode="after") + def check_adjusted_null(self) -> Self: + if self.adjusted_status is not None or self.adjusted_cpi is not None: + assert ( + self.adjusted_cpi is not None + ), "Set adjusted_cpi if the wall has been adjusted" + assert ( + self.adjusted_status is not None + ), "Set adjusted_status if the wall has been adjusted" + assert ( + self.adjusted_timestamp is not None + ), "Set adjusted_timestamp if the wall has been adjusted" + return self + + @model_validator(mode="after") + def check_adjusted_status_consistent(self) -> Self: + check_adjusted_status_consistent( + self.status, self.cpi, self.adjusted_status, self.adjusted_cpi + ) + return self + + # --- Properties --- + + @computed_field + @property + def elapsed(self) -> timedelta: + return self.finished - self.started if self.finished else None + + def to_json(self) -> str: + # We have to handle the computed_fields manually. I'm not sure if there is a better way + # to do this natively in pydantic... + d = self.model_dump(mode="json", exclude={"elapsed"}) + return json.dumps(d) + + def model_dump_mysql(self, *args, **kwargs) -> Dict: + # Generate a dictionary representation of the model, with special handling for datetimes + d = self.model_dump(mode="json", exclude={"elapsed"}, *args, **kwargs) + d["started"] = self.started.replace(tzinfo=None) + if self.finished: + d["finished"] = self.finished.replace(tzinfo=None) + if self.adjusted_timestamp: + d["adjusted_timestamp"] = self.adjusted_timestamp.replace(tzinfo=None) + return d + + +class WallOut(WallBase): + + # These get serialized to the enum name instead of the int value (for ease in UI) + status_code_1: Optional[Annotated[StatusCode1, EnumNameSerializer]] = Field( + default=None, + examples=[StatusCode1.COMPLETE.name], + description=StatusCode1.as_openapi_with_value_descriptions_name(), + ) + + status_code_2: Optional[Annotated[WallStatusCode2, EnumNameSerializer]] = Field( + default=None, + examples=[None], + description=WallStatusCode2.as_openapi_with_value_descriptions_name(), + ) + + # Exclude these 3 fields + cpi: Optional[Decimal] = Field(lt=1000, ge=0, default=None, exclude=True) + req_cpi: Optional[Decimal] = Field( + decimal_places=5, lt=1000, ge=0, default=None, exclude=True + ) + adjusted_cpi: Optional[Decimal] = Field(lt=1000, ge=0, default=None, exclude=True) + + # user_cpi is serialized to integer cents!!! + user_cpi: Optional[Decimal] = Field( + lt=1000, + ge=0, + default=None, + description=""" + The amount the user would earn from completing this task, if the status was a complete. + If the BP has no payout xform, the user_cpi is None. This is analogous to the session's + user_payout. + """, + examples=[123], + ) + + user_cpi_string: Optional[str] = Field( + default=None, + description="If a payout transformation is configured on this account, " + "this is the amount to display to the user", + examples=["123 Points"], + ) + + # Serialize user_cpi to an int + @field_serializer("user_cpi", return_type=int) + def serialize_user_cpi(self, v: Decimal, _info): + return decimal_to_int_cents(v) + + # If user_cpi is an int, put it back to a decimal + @field_validator("user_cpi", mode="before") + def deserialize_user_cpi(cls, v): + if isinstance(v, int): + return int_cents_to_decimal(v) + return v + + # noinspection PyNestedDecorators + @field_validator("user_cpi", mode="after") + @classmethod + def check_cpi_decimal_places(cls, v: Decimal) -> Decimal: + if v is not None: + assert ( + v.as_tuple().exponent >= -5 + ), "Must have 5 or fewer decimal places ('XXX.YYYYY')" + return v + + @field_validator("status_code_1", mode="before") + def transform_enum_name(cls, v: str | int) -> int: + # If we are serializing+deserializing this model (i.e. when we cache + # it), this fails because we've replaced the enum value with the + # name. Put it back here ... + if isinstance(v, str): + return StatusCode1[v] + return v + + @field_validator("status_code_2", mode="before") + def transform_enum_name2(cls, v: str | int) -> int: + # If we are serializing+deserializing this model (i.e. when we cache it), this fails because + # we've replaced the enum value with the name. But it back here ... + if isinstance(v, str): + return WallStatusCode2[v] + return v + + @classmethod + def from_wall(cls, wall: Wall, product: Product) -> Self: + d = wall.model_dump(exclude={"session_id", "user_id"}, round_trip=True) + d["user_cpi"] = None + if product.payout_config.payout_transformation is not None: + d["user_cpi"] = product.calculate_user_payment( + product.determine_bp_payment(wall.cpi) + ) + d["user_cpi_string"] = product.format_payout_format(d["user_cpi"]) + return cls.model_validate(d) + + +class WallAttempt(BaseModel): + """ + - We use this to de-duplicate entrances into surveys (prevent + sending the user in multiple times to the same survey). + - This could be just a Wall model instead, but avoiding doing that + because in this use case we only care about the "entrance", and + are not tracking/updating/caring about status/status_codes/finished. + - This is just a "minimal" Wall + """ + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + uuid: UUIDStr = Field() + source: Source = Field() + req_survey_id: str = Field(max_length=32) + started: AwareDatetimeISO = Field() + user_id: int = Field() + + @property + def task_sid(self) -> str: + return self.source.value + ":" + self.req_survey_id + + +class Session(BaseModel): + model_config = ConfigDict( + extra="forbid", validate_assignment=True, ser_json_timedelta="float" + ) + + # id will be None until db_create is called (or if this is instantiated + # from an existing session) + id: Optional[int] = None + uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + user: User + started: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # This is the "bucket" the user clicked on to start this session. We only + # store the 4 fields: loi_min, loi_max, user_payout_min, user_payout_max + # in the db, but there may be other metadata associated with the bucket + # that is cached, such as the category. + clicked_bucket: Optional[Bucket] = Field(default=None) + + country_iso: Optional[str] = Field( + default=None, max_length=2, pattern=r"^[a-z]{2}$" + ) + device_type: Optional[DeviceType] = Field(default=None) + ip: Optional[IPvAnyAddressStr] = Field(default=None) + + url_metadata: Optional[Dict[str, str]] = Field(default=None) + + # Below here shouldn't be set upon initialization, or directly. + wall_events: List[Wall] = Field(default_factory=list) + + # Gets set when a session is "finished" + finished: Optional[AwareDatetimeISO] = Field(default=None) + + status: Optional[Status] = None + status_code_1: Optional[StatusCode1] = None + status_code_2: Optional[SessionStatusCode2] = None + + # There are two scenarios. Let's say the user payout transformation is + # 40% and this session pays out $1. + # + # a) user wallet enabled: The BP gets $0.60 & the user gets $0.40. + # b) user wallet disabled: The BP gets $1.00. We store what the + # user_payout should be ($0.40) only on this model, but it is not + # actually paid. + # + # This is potentially confusing b/c in case (b), the fields would be + # ($0.60) and ($0.40), but we paid $1 to the BP. + # + # To try to address this: The `payout` is the total amount that "we" are + # paying for this session. The `user_payout` "comes out" of the `payout`. + # So, in both cases (a) and (b), the payout is $1.00 and user_payout is + # $0.40. If the user wallet is enabled, we interpret this is as ($1-$0.40) + # going to the BP and ($0.40) to the user, and if the wallet is disabled, + # then the whole $1 goes to the BP and $0 to the user, but the $0.40 value + # is saved, so it can be displayed in the task status endpoint. + payout: Optional[Decimal] = Field(default=None, lt=1000, ge=0) + user_payout: Optional[Decimal] = Field(default=None, lt=1000, ge=0) + + # This is the most recent reconciliation status of the session. Generally, + # we would adjust this if the last survey in the session was adjusted + # from complete to incomplete. If any survey in the session was adjusted + # from fail -> complete (and the user didn't already get a complete) + # we'll adjust this to a complete. + adjusted_status: Optional[SessionAdjustedStatus] = None + + # If adjusted_status = 'ac': payout = 0 and adjusted_payout is the amount paid + # If adjusted_status = 'af': payout = the amount paid, adjusted_payout is 0.00 + # (the `payout` never changed, only the adjusted_payout can change). + adjusted_payout: Optional[Decimal] = Field(default=None, lt=1000, ge=0) + adjusted_user_payout: Optional[Decimal] = Field(default=None, lt=1000, ge=0) + + # This timestamp gets updated every time there is an adjustment (even if + # there are flip-flops). + adjusted_timestamp: Optional[AwareDatetimeISO] = Field(default=None) + + # --- Validation --- + + @field_validator( + "payout", + "user_payout", + "adjusted_payout", + "adjusted_user_payout", + mode="before", + ) + @classmethod + def check_decimal_type(cls, v: Decimal) -> Decimal: + # pydantic is unable to set strict=True, so we'll do that manually here + if v is not None: + assert type(v) == Decimal, f"Must pass a Decimal, not a {type(v)}" + return v + + @field_validator( + "payout", + "user_payout", + "adjusted_payout", + "adjusted_user_payout", + mode="after", + ) + @classmethod + def check_payout_decimal_places(cls, v: Decimal) -> Decimal: + if v is not None: + assert ( + v.as_tuple().exponent >= -2 + ), "Must have 2 or fewer decimal places ('XXX.YY')" + # explicitly make sure it is 2 decimal places, after checking that it is already 2 or less. + v = v.quantize(Decimal("0.00")) + return v + + @model_validator(mode="after") + def check_statuses(self): + if self.status_code_1 is None: + return self + if self.status == Status.FAIL: + assert self.status_code_1 in { + StatusCode1.SESSION_START_FAIL, + StatusCode1.SESSION_CONTINUE_FAIL, + StatusCode1.SESSION_START_QUALITY_FAIL, + StatusCode1.SESSION_CONTINUE_QUALITY_FAIL, + StatusCode1.BUYER_FAIL, + StatusCode1.BUYER_QUALITY_FAIL, + StatusCode1.PS_OVERQUOTA, + StatusCode1.PS_DUPLICATE, + StatusCode1.PS_FAIL, + StatusCode1.PS_QUALITY, + StatusCode1.PS_BLOCKED, + }, f"status_code_1 {self.status_code_1.name} invalid for status {self.status.value}" + elif self.status in {Status.TIMEOUT, Status.ABANDON}: + assert self.status_code_1 in { + StatusCode1.PS_ABANDON, + StatusCode1.GRS_ABANDON, + StatusCode1.BUYER_ABANDON, + }, f"status_code_1 {self.status_code_1.name} invalid for status {self.status.value}" + elif self.status == Status.COMPLETE: + assert ( + self.status_code_1 == StatusCode1.COMPLETE + ), f"status_code_1 {self.status_code_1.name} invalid for status {self.status.value}" + else: + assert self.status_code_1 is None, ( + f"status_code_1 {self.status_code_1.name} invalid for status " + f"{self.status.value}" + ) + return self + + @model_validator(mode="after") + def check_timestamps(self): + if self.finished: + assert self.finished > self.started, "finished is before started!" + return self + + @model_validator(mode="after") + def check_status_when_finished(self): + if self.finished: + assert self.status is not None, "once finished, we should have a status!" + return self + + @model_validator(mode="after") + def check_payout_when_complete(self): + if self.status == Status.COMPLETE: + assert ( + self.payout is not None + ), "there should be a payout if the session is marked complete" + return self + + # @model_validator(mode='after') + # def check_payouts(self): + # if self.status == 'c': + # assert self.payout > 0 + # if self.user_payout is not None: + # assert self.payout > self.user_payout + # else: + # assert self.payout is None + # return self + + @field_validator("wall_events") + @classmethod + def check_wall_events(cls, wall_events: List[Wall]): + # Note: this can't work on modifications as pydantic/python doesn't + # know if a list is mutated. We have to run it manually, or hide + # the self.wall_events attr and wrap all access + assert sorted(wall_events, key=lambda x: x.started) == wall_events, "sorted" + assert len({w.uuid for w in wall_events}) == len(wall_events) + return wall_events + + @model_validator(mode="after") + def check_adjusted(self): + if self.adjusted_status is not None or self.adjusted_payout is not None: + assert ( + self.adjusted_payout is not None + ), "Set adjusted_payout if the session has been adjusted" + assert ( + self.adjusted_status is not None + ), "Set adjusted_status if the session has been adjusted" + assert ( + self.adjusted_timestamp is not None + ), "Set adjusted_timestamp if the session has been adjusted" + if self.adjusted_user_payout is not None: + assert ( + self.adjusted_payout is not None + ), "Set adjusted_payout if adjusted_user_payout is set" + # NOTE: the other way around is NOT required! + # (the adjusted_user_payout / user_payout can be null) + return self + + @model_validator(mode="after") + def check_adjusted_status(self): + if self.adjusted_status == SessionAdjustedStatus.ADJUSTED_TO_COMPLETE: + assert self.status != Status.COMPLETE, ( + "If a Session was originally completed, reversed, and then re-reversed to complete," + "the adjusted_status should be null" + ) + if self.adjusted_status == SessionAdjustedStatus.ADJUSTED_TO_FAIL: + assert ( + self.status == Status.COMPLETE + ), "Session.status must be COMPLETE for the adjusted_status to be ADJUSTED_TO_FAIL" + return self + + # --- Properties --- + + @property + def user_id(self): + return self.user.user_id + + @property + def elapsed(self) -> timedelta: + return self.finished - self.started if self.finished else None + + # --- Methods --- + + def update(self, **kwargs) -> None: + """We might have to update multiple fields at once, or else we'll + get validation errors. There doesn't seem to be a clean way of + doing this ... + + We need to be careful to not ignore a validation here b/c the + assignment will take either way. + + We shouldn't use this if the same object is being handled by + multiple threads. But I don't envision that happening. + + https://stackoverflow.com/questions/73718577/updating-multiple-pydantic-fields-that-are-validated-together + """ + self.model_config["validate_assignment"] = False + for k, v in kwargs.items(): + setattr(self, k, v) + self.model_config["validate_assignment"] = True + self.__class__.model_validate(self) + + def model_dump_mysql( + self, *args, **kwargs + ) -> Dict[str, Union[str, int, datetime, float, None]]: + + # Generate a dictionary representation of the model, with special + # handling for datetimes, and nested models such as User & Bucket + + d = self.model_dump(mode="json", *args, **kwargs) + d["started"] = self.started.replace(tzinfo=None) + + if self.finished: + d["finished"] = self.finished.replace(tzinfo=None) + + if self.adjusted_timestamp: + d["adjusted_timestamp"] = self.adjusted_timestamp.replace(tzinfo=None) + + d["url_metadata_json"] = json.dumps(d.pop("url_metadata", {})) + clicked_bucket = d.pop("clicked_bucket") or {} + d.update( + { + k: clicked_bucket.get(k) + for k in [ + "loi_min", + "loi_max", + "user_payout_min", + "user_payout_max", + ] + } + ) + + # pymysql will complain about various values being in the dictionary + # that gets used as the connection.execute(..., args=) so we want to + # explicit about what comes back. pymysql tries to escape everything + # even if it isn't used in the actual query + d["user_id"] = self.user_id + + d.pop("user", None) + d.pop("wall_events", None) + + return d + + def append_wall_event(self, w: Wall) -> None: + wall_events = self.wall_events + [w] + # the assignment causes check_wall_events to run + self.wall_events = wall_events + + def finalize_timeout(self, task_timeout_seconds: int = 5400) -> None: + """Would usually be called on a session that has no status, presumably + by some task, when this session has timed out. Results in setting + of status. + + On a session that already has a status, this does nothing. + """ + # We need the BP's default "task timeout". Assuming this is 90 min + last_wall = self.wall_events[-1] + if ( + last_wall.status is None + and self.status is None + and datetime.now(tz=timezone.utc) + > self.started + timedelta(seconds=task_timeout_seconds) + ): + last_wall.status = Status.TIMEOUT + self.status = Status.TIMEOUT + + def determine_session_status(self) -> Tuple[Status, StatusCode1]: + """Given a list of wall events, determine what the session status + should be. If this is called, it is because the Session is *over*, + or it has timed out. + + Note: this does not support multiple completes within a session. + This should be configurable or else it is very confusing... If I + get a complete and then abandon, the BP will see it as an abandon, + but we see a complete. We should only mark it as a complete if the + BP will get a POST or we know they poll the statuses endpoint. + """ + # # If there are no wall events, it is a GRL Fail + # if len(self.wall_events) == 0: + # self.status = Status.FAIL + # self.status_code_1 = StatusCode1.SESSION_START_FAIL + # return None + + # The last wall event, regardless of GRS or external. If it is an + # abandon, the session's status is abandon + self.finalize_timeout() + last_wall = self.wall_events[-1] + assert last_wall.status is not None, "Session is still active!" + + if last_wall.status in {Status.ABANDON, Status.TIMEOUT}: + status = last_wall.status + if last_wall.is_visible(): + status_code_1 = StatusCode1.BUYER_ABANDON + else: + status_code_1 = StatusCode1.GRS_ABANDON + return status, status_code_1 + + # The last non-GRS wall event + last_wall = self.get_last_visible_wall() + + # If there are only hidden wall events, it is a SESSION_CONTINUE_FAIL + if last_wall is None: + return Status.FAIL, StatusCode1.SESSION_CONTINUE_FAIL + + # Report the status of the last wall event + elif last_wall.status == Status.COMPLETE: + return Status.COMPLETE, StatusCode1.COMPLETE + + elif last_wall.status == Status.FAIL: + status = Status.FAIL + status_code_1s = {x.status_code_1 for x in self.wall_events} + + if StatusCode1.BUYER_FAIL in status_code_1s: + status_code_1 = StatusCode1.BUYER_FAIL + elif StatusCode1.BUYER_QUALITY_FAIL in status_code_1s: + status_code_1 = StatusCode1.BUYER_QUALITY_FAIL + else: + status_code_1 = last_wall.status_code_1 + + if status_code_1 == StatusCode1.UNKNOWN: + status_code_1 = StatusCode1.BUYER_FAIL + elif status_code_1 in { + StatusCode1.MARKETPLACE_FAIL, + StatusCode1.GRS_QUALITY_FAIL, + }: + status_code_1 = StatusCode1.SESSION_CONTINUE_QUALITY_FAIL + elif status_code_1 == StatusCode1.GRS_FAIL: + status_code_1 = StatusCode1.SESSION_CONTINUE_FAIL + return status, status_code_1 + + return Status.FAIL, StatusCode1.BUYER_FAIL + + def get_last_visible_wall(self): + return next( + iter(filter(lambda x: x.is_visible(), self.wall_events[::-1])), None + ) + + def should_end_session( + self, max_session_len: timedelta, max_session_hard_retry: int + ) -> bool: + + now = datetime.now(tz=timezone.utc) + last_wall = self.get_last_visible_wall() + + if last_wall and last_wall.status == Status.COMPLETE: + return True + + if (now - self.started) > max_session_len: + return True + + hard_retry_count = sum(not wall.is_soft_fail() for wall in self.wall_events) + if hard_retry_count >= max_session_hard_retry: + return True + + # Hard limit of 40 wall events per session + if len(self.wall_events) >= 40: + return True + + return False + + def determine_payments( + self, + thl_ledger_manager: Optional["ThlLedgerManager"] = None, + ) -> Tuple[Decimal, Decimal, Decimal, Optional[Decimal]]: + # How much we should get paid by the MPs for all completes in this + # session (usually 0 or 1 completes) + thl_net: Decimal = Decimal( + sum(wall.cpi for wall in self.wall_events if wall.is_visible_complete()) + ) + + product = self.user.product + # Handle brokerage product payouts + bp_pay: Decimal = product.determine_bp_payment(thl_net) + commission_amount: Decimal = thl_net - bp_pay + + # Some payout transformations may want this: + user_wallet_balance = None + if ( + product.payout_config.payout_transformation is not None + and product.payout_config.payout_transformation.f + == "payout_transformation_amt" + ): + assert thl_ledger_manager is not None + amt = thl_ledger_manager.get_user_wallet_balance(user=self.user) + user_wallet_balance = Decimal(amt / 100).quantize(Decimal("0.01")) + user_pay: Optional[Decimal] = product.calculate_user_payment( + bp_pay, user_wallet_balance=user_wallet_balance + ) + + return thl_net, commission_amount, bp_pay, user_pay + + def get_thl_net(self) -> Decimal: + assert self.wall_events, "populate wall_events!" + assert self.user.product, "prefetch user.product!" + walls = [w for w in self.wall_events if w.source != Source.GRS] + completed_walls = [ + w for w in walls if w.get_status_after_adjustment() == Status.COMPLETE + ] + + if completed_walls: + return Decimal(sum([w.get_cpi_after_adjustment() for w in completed_walls])) + + else: + return Decimal(0) + + def determine_new_status_and_payouts( + self, + ) -> Tuple[Status, Decimal, Optional[Decimal]]: + """Session is adjusted any time one of the wall events is. Assuming + status adjustments happened on a session's wall events. Calculate + if any status changes are need to the session. + + - It is possible that complicated outcomes occur. Such as, e.g. + originally [Fail, Fail, Complete ($2)], the complete gets + reversed (Session Adj to fail $0), then the 2nd fail gets + changed to complete --> [Fail, Complete ($1), Fail] (Session Adj + to Complete $1). But since it was originally a complete, the + final status is Payout Adjustment ($2 -> $1). + + - In summary, possible outcomes: orig complete -> adj to fail, orig + fail -> adj to complete, or orig complete -> payout adj. And + also adjustments being reverted back to normal (complete -> adj + to fail -> complete), etc. + + returns: status, bp_payout, Optional[user_payout] + """ + assert self.wall_events, "populate wall_events!" + assert self.user.product, "prefetch user.product!" + + product = self.user.product + thl_net = self.get_thl_net() + + if thl_net: + adjusted_payout = product.determine_bp_payment(thl_net) + adjusted_user_payout = product.calculate_user_payment(adjusted_payout) + return Status.COMPLETE, adjusted_payout, adjusted_user_payout + + else: + if product.payout_config.payout_transformation is None: + adjusted_user_payout = None + else: + adjusted_user_payout = Decimal(0) + + return Status.FAIL, Decimal(0), adjusted_user_payout + + def adjust_status(self) -> bool: + """A complete can go to an adj_fail, or a payout adjustment. It can + then go back to a complete. + + A session that was orig a failure, can go to adj_complete. But it + cannot go to payout_adj. + """ + adjusted_timestamp = max( + [x.adjusted_timestamp for x in self.wall_events if x.adjusted_timestamp], + default=None, + ) + + new_status, new_payout, new_user_payout = ( + self.determine_new_status_and_payouts() + ) + current_status = self.get_status_after_adjustment() + current_payout = self.get_payout_after_adjustment() + original_payout = self.payout + + if (current_status == Status.FAIL and new_status == Status.FAIL) or ( + current_status == Status.COMPLETE + and new_status == Status.COMPLETE + and new_payout == current_payout + ): + # If the session is originally a complete, or a fail adjusted to + # complete, and we want to change it to complete and the payout + # is the same, (or is fail, or adjusted to fail, and we want to + # change to fail): do nothing. + logger.info(f"adjust_status: session {self.uuid} is already {new_status}") + return False + + if self.status == Status.COMPLETE: + assert ( + self.adjusted_status != SessionAdjustedStatus.ADJUSTED_TO_COMPLETE + ), "Can't have complete adj to complete" + if self.adjusted_status in { + None, + SessionAdjustedStatus.PAYOUT_ADJUSTMENT, + }: + if new_status == Status.COMPLETE: + if original_payout == new_payout: + self.update( + adjusted_status=None, + adjusted_payout=None, + adjusted_user_payout=None, + adjusted_timestamp=adjusted_timestamp, + ) + + elif self.get_payout_after_adjustment() != new_payout: + # Complete -> Complete (different payout) OR + # Complete -> Complete (different payout) -> Complete (different payout) + self.update( + adjusted_status=SessionAdjustedStatus.PAYOUT_ADJUSTMENT, + adjusted_payout=new_payout, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=new_user_payout, + ) + + else: + # Complete -> Complete (same payout). do nothing + raise ValueError("should never reach here") + + else: + # Complete -> Fail OR Complete -> Complete (diff payout) -> Fail + self.update( + adjusted_status=SessionAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_payout=new_payout, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=new_user_payout, + ) + else: + # adj_status = adj to fail + if new_status == Status.FAIL: + # Complete -> Fail -> Fail (do nothing) + raise ValueError("should never reach here") + + else: + # Complete -> Fail -> Complete + if original_payout == new_payout: + self.update( + adjusted_status=None, + adjusted_payout=None, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=None, + ) + else: + # complete -> fail -> complete (different payout) + self.update( + adjusted_status=SessionAdjustedStatus.PAYOUT_ADJUSTMENT, + adjusted_payout=new_payout, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=new_user_payout, + ) + else: + # originally a failure. possible adj_status -> {None, adj to complete} + assert self.adjusted_status not in { + SessionAdjustedStatus.ADJUSTED_TO_FAIL, + SessionAdjustedStatus.PAYOUT_ADJUSTMENT, + }, "Can't have fail adj to fail or payout adj" + if self.adjusted_status == SessionAdjustedStatus.ADJUSTED_TO_COMPLETE: + if new_status == Status.FAIL: + # Fail -> Complete -> Fail + self.update( + adjusted_status=None, + adjusted_payout=None, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=None, + ) + else: + # Fail -> Complete + if new_payout != self.adjusted_payout: + # Fail -> Complete -> Complete (new payout) + # If a session is originally Fail. And then its adjusted + # to complete, and then a 2nd wall in that session is + # also adj to complete, is the session adj to complete + # or payout_adj? I'm sticking with adj to complete, + # and the adj payout changed. + self.update( + adjusted_status=SessionAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_payout=new_payout, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=new_user_payout, + ) + else: + # Fail -> Complete -> Complete (same payout) + raise ValueError("should never reach here") + else: + # adj status is None + if new_status == Status.COMPLETE: + # Fail -> Complete + self.update( + adjusted_status=SessionAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_payout=new_payout, + adjusted_timestamp=adjusted_timestamp, + adjusted_user_payout=new_user_payout, + ) + return True + + def get_status_after_adjustment(self) -> Status: + if self.adjusted_status in { + SessionAdjustedStatus.ADJUSTED_TO_COMPLETE, + SessionAdjustedStatus.PAYOUT_ADJUSTMENT, + }: + return Status.COMPLETE + elif self.adjusted_status == SessionAdjustedStatus.ADJUSTED_TO_FAIL: + return Status.FAIL + elif self.status == Status.COMPLETE: + return Status.COMPLETE + else: + return Status.FAIL + + def get_payout_after_adjustment(self) -> Decimal: + if self.adjusted_status is not None: + return self.adjusted_payout + else: + return self.payout or Decimal(0) + + def get_user_payout_after_adjustment(self) -> Optional[Decimal]: + if self.adjusted_status is not None: + return self.adjusted_user_payout + else: + return self.user_payout + + +def check_adjusted_status_consistent( + status: Status, + cpi: Decimal, + adjusted_status: WallAdjustedStatus, + adjusted_cpi: Decimal, +): + if adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE: + assert status != Status.COMPLETE, ( + "If a Wall was originally completed, reversed, and then re-reversed to complete," + "the adjusted_status should be null" + ) + assert adjusted_cpi == cpi, "adjusted_cpi should be equal to the original cpi" + + elif adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL: + assert ( + status == Status.COMPLETE + ), "Wall.status must be COMPLETE for the adjusted_status to be ADJUSTED_TO_FAIL" + assert ( + adjusted_cpi == 0 + ), "adjusted_cpi should be 0 if adjusted_status is ADJUSTED_TO_FAIL" + + elif adjusted_status == WallAdjustedStatus.CPI_ADJUSTMENT: + # the original status is allowed to be anything + # the adjusted cpi should be something different + assert ( + adjusted_cpi != 0 and adjusted_cpi != cpi + ), "If CPI_ADJUSTMENT, the adjusted_cpi should be different from the original cpi or 0" + + elif adjusted_status is None: + assert adjusted_cpi is None, "incompatible adjusted values" + + +def check_adjusted_status_wall_consistent( + status: Status, + cpi: Optional[Decimal] = None, + adjusted_status: Optional[WallAdjustedStatus] = None, + adjusted_cpi: Optional[Decimal] = None, + new_adjusted_status: Optional[WallAdjustedStatus] = None, + new_adjusted_cpi: Optional[Decimal] = None, +) -> Tuple[bool, str]: + """ + Raises an AssertionError if inconsistent. + + - status, cpi, adjusted_status, adjusted_cpi are the wall's CURRENT values + - new_adjusted_status & new_adjusted_cpi are attempting to be set + + We are checking if the adjustment is allowed, based on the attempt's current status. + """ + try: + _check_adjusted_status_wall_consistent( + status=status, + cpi=cpi, + adjusted_status=adjusted_status, + adjusted_cpi=adjusted_cpi, + new_adjusted_status=new_adjusted_status, + new_adjusted_cpi=new_adjusted_cpi, + ) + except AssertionError as e: + return False, str(e) + return True, "" + + +def _check_adjusted_status_wall_consistent( + status: Status, + cpi: Optional[Decimal] = None, + adjusted_status: Optional[WallAdjustedStatus] = None, + adjusted_cpi: Optional[Decimal] = None, + new_adjusted_status: Optional[WallAdjustedStatus] = None, + new_adjusted_cpi: Optional[Decimal] = None, +) -> None: + """ + See check_adjusted_status_wall_consistent + """ + # Check that we're actually changing something + if adjusted_status == new_adjusted_status and adjusted_cpi == new_adjusted_cpi: + raise AssertionError(f"attempt is already {adjusted_status=}, {adjusted_cpi=}") + + # adjusted_status/adjusted_cpi agreement + check_adjusted_status_consistent( + status=status, + cpi=cpi, + adjusted_status=new_adjusted_status, + adjusted_cpi=new_adjusted_cpi, + ) + + # status / adjusted_status agreement + if status == Status.COMPLETE: + assert ( + new_adjusted_status != WallAdjustedStatus.ADJUSTED_TO_COMPLETE + ), "adjusted status can't be ADJUSTED_TO_COMPLETE if the status is already COMPLETE" + elif status == Status.FAIL: + assert ( + new_adjusted_status != WallAdjustedStatus.ADJUSTED_TO_FAIL + ), "adjusted status can't be ADJUSTED_TO_FAIL if the status is already FAIL" + else: + # status is None/timeout/abandon, which we treat as a fail anyway + assert ( + new_adjusted_status != WallAdjustedStatus.ADJUSTED_TO_FAIL + ), "attempt is already a failure" + + # adjusted_status / new_adjusted_status agreement + if new_adjusted_status == WallAdjustedStatus.CPI_ADJUSTMENT: + assert ( + new_adjusted_cpi != adjusted_cpi + ), f"adjusted_cpi is already {adjusted_cpi}" diff --git a/generalresearch/models/thl/soft_pair.py b/generalresearch/models/thl/soft_pair.py new file mode 100644 index 0000000..b8bc1b4 --- /dev/null +++ b/generalresearch/models/thl/soft_pair.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Set + +from generalresearch.models import Source +from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, +) + + +class SoftPairResultType(int, Enum): + # type=1 - Eligible unconditionally + UNCONDITIONAL = 1 + # type=3 - Eligible conditionally. Must include question_ids. + CONDITIONAL = 3 + # type=2 - Eligible conditionally, includes the option_ids that would make + # the pairing eligible. This is unused in practice because it is often + # impossible to describe the relationship + UNUSED = 2 + # This isn't used in practice b/c the survey just wouldn't be returned. + # This is just for testing/validation. + INELIGIBLE = 4 + + +@dataclass +class SoftPairResult: + # We use this within the Marketplace's get_opportunities_soft_pairing + # "hot path". We instantiate a SoftPairResult for each survey-result. + # There is a lot of overhead in pydantic and that causes the call to be + # kind of slow, especially for spectrum. There isn't a lot of validation + # needed here, so I think it is a reasonable tradeoff to make this a + # dataclass instead. + pair_type: SoftPairResultType + source: Source + survey_id: str + conditions: Optional[Set[MarketplaceCondition]] = None + + @property + def survey_sid(self) -> str: + return self.source + ":" + self.survey_id + + @property + def grpc_string(self) -> Optional[str]: + # This is what is expected by thl-grpc in a mp_pb2.MPOpportunityIDListSoftPairing response (grpc) + if self.pair_type == SoftPairResultType.UNCONDITIONAL: + return self.survey_id + elif self.pair_type == SoftPairResultType.CONDITIONAL: + return ( + self.survey_id + + ":" + + ";".join(sorted(set([c.question_id for c in self.conditions]))) + ) + else: + return None + + +@dataclass +class SoftPairResultOut: + # This is used by the thl-grpc to parse the grpc message + pair_type: SoftPairResultType + source: Source + survey_id: str + question_ids: Optional[Set[str]] = None + + @property + def survey_sid(self) -> str: + return self.source + ":" + self.survey_id diff --git a/generalresearch/models/thl/stats.py b/generalresearch/models/thl/stats.py new file mode 100644 index 0000000..fc8be59 --- /dev/null +++ b/generalresearch/models/thl/stats.py @@ -0,0 +1,43 @@ +from typing import Optional + +from pydantic import BaseModel, Field, model_validator, computed_field + + +class StatisticalSummary(BaseModel): + """ + Stores the five-number summary of a dataset. This consists of the minimum, + first quartile (Q1), median (Q2), third quartile (Q3), and maximum. + + Mean is optional. + """ + + min: int = Field() + max: int = Field() + mean: Optional[int] = Field(default=None) + q1: int = Field() + q2: int = Field() + q3: int = Field() + + @model_validator(mode="after") + def check_values(self): + assert self.max >= self.min, "invalid max/min" + assert self.q1 >= self.min, "invalid q1/min" + assert self.q2 >= self.q1, "invalid q1/q2" + assert self.q3 >= self.q2, "invalid q2/q3" + assert self.max >= self.q3, "invalid q3/max" + return self + + @property + def iqr(self) -> int: + # Interquartile Range (IQR) + return self.q3 - self.q1 + + @computed_field + @property + def lower_whisker(self) -> int: + return round(self.q1 - (1.5 * self.iqr)) + + @computed_field + @property + def upper_whisker(self) -> int: + return round(self.q3 + (1.5 * self.iqr)) diff --git a/generalresearch/models/thl/supplier_tag.py b/generalresearch/models/thl/supplier_tag.py new file mode 100644 index 0000000..ad84c9b --- /dev/null +++ b/generalresearch/models/thl/supplier_tag.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class SupplierTag(str, Enum): + """Available tags which can be used to annotate supplier traffic + + Note: should not include commas! + """ + + MOBILE = "mobile" + JS_OFFERWALL = "js-offerwall" + DOI = "double-opt-in" + SSO = "single-sign-on" + PHONE_VERIFIED = "phone-number-verified" + TEST_A = "test-a" + TEST_B = "test-b" diff --git a/generalresearch/models/thl/survey/__init__.py b/generalresearch/models/thl/survey/__init__.py new file mode 100644 index 0000000..fd78091 --- /dev/null +++ b/generalresearch/models/thl/survey/__init__.py @@ -0,0 +1,225 @@ +from abc import ABC, abstractmethod +from decimal import Decimal +from itertools import product +from typing import Set, Optional, List, Dict, Type + +from more_itertools import flatten +from pydantic import BaseModel, Field + +from generalresearch.models import Source +from generalresearch.models.thl.demographics import ( + DemographicTarget, + Gender, + AgeGroup, +) +from generalresearch.models.thl.locales import ( + CountryISOs, + LanguageISOs, + CountryISO, + LanguageISO, +) +from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, +) + + +class MarketplaceTask(BaseModel, ABC): + """This is called a "Task" even though generally it represents a survey + because some marketplaces have non-standard nesting structures. The + task is the unit of work we target a user for. So if a marketplace has + a survey with quotas that have different CPIs and we target a user to + a specific quota, then the quota is the unit of work. + """ + + # model_config = ConfigDict(extra="allow") + + cpi: Decimal = Field(gt=0, le=100, decimal_places=2, max_digits=5) + + # In some marketplaces, a task can be targeted to one or more country or language. + country_isos: CountryISOs = Field(min_items=1) + language_isos: LanguageISOs = Field(min_items=1) + + # For convenience, we'll store a single country/lang field as well, since + # 99% of tasks across all marketplaces, even those that support multiple, + # are only targeted to 1 country/lang. Which specific country/lang is + # stored here, for tasks that target more than 1, is undefined. + country_iso: CountryISO = Field() + language_iso: LanguageISO = Field() + + # These should be overloaded with more specific type hints + buyer_id: Optional[str] = Field(min_length=1, max_length=32, default=None) + # This is in seconds + bid_loi: Optional[int] = Field(default=None, le=90 * 60) + bid_ir: Optional[float] = Field(default=None, ge=0, le=1) + + # This should be an "abstract field", but there is no way to do that, so + # just listing it here. It should be overridden by the implementation + source: Source = Field() + # This should also + used_question_ids: Set[str] = Field(default_factory=set) + + # This is a "special" key to store all conditions that are used (as + # "condition_hashes") throughout this survey. In the reduced + # representation of this task (nearly always, for db i/o, in global_vars) + # this field will be null. + conditions: Optional[Dict[str, MarketplaceCondition]] = Field(default=None) + + @property + @abstractmethod + def internal_id(self) -> str: + """This is the value that is used for this survey within the marketplace. Typically, + this is survey_id/survey_number. Morning is quota_id, repdata: stream_id. + """ + ... + + @property + def external_id(self) -> str: + return f"{self.source.value}:{self.internal_id}" + + @property + @abstractmethod + def all_hashes(self) -> Set[str]: ... + + @property + @abstractmethod + def is_open(self) -> bool: ... + + @property + @abstractmethod + def is_live(self) -> bool: ... + + def __hash__(self): + # We need this so this obj can be added into a set. + return hash(self.external_id) + + def is_unchanged(self, other) -> bool: + # Avoiding overloading __eq__ because it looks kind of complicated? I + # want to be explicit that this is not testing object equivalence, + # just that the objects don't require any db updates. We also exclude + # conditions b/c this is just the condition_hash definitions + return self.model_dump() == other.model_dump() + + def is_changed(self, other) -> bool: + return not self.is_unchanged(other) + + @property + @abstractmethod + def condition_model(self) -> Type[MarketplaceCondition]: + """ + The Condition Model for this survey class + """ + pass + + @property + @abstractmethod + def age_question(self) -> str: + """ + The age question ID + """ + pass + + @property + @abstractmethod + def marketplace_genders( + self, + ) -> Dict[Gender, Optional[MarketplaceCondition]]: + """ + Mapping of generic Gender to the marketplace condition for that gender + """ + pass + + @property + def marketplace_age_groups( + self, + ) -> Dict[AgeGroup, Optional[MarketplaceCondition]]: + """ + Mapping of generic age groups to the marketplace condition for those ages + """ + return { + ag: self.condition_model( + question_id=self.age_question, + values=list(map(str, range(ag.low, ag.high + 1))), + value_type=ConditionValueType.LIST, + ) + for ag in AgeGroup + } + + @property + def targeted_ages(self) -> Set[str]: + assert self.conditions is not None, "conditions must be populated" + cs = [self.conditions[k] for k in self.all_hashes if k in self.conditions] + age_cs = [c for c in cs if c.question_id == self.age_question] + age_list = [c for c in age_cs if c.value_type == ConditionValueType.LIST] + age_range = [c for c in age_cs if c.value_type == ConditionValueType.RANGE] + age_values = set(flatten([c.values for c in age_list])) + for c in age_range: + ranges = c.values_ranges + for r in ranges: + r = list(r) + if r[0] == float("-inf"): + r[0] = 0 + if r[1] == float("inf"): + r[1] = 120 + age_values.update(set(map(str, range(int(r[0]), int(r[1]) + 1)))) + return age_values + + @property + def targeted_age_groups(self) -> Set[AgeGroup]: + age_values = self.targeted_ages + age_conditions = self.marketplace_age_groups + age_targeting = set() + for ag, condition in age_conditions.items(): + if condition.evaluate_criterion({condition.question_id: age_values}): + age_targeting.add(ag) + if len(age_targeting) == 0: + # survey with no age targeting is implicitly targeting any age? or only >18? idk + age_targeting.update( + { + AgeGroup.AGE_18_TO_35, + AgeGroup.AGE_36_TO_55, + AgeGroup.AGE_56_TO_75, + AgeGroup.AGE_OVER_75, + } + ) + return age_targeting + + @property + def targeted_genders(self) -> Set[Gender]: + mp_genders = self.marketplace_genders + gender_targeting = set() + if mp_genders[Gender.MALE].criterion_hash in self.all_hashes: + gender_targeting.add(Gender.MALE) + if mp_genders[Gender.FEMALE].criterion_hash in self.all_hashes: + gender_targeting.add(Gender.FEMALE) + if len(gender_targeting) == 0: + gender_targeting.update({Gender.MALE, Gender.FEMALE}) + return gender_targeting + + @property + def demographic_targets(self) -> List[DemographicTarget]: + targets = [DemographicTarget(country="*", gender="*", age_group="*")] + + gt = self.targeted_genders + for gender in gt: + targets.append(DemographicTarget(country="*", gender=gender, age_group="*")) + + at = self.targeted_age_groups + for age_grp in at: + targets.append( + DemographicTarget(country="*", gender="*", age_group=age_grp) + ) + + for gender, age_grp in product(gt, at): + targets.append( + DemographicTarget(country="*", gender=gender, age_group=age_grp) + ) + + for c in self.country_isos: + orig_targets = targets.copy() + country_targets = [ + DemographicTarget(country=c, gender=t.gender, age_group=t.age_group) + for t in orig_targets + ] + targets.extend(country_targets) + return targets diff --git a/generalresearch/models/thl/survey/buyer.py b/generalresearch/models/thl/survey/buyer.py new file mode 100644 index 0000000..0d235ed --- /dev/null +++ b/generalresearch/models/thl/survey/buyer.py @@ -0,0 +1,218 @@ +from datetime import timezone, datetime +from decimal import Decimal +from typing import Optional, Annotated + +from math import log +from pydantic import ( + model_validator, + BaseModel, + ConfigDict, + Field, + PositiveInt, + NonNegativeInt, + computed_field, +) +from scipy.stats import beta as beta_dist + +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + CountryISOLike, + UUIDStr, +) + + +class Buyer(BaseModel): + """ + The entity that commissions and pays for a task and uses the resulting data or insights. + """ + + model_config = ConfigDict(validate_assignment=True) + + id: Optional[PositiveInt] = Field(default=None, exclude=True) + # todo: need to add to db + uuid: Optional[UUIDStr] = Field(default=None) + + source: Source = Field( + description="The marketplace this buyer is on.\n" + Source.as_openapi() + ) + code: str = Field( + min_length=1, + max_length=128, + description="The internal code on this marketplace for this buyer", + ) + label: Optional[str] = Field(default=None, max_length=255) + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this entry was made, or when the buyer was first seen", + ) + + @property + def natural_key(self) -> str: + return f"{self.source.value}:{self.code}" + + @property + def source_code(self) -> str: + return self.natural_key + + +class BuyerActivity(BaseModel): + """ + Information about live Tasks from this buyer + """ + + live_task_count: PositiveInt = Field() + avg_cpi: Decimal = Field() + avg_score: float = Field() + max_score: float = Field() + + +class BuyerWithDetail(BaseModel): + """For API Responses""" + + buyer: Buyer = Field() + activity: Optional[BuyerActivity] = Field(default=None) + + +class BuyerCountryStat(BaseModel): + """ + Aggregated performance summary for a specific buyer within a single + country. Metrics are computed across all observed tasks for this (buyer, + country) pair and include risk-adjusted conversion and dropoff estimates, + LOI deviation relative to survey expectations, quality signals, and a + composite ranking score. All rate-based metrics use Bayesian shrinkage to + reduce small-sample noise. The score is intended for relative ranking + among buyers within comparable contexts. + """ + + model_config = ConfigDict(validate_assignment=False) + + # ---- Identity ---- + buyer_id: Optional[PositiveInt] = Field( + default=None, + exclude=True, + description="This is the pk of the Buyer object in the db", + ) + country_iso: Optional[CountryISOLike] = Field( + default=None, + description="If null, this is a weighted average across all countries", + examples=["us"], + ) + + # --- For lookup / de-normalization --- + source: Source = Field( + description="The marketplace this buyer is on.\n" + Source.as_openapi(), + examples=[Source.DYNATA], + ) + code: str = Field( + min_length=1, + max_length=128, + description="The internal code on this marketplace for this buyer", + examples=["abc123"], + ) + + # --- Observation Counts --- + task_count: NonNegativeInt = Field( + description="The count of observed tasks", examples=[100] + ) + + # ---- Distributions ---- + conversion_alpha: float = Field( + gt=0, description="Alpha parameter from a Beta distribution", examples=[40.0] + ) + conversion_beta: float = Field( + gt=0, description="Beta parameter from a Beta distribution", examples=[190.0] + ) + + @computed_field( + description="Penalized mean (20th percentile) of conversion", + examples=[0.15264573817318744], + ) + @property + def conversion_p20(self) -> Annotated[float, Field(ge=0, le=1)]: + return float(beta_dist.ppf(0.2, self.conversion_alpha, self.conversion_beta)) + + dropoff_alpha: float = Field( + gt=0, description="Alpha parameter from a Beta distribution", examples=[20.0] + ) + dropoff_beta: float = Field( + gt=0, description="Beta parameter from a Beta distribution", examples=[50.0] + ) + + @computed_field( + description="Penalized mean (60th percentile) of the dropoff/abandonment rate", + examples=[0.29748756969632695], + ) + @property + def dropoff_p60(self) -> Annotated[float, Field(ge=0, le=1)]: + return float(beta_dist.ppf(0.6, self.dropoff_alpha, self.dropoff_beta)) + + # --- Expectations --- + loi_excess_ratio: float = Field( + ge=0, + description=( + "Volume-weighted average of (observed LOI / expected LOI) " + "across all completed tasks. " + "1.0 means exactly as expected. " + ">1 longer than expected. <1 shorter." + ), + examples=[1], + ) + + # ---- Risk / quality ---- + long_fail_rate: float = Field( + ge=0, + le=10, + description="Lower values indicate tasks are likely to late terminate", + examples=[1], + ) + user_report_coeff: float = Field( + ge=0, le=1, description="Lower values indicate more user reports", examples=[1] + ) + recon_likelihood: float = Field( + ge=0, le=1, description="Likelihood tasks will get reconciled", examples=[0.05] + ) + + # ---- Scoring ---- + score: float = Field( + description="Composite score calculated from all of the individual features", + default=None, + examples=[-5.329389837486194], + ) + + @model_validator(mode="after") + def compute_score(self): + eps = 1e-12 + + # ---- Conversion (logit) ---- + c = min(max(self.conversion_p20, eps), 1 - eps) + C = log(c / (1 - c)) + + # ---- Dropoff ---- + d = min(max(self.dropoff_p60, 0.0), 1.0 - eps) + D = log(1 - d) + + # ---- LOI symmetric penalty ---- + loi = max(self.loi_excess_ratio, eps) + L = -abs(log(loi)) + + # ---- Long fail ---- + F = -log(1 + max(self.long_fail_rate, 0.0)) + + # ---- User report ---- + R = log(max(self.user_report_coeff, eps)) + + # ---- Reconciliation ---- + Q = log(max(self.recon_likelihood, eps)) + + raw_score = 2.0 * C + 1.5 * D + L + F + R + Q + + # ---- Small-sample shrinkage ---- + n_eff = self.conversion_alpha + self.conversion_beta + k = 100.0 # tuning parameter + + weight = n_eff / (n_eff + k) + + self.score = weight * raw_score + + return self diff --git a/generalresearch/models/thl/survey/condition.py b/generalresearch/models/thl/survey/condition.py new file mode 100644 index 0000000..a60b034 --- /dev/null +++ b/generalresearch/models/thl/survey/condition.py @@ -0,0 +1,337 @@ +import hashlib +from abc import ABC +from enum import Enum +from functools import cached_property +from typing import List, Dict, Set, Optional, Any, Tuple + +from pydantic import ( + BaseModel, + Field, + computed_field, + ConfigDict, + field_validator, + model_validator, + StringConstraints, + PrivateAttr, +) +from typing_extensions import Self, Annotated + +from generalresearch.models import LogicalOperator + +MarketplaceConditionHash = Annotated[ + str, StringConstraints(min_length=7, max_length=7, pattern=r"^[a-f0-9]+$") +] + + +class ConditionValueType(int, Enum): + # The values are a list of strings that are matched entirely. + # e.g. ['a', 'b', 'c'] + LIST = 1 + + # The values are a list of ranges. e.g. ["19-25", "35-40"], + RANGE = 2 + + # The values should be empty, we only care that the user has an answer + # for question_id + ANSWERED = 3 + + # The condition cannot be defined in any way that can be understood by us. + # The question_id may not even be one that is exposed to us. This is + # solely to indicate there is additional profiling on a survey. + # `values` is ignored. + INEFFABLE = 4 + + # The condition is checking a user's membership / group IDs; this is + # analogous to a recontact, where specific users are targeted. The + # question_id here should be null. In dynata this is called an invite + # collection. + RECONTACT = 5 + + +class MarketplaceCondition(BaseModel, ABC): + """This represents a targeting condition that can be attached to a + qualification or quota + """ + + model_config = ConfigDict(populate_by_name=True) + + logical_operator: LogicalOperator = Field(default=LogicalOperator.OR) + value_type: ConditionValueType = Field() + negate: bool = Field(default=False) + + # ---- These fields should be overridden in the implementor --- + question_id: Optional[str] = Field(frozen=True) + values: List[str] = Field() + + # These question_ids get converted to list value types + _CONVERT_LIST_TO_RANGE: List[str] = PrivateAttr(default_factory=list) + + @field_validator("values", mode="after") + def sort_values(cls, values: List[str]): + return sorted(values) + + @field_validator("values", mode="after") + def check_values_lower(cls, values: List[str]): + assert values == [s.lower() for s in values], "values must be lowercase" + return values + + @field_validator("logical_operator", mode="after") + def explain_not(cls, logical_operator: LogicalOperator): + assert logical_operator != LogicalOperator.NOT, ( + "Use LogicalOperator.OR/AND and negate=True in place of LogicalOperator.NOT. Otherwise the meaning" + "is ambiguous. Do you want people who don't have a (CAT and DOG), or either don't have a CAT or" + "don't have a DOG?" + ) + return logical_operator + + @model_validator(mode="after") + def type_values_default(self) -> Self: + if self.value_type in { + ConditionValueType.ANSWERED, + ConditionValueType.INEFFABLE, + }: + assert not self.values, "values must be empty" + return self + + @model_validator(mode="after") + def check_type_question_id_agreement(self) -> Self: + if self.value_type in {ConditionValueType.RECONTACT}: + assert ( + self.question_id is None + ), "question_id should be NULL for ConditionValueType.RECONTACT" + else: + assert self.question_id is not None, "question_id must be set" + return self + + @model_validator(mode="after") + def check_type_values_agreement(self) -> Self: + if self.value_type in { + ConditionValueType.LIST, + ConditionValueType.RANGE, + }: + assert len(self.values) > 0, "values must not be empty" + if self.value_type == ConditionValueType.RANGE: + assert ( + self.logical_operator == LogicalOperator.OR + ), "Only OR is allowed with ranges" + assert all( + s.count("-") == 1 for s in self.values + ), "range values must have one hyphen" + for v in self.values: + assert all( + self.is_numeric_including_inf(x) for x in v.split("-") + ), f"invalid range: {v}" + elif self.value_type in { + ConditionValueType.ANSWERED, + ConditionValueType.INEFFABLE, + }: + assert len(self.values) == 0, "values must be empty" + return self + + @model_validator(mode="after") + def change_ranges_to_list(self) -> Self: + """ + Decide to do this per marketplace. + Some use ranges for ages. Ranges take longer to evaluate b/c they have to be converted + into ints and then require multiple evaluations. Just convert into a list of values + which only requires one easy match. + e.g. convert age values from '20-22|20-21|25-26' to '|20|21|22|25|26|' + """ + if ( + self.question_id in self._CONVERT_LIST_TO_RANGE + and self.value_type == ConditionValueType.RANGE + ): + try: + values = [tuple(map(int, v.split("-"))) for v in self.values] + assert all(len(x) == 2 for x in values) + except (ValueError, AssertionError): + return self + self.values = sorted( + {str(val) for tupl in values for val in range(tupl[0], tupl[1] + 1)} + ) + self.value_type = ConditionValueType.LIST + return self + + @computed_field + @cached_property + def criterion_hash(self) -> MarketplaceConditionHash: + # This model is frozen, so the criterion string can/will never change. + return self._hash_string(self._criterion_str) + + @property + def hash(self) -> MarketplaceConditionHash: + return self.criterion_hash + + @cached_property + def values_str(self) -> str: + return f"|{'|'.join(self.values)}|".lower() if self.values else "" + + @cached_property + def _criterion_str(self) -> str: + # e.g. '42;OR;False;1;|18|19|20|21|' + return ";".join( + [ + str(self.question_id), + self.logical_operator, + str(self.negate), + str(self.value_type.value), + self.values_str, + ] + ) + + @cached_property + def values_minified(self) -> str: + if len(self.values) > 6: + v = self.values[:3] + ["…"] + self.values[-3:] + else: + v = self.values + return f"|{'|'.join(v)}|" + + @cached_property + def minified(self): + return ";".join( + [ + str(self.question_id), + self.logical_operator, + str(self.negate), + str(self.value_type.value), + self.values_minified, + ] + ) + + @computed_field + @property + def value_len(self) -> int: + return len(self.values) + + @computed_field + @property + def sizeof(self) -> int: + return sum(len(v) for v in self.values) + + @cached_property + def values_ranges(self) -> List[Tuple[float, float]]: + assert ( + self.value_type == ConditionValueType.RANGE + ), "only call this method when value_type is RANGE" + values = [tuple(map(float, v.split("-"))) for v in self.values] + # Treat 'inf' as negative infinity if it is a lower bound. + values = [ + (float("-inf") if start == float("inf") else start, end) + for start, end in values + ] + return values + + @classmethod + def _hash_string(cls, s: str) -> str: + return hashlib.md5(s.encode()).hexdigest()[:7] + + @classmethod + def from_mysql(cls, d: Dict[str, Any]) -> Self: + d["values"] = d["values"][1:-1].split("|") if d["values"][1:-1] else [] + return cls.model_validate(d) + + def to_mysql(self) -> Dict[str, str]: + # This is what is stored in the xxx_criterion table + d = self.model_dump( + mode="json", + include={ + "question_id", + "criterion_hash", + "value_type", + "logical_operator", + "negate", + }, + ) + d["values"] = self.values_str + return d + + @staticmethod + def is_numeric_including_inf(s) -> bool: + try: + float(s) + return True + except ValueError: + return False + + def __hash__(self) -> int: + # this is so it can be put into a set / dictionary key + return hash(self.criterion_hash) + + def __repr__(self) -> str: + # Fancy repr that only shows the first and last 3 values if there are more than 6. + repr_args = list(self.__repr_args__()) + for n, (k, v) in enumerate(repr_args): + if k == "values": + if v and len(v) > 6: + v = v[:3] + ["…"] + v[-3:] + repr_args[n] = ("values", v) + join_str = ", " + repr_str = join_str.join( + repr(v) if a is None else f"{a}={v!r}" for a, v in repr_args + ) + return f"{self.__repr_name__()}({repr_str})" + + def evaluate_criterion( + self, + user_qas: Dict[str, Set[str]], + user_groups: Optional[Set[str]] = None, + ) -> Optional[bool]: + """Given this user's MRPQs, do they "pass" this criterion? + + :param user_qas: user's quals. Looks like {'qid1': {'ans1', 'ans2'}} + :param user_groups: a list of "groups" the user is associated with. + This is only used for RECONTACT conditions. + :return: True, False, or None (means we don't know) + """ + if self.value_type == ConditionValueType.RECONTACT: + assert ( + user_groups is not None + ), "user_groups must be known for RECONTACT conditions" + if self.logical_operator == LogicalOperator.OR: + passes = any(x in self.values for x in user_groups) + return not passes if self.negate else passes + elif self.logical_operator == LogicalOperator.AND: + passes = all(x in user_groups for x in self.values) + return not passes if self.negate else passes + + # It is unclear what we should do with INEFFABLE conditions. We keep + # them b/c we want to know that they exist, but we have nothing to + # check, so they'll just return True always + if self.value_type == ConditionValueType.INEFFABLE: + return True + + answer = user_qas.get(self.question_id) + if self.value_type == ConditionValueType.ANSWERED: + if answer is None: + return self.negate + else: + return not self.negate + + if answer is None: + return None + + if self.value_type == ConditionValueType.LIST: + if self.logical_operator == LogicalOperator.OR: + passes = any(f"|{x}|" in self.values_str for x in answer) + return not passes if self.negate else passes + elif self.logical_operator == LogicalOperator.AND: + passes = all(x in answer for x in self.values) + return not passes if self.negate else passes + + if self.value_type == ConditionValueType.RANGE: + assert ( + self.logical_operator == LogicalOperator.OR + ), "Only OR is allowed with ranges" + # The answer and values are assumed here to be numeric. The values + # are expected to be two numerics separated by a dash. e.g. + # "1-10". The interval is always closed (includes the endpoints, + # gte/lte). Unbounded ranges are also supported, indicated by + # "inf". e.g. 'inf-100' (meaning -Infinity to 100), or '10-inf' + try: + answer = list(map(float, answer)) + except ValueError: + return None + values = self.values_ranges + passes = any([start <= x <= end for start, end in values for x in answer]) + return not passes if self.negate else passes diff --git a/generalresearch/models/thl/survey/model.py b/generalresearch/models/thl/survey/model.py new file mode 100644 index 0000000..9e7a402 --- /dev/null +++ b/generalresearch/models/thl/survey/model.py @@ -0,0 +1,321 @@ +from datetime import timezone, datetime +from decimal import Decimal +from typing import Optional, List, Tuple, Dict +from typing_extensions import Annotated + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + model_validator, + computed_field, + NonNegativeInt, + NonNegativeFloat, + field_validator, +) + +from generalresearch.managers.thl.buyer import Buyer +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + CountryISOLike, + SurveyKey, + EnumNameSerializer, + PropertyCode, +) +from generalresearch.models.thl.category import Category +from generalresearch.models.thl.definitions import Status, StatusCode1 +from generalresearch.models.thl.pagination import Page + + +class SurveyCategoryModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + category: Category = Field() + strength: Optional[float] = Field(default=None) + + +class SurveyEligibilityDefinition(BaseModel): + """ + Survey-level declaration of which questions + may contribute to eligibility. + + This does NOT encode rules or qualifying values. + """ + + # References a marketplace-specific question + property_codes: Tuple[PropertyCode, ...] = Field(default_factory=tuple) + + @model_validator(mode="after") + def sort_question_ids(self): + self.property_codes = tuple(sorted(self.property_codes)) + return self + + +class Survey(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + id: Optional[PositiveInt] = Field(default=None, exclude=True) + + source: Source = Field() + survey_id: str = Field(min_length=1, max_length=32, examples=["127492892"]) + + buyer_id: Optional[int] = Field( + default=None, exclude=True, description="This is the DB's fk id" + ) + # ---v So the fk id can be looked up from the code + buyer_code: Optional[str] = Field( + min_length=1, max_length=128, default=None, examples=["124"] + ) + + created_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + updated_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + is_live: bool = Field(default=True) + is_recontact: bool = Field(default=False) + + categories: List[SurveyCategoryModel] = Field(default_factory=list) + + eligibility_criteria: Optional[SurveyEligibilityDefinition] = Field(default=None) + + @property + def natural_key(self) -> SurveyKey: + return f"{self.source.value}:{self.survey_id}" + + @property + def buyer(self): + assert self.buyer_code is not None + return Buyer(source=self.source, code=self.buyer_code) + + @property + def buyer_natural_key(self) -> str: + return self.buyer.natural_key + + @model_validator(mode="after") + def category_strengths(self): + if any(s.strength is not None for s in self.categories): + assert all( + s.strength is not None for s in self.categories + ), "If any category strength is not None, all should be set" + assert ( + abs(sum(s.strength for s in self.categories) - 1) <= 0.01 + ), "Strengths should some to 1" + return self + + def model_dump_sql(self): + d = self.model_dump(mode="json", exclude={"categories"}) + d["buyer_id"] = self.buyer_id + d["eligibility_criteria"] = None + if self.eligibility_criteria is not None: + d["eligibility_criteria"] = self.eligibility_criteria.model_dump_json() + return d + + +class SurveyStat(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + id: Optional[PositiveInt] = Field(exclude=True, default=None) + + # ---- Identity ---- + survey_id: Optional[PositiveInt] = Field( + default=None, + exclude=True, + description="This is the pk of the Survey object in the db", + ) + quota_id: str = Field( + default="__all__", + max_length=32, + description="The marketplace's internal quota id", + ) + country_iso: CountryISOLike = Field() + version: int = Field(ge=0) + + # --- For lookup / de-normalization, to avoid potentially costly + # joins on marketplace_survey table --- + survey_source: Optional[Source] = Field(default=None, exclude=True) + survey_survey_id: Optional[str] = Field( + default=None, exclude=True, min_length=1, max_length=32 + ) + survey_is_live: bool = Field(default=True, exclude=True) + + # ---- Pricing / cutoffs ---- + cpi: Decimal = Field(decimal_places=5, lt=1000, ge=0) + complete_too_fast_cutoff: PositiveInt = Field(description="in seconds") + + # ---- Distributions ---- + + prescreen_conv_alpha: float = Field(..., ge=0) + prescreen_conv_beta: float = Field(..., ge=0) + + conv_alpha: float = Field(..., ge=0) + conv_beta: float = Field(..., ge=0) + + dropoff_alpha: float = Field(..., ge=0) + dropoff_beta: float = Field(..., ge=0) + + completion_time_mu: float = Field(...) + completion_time_sigma: float = Field(..., gt=0) + + # ---- Eligibility (probabilistic) ---- + + mobile_eligible_alpha: float = Field(..., ge=0) + mobile_eligible_beta: float = Field(..., ge=0) + + desktop_eligible_alpha: float = Field(..., ge=0) + desktop_eligible_beta: float = Field(..., ge=0) + + tablet_eligible_alpha: float = Field(..., ge=0) + tablet_eligible_beta: float = Field(..., ge=0) + + # ---- Risk / quality ---- + + long_fail_rate: float = Field(..., ge=0, le=1) + user_report_coeff: float = Field(..., ge=0, le=1) + recon_likelihood: float = Field(..., ge=0, le=1) + + # ---- Scoring ---- + + score_x0: float = Field(...) + score_x1: float = Field(...) + score: float = Field(...) + + # ---- Metadata ---- + + updated_at: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + @property + def natural_key(self) -> str: + assert self.survey_source is not None and self.survey_survey_id is not None + return f"{self.survey_source.value}:{self.survey_survey_id}:{self.quota_id}:{self.country_iso}:{self.version}" + + @property + def survey_natural_key(self) -> str: + # same as Survey.natural_key + return f"{self.survey_source.value}:{self.survey_survey_id}" + + @property + def unique_key(self) -> Tuple[int, Optional[str], str, int]: + return self.survey_id, self.quota_id, self.country_iso, self.version + + def model_dump_sql(self): + d = self.model_dump(mode="json") + d["survey_id"] = self.survey_id + d["survey_is_live"] = self.survey_is_live + d["survey_survey_id"] = self.survey_survey_id + d["survey_source"] = self.survey_source + return d + + +class TaskActivity(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + source: Source = Field() + survey_id: str = Field(min_length=1, max_length=32, examples=["127492892"]) + + status_counts: Dict[Status, NonNegativeInt] = Field(default_factory=dict) + status_code_1_counts: Dict[StatusCode1, NonNegativeInt] = Field( + default_factory=dict + ) + in_progress_count: NonNegativeInt = Field( + default=0, + description="Count of entrances that have no Status and were entered within the past 90 minutes", + ) + last_complete: Optional[AwareDatetimeISO] = Field(default=None) + last_entrance: Optional[AwareDatetimeISO] = Field(default=None) + + @computed_field + @property + def total_finished(self) -> int: + return sum(self.status_counts.values()) + + @computed_field + @property + def total_entrances(self) -> int: + return self.total_finished + self.in_progress_count + + # ---- percentages ---- + @computed_field + @property + def status_percentages(self) -> Dict[Status, NonNegativeFloat]: + total = self.total_finished + if total == 0: + return {} + return {k: round(v / total, 3) for k, v in self.status_counts.items()} + + @computed_field + @property + def status_code_1_percentages(self) -> Dict[StatusCode1, NonNegativeFloat]: + total = sum(self.status_code_1_counts.values()) + if total == 0: + return {} + return {k: round(v / total, 3) for k, v in self.status_code_1_counts.items()} + + +class TaskActivityPublic(BaseModel): + source: Optional[Source] = Field(exclude=True, default=None) + survey_id: Optional[str] = Field( + min_length=1, max_length=32, examples=["127492892"], exclude=True, default=None + ) + + status_percentages: Dict[Status, NonNegativeFloat] = Field(default_factory=dict) + status_code_1_percentages: Dict[ + Annotated[StatusCode1, EnumNameSerializer], NonNegativeFloat + ] = Field(default_factory=dict) + + last_complete: Optional[AwareDatetimeISO] = Field(default=None) + last_entrance: Optional[AwareDatetimeISO] = Field(default=None) + + @field_validator("status_code_1_percentages", mode="before") + def transform_enum_name_pct(cls, value: dict) -> dict: + # If we are serializing+deserializing this model (i.e. when we cache + # it), this fails because we've replaced the enum value with the + # name. Put it back here ... + return { + StatusCode1[k] if isinstance(k, str) else k: v for k, v in value.items() + } + + @property + def natural_key(self) -> SurveyKey: + return f"{self.source.value}:{self.survey_id}" + + +class TaskActivityPrivate(TaskActivityPublic): + status_counts: Dict[Status, int] = Field(default_factory=dict) + status_code_1_counts: Dict[Annotated[StatusCode1, EnumNameSerializer], int] = Field( + default_factory=dict + ) + in_progress_count: NonNegativeInt = Field( + description="Count of entrances that have no Status and were entered within the past 90 minutes", + default=0, + ) + + @field_validator("status_code_1_counts", mode="before") + def transform_enum_name_cnt(cls, value: dict) -> dict: + # If we are serializing+deserializing this model (i.e. when we cache + # it), this fails because we've replaced the enum value with the + # name. Put it back here ... + return { + StatusCode1[k] if isinstance(k, str) else k: v for k, v in value.items() + } + + +class TaskWithDetail(BaseModel): + """For API Responses""" + + task: Survey = Field() + stats: List[SurveyStat] = Field(default_factory=list) + activity_global: Optional[TaskActivityPublic] = Field(default=None) + activity_product: Optional[TaskActivityPrivate] = Field(default=None) + + +class TasksWithDetail(Page): + """For API Responses""" + + tasks: List[TaskWithDetail] = Field(default_factory=list) diff --git a/generalresearch/models/thl/survey/penalty.py b/generalresearch/models/thl/survey/penalty.py new file mode 100644 index 0000000..915254f --- /dev/null +++ b/generalresearch/models/thl/survey/penalty.py @@ -0,0 +1,63 @@ +import abc +from datetime import timezone, datetime +from typing import List, Literal, Union + +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from typing_extensions import Annotated + +from generalresearch.models import Source +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, +) + + +class SurveyPenalty(BaseModel, abc.ABC): + """ + BP or Team-specific penalization to a survey, for the purpose of + rate-limiting entrances from a BP or Team into a survey + """ + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + kind: Literal["bp", "team"] + + source: Source = Field() + survey_id: str = Field(min_length=1, max_length=32) + + penalty: float = Field(ge=0, le=1) + + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + @property + def sid(self): + return f"{self.source.value}:{self.survey_id}" + + +class BPSurveyPenalty(SurveyPenalty): + """ + BP-specific penalization to a survey, for the purpose of + rate-limiting entrances from a BP into a survey + """ + + kind: Literal["bp"] = "bp" + product_id: UUIDStr = Field(examples=["be40ff316fd4450dbaa53c13cc0cba04"]) + + +class TeamSurveyPenalty(SurveyPenalty): + """ + Team-specific penalization to a survey, for the purpose of + rate-limiting entrances from a Team into a survey + """ + + kind: Literal["team"] = "team" + team_id: UUIDStr = Field(examples=["2ac57f2264334af7874be56a06ef75db"]) + + +Penalty = Annotated[ + Union[BPSurveyPenalty, TeamSurveyPenalty], + Field(discriminator="kind"), +] +PenaltyListAdapter = TypeAdapter(List[Penalty]) diff --git a/generalresearch/models/thl/survey/task_collection.py b/generalresearch/models/thl/survey/task_collection.py new file mode 100644 index 0000000..68c2ff3 --- /dev/null +++ b/generalresearch/models/thl/survey/task_collection.py @@ -0,0 +1,60 @@ +import copy +import json +import logging +from typing import List + +import pandas as pd +import pandera +from pandera import DataFrameSchema +from pydantic import Field, ConfigDict, BaseModel, model_validator + +from generalresearch.models.thl.survey import MarketplaceTask + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class TaskCollection(BaseModel): + """I'm calling this a task and not a survey or whatever b/c it will be + exposed externally to this project and we don't care what the internal + structure is. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + # overload this with the correct type! + items: List[MarketplaceTask] + df: pd.DataFrame = Field(default_factory=pd.DataFrame) + + # overload this with the correct schema! + _schema: DataFrameSchema + + @model_validator(mode="after") + def handle_df(self): + df = self.to_df() + try: + df = self._schema.validate(df, lazy=True) + except pandera.errors.SchemaErrors as exc: + idx = exc.failure_cases["index"] + if len(idx) >= len(df) * 0.10: + raise exc + logger.info(f"{self.__repr_name__()}:handle_df:{json.dumps(exc.message)}") + df.drop(index=list(idx), inplace=True) + # we need to redo the validation after removing failing rows! + df = self._schema.validate(df) + self.df = df + return self + + def to_df(self) -> pd.DataFrame: ... + + +def create_empty_df_from_schema(schema: DataFrameSchema) -> pd.DataFrame: + # Create an empty df from the schema. We have to do this or else a plain empty df + # will fail validating non-nullable columns b/c they don't have a default. + schema = copy.deepcopy(schema) + schema.coerce = True + schema.add_missing_columns = True + index = pd.Index([], name=schema.index.name, dtype=schema.index.dtype.type) + empty_df = schema.coerce_dtype(pd.DataFrame(columns=[*schema.columns], index=index)) + return empty_df diff --git a/generalresearch/models/thl/synchronize_global_vars.py b/generalresearch/models/thl/synchronize_global_vars.py new file mode 100644 index 0000000..4a587e1 --- /dev/null +++ b/generalresearch/models/thl/synchronize_global_vars.py @@ -0,0 +1,17 @@ +from typing import List + +from pydantic import Field, BaseModel + + +class SynchronizeGlobalVarsMsg(BaseModel): + """Used within a Redis pub/sub to clear/invalidate internal caches, + typically on objects stored in GLOBAL_VARS or functools caches. + """ + + # Specifies the key / subkey to be acted upon. + # For example ["mrpq", 123] would apply to GLOBAL_VARS["mrpq"][123] + key_path: List[str] = Field() + # e.g. GLOBAL_VARS["mrpq"].pop(123, None) + pop: bool = Field(default=False) + # e.g. GLOBAL_VARS["mrpq"][123].clear() + clear: bool = Field(default=False) diff --git a/generalresearch/models/thl/task_adjustment.py b/generalresearch/models/thl/task_adjustment.py new file mode 100644 index 0000000..afb2a6a --- /dev/null +++ b/generalresearch/models/thl/task_adjustment.py @@ -0,0 +1,89 @@ +from datetime import datetime, timezone +from decimal import Decimal +from typing import Optional +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator + +from generalresearch.models import Source, MAX_INT32 +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.definitions import ( + WallAdjustedStatus, +) + + +class TaskAdjustmentEvent(BaseModel): + """ + This represents a notification that we've received from a marketplace + about the adjustment of a Wall's status. We might have multiple events + for the same wall event. The Wall.adjusted_status stores the latest + status, while the thl_taskadjustment table stores each time a + change occurred. + """ + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + uuid: UUIDStr = Field(default_factory=lambda: uuid4().hex) + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When this event was created in the db", + ) + alerted: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When we were notified about this change", + ) + + """ + Please read carefully. These 3 are to be interpreted differently than how + they are used on a Wall/Session.adjusted_status or .adjusted_cpi. + + Scenario: + - Wall was originally a fail, and then it is adjusted to complete, + and then adjusted back to fail. + - We'll have two TaskAdjustmentEvents, one with adjusted_status + ADJUSTED_TO_COMPLETE and one with ADJUSTED_TO_FAIL. + - The Wall's adjusted_status will be NULL! (b/c it was adjusted back + to what it originally was) + + In other words, the TaskAdjustmentEvent's adjustment records the direction + of the adjustment, regardless of the current state of the Wall, whereas + the Wall.adjusted_status is the latest value. + """ + + adjusted_status: WallAdjustedStatus = Field() + # If WallAdjustedStatus == ac, amount is positive, af amount is negative + # Same thing as with adjusted_status, the amount is the "amount the cpi + # is changing by"!! + + amount: Optional[Decimal] = Field(lt=1000, ge=-1000, default=None) + ext_status_code: Optional[str] = Field(default=None, max_length=32) + + wall_uuid: UUIDStr = Field(description="The wall event being adjusted") + + # These 4 are just for convenience (repeated from the Wall/Session) + user_id: PositiveInt = Field( + lt=MAX_INT32, description="The user who did this wall event" + ) + started: AwareDatetimeISO = Field( + description="The wall event's started", + ) + source: Source = Field() + survey_id: str = Field(max_length=32) + + @model_validator(mode="after") + def validate_amount(self): + if self.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL: + assert self.amount < 0, ( + "The amount is the amount the cpi is changing by, so for a adj to fail," + "the amount should be negative" + ) + elif self.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE: + assert self.amount > 0, ( + "The amount is the amount the cpi is changing by, so for a adj to complete," + "the amount should be positive" + ) + elif self.adjusted_status == WallAdjustedStatus.CONFIRMED_COMPLETE: + assert self.amount is None, "cannot change the cpi for a confirmed complete" + elif self.adjusted_status == WallAdjustedStatus.CPI_ADJUSTMENT: + assert self.amount is not None + return self diff --git a/generalresearch/models/thl/task_status.py b/generalresearch/models/thl/task_status.py new file mode 100644 index 0000000..f476286 --- /dev/null +++ b/generalresearch/models/thl/task_status.py @@ -0,0 +1,292 @@ +from datetime import datetime +from typing import Dict, Optional, Any, Literal, Annotated, List + +from pydantic import ( + BaseModel, + Field, + model_validator, + NonNegativeInt, + computed_field, + field_validator, + field_serializer, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import ( + UUIDStr, + AwareDatetimeISO, + EnumNameSerializer, +) +from generalresearch.models.thl import decimal_to_int_cents +from generalresearch.models.thl.definitions import ( + StatusCode1, + SessionStatusCode2, + Status, + SessionAdjustedStatus, +) +from generalresearch.models.thl.pagination import Page +from generalresearch.models.thl.payout_format import ( + PayoutFormatType, + PayoutFormatOptionalField, +) +from generalresearch.models.thl.product import ( + PayoutTransformation, + Product, +) +from generalresearch.models.thl.session import WallOut, Session + +# API uses the ints, b/c this is what the grpc returned originally ... +STATUS_MAP = { + None: 1, # generalresearch_pb2.STATUS_ENTER + Status.ABANDON: 2, # generalresearch_pb2.STATUS_INCOMPLETE + Status.TIMEOUT: 2, # generalresearch_pb2.STATUS_INCOMPLETE + Status.FAIL: 2, # generalresearch_pb2.STATUS_INCOMPLETE + Status.COMPLETE: 3, # generalresearch_pb2.STATUS_COMPLETE +} +REVERSE_STATUS_MAP = {v: k for k, v in STATUS_MAP.items()} + + +class TaskStatusResponse(BaseModel): + """The status of a session""" + + tsid: UUIDStr = Field( + description="A unique identifier for the session", + examples=["a3848e0a53d64f68a74ced5f61b6eb68"], + ) + + product_id: UUIDStr = Field( + description="The BP ID of the associated respondent", + examples=["1188cb21cb6741d79f614f6d02e9bc2a"], + ) + + product_user_id: str = Field( + min_length=3, + max_length=128, + description="A unique identifier for each user, which is set by the " + "Supplier", + examples=["app-user-9329ebd"], + ) + + started: AwareDatetimeISO = Field(description="When the session was started") + + finished: Optional[AwareDatetimeISO] = Field( + default=None, description="When the session was finished" + ) + + # This uses the grpc's Status enum. It gets serialized to an int. + status: Optional[Status] = Field( + default=None, + examples=[3], + description="The outcome of a session." + "\n - 0 - UNKNOWN\n" + " - 1 - ENTER (the user has not yet returned)\n" + " - 2 - INCOMPLETE (the user failed)\n" + " - 3 - COMPLETE (the user completed the task)", + ) + + payout: Optional[NonNegativeInt] = Field( + default=None, + lt=100_000, + examples=[500], + description="The amount paid to the supplier, in integer USD cents", + ) + + user_payout: Optional[NonNegativeInt] = Field( + default=None, + lt=100_000, + description="If a payout transformation is configured on this account, " + "this is the amount the user should earn, in integer USD cents", + examples=[337], + ) + + payout_format: Optional[PayoutFormatType] = PayoutFormatOptionalField + + user_payout_string: Optional[str] = Field( + default=None, + description="If a payout transformation is configured on this account, " + "this is the amount to display to the user", + examples=["3370 Points"], + ) + + kwargs: Dict[str, str] = Field( + default_factory=dict, + description="Any extra url params used in the offerwall request will be " + "passed back here", + ) + + status_code_1: Optional[Annotated[StatusCode1, EnumNameSerializer]] = Field( + default=None, + examples=[StatusCode1.COMPLETE.name], + description=StatusCode1.as_openapi_with_value_descriptions_name(), + ) + + status_code_2: Optional[Annotated[SessionStatusCode2, EnumNameSerializer]] = Field( + default=None, + examples=[None], + description=SessionStatusCode2.as_openapi_with_value_descriptions_name(), + ) + + adjusted_status: Optional[SessionAdjustedStatus] = Field( + default=None, + description=SessionAdjustedStatus.as_openapi_with_value_descriptions(), + examples=[None], + ) + + adjusted_timestamp: Optional[AwareDatetimeISO] = Field( + default=None, + description="When the adjusted status was last set.", + examples=[None], + ) + + adjusted_payout: Optional[NonNegativeInt] = Field( + default=None, + lt=100_000, + description="The new payout after adjustment.", + examples=[None], + ) + + adjusted_user_payout: Optional[NonNegativeInt] = Field( + default=None, + lt=100_000, + description="The new user_payout after adjustment.", + examples=[None], + ) + + adjusted_user_payout_string: Optional[str] = Field( + default=None, + description="The new user_payout_string after adjustment.", + examples=[None], + ) + + # This is used for validation purposes only. It won't get serialized + payout_transformation: Optional[PayoutTransformation] = Field( + default=None, exclude=True + ) + + wall_events: Optional[List[WallOut]] = Field(default=None) + + currency: Literal["USD"] = Field(default="USD") + final_status: int = Field(default=0, description="This is deprecated") + + # Serialize enum → int + @field_serializer("status", return_type=int) + def serialize_status(self, v: Optional[Status], _info): + return STATUS_MAP[v] + + # Accept int OR string for input, but internally store a Status enum + @field_validator("status", mode="before") + def deserialize_status(cls, v): + # int → enum + if isinstance(v, int): + return REVERSE_STATUS_MAP[v] + + if isinstance(v, str): + return Status(v) + + return v + + @model_validator(mode="before") + @classmethod + def user_payout_none(cls, data: Any): + # We changed the behaviour of user_payout at some point so that if the + # user_transformation is None, the user_payout is None, but this is + # not reflected in mysql. Change that here. + if "payout_transformation" in data and data["payout_transformation"] is None: + data["user_payout"] = None + data["adjusted_user_payout"] = None + data["user_payout_string"] = None + data["adjusted_user_payout_string"] = None + return data + + @field_validator("status_code_1", mode="before") + def transform_enum_name(cls, v: str | int) -> int: + # If we are serializing+deserializing this model (i.e. when we cache + # it), this fails because we've replaced the enum value with the + # name. Put it back here ... + if isinstance(v, str): + return StatusCode1[v] + return v + + @field_validator("status_code_2", mode="before") + def transform_enum_name2(cls, v: str | int) -> int: + # If we are serializing+deserializing this model (i.e. when we cache + # it), this fails because we've replaced the enum value with the + # name. But it back here ... + + if isinstance(v, str): + return SessionStatusCode2[v] + + return v + + @field_validator("payout", mode="before") + def transform_payout(cls, v: Optional[NonNegativeInt]) -> NonNegativeInt: + return v or 0 + + @field_validator("kwargs", mode="after") + def sanitize_kwargs(cls, v: Optional[Dict]) -> Optional[Dict]: + if v and "clicked_timestamp" in v: + try: + clicked_timestamp = datetime.strptime( + v["clicked_timestamp"], "%Y-%m-%d %H:%M:%S.%f" + ) + v["clicked_timestamp"] = ( + clicked_timestamp.isoformat(timespec="microseconds") + "Z" + ) + except ValueError: + pass + return v + + @model_validator(mode="before") + def transform_user_payout(cls, d): + # If the user_payout is None and there is a payout_format, make the user_payout 0 + if d.get("user_payout") is None and d.get("payout_format"): + d["user_payout"] = 0 + return d + + # --- Properties --- + @computed_field(return_type=str) + @property + def bpuid(self) -> str: + return self.product_user_id + + @classmethod + def from_session(cls, session: Session, product: Product) -> Self: + + user_payout_string = None + if session.user_payout is not None: + user_payout_string = product.format_payout_format(session.user_payout) + + adjusted_user_payout_string = None + if session.adjusted_user_payout is not None: + adjusted_user_payout_string = product.format_payout_format( + session.adjusted_user_payout + ) + + return TaskStatusResponse( + tsid=session.uuid, + status=session.status, + started=session.started, + finished=session.finished, + payout=decimal_to_int_cents(session.payout), + user_payout=decimal_to_int_cents(session.user_payout), + payout_format=product.payout_config.payout_format, + user_payout_string=user_payout_string, + product_id=session.user.product_id, + product_user_id=session.user.product_user_id, + kwargs=session.url_metadata or dict(), + status_code_1=session.status_code_1, + status_code_2=session.status_code_2, + adjusted_status=session.adjusted_status, + adjusted_payout=decimal_to_int_cents(session.adjusted_payout), + adjusted_user_payout=decimal_to_int_cents(session.adjusted_user_payout), + adjusted_timestamp=session.adjusted_timestamp, + adjusted_user_payout_string=adjusted_user_payout_string, + payout_transformation=product.payout_config.payout_transformation, + wall_events=[ + WallOut.from_wall(w, product=product) for w in session.wall_events + ], + ) + + +class TasksStatusResponse(Page): + tasks_status: List[TaskStatusResponse] = Field(default_factory=list) diff --git a/generalresearch/models/thl/user.py b/generalresearch/models/thl/user.py new file mode 100644 index 0000000..e393830 --- /dev/null +++ b/generalresearch/models/thl/user.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import json +import logging +import re +from datetime import timezone, datetime +from typing import Optional, Dict, List, TYPE_CHECKING +from uuid import uuid4, UUID + +from pydantic import ( + AwareDatetime, + Field, + BaseModel, + field_validator, + model_validator, + PositiveInt, + ConfigDict, + StringConstraints, + AfterValidator, +) +from sentry_sdk import set_tag, set_user +from typing_extensions import Annotated, Self + +from generalresearch.models import MAX_INT32 +from generalresearch.models.custom_types import AwareDatetimeISO, UUIDStr +from generalresearch.models.thl.ipinfo import GeoIPInformation +from generalresearch.models.thl.ledger import LedgerTransaction +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.userhealth import AuditLog +from generalresearch.pg_helper import PostgresConfig + +if TYPE_CHECKING: + from generalresearch.managers.thl.userhealth import AuditLogManager + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + + # from generalresearch.managers.thl.userhealth import UserIpHistoryManager + +logger = logging.getLogger() + +BPUID_ALLOWED = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&()*+,-.:;<=>?@[\]^_{|}~" + + +class User(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + user_id: Optional[PositiveInt] = Field( + default=None, lt=MAX_INT32, serialization_alias="id" + ) + + uuid: Optional[UUIDStr] = Field(default=None, examples=[uuid4().hex]) + + # 'product' is a Class with values that are fetched from the DB. + # Initialization is deferred until it is actually needed + # (see .prefetch_product()) + product: Optional[Product] = Field(default=None) + + product_id: Optional[UUIDStr] = Field( + default=None, examples=["4fe381fb7186416cb443a38fa66c6557"] + ) + + product_user_id: Optional[BPUIDStr] = Field( + default=None, + examples=["app-user-9329ebd"], + description="A unique identifier for each user, which is set by the " + "Supplier. It should not contain any sensitive information" + "like email or names, and should avoid using any" + "incrementing values.", + ) + + # TODO: Is it possible to protect these from ever being initialized? + # - Would need to be allowed with .from_json but not User constructor directly + # - Would need to allow private setters for setting from DB values + blocked: Optional[bool] = Field(default=False, strict=True) + + created: Optional[AwareDatetimeISO] = Field( + default=None, + description="When the user was created on the GRL platform.", + ) + + # Note: due to cacheing, last_seen might be up to a day out of date! + last_seen: Optional[AwareDatetimeISO] = Field( + default=None, + description="When the user was last seen on, or acting on any" + "part of the GRL platform.", + ) + + # --- Prefetch Fields --- + audit_log: Optional[List[AuditLog]] = Field(default=None) + transactions: Optional[List["LedgerTransaction"]] = Field(default=None) + location_history: Optional[List["GeoIPInformation"]] = Field(default=None) + + # --- Prebuild Fields --- + # session: Optional[List] = Field(default=None) + # wall: Optional[List] = Field(default=None) + + def __eq__(self, other: "User"): + return ( + self.product_id == other.product_id + and self.product_user_id == other.product_user_id + and self.user_id == other.user_id + and self.uuid == other.uuid + ) + + # --- Validation --- + @field_validator("product_user_id") + def check_product_user_id(cls, v: str) -> str: + if v is not None: + if " " in v: + raise ValueError("String cannot contain spaces") + if "\\" in v: + raise ValueError("String cannot contain backslash") + if "/" in v: + raise ValueError("String cannot contain slash") + # I think the * on the regex messes up value matches that are + # the same length as the + rex = re.fullmatch("[" + BPUID_ALLOWED + "]*", v) + if not bool(rex): + raise ValueError("String is not valid regex") + return v + + # noinspection PyNestedDecoratorsk + @field_validator("created", "last_seen") + @classmethod + def check_not_in_future(cls, v: AwareDatetime) -> AwareDatetime: + if v is not None: + try: + assert v < datetime.now(tz=timezone.utc) + except Exception: + raise ValueError("Input is in the future") + return v + + # noinspection PyNestedDecorators + @field_validator("created", "last_seen") + @classmethod + def check_after_anno_domini(cls, v: AwareDatetime) -> AwareDatetime: + if v is not None: + try: + assert v > datetime(year=2016, month=7, day=13, tzinfo=timezone.utc) + except Exception: + raise ValueError("Input is before Anno Domini") + return v + + @model_validator(mode="after") + def check_identifiable(self) -> "User": + if not self.is_identifiable: + raise ValueError("User is not identifiable") + + return self + + @model_validator(mode="after") + def check_created_first(self) -> "User": + # TODO: require the created value comes before, or is equal to the + # last_seen + created = self.created + last_seen = self.last_seen + if created is not None and last_seen is not None and created > last_seen: + raise ValueError("User created time invalid") + return self + + # --- Properties --- + @property + def is_identifiable(self) -> bool: + return bool( + self.user_id is not None + or self.uuid is not None + or (self.product_id is not None and self.product_user_id) + ) + + @classmethod + def is_valid_ubp(cls, *, product_id, product_user_id) -> bool: + # Attempt to create common_struct solely for validation purposes, + # using the product_id and product_user_id + try: + cls.check_bpuid_is_not_bpid(product_id, product_user_id) + cls( + user_id=None, + product_id=product_id, + product_user_id=product_user_id, + ) + except Exception as e: + logger.info(e) + return False + else: + return True + + # --- Methods --- + @staticmethod + def check_bpuid_is_not_bpid(product_id, product_user_id): + """Unfortunately users were already created failing this constraint, + so only check for new users! + """ + if ( + product_id is not None + and product_user_id is not None + and product_id == product_user_id + ): + raise ValueError("product_user_id must not equal the product_id") + return True + + def to_dict(self) -> Dict: + return self.model_dump(mode="python", exclude={"product"}) + + def to_json(self) -> str: + d = self.model_dump(mode="json", exclude={"product"}) + d["user_id"] = self.user_id + return json.dumps(d) + + def set_sentry_user(self): + # https://docs.sentry.io/platforms/python/enriching-events/identify-user/ + set_user( + { + "id": self.user_id, + "product_id": self.product_id, + "product_user_id": self.product_user_id, + } + ) + set_tag(key="bpid", value=self.product_id) + set_tag(key="bpuid", value=self.product_user_id) + + def delete_profiling_history(self, thl_sql_rw: PostgresConfig) -> bool: + """This is how we remove any profiling data on a user from our system. + + (1) Delete from thl-web tables + (2) Delete from thl-marketplace tables + + # Possible future steps: + # - Notify Marketplaces of deletion requests + # - FullCircle: DanaH@ilovefullcircle.com + """ + + self.set_sentry_user() + + # Delete from db.300large-web + for table in [ + "marketplace_userprofileknowledgeitem", + "marketplace_userprofileknowledgenumerical", + "marketplace_userprofileknowledgetext", + "marketplace_userquestionanswer", + "userprofile_useriphistory", + ]: + thl_sql_rw.execute_write( + query=f""" + DELETE FROM {table} + WHERE user_id = %s; + """, + params=[self.user_id], + ) + + # # Delete from db.thl-marketplaces + # We need DELETE credentials for all these... + # from generalresearch.models import Source + # mp_db_table = { + # Source.SPECTRUM: "`thl-spectrum`.`spectrum_marketresearchprofilequestion`", + # Source.INNOVATE: "`thl-innovate`.`innovate_marketresearchprofilequestion`", + # Source.DYNATA: "`thl-dynata`.`dynata_rexmarketresearchprofilequestion`", + # Source.SAGO: "`thl-schlesinger`.`sago_marketresearchprofilequestion`", + # Source.PRODEGE: "`thl-prodege`.`prodege_marketresearchprofilequestion`", + # Source.POLLFISH: "`thl-pollfish`.`pollfish_marketresearchprofilequestion`", + # Source.PRECISION: "`thl-precision`.`precision_marketresearchprofilequestion`", + # Source.MORNING_CONSULT: "`thl-morning`.`morning_marketresearchprofilequestion`", + # # Source.FULL_CIRCLE: "`300large-fullcircle`.`fullcircle_marketresearchprofilequestion`" + # } + # for source in PRIVACY_MP_MYSQLC.keys(): + # PRIVACY_MP_MYSQLC[source].execute_sql_query(f""" + # DELETE FROM {MP_DB_TABLE[source]} + # WHERE user_id = %s;""", [user.user_id], commit=True) + # + + return True + + # --- Prefetch --- + + def prefetch_product(self, pg_config: PostgresConfig) -> None: + from generalresearch.managers.thl.product import ProductManager + + if self.product is None: + pm = ProductManager(pg_config=pg_config) + self.product = pm.get_by_uuid(product_uuid=self.product_id) + + return None + + def prefetch_audit_log(self, audit_log_manager: "AuditLogManager") -> None: + self.audit_log = audit_log_manager.filter_by_user_id(user_id=self.user_id) + return None + + def prefetch_transactions(self, thl_lm: "ThlLedgerManager") -> None: + account = thl_lm.get_account_or_create_user_wallet(user=self) + self.transactions = thl_lm.get_tx_filtered_by_account(account_uuid=account.uuid) + return None + + # def prefetch_location_history(self, user_ip_history_manager: "UserIpHistoryManager") -> None: + # return user_ip_history_manager.get_user_ip_history(user_id=self.user_id) + + # --- Prebuild --- + + @classmethod + def from_db(cls, res) -> Self: + if res["created"]: + res["created"] = res["created"].replace(tzinfo=timezone.utc) + if res["last_seen"]: + res["last_seen"] = res["last_seen"].replace(tzinfo=timezone.utc) + res["product_id"] = UUID(res["product_id"]).hex + res["uuid"] = UUID(res["uuid"]).hex + return cls( + user_id=res["user_id"], + product_id=res["product_id"], + product_user_id=res["product_user_id"], + uuid=res["uuid"], + blocked=bool(res["blocked"]), + created=res["created"], + last_seen=res["last_seen"], + ) + + +# Used in other places where the bpuid is part of a model that's used in +# the API (separate from a User) +BPUIDStr = Annotated[ + str, + StringConstraints(min_length=3, max_length=128), + AfterValidator(User.check_product_user_id), +] diff --git a/generalresearch/models/thl/user_iphistory.py b/generalresearch/models/thl/user_iphistory.py new file mode 100644 index 0000000..ecdbc7a --- /dev/null +++ b/generalresearch/models/thl/user_iphistory.py @@ -0,0 +1,245 @@ +import ipaddress +from datetime import timezone, datetime, timedelta +from typing import List, Optional, Dict + +from faker import Faker +from pydantic import ( + BaseModel, + Field, + ConfigDict, + PositiveInt, + field_validator, +) +from typing_extensions import Self + +from generalresearch.models.custom_types import ( + AwareDatetimeISO, + IPvAnyAddressStr, + CountryISOLike, +) +from generalresearch.models.thl.ipinfo import ( + GeoIPInformation, + normalize_ip, +) +from generalresearch.models.thl.maxmind.definitions import UserType +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig + +fake = Faker() + + +class UserIPRecord(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + ip: IPvAnyAddressStr = Field() + created: AwareDatetimeISO = Field() + information: Optional[GeoIPInformation] = Field(default=None, exclude=True) + + @property + def country_iso(self) -> Optional[CountryISOLike]: + return self.information.country_iso if self.information else None + + @property + def is_anonymous(self) -> bool: + # default to False even if insights is not looked up + return ( + self.information.is_anonymous + if self.information + and self.information.basic is False + and self.information.is_anonymous is not None + else False + ) + + @property + def user_type(self) -> Optional[UserType]: + return self.information.user_type if self.information else None + + @property + def subdivision_1_iso(self) -> Optional[str]: + return self.information.subdivision_1_iso if self.information else None + + @property + def subdivision_2_iso(self) -> Optional[str]: + return self.information.subdivision_2_iso if self.information else None + + +class IPRecord(BaseModel): + user_id: PositiveInt = Field() + ip: IPvAnyAddressStr = Field() + created: AwareDatetimeISO = Field() + + # On a top-level, this should be an empty list if there are no forwarded_ip. + # Within a forwarded_ip record, this should be None. + forwarded_ip_records: Optional[List["IPRecord"]] = Field( + default=None, description="" + ) + + information: Optional[GeoIPInformation] = Field(default=None) + + @property + def forwarded_ips(self) -> Optional[List[IPvAnyAddressStr]]: + return ( + [x.ip for x in self.forwarded_ip_records] + if self.forwarded_ip_records is not None + else None + ) + + def ip_changed( + self, ip: IPvAnyAddressStr, forwarded_ips: List[IPvAnyAddressStr] + ) -> bool: + return not (ip == self.ip and forwarded_ips == self.forwarded_ips) + + # --- prefetch_* --- + def prefetch_ipinfo( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + include_forwarded: bool = True, + ) -> None: + from generalresearch.managers.thl.ipinfo import GeoIpInfoManager + + m = GeoIpInfoManager(pg_config=pg_config, redis_config=redis_config) + + if include_forwarded: + ips = {self.ip} + ips.update(set(self.forwarded_ips)) + res = m.get_multi(ips) + self.information = res.get(self.ip) + for x in self.forwarded_ip_records: + x.information = res.get(x.ip) + else: + self.information = m.get(ip_address=self.ip) + return None + + # --- ORM --- + @classmethod + def from_mysql(cls, d: Dict) -> Self: + created = d["created"].replace(tzinfo=timezone.utc) + + d["created"] = created + d["forwarded_ip_records"] = [] + + for fip in [ + d.get("forwarded_ip1"), + d.get("forwarded_ip2"), + d.get("forwarded_ip3"), + d.get("forwarded_ip4"), + d.get("forwarded_ip5"), + d.get("forwarded_ip6"), + ]: + if fip: + d["forwarded_ip_records"].append( + { + "user_id": d["user_id"], + "ip": fip, + "created": created, + "forwarded_ip_records": None, + } + ) + + return cls.model_validate(d) + + +class UserIPHistory(BaseModel): + model_config = ConfigDict(validate_assignment=True) + + user_id: PositiveInt = Field() + + # In thl-gprc, we run "audit_ip_history()", and so a user should + # get blocked after 100 IP switches or 30 unique IPs + # Sorted created DESC + ips: Optional[List[UserIPRecord]] = Field( + default=None, + description="These are any IP addresses that came in ", + max_length=101, + ) + + ips_ws: Optional[List[IPRecord]] = Field( + default=None, description="These are any IP addresses that came in " + ) + + ips_dns: Optional[List[IPRecord]] = Field( + default=None, description="These are any IP addresses that came in " + ) + + # -- prefetch_ fields + user: Optional[User] = Field(default=None) + + @field_validator("ips", mode="after") + @classmethod + def ips_timestamp(cls, ips): + if ips is None: + return None + cutoff = datetime.now(tz=timezone.utc) - timedelta(days=28) + return sorted( + [x for x in ips if x.created > cutoff], + key=lambda x: x.created, + reverse=True, + ) + + def prefetch_user( + self, + pg_config: PostgresConfig, + redis_config: RedisConfig, + pg_config_rr: PostgresConfig, + ) -> None: + from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, + ) + + um = UserManager( + pg_config=pg_config, + pg_config_rr=pg_config_rr, + redis=redis_config.dsn, + ) + self.user = um.get_user(user_id=self.user_id) + + return None + + def enrich_ips(self, pg_config: PostgresConfig, redis_config: RedisConfig) -> None: + from generalresearch.managers.thl.ipinfo import GeoIpInfoManager + + m = GeoIpInfoManager(pg_config=pg_config, redis_config=redis_config) + + ip_addresses = {x.ip for x in self.ips if x.information is None} + res = m.get_multi(ip_addresses=ip_addresses) + for x in self.ips: + if res.get(x.ip): + x.information = res[x.ip] + + return None + + def collapse_ip_records(self): + """ + - Records where sequential ipv6 addresses are in the same /64 block, + just keep the last one. + - If a user has a new ip b/c they've simply alternated between a ipv4 + and ipv6, only keep the most recent 1 of each version. + """ + records = self.ips.copy() + + res = [] + last_ipv4 = None + last_ipv6 = None + + # Iterate through (most recent first) + for record in records: + ip = ipaddress.ip_address(record.ip) + if ip.version == 4: + if last_ipv4 and last_ipv4 == record.ip: + continue + last_ipv4 = record.ip + res.append(record) + elif ip.version == 6: + normalized_ip, _ = normalize_ip(ip) + # If the latest ipv6 is the same /64 block as an older one, + # discard the older one. + if last_ipv6 and last_ipv6 == normalized_ip: + continue + last_ipv6 = normalized_ip + res.append(record) + else: + raise ValueError("we've ripped a hole in the universe") + + return res diff --git a/generalresearch/models/thl/user_profile.py b/generalresearch/models/thl/user_profile.py new file mode 100644 index 0000000..6514c7d --- /dev/null +++ b/generalresearch/models/thl/user_profile.py @@ -0,0 +1,122 @@ +import hashlib +from typing import Optional, Dict, Any, List + +from pydantic import ( + Field, + BaseModel, + ConfigDict, + EmailStr, + PositiveInt, + computed_field, +) +from pydantic.json_schema import SkipJsonSchema +from typing_extensions import Self, Annotated + +from generalresearch.models import MAX_INT32, Source +from generalresearch.models.custom_types import UUIDStr +from generalresearch.models.thl.user import User +from generalresearch.models.thl.user_streak import UserStreak + + +class UserMetadata(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + user_id: SkipJsonSchema[Optional[PositiveInt]] = Field( + exclude=True, default=None, lt=MAX_INT32 + ) + + email_address: Optional[EmailStr] = Field( + default=None, examples=["contact@mail.com"] + ) + + @computed_field + def email_md5( + self, + ) -> Annotated[ + Optional[str], + Field( + min_length=32, + max_length=32, + description="MD5 hash of the email address", + examples=["053fc3d5575362159e0c782abec83ffa"], + ), + ]: + if self.email_address is None: + return None + + return hashlib.md5(self.email_address.encode("utf-8")).hexdigest() + + @computed_field + def email_sha1( + self, + ) -> Annotated[ + Optional[str], + Field( + min_length=40, + max_length=40, + description="SHA1 hash of the email address", + examples=["6280fb76135b3585c0c5403be04844a0f0bae726"], + ), + ]: + if self.email_address is None: + return None + return hashlib.sha1(self.email_address.encode("utf-8")).hexdigest() + + @computed_field + def email_sha256( + self, + ) -> Annotated[ + Optional[str], + Field( + min_length=64, + max_length=64, + description="SHA256 hash of the email address", + examples=[ + "8a098233e750f08de87d6053c06a58724287f34372368b6dc28b7ad4a77f3d39" + ], + ), + ]: + if self.email_address is None: + return None + return hashlib.sha256(self.email_address.encode("utf-8")).hexdigest() + + def to_db(self) -> Dict[str, Any]: + res = self.model_dump(mode="json") + res["user_id"] = self.user_id + return res + + @classmethod + def from_db(cls, user_id, email_address, **kwargs) -> Self: + # If the hashes are passed, just validate that they match + obj = cls.model_validate({"user_id": user_id, "email_address": email_address}) + + if kwargs.get("email_md5") is not None: + assert obj.email_md5 == kwargs["email_md5"], "email_md5 mismatch" + + if kwargs.get("email_sha1") is not None: + assert obj.email_sha1 == kwargs["email_sha1"], "email_sha1 mismatch" + + if kwargs.get("email_sha256") is not None: + assert obj.email_sha256 == kwargs["email_sha256"], "email_sha256 mismatch" + + return obj + + +class UserProfile(UserMetadata): + model_config = ConfigDict() + + user: User = Field() + + marketplace_pids: Dict[Source, UUIDStr] = Field( + default_factory=dict, + description="User's PID in marketplaces", + examples=[ + { + Source.CINT: "b507a2c00c3e481fb82f23655d142198", + Source.DYNATA: "deffe922063e4b9980206a62c3df2fba", + Source.INNOVATE: "1dd9bd986794444eb97cb921aee5663f", + } + ], + ) + + streaks: List[UserStreak] = Field(default_factory=list) diff --git a/generalresearch/models/thl/user_quality_event.py b/generalresearch/models/thl/user_quality_event.py new file mode 100644 index 0000000..f98e42d --- /dev/null +++ b/generalresearch/models/thl/user_quality_event.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from decimal import Decimal +from enum import Enum +from typing import List, Literal, Optional + +from pydantic import BaseModel, Field, PositiveInt + +from generalresearch.models import Source, MAX_INT32 +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.definitions import WallAdjustedStatus +from generalresearch.models.thl.user import BPUIDStr +from generalresearch.utils.enum import ReprEnumMeta + +""" +Typically used internally. These affect a user's quality standing. +""" + + +class QualityEventType(str, Enum, metaclass=ReprEnumMeta): + """ + Currently, the grpc call SendUserQualityEvents handles both the + recons/task adj, access control, and "security/hash failure" events. + Splitting those up for the web api, even though in the backend, all + 3 might hit the same grpc call. + """ + + # Used to adjust a Wall's adjustment_status + task_adjustment = "task_adjustment" + + # Manually adding a user to the whitelist + add_to_whitelist = "add_to_whitelist" + + # Manually adding a user to the blacklist + add_to_blacklist = "add_to_blacklist" + + # Clear any manual access control for a user + clear_access_control_list = "clear_access_control_list" + + +class AccessControlEvent(BaseModel): + quality_event_type: Literal[ + QualityEventType.add_to_whitelist, + QualityEventType.add_to_blacklist, + QualityEventType.clear_access_control_list, + ] = Field() + # One of user_id / (product_id, bpuid) is required. + product_id: Optional[UUIDStr] = Field( + default=None, examples=["4fe381fb7186416cb443a38fa66c6557"] + ) + bpuid: Optional[BPUIDStr] = Field(default=None, examples=["app-user-9329ebd"]) + user_id: Optional[PositiveInt] = Field(default=None, lt=MAX_INT32) + + +class AccessControlEventBody(BaseModel): + events: List[AccessControlEvent] = Field(max_length=100, min_length=1) + + +class TaskAdjustmentEvent(BaseModel): + mid: UUIDStr = Field() + source: Source = Field() + status: WallAdjustedStatus = Field() + alert_time: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + quality_event_type: Literal[QualityEventType.task_adjustment] = Field( + default=QualityEventType.task_adjustment + ) + + # Only MID is needed to populate all the following, however we can pass them in order + # to perform validation. If any disagree, an error should be raised. + survey_id: Optional[str] = Field(max_length=32, default=None) + amount: Optional[Decimal] = Field( + description="If negative, the status should adjusted to incomplete", + default=None, + ) + event_time: Optional[AwareDatetimeISO] = Field( + description="This is when the original wall event was started", + default=None, + ) + product_id: Optional[UUIDStr] = Field( + default=None, examples=["4fe381fb7186416cb443a38fa66c6557"] + ) + bpuid: Optional[BPUIDStr] = Field(default=None, examples=["app-user-9329ebd"]) + user_id: Optional[PositiveInt] = Field(default=None, lt=MAX_INT32) + + +class TaskAdjustmentEventBody(BaseModel): + events: List[TaskAdjustmentEvent] = Field(max_length=100, min_length=1) diff --git a/generalresearch/models/thl/user_streak.py b/generalresearch/models/thl/user_streak.py new file mode 100644 index 0000000..fa6d3b1 --- /dev/null +++ b/generalresearch/models/thl/user_streak.py @@ -0,0 +1,152 @@ +from datetime import date, datetime, timedelta +from enum import Enum +from typing import Optional, Tuple +from zoneinfo import ZoneInfo + +import pandas as pd +from pydantic import ( + BaseModel, + NonNegativeInt, + Field, + computed_field, + AwareDatetime, + model_validator, + ConfigDict, + PositiveInt, +) +from pydantic.json_schema import SkipJsonSchema + +from generalresearch.managers.leaderboard import country_timezone +from generalresearch.models import MAX_INT32 +from generalresearch.models.thl.locales import CountryISO + + +class StreakPeriod(str, Enum): + # Midnight to midnight in the tz associated with the user's country + DAY = "day" + # Sunday midnight - sunday midnight + WEEK = "week" + # e.g. 2000-01-01 to 2000-01-31 23:59:59.999999 + MONTH = "month" + + +class StreakFulfillment(str, Enum): + """ + What has to happen for a user to fulfill a period for a streak + """ + + # User has to finish a Session (excluding Session start failure) + ACTIVE = "active" + # User has to complete a Session + COMPLETE = "complete" + + +class StreakState(str, Enum): + # The activity for today was completed! + ACTIVE = "active" + # They had activity yesterday, but not today, and can still continue today + # Should we call this "AT_RISK" instead ?? (I had "open") + AT_RISK = "at_risk" + # Missed the window. Streak is broken + BROKEN = "broken" + + +PERIOD_TO_PD_FREQ = { + StreakPeriod.DAY: "D", + StreakPeriod.WEEK: "W-SUN", # Sunday-based week + StreakPeriod.MONTH: "M", +} + + +class UserStreak(BaseModel): + model_config = ConfigDict( + ser_json_timedelta="float", validate_assignment=True, extra="forbid" + ) + + user_id: SkipJsonSchema[Optional[PositiveInt]] = Field( + exclude=True, default=None, lt=MAX_INT32 + ) + country_iso: CountryISO = Field() + + # What defines the streak + period: StreakPeriod = Field() + fulfillment: StreakFulfillment = Field() + + current_streak: NonNegativeInt = Field() + longest_streak: NonNegativeInt = Field() + state: StreakState = Field() + last_fulfilled_period_start: Optional[date] = Field(default=None) + + @computed_field() + @property + def timezone_name(self) -> str: + return str(self.timezone) + + @property + def timezone(self) -> ZoneInfo: + return country_timezone()[self.country_iso] + + @property + def now_local(self) -> AwareDatetime: + return datetime.now(tz=self.timezone) + + @computed_field() + @property + def current_period_bounds(self) -> Tuple[AwareDatetime, AwareDatetime]: + return self.get_period_bounds(datetime.now(tz=self.timezone).date()) + + @computed_field() + @property + def last_fulfilled_period_bounds(self) -> Optional[Tuple[datetime, datetime]]: + return self.get_period_bounds(self.last_fulfilled_period_start) + + @computed_field() + @property + def time_remaining_in_period(self) -> Optional[timedelta]: + # Time left to continue your streak + if self.state in {StreakState.BROKEN, StreakState.ACTIVE}: + return None + period_end = self.current_period_bounds[1] + return period_end - self.now_local + + @model_validator(mode="after") + def check_state(self): + if self.state == StreakState.BROKEN: + assert ( + self.current_streak == 0 + ), "StreakState.BROKEN but current_streak not 0" + + if self.current_streak != 0: + assert ( + self.state != StreakState.BROKEN + ), "current_streak not 0 but StreakState.BROKEN" + return self + + @model_validator(mode="after") + def check_longest_streak(self): + assert ( + self.longest_streak >= self.current_streak + ), "Current streak can't be longer than longest streak" + return self + + def get_period_bounds( + self, start_date: date + ) -> Optional[Tuple[datetime, datetime]]: + """ + Returns (period_start_local, period_end_local) + Both timezone-aware. + """ + + if not start_date: + return None + + freq = PERIOD_TO_PD_FREQ[self.period] + tz = self.timezone + + period = pd.Timestamp(start_date).to_period(freq) + period_start_local = period.start_time.to_pydatetime(warn=False).replace( + tzinfo=tz + ) + period_end_local = period.end_time.to_pydatetime(warn=False).replace(tzinfo=tz) + + return period_start_local, period_end_local diff --git a/generalresearch/models/thl/userhealth.py b/generalresearch/models/thl/userhealth.py new file mode 100644 index 0000000..8ea81dd --- /dev/null +++ b/generalresearch/models/thl/userhealth.py @@ -0,0 +1,77 @@ +from datetime import datetime, timezone +from enum import Enum +from typing import Optional, Dict + +from pydantic import Field, BaseModel, PositiveInt, NonNegativeFloat +from typing_extensions import Self + +from generalresearch.models.custom_types import AwareDatetimeISO + + +class AuditLogLevel(int, Enum): + CRITICAL = 50 + FATAL = CRITICAL + ERROR = 40 + WARNING = 30 + WARN = WARNING + INFO = 20 + DEBUG = 10 + NOTSET = 0 + + +class AuditLog(BaseModel): + """Table / Model for logging "actions" taken by a user or "events" that + are related to a User + """ + + id: Optional[PositiveInt] = Field(default=None) + user_id: PositiveInt = Field() + + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + examples=[datetime.now(tz=timezone.utc)], + description="When did this event occur", + ) + + level: AuditLogLevel = Field( + description="The level of importance for this event. Works the same as " + "python logging levels. It is an integer 0 - 50, and " + "implementers of this field could map it to the predefined " + "levels: (`CRITICAL`, `ERROR`, `WARNING`, `INFO`, `DEBUG`)." + "This is NOT the same concept as the 'strength' of whatever " + "event happened; it is just for sorting, filtering and " + "display purposes. For e.g. multiple level 20 events != the " + "'importance' of one level 40 event.", + examples=[AuditLogLevel.WARNING], + ) + + # The "class" or "type" or event that happened. + # e.g. "upk-audit", "ip-audit", "entrance-limit" + event_type: str = Field(max_length=64, examples=["entrance-limit"]) + + event_msg: Optional[str] = Field( + default=None, + min_length=3, + max_length=256, + description="The event message. Could be displayed on user's page", + ) + + event_value: Optional[NonNegativeFloat] = Field( + default=None, + description="Optionally store a numeric value associated with this " + "event. For e.g. if we recalculate the user's normalized " + "recon rate, and it is 'high', we could store an event like " + "(event_type='recon-rate', event_msg='higher than allowed " + "recon rate' event_value=0.42)", + examples=[0.42], + ) + + def model_dump_mysql(self, **kwargs) -> Dict: + d = self.model_dump(mode="json", **kwargs) + d["created"] = self.created.replace(tzinfo=None) + return d + + @classmethod + def from_mysql(cls, d: Dict) -> Self: + d["created"] = d["created"].replace(tzinfo=timezone.utc) + return AuditLog.model_validate(d) diff --git a/generalresearch/models/thl/wallet/__init__.py b/generalresearch/models/thl/wallet/__init__.py new file mode 100644 index 0000000..928d67f --- /dev/null +++ b/generalresearch/models/thl/wallet/__init__.py @@ -0,0 +1,87 @@ +from enum import Enum + +from generalresearch.utils.enum import ReprEnumMeta + + +class PayoutType(str, Enum, metaclass=ReprEnumMeta): + """ + The method in which the requested payout is delivered. + """ + + # The max size of the db field that holds this value is 14, so please + # don't add new values longer than that! + + # User is paid out to their personal PayPal email address + PAYPAL = "PAYPAL" + # User is paid uut via a Tango Gift Card + TANGO = "TANGO" + # DWOLLA + DWOLLA = "DWOLLA" + # A payment is made to a bank account using ACH + ACH = "ACH" + # A payment is made to a bank account using ACH + WIRE = "WIRE" + # A payment is made in cash and mailed to the user. + CASH_IN_MAIL = "CASH_IN_MAIL" + # A payment is made as a prize with some monetary value + PRIZE = "PRIZE" + + # This is used to designate either AMT_BONUS or AMT_HIT + AMT = "AMT" + # Amazon Mechanical Turk as a Bonus + AMT_BONUS = "AMT_BONUS" + # Amazon Mechanical Turk for a HIT + AMT_HIT = "AMT_ASSIGNMENT" + AMT_ASSIGNMENT = "AMT_ASSIGNMENT" + + +class Currency(str, Enum): + # United States Dollar + USD = "USD" + # Canadian Dollar + CAD = "CAD" + # British Pound Sterling + GBP = "GBP" + # Euro + EUR = "EUR" + # Indian Rupee + INR = "INR" + # Australian Dollar + AUD = "AUD" + # Polish Zloty + PLN = "PLN" + # Swedish Krona + SEK = "SEK" + # Singapore Dollar + SGD = "SGD" + # Mexican Peso + MXN = "MXN" + + +CURRENCY_FORMATTER = { + "USD": lambda x: "${:,.2f}".format(x / 100), + "CAD": lambda x: "${:,.2f} CAD".format(x / 100), + "GBP": lambda x: "{:,.2f} £".format(x / 100), + "EUR": lambda x: "€{:,.2f}".format(x / 100), + "INR": lambda x: "₹{:,.2f}".format(x / 100), + "AUD": lambda x: "${:,.2f} AUD".format(x / 100), + "PLN": lambda x: "{:,.2f} zł".format(x / 100), + "SEK": lambda x: "{:,.2f} kr".format(x / 100), + "SGD": lambda x: "${:,.2f} SGD".format(x / 100), + "MXN": lambda x: "${:,.2f} MXN".format(x / 100), +} + +# The max value user can redeem in one go in foreign currencies. should be < $250 +# in order to avoid exchange rate issues +CURRENCY_MAX_VALUE = { + "USD": 250, + "CAD": 200, + "GBP": 100, + "EUR": 100, + "INR": 10000, + "AUD": 200, + "PLN": 500, + "SEK": 1000, + "SGD": 200, + "MXN": 4000, +} diff --git a/generalresearch/models/thl/wallet/cashout_method.py b/generalresearch/models/thl/wallet/cashout_method.py new file mode 100644 index 0000000..def724d --- /dev/null +++ b/generalresearch/models/thl/wallet/cashout_method.py @@ -0,0 +1,443 @@ +from __future__ import annotations + +import hashlib +import logging +from datetime import datetime, timezone +from enum import Enum +from typing import List, Dict, Any, Optional, Literal, Union + +from pydantic import ( + BaseModel, + Field, + ConfigDict, + NonNegativeInt, + PositiveInt, + EmailStr, + model_validator, + field_validator, +) +from typing_extensions import Self + +from generalresearch.currency import USDCent +from generalresearch.models.custom_types import ( + UUIDStr, + HttpsUrlStr, + AwareDatetimeISO, +) +from generalresearch.models.legacy.api_status import StatusResponse +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.locales import CountryISO +from generalresearch.models.thl.user import BPUIDStr, User +from generalresearch.models.thl.wallet import PayoutType, Currency +from generalresearch.utils.enum import ReprEnumMeta + +logger = logging.getLogger() + +example_cashout_method = { + "id": "941d489c3ce04eb39a0ddb7f8f75db74", + "bpid": "6a3ddfb747344bbc93efadf1c3a16e1a", + "bpuid": None, + "currency": "USD", + "data": {"terms": "...", "disclaimer": "..."}, + "description": "...", + "image_url": "https://d30s7yzk2az89n.cloudfront.net/images/brands/b238587-1200w-326ppi.png", + "max_value": 25000, + "min_value": 500, + "name": "Visa® Prepaid Card USD", + "type": "TANGO", +} + + +class CashoutMethodBase(BaseModel): + """ + A user can request a payout of their wallet balance via a cashout method. This is the way + in which the money is paid. The terms cashout and payout are used interchangeably. + """ + + model_config = ConfigDict(json_schema_extra={"example": example_cashout_method}) + + id: UUIDStr = Field(description="Unique ID for this cashout method") + + currency: Literal["USD"] = Field( + default="USD", + description="The currency of the cashout. Only USD is supported.", + ) + original_currency: Optional[Currency] = Field( + default=None, + description="The base currency of the money paid out. This is used for " + "e.g. sending an Amazon UK gift card", + ) + # This also is used for the PayoutEvent.request_data + data: Union[ + PaypalCashoutMethodData, + TangoCashoutMethodData, + CashMailCashoutMethodData, + AmtCashoutMethodData, + ] = Field(discriminator="type") + description: str = Field( + description="The description of the cashout method.", default="" + ) + image_url: Optional[HttpsUrlStr] = Field( + description="Link to an image to display", default=None + ) + max_value: PositiveInt = Field( + description="(In lowest unit of the original_currency), " + "The maximum amount that can be cashed out in one transaction." + ) + min_value: NonNegativeInt = Field( + description="(In lowest unit of the original_currency), " + "The minimum amount that can be cashed out in one transaction." + ) + name: str = Field(description="A descriptive name for the cashout method.") + # In the db, this is called "provider" + type: PayoutType = Field( + description=PayoutType.as_openapi_with_value_descriptions(), + ) + ext_id: Optional[str] = Field( + default=None, + description="An external ID. Can be shown to a user to disambiguate " + "a user's possibly multiple methods", + ) + usd_exchange_rate: Optional[float] = Field(default=None) + max_value_usd: Optional[USDCent] = Field( + default=None, + description="(In lowest unit of USD), " + "The maximum amount that can be cashed out in one transaction.", + ) + min_value_usd: Optional[USDCent] = Field( + default=None, + description="(In lowest unit of USD), " + "The minimum amount that can be cashed out in one transaction.", + ) + + # + # @property + # def min_value_usd(self): + # if self.original_currency == Currency.USD: + # return self.min_value + # if self.usd_exchange_rate is None: + # return None + # return self.min_value * self.usd_exchange_rate + + def validate_requested_amount(self, amount: PositiveInt): + """ + Check if 'amount' is a valid amount that can be requested. + :param amount: The amount to be requested in USD Cents + """ + if amount <= 0: + raise ValueError("Amount must be positive") + if not self.min_value <= amount <= self.max_value: + raise ValueError( + f"Invalid amount requested: ${amount / 100:.2f}. Must be between" + f" ${int(self.min_value) / 100:.2f} and ${int(self.max_value) / 100:.2f}" + ) + if self.type == PayoutType.CASH_IN_MAIL: + if amount % 500 != 0: + raise ValueError("Amount must be in increments of $5.00") + return True + + +class CashoutMethod(CashoutMethodBase): + user: Optional[User] = Field( + default=None, + description="If set, this cashout method is custom for this user. For example" + "a user may have a paypal cashout method with their paypal" + "email associated.", + ) + last_updated: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + is_live: bool = Field(default=True) + + @model_validator(mode="after") + def validate_user(self) -> Self: + if self.type in {PayoutType.PAYPAL, PayoutType.CASH_IN_MAIL}: + assert ( + self.user is not None + ), "user_id must be set for this cashout method type" + else: + assert ( + self.user is None + ), "user_id must NOT be set for this cashout method type" + return self + + +class CashoutMethodOut(CashoutMethodBase): + product_id: Optional[UUIDStr] = Field( + default=None, examples=["4fe381fb7186416cb443a38fa66c6557"] + ) + + product_user_id: Optional[str] = Field( + default=None, + min_length=3, + max_length=128, + examples=["app-user-9329ebd"], + description="A unique identifier for each user, which is set by the " + "Supplier. It should not contain any sensitive information" + "like email or names, and should avoid using any" + "incrementing values.", + ) + + @classmethod + def from_cashout_method(cls, cm: CashoutMethod) -> Self: + d = cm.model_dump() + if cm.user: + d["product_id"] = cm.user.product_id + d["product_user_id"] = cm.user.product_user_id + return cls.model_validate(d) + + +class USDeliveryAddress(BaseModel): + name_or_attn: str = Field(min_length=1, max_length=50) + company: Optional[str] = Field( + default=None, + min_length=1, + max_length=50, + ) + phone_number: Optional[str] = Field( + default=None, + min_length=10, + max_length=10, + pattern=r"^[0-9]+$", + ) + address: str = Field(min_length=1, max_length=100) + city: str = Field(min_length=1, max_length=100) + state: str = Field(min_length=1, max_length=2) + postal_code: str = Field(min_length=1, max_length=10) + country: CountryISO = Field(default="us") + + def md5sum(self) -> str: + return hashlib.md5(self.model_dump_json().encode()).hexdigest() + + +class CashMailCashoutMethodData(BaseModel): + type: Literal[PayoutType.CASH_IN_MAIL] = Field(default=PayoutType.CASH_IN_MAIL) + + delivery_address: USDeliveryAddress = Field( + description="Delivery address where payment should be sent" + ) + + +class PaypalCashoutMethodData(BaseModel): + type: Literal[PayoutType.PAYPAL] = Field(default=PayoutType.PAYPAL) + + email: EmailStr = Field( + description="Email address of the paypal user", + examples=["test@example.com"], + ) + + +class TangoCashoutMethodData(BaseModel): + type: Literal[PayoutType.TANGO] = Field(default=PayoutType.TANGO) + utid: str = Field(description="tango utid") + # TODO: Can't be CountryISOLike because it appears to be allcaps + countries: List[str] = Field() + value_type: Literal["variable", "fixed"] = Field() + disclaimer: str = Field(default="") + terms: str = Field(default="") + + @field_validator("countries", mode="after") + def countries_case(cls, countries: List[str]) -> List[str]: + return [x.lower() for x in countries] + + +class AmtCashoutMethodData(BaseModel): + type: Literal[PayoutType.AMT] = Field(default=PayoutType.AMT) + + +class CashoutMethodsResponse(StatusResponse): + cashout_methods: List[CashoutMethodOut] = Field() + + +class DeliveryStatus(str, Enum): + PENDING = "Pending" + SHIPPED = "Shipped" + IN_TRANSIT = "In Transit" + OUT_FOR_DELIVERY = "Out for Delivery" + DELIVERED = "Delivered" + RETURNED = "Returned" + CANCELED = "Canceled" + FAILED_ATTEMPT = "Failed Attempt" + LOST = "Lost" + + +class ShippingCarrier(str, Enum): + USPS = "USPS" + FEDEX = "FedEx" + UPS = "UPS" + DHL = "DHL" + + +class ShippingMethod(str, Enum): + STANDARD = "Standard" + EXPRESS = "Express" + TWO_DAY = "Two-Day" + OVERNIGHT = "Overnight" + SAME_DAY = "Same Day" + + +# This goes in the PayoutEvent.order_data +class CashMailOrderData(BaseModel): + type: Literal[PayoutType.CASH_IN_MAIL] = Field(default=PayoutType.CASH_IN_MAIL) + shipping_cost: Optional[PositiveInt] = Field( + description="(USD cents) The shipping cost. This amount get charged to the BP.", + strict=True, + ) + tracking_number: Optional[str] = Field( + default=None, + min_length=1, + max_length=50, + ) + shipping_method: Optional[ShippingMethod] = Field( + default=None, + min_length=1, + max_length=50, + description="Standard, express, etc.", + ) + carrier: Optional[ShippingCarrier] = Field( + default=None, + min_length=1, + max_length=50, + description="Name of the shipping company, e.g., USPS, FedEx, DHL", + ) + ship_date: Optional[AwareDatetimeISO] = Field(default=None) + estimated_delivery_date: Optional[AwareDatetimeISO] = Field(default=None) + delivery_status: Optional[DeliveryStatus] = Field( + default=None, + min_length=1, + max_length=50, + description="Current status of delivery, e.g., pending, in " + "transit, delivered", + ) + last_updated: Optional[AwareDatetimeISO] = Field( + default=None, + description="Timestamp of the last status update", + ) + + +class CreateCashoutMethodRequest(BaseModel): + bpuid: BPUIDStr = Field( + description="(product_user_id) The user to create this cashout method for.", + examples=["app-user-9329ebd"], + ) + type: PayoutType = Field( + description=PayoutType.as_openapi_with_value_descriptions(), + examples=[PayoutType.PAYPAL], + ) + + +class CreatePayPalCashoutMethodRequest( + PaypalCashoutMethodData, + CreateCashoutMethodRequest, +): + pass + + +class CreateCashMailCashoutMethodRequest( + CashMailCashoutMethodData, CreateCashoutMethodRequest +): + pass + + +class CashoutMethodResponse(StatusResponse): + cashout_method: CashoutMethodOut = Field() + + +class CreateCashoutRequest(BaseModel): + bpuid: BPUIDStr = Field( + description="(product_user_id) The user requesting a cashout.", + examples=["app-user-9329ebd"], + ) + amount: PositiveInt = Field( + description="(USD cents) The amount requested for the cashout.", + strict=True, + examples=[531], + ) + cashout_method_id: UUIDStr = Field( + description="Unique ID for the cashout method the cashout is being requested with.", + examples=["941d489c3ce04eb39a0ddb7f8f75db74"], + ) + + +class CashoutRequestInfo(BaseModel): + """See models.thl.payout: PayoutEvent. We've confused a CashOut and a + Payout. This is used only in the API response. + """ + + id: Optional[UUIDStr] = Field( + description="Unique ID for this cashout. This may be NULL if the " + "status is REJECTED or FAILED, which may happen if the " + "request is invalid.", + examples=["3ceb847aaf9f40f4bd15b2b5e083abf6"], + ) + description: str = Field( + description="This is the name of the cashout method.", + examples=["Visa® Prepaid Card USD"], + ) + message: Optional[str] = Field(default=None) + status: Optional[PayoutStatus] = Field( + default=PayoutStatus.PENDING, + description=PayoutStatus.as_openapi(), + examples=[PayoutStatus.PENDING], + ) + transaction_info: Optional[Dict[str, Any]] = Field(default=None) + + +class CashoutRequestResponse(StatusResponse): + cashout: CashoutRequestInfo = Field() + + +example_foreign_value = { + "value": "138", + "currency": "CAD", + "value_string": "$1.38 CAD", +} + + +class RedemptionCurrency(str, Enum, metaclass=ReprEnumMeta): + """ + Supported Currencies for Foreign Redemptions + """ + + # US Dollars. Smallest Unit: Cents. + USD = "USD" + # Canadian Dollars. Smallest Unit: Cents. + CAD = "CAD" + # British Pounds. Smallest Unit: Pence. + GBP = "GBP" + # Euros. Smallest Unit: Cents. + EUR = "EUR" + # Indian Rupees. Smallest Unit: Paise. + INR = "INR" + # Australian Dollars. Smallest Unit: Cents. + AUD = "AUD" + # Polish Zloty. Smallest Unit: Grosz. + PLN = "PLN" + # Swedish Krona. Smallest Unit: Öre. + SEK = "SEK" + # Singapore Dollars. Smallest Unit: Cents. + SGD = "SGD" + # Mexican Pesos. Smallest Unit: Centavos. + MXN = "MXN" + + +class CashoutMethodForeignValue(BaseModel): + """ + Shows the expected value of a redemption in a foreign currency. + """ + + model_config = ConfigDict(json_schema_extra={"example": example_foreign_value}) + + value: NonNegativeInt = Field( + description="Value of the redemption in the currency's smallest unit." + ) + currency: RedemptionCurrency = Field( + description=RedemptionCurrency.as_openapi_with_value_descriptions() + ) + value_string: str = Field( + description="A string representation of the value in the currency." + ) + + +class CashoutMethodForeignValueResponse(StatusResponse): + cashout_method_value: CashoutMethodForeignValue = Field() diff --git a/generalresearch/models/thl/wallet/payout.py b/generalresearch/models/thl/wallet/payout.py new file mode 100644 index 0000000..97e0c3d --- /dev/null +++ b/generalresearch/models/thl/wallet/payout.py @@ -0,0 +1,214 @@ +import json +from datetime import datetime, timezone +from typing import Dict, Optional, Collection, List +from uuid import uuid4 + +from pydantic import ( + BaseModel, + Field, + PositiveInt, + computed_field, + field_validator, +) + +from generalresearch.currency import USDCent +from generalresearch.models.custom_types import UUIDStr, AwareDatetimeISO +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashMailOrderData, +) + + +class PayoutEvent(BaseModel, validate_assignment=True): + """A user has requested to be paid from their wallet balance.""" + + uuid: UUIDStr = Field( + default_factory=lambda: uuid4().hex, + examples=["9453cd076713426cb68d05591c7145aa"], + ) + + # This is the LedgerAccount.uuid that this money is being requested + # from. The user/BP is retrievable through the LedgerAccount.reference_uuid + debit_account_uuid: UUIDStr = Field(examples=["18298cb1583846fbb06e4747b5310693"]) + + # These two fields are copied here from the LedgerAccount through the + # debit_account_uuid for convenience. They will get populated if the + # PayoutEventManager retrieves a PayoutEvent from the db. + account_reference_type: Optional[str] = Field(default=None) + account_reference_uuid: Optional[UUIDStr] = Field(default=None) + + # References a row in the account_cashoutmethod table. This is the + # cashout method that was used to request this payout. (A cashout is + # the same thing as a payout) + cashout_method_uuid: UUIDStr = Field(examples=["a6dc1fc1bf934557b952f253dee12813"]) + + # By default, this will just be the cashout_method.name. This also is + # populated from the db and so does not need to be set (there is no + # `description` field in event_payout) + description: Optional[str] = Field(default=None) + created: AwareDatetimeISO = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + + # In the smallest unit of the currency being transacted. For USD, this + # is cents. + amount: PositiveInt = Field( + lt=2**63 - 1, + strict=True, + description="The USDCent amount int. This cannot be 0 or negative", + examples=[531], + ) + + status: PayoutStatus = Field( + default=PayoutStatus.PENDING, + description=PayoutStatus.as_openapi(), + examples=[PayoutStatus.COMPLETE], + ) + + # Used for holding an external, payout-type-specific identifier + ext_ref_id: Optional[str] = Field(default=None) + payout_type: PayoutType = Field( + description=PayoutType.as_openapi(), examples=[PayoutType.ACH] + ) + + # Stores payout-type-specific information that is used to request this + # payout from the external provider. + request_data: Dict = Field(default_factory=dict) + + # Stores payout-type-specific order information that is returned from + # the external payout provider. + order_data: Optional[Dict | CashMailOrderData] = Field(default=None) + + @field_validator("payout_type", mode="before") + @classmethod + def normalize_enum(cls, v): + if isinstance(v, str): + try: + return PayoutType[v.upper()] + except KeyError: + raise ValueError(f"Invalid payout_type: {v}") + return v + + def update( + self, + status: PayoutStatus, + ext_ref_id: Optional[str] = None, + order_data: Optional[Dict] = None, + ) -> None: + # These 3 things are the only modifiable attributes + self.check_status_change_allowed(status) + self.status = status + self.ext_ref_id = ext_ref_id + self.order_data = order_data + + def check_status_change_allowed(self, status: PayoutStatus) -> None: + if self.status in { + PayoutStatus.REJECTED, + PayoutStatus.CANCELLED, + PayoutStatus.COMPLETE, + }: + raise ValueError(f"status {self.status} is final. No changes allowed") + + if self.status == PayoutStatus.PENDING: + assert status != PayoutStatus.PENDING, "status is already PENDING!" + + elif self.status == PayoutStatus.APPROVED: + assert status in { + PayoutStatus.FAILED, + PayoutStatus.COMPLETE, + }, f"status APPROVED can only be FAILED or COMPLETED, not {status}" + + elif self.status == PayoutStatus.FAILED: + assert status in { + PayoutStatus.CANCELLED, + PayoutStatus.COMPLETE, + }, f"status FAILED can only be CANCELLED or COMPLETED, not {status}" + + else: + raise ValueError("this shouldn't happen") + + def model_dump_mysql(self, *args, **kwargs) -> dict: + d = self.model_dump(mode="json", *args, **kwargs) + if "created" in d: + d["created"] = self.created.replace(tzinfo=None) + if d.get("request_data") is not None: + d["request_data"] = json.dumps(self.request_data) + if d.get("order_data") is not None: + if isinstance(self.order_data, dict): + d["order_data"] = json.dumps(self.order_data) + else: + d["order_data"] = self.order_data.model_dump_json() + return d + + +class BPPayoutEvent(BaseModel): + uuid: UUIDStr = Field( + title="Brokerage Product Payout ID", + description="Unique identifier for the Payout Event", + examples=["9453cd076713426cb68d05591c7145aa"], + ) + + product_id: UUIDStr = Field( + description="The Brokerage Product that was paid out", + examples=["1108d053e4fa47c5b0dbdcd03a7981e7"], + ) + + created: AwareDatetimeISO = Field( + description="When the Brokerage Product was paid out", + default_factory=lambda: datetime.now(tz=timezone.utc), + ) + + amount: USDCent = Field( + lt=2**63 - 1, + strict=True, + description="The USDCent amount int. This cannot be 0 or negative", + examples=[531], + ) + + status: Optional[PayoutStatus] = Field( + default=PayoutStatus.PENDING, + description=PayoutStatus.as_openapi(), + examples=[PayoutStatus.COMPLETE], + ) + + method: PayoutType = Field( + title="Payout Method", + description=PayoutType.as_openapi(), + examples=[PayoutType.ACH], + ) + + @computed_field(return_type=str, examples=["$10,000.000"]) + @property + def amount_usd(self) -> str: + return self.amount.to_usd_str() + + @staticmethod + def from_pe( + payout_events: Collection[PayoutEvent], + account_product_mapping: Dict[str, str], + order_by="ASC", + ) -> List["BPPayoutEvent"]: + res = [] + for pe in payout_events: + bp_pe = BPPayoutEvent.model_validate( + { + "uuid": pe.uuid, + "product_id": account_product_mapping[pe.debit_account_uuid], + "created": pe.created, + "amount": USDCent(pe.amount), + "status": pe.status, + "method": pe.payout_type, + } + ) + res.append(bp_pe) + + match order_by: + case "ASC": + sorted_list = sorted(res, key=lambda x: x.created, reverse=False) + case "DESC": + sorted_list = sorted(res, key=lambda x: x.created, reverse=True) + case _: + raise ValueError("Invalid order provided..") + + return sorted_list diff --git a/generalresearch/models/thl/wallet/user_wallet.py b/generalresearch/models/thl/wallet/user_wallet.py new file mode 100644 index 0000000..917a09d --- /dev/null +++ b/generalresearch/models/thl/wallet/user_wallet.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import logging + +from pydantic import BaseModel, Field, ConfigDict, NonNegativeInt + +from generalresearch.models.legacy.api_status import StatusResponse +from generalresearch.models.thl.payout_format import ( + PayoutFormatType, + PayoutFormatField, +) + +logger = logging.getLogger() + +example_wallet_balance = { + "amount": 123, + "redeemable_amount": 100, + "payout_format": "{payout*10:,.0f} Points", + "amount_string": "1230 Points", + "redeemable_amount_string": "1000 Points", +} + + +class UserWalletBalance(BaseModel): + model_config = ConfigDict(json_schema_extra={"example": example_wallet_balance}) + + # This can be negative (due to recons for instance), but shouldn't be often ... + amount: int = Field(description="(USD cents) The amount in the user's wallet.") + redeemable_amount: NonNegativeInt = Field( + description="(USD cents) The amount in the user's wallet this is currently redeemable." + ) + payout_format: PayoutFormatType = PayoutFormatField + amount_string: str = Field( + description="The 'amount' with the payout_format applied. Can be displayed to the user." + ) + redeemable_amount_string: str = Field( + description="The 'redeemable_amount' with the payout_format applied. Can be displayed to the user." + ) + + +class UserWalletBalanceResponse(StatusResponse): + wallet: UserWalletBalance = Field() diff --git a/generalresearch/models/utils.py b/generalresearch/models/utils.py new file mode 100644 index 0000000..29a48f7 --- /dev/null +++ b/generalresearch/models/utils.py @@ -0,0 +1,9 @@ +from decimal import Decimal + + +def usd_cents_to_decimal(v: int) -> Decimal: + return Decimal(Decimal(int(v)) / Decimal(100)) + + +def decimal_to_usd_cents(d: Decimal) -> int: + return round(d * Decimal(100)) diff --git a/generalresearch/pg_helper.py b/generalresearch/pg_helper.py new file mode 100644 index 0000000..064a883 --- /dev/null +++ b/generalresearch/pg_helper.py @@ -0,0 +1,124 @@ +from typing import Optional + +from psycopg.adapt import Buffer +from psycopg.types.net import InetLoader, Address, Interface +from psycopg.types.string import TextLoader +from pydantic import PostgresDsn + +import psycopg +from psycopg.rows import dict_row, RowFactory +from psycopg.types.uuid import UUIDLoader +from psycopg.types.datetime import TimestampLoader +from datetime import timezone + + +class UUIDHexLoader(UUIDLoader): + def load(self, data): + value = super().load(data) + return value.hex + + +class UTCTimestampLoader(TimestampLoader): + def load(self, data): + dt = super().load(data) + if dt is None: + return None + assert dt.tzinfo is None, "expected naive dt" + return dt.replace(tzinfo=timezone.utc) + + +class BPCharLoader(TextLoader): + def load(self, data): + data = super().load(data) + if data is None: + return None + if type(data) is bytes: + return data.decode("utf-8").rstrip(" ") + else: + return data.rstrip(" ") + + +class InetHostLoader(InetLoader): + def load(self, data): + data = super().load(data) + if data is None: + return None + return str(data.exploded).split("/")[0] + + +class PostgresConfig: + def __init__( + self, + dsn: PostgresDsn, + connect_timeout: int, + statement_timeout: float, + schema: Optional[str] = None, + row_factory: RowFactory = dict_row, + ): + """ + Hold configuration to enable postgres operations. + + :param dsn: See https://www.postgresql.org/docs/current/libpq-connect.html + For timeouts and other options, see: + https://www.postgresql.org/docs/current/runtime-config-client.html + :param connect_timeout: (seconds) Maximum time to wait while connecting. + :param statement_timeout: (seconds) Abort any statement that takes more than the specified amount of time. + + # Note, there is no read/write timeout. See idle_in_transaction_session_timeout, lock_timeout, etc. + # There is also transaction_timeout also, but is only in the latest version? and I'm not sure the difference. + """ + self.dsn = dsn + self.connect_timeout = connect_timeout + self.statement_timeout = statement_timeout + self.schema = schema or dsn.path.lstrip("/") + assert 0 < connect_timeout < 130, "connect_timeout should be in seconds" + self.row_factory = row_factory + + @property + def db(self): + return self.dsn.path[1:] + + def make_connection(self) -> psycopg.Connection: + options = [ + f"-c statement_timeout={round(self.statement_timeout*1000)}", + "-c timezone=UTC", + "-c client_encoding=UTF8", + ] + if self.schema: + options.append(f"-c search_path={self.schema},public") + options_str = " ".join(options) + conn = psycopg.connect( + str(self.dsn), + connect_timeout=self.connect_timeout, + options=options_str, + row_factory=self.row_factory, + ) + conn.adapters.register_loader("uuid", UUIDHexLoader) + conn.adapters.register_loader("timestamp", UTCTimestampLoader) + conn.adapters.register_loader("bpchar", BPCharLoader) + conn.adapters.register_loader("inet", InetHostLoader) + return conn + + def execute_sql_query(self, query, params=None): + # This is only intended for SELECT queries + assert "SELECT" in query.upper(), "Supports SELECTs only" + + with self.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + return c.fetchall() + + def execute_write(self, query, params=None) -> int: + cmd = query.lstrip().upper() + assert ( + cmd.startswith("INSERT") + or cmd.startswith("UPDATE") + or cmd.startswith("DELETE") + ), "Supports INSERT/UPDATE only" + + with self.make_connection() as conn: + with conn.cursor() as c: + c.execute(query=query, params=params) + rowcount = c.rowcount + conn.commit() + return rowcount diff --git a/generalresearch/priority_thread_pool.py b/generalresearch/priority_thread_pool.py new file mode 100644 index 0000000..9d254e4 --- /dev/null +++ b/generalresearch/priority_thread_pool.py @@ -0,0 +1,67 @@ +import time +from concurrent.futures.thread import ThreadPoolExecutor, _WorkItem +from queue import PriorityQueue +from uuid import uuid4 + + +class WorkItemPriorityQueue(PriorityQueue): + """ + Custom Class that overloads get and put so that the priority is handled in + a way as to not break ThreadPoolExecutor, which expect a regular Queue + """ + + def get(self, block=True, timeout=None): + """ + Assumes the items are of format: (priority, tie-breaker, item) + """ + res = super().get(block, timeout) + assert type(res) == tuple and len(res) == 3 + return res[2] + + def put(self, item, block=True, timeout=None) -> None: + """ + Assumes `item` is a concurrent.futures.thread._WorkItem (but can also + be None). If not None, attempts to pull out `priority` from the + kwargs and use it to build the (priority, tie-breaker, item) tuple, + which is put on the PriorityQueue. + """ + if item is None: + item = (0, uuid4().hex, None) + elif type(item) == _WorkItem: + priority = item.kwargs.pop("priority", 0) + item = (priority, uuid4().hex, item) + else: + raise ValueError("unexpected item, type: ", type(item)) + super().put(item, block, timeout) + + +class PriorityThreadPoolExecutor(ThreadPoolExecutor): + """ + Set the priority of a job using the kwarg 'priority'. Note, if you are + attempting to run a function that itself has a kwarg called `priority`, + this will not work as expected. + + Example usage: + >> q = PriorityThreadPoolExecutor(max_workers=1) + >> q.submit(do_nothing, 'high', priority=-1) + """ + + def __init__(self, max_workers=None, thread_name_prefix=""): + super().__init__(max_workers, thread_name_prefix) + self._work_queue = WorkItemPriorityQueue() + + +def do_nothing(input): + print("working: ", input) + time.sleep(1) + print("done: ", input) + + +if __name__ == "__main__": + q = PriorityThreadPoolExecutor(max_workers=1) + q.submit(do_nothing, "low") + q.submit(do_nothing, "low") + q.submit(do_nothing, "low") + q.submit(do_nothing, "high", priority=-1) + q.submit(do_nothing, "super-high!", priority=-100) + q.shutdown(wait=False) diff --git a/generalresearch/redis_helper.py b/generalresearch/redis_helper.py new file mode 100644 index 0000000..61e02d9 --- /dev/null +++ b/generalresearch/redis_helper.py @@ -0,0 +1,33 @@ +import redis +from pydantic import RedisDsn + + +class RedisConfig: + def __init__( + self, + dsn: RedisDsn, + decode_responses: bool = True, + socket_timeout: float = 0.1, + socket_connect_timeout: float = 0.1, + ): + """ + Holds configuration for creating redis clients. + """ + self.dsn = dsn + self.decode_responses = decode_responses + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout + + @property + def db(self): + return self.dsn.path[1:] + + def create_redis_client(self) -> redis.Redis: + # Clients are thread safe. We can just create one upon init + redis_config_dict = { + "url": str(self.dsn), + "decode_responses": self.decode_responses, + "socket_timeout": self.socket_timeout, + "socket_connect_timeout": self.socket_connect_timeout, + } + return redis.Redis.from_url(**redis_config_dict) diff --git a/generalresearch/resources/__init__.py b/generalresearch/resources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/schemas/__init__.py b/generalresearch/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/schemas/survey_stats.py b/generalresearch/schemas/survey_stats.py new file mode 100644 index 0000000..532e0d7 --- /dev/null +++ b/generalresearch/schemas/survey_stats.py @@ -0,0 +1,159 @@ +import pandas as pd +from pandera import DataFrameSchema, Column, Check, Index + +from generalresearch.locales import Localelator +from generalresearch.models import Source + +COUNTRY_ISOS = Localelator().get_all_countries() +kosovo = "xk" +COUNTRY_ISOS.add(kosovo) + +SURVEY_STATS_COLS = [ + "PRESCREEN_CONVERSION.alpha", + "PRESCREEN_CONVERSION.beta", + "CONVERSION.alpha", + "CONVERSION.beta", + "COMPLETION_TIME.mu", + "COMPLETION_TIME.sigma", + "LONG_FAIL.value", + "USER_REPORT_COEFF.value", + "RECON_LIKELIHOOD.value", + "DROPOFF_RATE.alpha", + "DROPOFF_RATE.beta", + "IS_MOBILE_ELIGIBLE.alpha", + "IS_MOBILE_ELIGIBLE.beta", + "IS_DESKTOP_ELIGIBLE.alpha", + "IS_DESKTOP_ELIGIBLE.beta", + "IS_TABLET_ELIGIBLE.alpha", + "IS_TABLET_ELIGIBLE.beta", + "cpi", + "country_iso", + "is_recontact", + "score_x0", + "score_x1", + "buyer_id", + "complete_too_fast_cutoff", +] + +# e.g. The parameters for a beta distribution are just real numbers > 0, but +# in practice, if the numbers are "very" large, something is def wrong. +# Defining "very" somewhat arbitrarily here. +SUSPICIOUSLY_LARGE_NUMBER = (2**32 / 2) - 1 # 2147483647 + +PositiveRealNumber = Column( + dtype=float, + nullable=False, + checks=Check.between( + min_value=0, max_value=SUSPICIOUSLY_LARGE_NUMBER, include_min=False + ), +) +NonNegativeRealNumber = Column( + dtype=float, + nullable=False, + checks=Check.between( + min_value=0, max_value=SUSPICIOUSLY_LARGE_NUMBER, include_min=True + ), +) +RealNumber = Column( + dtype=float, + nullable=False, + checks=Check.between( + min_value=SUSPICIOUSLY_LARGE_NUMBER, max_value=SUSPICIOUSLY_LARGE_NUMBER + ), +) +# Real number between 0 and 1 inclusive. +UnitInterval = Column( + dtype=float, + nullable=False, + checks=Check.between(min_value=0, max_value=1), +) + +SID_CHECKS = [ + Check.str_length(min_value=3, max_value=67), + Check.str_matches("^[a-z]{1,2}\:[A-Za-z0-9]+"), + Check( + lambda x: len(set(x.str.split(":").str[0])) == 1, + error="the sources must all be the same", + ), +] + +SurveyStatSchema = DataFrameSchema( + index=Index( + name="sid", + description="CURIE format (source:task_id)", + dtype=str, + unique=True, + checks=SID_CHECKS, + ), + columns={ + "source": Column( + dtype=str, nullable=False, checks=[Check.isin([e.value for e in Source])] + ), + "task_id": Column(dtype=str, nullable=False, checks=[Check.str_length(1, 64)]), + "PRESCREEN_CONVERSION.alpha": PositiveRealNumber, + "PRESCREEN_CONVERSION.beta": PositiveRealNumber, + "CONVERSION.alpha": PositiveRealNumber, + "CONVERSION.beta": PositiveRealNumber, + # Normal distribution, so mu is real number, but this represents the + # completion time, so it has to be positive. We can restrict it more + # in that me are never going to predict time longer than ~~ 2 + # hours (np.log(120*60)) or <= 0 sec (np.log(1) = 0) + "COMPLETION_TIME.mu": Column( + dtype=float, + nullable=False, + checks=Check.between(min_value=1, max_value=10, include_min=False), + ), + "COMPLETION_TIME.sigma": Column( + dtype=float, + nullable=False, + checks=Check.between(min_value=0, max_value=10, include_min=False), + ), + # this should be much less than 10... I think check + "LONG_FAIL.value": Column( + dtype=float, + nullable=False, + checks=Check.between(min_value=0, max_value=10, include_min=False), + ), + "USER_REPORT_COEFF.value": UnitInterval, + "RECON_LIKELIHOOD.value": UnitInterval, + "DROPOFF_RATE.alpha": PositiveRealNumber, + "DROPOFF_RATE.beta": PositiveRealNumber, + "IS_MOBILE_ELIGIBLE.alpha": PositiveRealNumber, + "IS_MOBILE_ELIGIBLE.beta": PositiveRealNumber, + "IS_DESKTOP_ELIGIBLE.alpha": PositiveRealNumber, + "IS_DESKTOP_ELIGIBLE.beta": PositiveRealNumber, + "IS_TABLET_ELIGIBLE.alpha": PositiveRealNumber, + "IS_TABLET_ELIGIBLE.beta": PositiveRealNumber, + "cpi": Column( + dtype=float, + checks=Check.between(min_value=0, max_value=1_000), + nullable=False, + ), + "country_iso": Column( + dtype=str, + checks=[ + Check.str_length(min_value=1, max_value=2), + Check.isin(COUNTRY_ISOS), # 2 letter, lowercase + ], + nullable=True, + ), + "is_recontact": Column(dtype=bool), + "score_x0": NonNegativeRealNumber, + "score_x1": NonNegativeRealNumber, + "buyer_id": Column( + dtype=str, + checks=[Check.str_length(min_value=1, max_value=32)], + nullable=True, + ), + "complete_too_fast_cutoff": Column( + dtype=float, + nullable=False, + checks=Check.between(min_value=0, max_value=120 * 60, include_min=False), + ), + "created": Column( + dtype=pd.DatetimeTZDtype(tz="UTC"), nullable=True, required=False + ), + }, + checks=[], + coerce=True, +) diff --git a/generalresearch/sql_helper.py b/generalresearch/sql_helper.py new file mode 100644 index 0000000..b92813b --- /dev/null +++ b/generalresearch/sql_helper.py @@ -0,0 +1,351 @@ +import logging +from typing import Any, Dict, List, Tuple, Union, Optional +from uuid import UUID + +from pydantic import MySQLDsn, PostgresDsn, MariaDBDsn +from pymysql import Connection + +ListOrTupleOfStrings = Union[List[str], Tuple[str, ...]] +ListOrTupleOfListOrTuple = Union[ + List[List], List[Tuple], Tuple[List, ...], Tuple[Tuple, ...] +] + +DataBaseDsn = Union[MySQLDsn, MariaDBDsn, PostgresDsn] + + +class MultipleObjectsReturned(Exception): + pass + + +class SqlConnector: + """ + SqlConnector is GRL's simplified SQLAlchemy.. it's basic, it's raw + it does whatever want. Maybe we overwrite this to just use SQLAlchemy + on the backend... but it's just not worth it for now. + """ + + # For connection and cursor handling, and any difference between mysql + # and postgresql + def __init__(self, dsn: Optional[DataBaseDsn] = None, **kwargs): + """ + Anything in kwargs gets passed into the engine_module's connect + function. To be used for e.g.: + s = SqlHelper('127.0.0.1', 'root', '', '300large', read_timeout=10) + """ + + self.dsn = dsn + + # I'm intentionally doing a match case here so that we'll make sure + # we can NOT use this on old versions of python 😈 + if "mysql" in self.dsn.scheme: + import pymysql as engine_module + + self.engine_module = engine_module + self.cursor_class = engine_module.cursors.DictCursor + self.quote_char = "`" + + elif "maria" in self.dsn.scheme: + import pymysql as engine_module + + self.engine_module = engine_module + self.cursor_class = engine_module.cursors.DictCursor + self.quote_char = "`" + + if "autocommit" in kwargs: + raise AssertionError("Be clear, be explicit.") + + self.kwargs = kwargs + + @property + def dbname(self) -> str: + return self.dsn.path[1:] + + @property + def db_name(self) -> str: + return self.dsn.path[1:] + + @property + def db(self) -> str: + return self.dsn.path[1:] + + def is_mysql(self) -> bool: + return "mysql" in self.dsn.scheme + + def is_maria(self) -> bool: + return "maria" in self.dsn.scheme + + def make_connection(self) -> Connection: + # We are making a new connection for every cursor to make sure + # multithreading/processing works correctly. + if self.is_mysql(): + connection = self.engine_module.connect( + host=self.dsn.host, + user=self.dsn.username, + password=self.dsn.password, + db=self.dsn.path[1:], + cursorclass=self.cursor_class, + **self.kwargs, + ) + return connection + + elif self.is_maria(): + # TODO: We want to to support this at some point, but will + # require alt handling of uuid hex values as MariaDB + # saves them as xxxx-xxxx-xxxx. We must: + # - Confirm evaluations and joins with with hex or non-hex versions + # - Decide if we return as hex version, or alter the UUIDStr custom_type + # as they'll now have more than 32 chars + connection = self.engine_module.connect( + host=self.dsn.host, + user=self.dsn.username, + password=self.dsn.password, + database=self.dsn.path[1:], + cursorclass=self.cursor_class, + ) + return connection + + +def is_uuid4(s: Any) -> bool: + if not isinstance(s, str): + return False + + if len(s) not in (32, 36): + return False + + try: + u = UUID(s, version=4) + return u.hex == s if len(s) == 32 else str(u) == s + except (ValueError, AttributeError, TypeError): + return False + + +def decode_uuids(row: Dict) -> Dict: + return { + key: (UUID(value, version=4).hex if is_uuid4(value) else value) + for key, value in row.items() + } + + +class SqlHelper(SqlConnector): + + def __init__(self, dsn: Optional[DataBaseDsn] = None, **kwargs): + super(SqlHelper, self).__init__(dsn, **kwargs) + + def execute_sql_query(self, query, params=None, commit=False) -> List[Dict]: + for param in params if params else []: + if isinstance(param, (tuple, list, set)) and len(param) == 0: + logging.warning("param is empty. not executing query") + return [] + connection = self.make_connection() + c = connection.cursor() + c.execute(query, params) + if commit: + connection.commit() + + if self.is_maria(): + return [decode_uuids(row) for row in c.fetchall()] + + else: + return c.fetchall() + + def _quote(self, s) -> str: + return f"{self.quote_char}{s}{self.quote_char}" + + def bulk_insert( + self, + table_name: str, + field_names: ListOrTupleOfStrings, + values_to_insert: ListOrTupleOfListOrTuple, + cursor=None, + ignore_existing=False, + ) -> None: + """ + :param table_name: name of table + :param field_names: list or tuple of field names, corresponding to their + index/order in `values_to_insert`. + :param values_to_insert: list of lists, where the inner list contains + each row of values to be inserted. The order corresponds to the order + of `field_names`. + :param cursor: If cursor is passed, the insert is NOT committed! + :param ignore_existing: adds 'ON CONFLICT DO NOTHING' to SQL statement. + """ + assert len(set([len(x) for x in values_to_insert])) == 1 + if cursor is None: + connection = self.make_connection() + c = connection.cursor() + else: + c = cursor + + if self.is_mysql() or self.is_maria(): + values_to_insert = [ + [c.connection.escape_string(v) if isinstance(v, str) else v for v in vv] + for vv in values_to_insert + ] + + table_name_str = self._quote(table_name) + field_name_str = ",".join(map(self._quote, field_names)) + values_str = ",".join(["%s"] * len(values_to_insert[0])) + query = "" + + if self.is_mysql() or self.is_maria(): + ignore_str = "IGNORE" if ignore_existing else "" + query = f"INSERT {ignore_str} INTO {table_name_str} ({field_name_str}) VALUES ({values_str});" + + c.executemany(query, values_to_insert) + if cursor is None: + c.connection.commit() + + return None + + def bulk_update( + self, + table_name: str, + field_names: ListOrTupleOfStrings, + values_to_insert: ListOrTupleOfListOrTuple, + cursor=None, + ) -> None: + if len(values_to_insert) == 0: + return None + + assert len(set([len(x) for x in values_to_insert])) == 1 + if cursor is None: + connection = self.make_connection() + c = connection.cursor() + else: + c = cursor + + values_to_insert = [ + [c.connection.literal(v) for v in vv] for vv in values_to_insert + ] + field_names = ["`" + x + "`" for x in field_names] + field_name_str = ",".join(field_names) + table_name_str = self._quote(table_name) + + values_str = ",\n".join(["(" + ",".join(x) + ")" for x in values_to_insert]) + update_col_str = ", ".join(f"{k}=VALUES({k})" for k in field_names) + update_str = f"ON DUPLICATE KEY UPDATE {update_col_str}" + query = f"INSERT INTO {table_name_str} ({field_name_str}) VALUES {values_str} {update_str};" + c.execute(query) + if cursor is None: + c.connection.commit() + + return None + + def get_or_create( + self, + table_name: str, + primary_key: str, + lookup_dict: dict, + update_dict: dict, + cursor=None, + ) -> Tuple[Union[str, int], bool]: + """ + returns the value of the primary key ONLY, and bool (created) + """ + # primary_key = "id" + # lookup_dict = {"name": "Market Cube"} + # table_name = "lucid_lucidaccount" + # update_dict = {"name": "Market Cube"} + lookup_fns = ",".join( + ["`" + x + "`" for x in set(lookup_dict.keys()) | {primary_key}] + ) + lookup_vals = " AND ".join([f"`{fn}`=%({fn})s" for fn in lookup_dict.keys()]) + table_name_str = self._quote(table_name) + query = f"SELECT {lookup_fns} FROM {table_name_str} WHERE {lookup_vals} LIMIT 2" + if cursor is None: + connection = self.make_connection() + c = connection.cursor() + else: + c = cursor + + c.execute(query, lookup_dict) + res = c.fetchall() + num = len(res) + if num > 1: + raise MultipleObjectsReturned( + f"get() {table_name} returned more than one obj -- it returned {num}!" + ) + if num == 1: + return res[0][primary_key], False + new_pk = self.create(table_name, update_dict, cursor=c) + return new_pk, True + + def create( + self, + table_name: str, + create_dict: dict, + cursor=None, + commit=True, + primary_key=None, + ) -> Optional[int]: + """ + Create the item in table `table_name`. + In postgresql, `primary_key` needs to be given in order to return the + pk of the just created item + """ + if cursor is None: + connection = self.make_connection() + c = connection.cursor() + else: + c = cursor + field_names = ",".join(map(self._quote, create_dict)) + vals = ",".join([f"%({fn})s" for fn in create_dict.keys()]) + table_name_str = self._quote(table_name) + query = f"INSERT INTO {table_name_str} ({field_names}) VALUES ({vals})" + c.execute(query, create_dict) + if self.is_mysql(): + new_pk = c.lastrowid + else: + new_pk = None + if commit: + c.connection.commit() + return new_pk + + def filter( + self, + table_name: str, + field_names: ListOrTupleOfStrings, + filter_d=None, + limit=None, + cursor=None, + ) -> List[Dict[Any, Any]]: + + if cursor is None: + connection = self.make_connection() + c = connection.cursor() + else: + c = cursor + + table_name_str = self._quote(table_name) + field_names = ["`" + x + "`" for x in field_names] + field_name_str = ",".join(field_names) + if filter_d: + lookup_vals = " AND ".join([f"`{fn}`=%({fn})s" for fn in filter_d.keys()]) + lookup_str = f" WHERE {lookup_vals}" + else: + lookup_str = "" + limit_str = f"LIMIT {limit}" if limit else "" + + query = ( + f"SELECT {field_name_str} FROM {table_name_str} {lookup_str} {limit_str};" + ) + c.execute(query, filter_d) + + return c.fetchall() + + def delete(self, table_name: str, field_name: str, values, cursor=None) -> None: + if cursor is None: + connection = self.make_connection() + c = connection.cursor() + else: + c = cursor + + table_name_str = self._quote(table_name) + field_name_str = self._quote(field_name) + values_str = ",".join([c.connection.literal(v) for v in values]) + query = f"DELETE FROM {table_name_str} WHERE {field_name_str} IN ({values_str})" + c.execute(query) + if cursor is None: + c.connection.commit() + + return None diff --git a/generalresearch/thl_django/README.md b/generalresearch/thl_django/README.md new file mode 100644 index 0000000..c9479ee --- /dev/null +++ b/generalresearch/thl_django/README.md @@ -0,0 +1,70 @@ +# THL Django App + +This package contains the Django models and migrations required to create the +THL-compatible database schema. It is meant to be installed inside any Django +project so the schema can be applied automatically with `makemigrations` and +`migrate`. + +--- + +## 1. Installation + +Add the package to your environment: +(e.g. local development) +```bash +pip install generalresearch[django] +``` + +(e.g. editable install recommended during development) +```bash +pip install -e '/path/to/project/py-utils[django]' +``` + + +## 2. Add the App to INSTALLED_APPS + +In your Django test project's settings.py: + +``` +INSTALLED_APPS = [ + # ... + "generalresearch.thl_django", +] +``` + +## 3. Test that it worked +```shell +python manage.py shell +``` + +# For use in Jenkins / pytest + +There is a dummy/minimal django project under the `app` folder. This is set up +with thl_django as the only installed_app, and to read all setting from +environment variables. + +## Example Usage + +```postgresql +-- postgres=# +CREATE DATABASE "thl-jenkins" WITH TEMPLATE = template0 ENCODING = 'UTF8'; +``` + +```shell +pip install generalresearch[django] +export DB_NAME=thl-jenkins +export DB_USER=postgres +export DB_PASSWORD=password +export DB_HOST=127.0.0.1 +``` +```shell +# Confirm imports worked +python -m generalresearch.thl_django.app.manage shell -v 2 + +> assert settings.DATABASES['default']['NAME'] == 'thl-jenkins' +``` + +```shell +# Migrate +python -m generalresearch.thl_django.app.manage migrate --noinput +``` \ No newline at end of file diff --git a/generalresearch/thl_django/__init__.py b/generalresearch/thl_django/__init__.py new file mode 100644 index 0000000..877290c --- /dev/null +++ b/generalresearch/thl_django/__init__.py @@ -0,0 +1 @@ +default_app_config = "generalresearch.thl_django.apps.THLSchemaConfig" diff --git a/generalresearch/thl_django/accounting/__init__.py b/generalresearch/thl_django/accounting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/accounting/models.py b/generalresearch/thl_django/accounting/models.py new file mode 100644 index 0000000..ac65f0a --- /dev/null +++ b/generalresearch/thl_django/accounting/models.py @@ -0,0 +1,44 @@ +import uuid + +from django.db import models + + +class CashoutMethod(models.Model): + """ + Stores info about different methods a user could use to redeem money + from their wallet. Each entry is a specific instance of a "thing" a user + can get, for e.g. a Visa Prepaid Card from Tango, and there will be many + CashoutMethods, all of provider "tango". + """ + + # The primary identifier for this method, may be exposed external to THL + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + last_updated = models.DateTimeField(auto_now=True) + is_live = models.BooleanField(default=False) + + # This is the service that is "handling" the cashout, e.g. "TANGO", + # "DWOLLA", "PAYPAL", etc... + provider = models.CharField(max_length=32) + + # This is the method_provider's identifier for this cashout method + # (e.g. if method_provider = TANGO, this is the UTID. A (method_provider, + # ext_id) should uniquely map to an `id`. + ext_id = models.CharField(max_length=255, null=True) + + # Not required here as it will prob be in `data`, but just for convenience. + name = models.CharField(max_length=512) + + # Other method_class-specific data (min_value, max_value, value_type, + # disclaimer, etc...) + data = models.JSONField(default=dict) + + # For creating user-specific cashout methods + user_id = models.PositiveIntegerField(null=True) + + class Meta: + db_table = "accounting_cashoutmethod" + + indexes = [ + models.Index(fields=["user_id"]), + models.Index(fields=["provider", "ext_id"]), + ] diff --git a/generalresearch/thl_django/app/__init__.py b/generalresearch/thl_django/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/app/manage.py b/generalresearch/thl_django/app/manage.py new file mode 100644 index 0000000..33f2367 --- /dev/null +++ b/generalresearch/thl_django/app/manage.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault( + "DJANGO_SETTINGS_MODULE", "generalresearch.thl_django.app.settings" + ) + from django.core.management import execute_from_command_line + + execute_from_command_line(sys.argv) diff --git a/generalresearch/thl_django/app/settings.py b/generalresearch/thl_django/app/settings.py new file mode 100644 index 0000000..0d3f47a --- /dev/null +++ b/generalresearch/thl_django/app/settings.py @@ -0,0 +1,23 @@ +import os + +INSTALLED_APPS = [ + "django.contrib.contenttypes", + "generalresearch.thl_django", +] + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.postgresql", + "NAME": os.environ.get("DB_NAME", "thl-test"), + "USER": os.environ.get("DB_USER", "postgres"), + "PASSWORD": os.environ.get("DB_PASSWORD", "password"), + "HOST": os.environ.get("DB_HOST", "127.0.0.1"), + "PORT": "5432", + } +} +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" +LANGUAGE_CODE = "en-us" +TIME_ZONE = "UTC" +USE_I18N = True +USE_L10N = True +USE_TZ = True diff --git a/generalresearch/thl_django/apps.py b/generalresearch/thl_django/apps.py new file mode 100644 index 0000000..a3a0721 --- /dev/null +++ b/generalresearch/thl_django/apps.py @@ -0,0 +1,15 @@ +from django.apps import AppConfig + + +class THLSchemaConfig(AppConfig): + name = "generalresearchutils.thl_django" + label = "thl_django" + + def ready(self): + from .accounting import models # noqa: F401 # pycharm: keep + from .common import models # noqa: F401 # pycharm: keep + from .contest import models # noqa: F401 # pycharm: keep + from .event import models # noqa: F401 # pycharm: keep + from .marketplace import models # noqa: F401 # pycharm: keep + from .userhealth import models # noqa: F401 # pycharm: keep + from .userprofile import models # noqa: F401 # pycharm: keep diff --git a/generalresearch/thl_django/common/__init__.py b/generalresearch/thl_django/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/common/models.py b/generalresearch/thl_django/common/models.py new file mode 100644 index 0000000..a41d2eb --- /dev/null +++ b/generalresearch/thl_django/common/models.py @@ -0,0 +1,745 @@ +import uuid + +from django.db import models +from django.db.models import Q + + +class THLSession(models.Model): + """ + The top level table is a session. Instead of only containing an ID, it'll have an internal + auto-increment, and an external UUID (that is for the session itself, not shared with the Wall ids). + """ + + id = models.BigAutoField(primary_key=True, null=False) + + # This is what gets exposed externally + uuid = models.UUIDField(null=False, unique=True) + + # The user that started this session + user_id = models.BigIntegerField(null=False) + + # Makes it easy to look at session total elapsed time + started = models.DateTimeField(null=False) + finished = models.DateTimeField(null=True) + + # These are the promised "parameters" of the session, as specified by the + # clicked bucket. User clicked on a bucket with LOI between these values. + loi_min = models.SmallIntegerField(null=True) + loi_max = models.SmallIntegerField(null=True) + + # User clicked on a bucket with a promised payout between these values. + # This is the user_payout, which is the bp_payout with the + # user_payout_transformation applied. + user_payout_min = models.DecimalField(max_digits=5, decimal_places=2, null=True) + user_payout_max = models.DecimalField(max_digits=5, decimal_places=2, null=True) + + # This is a shortcut for us to easily see country specific activity, it's + # set when the user enters the thl_session. We assert elsewhere in the code + # that prevents a user from changing country_iso while in a survey + country_iso = models.CharField(max_length=2, null=True) + + # The user's device type as determined by their useragent (nullable for + # legacy reasons) + device_type = models.SmallIntegerField(null=True) + + # The user's latest IP address when starting this session (nullable for + # legacy reasons) + ip = models.GenericIPAddressField(null=True) + + # The GRL status of the session. This is reportable externally. This can be + # independent of the status of the final wall event. Possible values: + # NULL (enter) f (failure) c (complete) a (user exited, or started + # another session) t (timeout). + # The status can only ever change from NULL-> something, and once it is + # set, cannot change. + status = models.CharField(max_length=1, default=None, null=True) + + # This is a more detailed status for the session. Again, may be unrelated + # to any associated wall events, or there may even be NO wall events + # (GRL Fail). This is also reportable externally. + # + # This is calculated from the underlying wall events' status codes. We may + # just report the last wall event's status_code, or maybe the most + # common? Or: If the user entered any client survey, it is a buyer_fail. + # Otherwise, it is usually the last wall event's. + # + # Uses GRL's category: (only those where status = 'f' would have a status_code) + # Buyer Fail: User terminated in buyer survey + # Buyer Quality Fail: User terminated in buyer survey for quality reasons + # PS Term: User failed in marketplace's pre-screener + # PS Quality Term: User rejected by marketplace for quality reasons + # PS OverQuota: User rejected by marketplace + # PS Duplicate: User rejected by marketplace + # GRL Fail: User was never sent into a marketplace (generally quality + # reasons, or user answered questions that made them ineligible for + # the requested survey) + # + # Note: PS OverQuota is theoretically "our" fault, and we don't want to + # expose that? Maybe we map it to PS Term when we expose it? + status_code_1 = models.SmallIntegerField(null=True) + + # If the status_code_1 is GRL Fail, we could include another reason for the + # failure, such as VPN Usage, user is blocked, no eligible surveys, etc. + status_code_2 = models.SmallIntegerField(null=True) + + # This may or may not be related to the CPI of the final survey. A session + # could have a payout even if the final survey was not a complete. This + # is the amount paid to the BP. It does not change, even if the session + # is reversed. + # -- Do we set it to 0.00 once the session is over? + payout = models.DecimalField(max_digits=5, decimal_places=2, null=True) + + # The amount the BP should pay to the user. Only is set if configured by the BP. + user_payout = models.DecimalField(max_digits=5, decimal_places=2, null=True) + + # This is the most recent reconciliation status of the session. Generally, + # we would adjust this if the last survey in the session was adjusted + # from complete to incomplete. + # + # Possible values: 'ac' (adjusted to complete), 'af' (adj to fail) + adjusted_status = models.CharField(max_length=2, null=True) + + # If the session is 'fail' -> 'adj. to complete': payout is NULL (or 0?), + # adjusted_payout is the amount paid if 'complete' -> 'adj. to fail': + # payout is the amount, adjusted_payout is 0.00 + # + # If a survey is complete, and then adjusted to incomplete, and then back + # to complete, then both adjusted_status and adjusted_payout would go + # back to NULL, however the adjusted_timestamp would be set! + adjusted_payout = models.DecimalField(max_digits=5, decimal_places=2, null=True) + adjusted_user_payout = models.DecimalField( + max_digits=5, decimal_places=2, null=True + ) + + # This timestamp gets updated every time there is an adjustment (it is + # the latest) + adjusted_timestamp = models.DateTimeField(null=True) + + # Wall Session Metadata. This is a passthrough of any extra arguments the + # BP appends on the offerwall request. + url_metadata = models.JSONField(null=True) + + class Meta: + db_table = "thl_session" + indexes = [ + # For rolling window searches + models.Index(fields=["user_id", "started"]), + models.Index(fields=["started"]), + # Used primarily for "dashboard"-related tasks, where we would + # filter by started first and the group by one of these field. + models.Index(fields=["country_iso"]), + models.Index( + fields=["adjusted_status"], + name="thl_session_adj_status_nn_idx", + condition=Q(adjusted_status__isnull=False), + ), + models.Index( + fields=["adjusted_timestamp"], + name="thl_session_adj_ts_nn_idx", + condition=Q(adjusted_timestamp__isnull=False), + ), + models.Index(fields=["device_type"]), + models.Index(fields=["ip"]), + # uuid will already have an index due to unique + ] + + +class THLWall(models.Model): + """ + A wall event must always exist within a session. + """ + + # This is what gets exposed externally (to marketplaces), it's also what + # we'll map over from wall.mid. We need this for the marketplace redirects. + uuid = models.UUIDField(primary_key=True) + + # We can use a ForeignKey here because we want these two tables + # connected with a key constraint + session = models.ForeignKey( + THLSession, + on_delete=models.RESTRICT, + null=False, + related_name="session", + ) + + # This is the marketplace we sent user to. len=2 for us to potentially expand. + source = models.CharField(max_length=2, null=False) + + # Buyer / account within the marketplace / source + buyer_id = models.CharField(max_length=32, null=True) + + # When we create the wall event, we set both survey_id & req_survey_id to + # the same value. If the user comes back from the redirect from a + # different survey_id, we'll change the survey_id to the + # "returned"/"actual" survey_id and req_survey_id unchanged. + survey_id = models.CharField(max_length=32, null=False) + req_survey_id = models.CharField(max_length=32, null=False) + + # This works the exact same as survey_id / req_survey_id. This CPI 2includes + # any applicable marketplace commission. + # (It is possible they got sent elsewhere, or the CPI of the survey changed, + # and it wasn't updated in time. If so, we'd update the cpi field). + cpi = models.DecimalField(max_digits=8, decimal_places=5, null=False) + req_cpi = models.DecimalField(max_digits=8, decimal_places=5, null=False) + + # thl_session.started does not necessarily equal thl_wall.started? + started = models.DateTimeField(null=False) + finished = models.DateTimeField(null=True) + + # The GRL status of the wall event. Possible values: + # NULL (enter) f (failure) c (complete) a (user exited, or started + # another session) t (timeout). + # + # The status can only ever change from NULL-> something, and once it is + # set, cannot change. + status = models.CharField(max_length=1, default=None, null=True) + + # This is a more detailed status for the wall event. We will map each + # marketplace's status codes to one of these categories. Note: some + # marketplaces don't return enough information and so some marketplaces + # might only ever use a subset of these. + # + # Uses GRL's category: (only those where status = 'f' would have a status_code) + # Buyer Fail: User terminated in buyer survey + # Buyer Quality Fail: User terminated in buyer survey for quality reasons + # PS Term: User failed in marketplace's prescreener + # PS Quality Term: User rejected by marketplace for quality reasons + # PS OverQuota: User rejected by marketplace + # PS Duplicate: User rejected by marketplace + status_code_1 = models.SmallIntegerField(null=True) + + # For future expansion + status_code_2 = models.SmallIntegerField(null=True) + + # External status codes + # This is the marketplace's status / status code / status reason / whatever + # they call it. + ext_status_code_1 = models.CharField(max_length=32, null=True) + ext_status_code_2 = models.CharField(max_length=32, null=True) + ext_status_code_3 = models.CharField(max_length=32, null=True) + + # The thl_wall event can be reported, without breaking the thl_session. A + # user may want to report the first survey as invasive, but still continue + report_value = models.SmallIntegerField(null=True) + report_notes = models.CharField(max_length=255, null=True) + + # This is the most recent reconciliation status of the wall event. + adjusted_status = models.CharField(max_length=2, null=True) + + # If the session is 'fail' -> 'adj. to complete': cpi is NULL (or 0?), + # adjusted_cpi is the amount paid if 'complete' -> 'adj. to fail': cpi + # is the amount, adjusted_cpi is 0.00 + adjusted_cpi = models.DecimalField(max_digits=8, decimal_places=5, null=True) + adjusted_timestamp = models.DateTimeField(null=True) + + class Meta: + db_table = "thl_wall" + + # We could start to do stuff like this to ensure a session doesn't + # contain the same survey more than twice in the session + unique_together = ("session", "source", "survey_id") + + # A session shouldn't have more than 100 wall events, or we should put + # additional indices. (session_id, started). + + indexes = [ + # uuid is primary key so already has an index + models.Index(fields=["started"]), + models.Index(fields=["source", "survey_id", "started"]), + models.Index( + fields=["adjusted_status"], + name="thl_wall_adj_status_nn_idx", + condition=Q(adjusted_status__isnull=False), + ), + models.Index( + fields=["adjusted_timestamp"], + name="thl_wall_adj_ts_nn_idx", + condition=Q(adjusted_timestamp__isnull=False), + ), + ] + + +# # TODO: in the future +# class WallProgress(models.Model): +# # Completion Percentage. We could have GRS send this, or calculate it +# from the received answers, and/or direct buyer relationships could +# send us this data. +# progress = models.FloatField(null=True) +# # Useful for tracking if a user is still "in" a survey. +# progress_last_updated = models.DateTimeField(null=True) +# # wall id +# wall = models.ForeignKey + + +class THLUser(models.Model): + """ + Class for the generic concept of a user in our entire platform + """ + + # This is the value that will get passed around internally to uniquely + # identify a user. We'll never share this value outside of + # General Research. Yes, it supports integers far larger + # than the world's population. However, the additional storage overhead + # is trivial and this table will likely get spammed with tons of signups + # for users that will never be used. + id = models.BigAutoField(primary_key=True, null=False, blank=False) + + # This uuid is what gets exposed anytime the user value is publicly + # available. We don't use it for passing around internally because + # it's large (32 char str). However, we keep it as the primary key so that + # this table has less auto-increment issues + uuid = models.UUIDField(null=False, blank=False, unique=True) + + # This is no longer a foreign key, so at least enforce they're UUIDs + product_id = models.UUIDField(null=False, blank=False, unique=False) + + # We're going to limit the length of BPUID values to 128 characters. + product_user_id = models.CharField( + max_length=128, null=False, blank=False, unique=False + ) + + # ------------------ + # ---- METADATA ---- + # ------------------ + + # This will be useful for looking at signups per country globally. + created = models.DateTimeField(null=False) + + # This will be useful as our Daily Active User metric (DAU) that a lot + # of marketplaces want reported on. We'll be able to change the + # logic for when this is updated, but keeping it in a table will + # make queries much easier. + # We'll force null=True, and his will be the same as created + # until they're "seen again". + last_seen = models.DateTimeField(null=False) + + # this isn't used for any security measures, but we have an increasing + # need to provide a "live panel book" to describe our user base. This + # will help immediately filter users down by "after users from X country + # in the past X days". Which is now nearly impossible. + last_country_iso = models.CharField(max_length=2, null=True) + + # Along with the last_country_iso, we have the last_geoname_id which + # we could use to aggregate by state, timezone, continent, etc. + # Note: Do not use PositiveIntegerField as this adds a constraint which + # prevents this column from being added instantly + last_geoname_id = models.IntegerField(null=True) + + # Also for convenience, as this is available in the userhealth_iphistory + # table, but we'd need to groupby/sort, so store the user's latest IP here + last_ip = models.GenericIPAddressField(null=True) + + # No index needed on it, just a quick attribute check for if we process + # additional resources for this user + blocked = models.BooleanField(default=False) + + class Meta: + db_table = "thl_user" + + # The same BPUID can't be present within a product_id + unique_together = ("product_id", "product_user_id") + + indexes = [ + # id already has an index as the primary key + # This will be used to look up a from GRS or + # possibly another "outside source". Does + # not need to be a composite with anything else. + # Note: Index is already created due to being marked unique + # models.Index(fields=["uuid"]), + # We will never look up a product_user_id by itself (because it's + # not unique), so this will always be a composite + # Note: Index is created by the `unique_together` above + # models.Index(fields=["product_id", "product_user_id"]), + models.Index(fields=["created"]), + models.Index(fields=["last_seen"]), + models.Index(fields=["last_country_iso"]), + ] + + +class THLUserMetadata(models.Model): + """ + Stores information about a user that is modifiable, including by the BP or + potentially by the user itself. As opposed to the THLUser table + which does not store fields that can be directly set. + """ + + # There is a one-to-one relationship between this and the THLUser table, + # so this id equals THLUser.id + user = models.OneToOneField( + to=THLUser, on_delete=models.RESTRICT, null=False, primary_key=True + ) + + email_address = models.CharField(max_length=320, null=True) + email_sha256 = models.CharField(max_length=64, null=True) + email_sha1 = models.CharField(max_length=40, null=True) + email_md5 = models.CharField(max_length=32, null=True) + + class Meta: + db_table = "thl_usermetadata" + indexes = [ + models.Index(fields=["email_address"]), + models.Index(fields=["email_sha256"]), + models.Index(fields=["email_sha1"]), + models.Index(fields=["email_md5"]), + ] + + +class IPInformation(models.Model): + """ + Most of the info in here can be imported from the City Plus csv + (https://www.maxmind.com/en/geoip2-city), but we'll wait just use Insights. + + Using the Country DB files, we can only populate: ip, country_iso, + registered_country_iso. + + If the IP address is in a tier 1 or 2 country, we'll call insights after. + If the IP is in a tier 3 country, we'll call insights only if they + actually enter a bucket. + """ + + ip = models.GenericIPAddressField(primary_key=True) + + # Use this to join on IPLocation table. + # In the insights API response, this is the city.geoname_id + geoname_id = models.PositiveIntegerField(null=True) + + # This is duplicated in the IPLocation table, but keeping for convenience. + # This is the country the IP address is physically in, and is inferrable + # from the geoname_id. + country_iso = models.CharField(max_length=2, blank=False, null=True) + + # The country in which the IP is registered (by the ISP) + registered_country_iso = models.CharField(max_length=2, blank=False, null=True) + + # Traits (these come from Anonymous DB, through GeoIP2 Insights Web Service) + # https://dev.maxmind.com/geoip/docs/databases/anonymous-ip + is_anonymous = models.BooleanField(null=True) + is_anonymous_vpn = models.BooleanField(null=True) + is_hosting_provider = models.BooleanField(null=True) + is_public_proxy = models.BooleanField(null=True) + is_tor_exit_node = models.BooleanField(null=True) + is_residential_proxy = models.BooleanField(null=True) + + # More Traits + autonomous_system_number = models.IntegerField(null=True, blank=False) + autonomous_system_organization = models.CharField( + max_length=255, null=True, blank=False + ) + domain = models.CharField(max_length=255, null=True, blank=True) + isp = models.CharField(max_length=255, null=True, blank=False) + mobile_country_code = models.CharField(max_length=3, null=True, blank=False) + mobile_network_code = models.CharField(max_length=3, null=True, blank=False) + network = models.CharField(max_length=56, null=True, blank=False) + organization = models.CharField(max_length=255, null=True, blank=False) + static_ip_score = models.FloatField(null=True) # ranges from 0 to 99.99 + user_type = models.CharField(max_length=64, null=True, blank=False) + # Leaving this out as it will be immediately out of date unless we keep this updated + # user_count = models.PositiveIntegerField(null=True) + + # Location fields that may be different for different IPs in the same City + postal_code = models.CharField(max_length=20, blank=True, null=True) + latitude = models.DecimalField(max_digits=10, decimal_places=6, null=True) + longitude = models.DecimalField(max_digits=10, decimal_places=6, null=True) + accuracy_radius = models.PositiveSmallIntegerField(null=True) + + updated = models.DateTimeField(auto_now=True) + + class Meta: + db_table = "thl_ipinformation" + indexes = [ + models.Index(fields=["updated"]), + ] + + +class GeoName(models.Model): + """ + Stores information about the city, continent, country, postal, and subdivisions + https://dev.maxmind.com/geoip/docs/databases/city-and-country#locations-files + + All of this info comes back in an Insights API call. We can check if the row + exists in this table, and just create it once. + """ + + geoname_id = models.PositiveIntegerField(primary_key=True, null=False) + # We could store place names in other languages, but I don't anticipate us + # doing this. If we did, the primary key would be (geoname_id, locale_code). + # locale_code = models.CharField(max_length=5, default='eng') + + # AF - Africa, AN - Antarctica, AS - Asia, EU - Europe, + # NA - North America, OC - Oceania, SA - South America + continent_code = models.CharField(max_length=2, blank=False, null=False) + continent_name = models.CharField(max_length=32, blank=False, null=False) + + # Below here are all optional, although country will be set 99% of the time + country_iso = models.CharField(max_length=2, blank=False, null=True) + country_name = models.CharField(max_length=64, blank=False, null=True) + subdivision_1_iso = models.CharField(max_length=3, blank=False, null=True) + subdivision_1_name = models.CharField(max_length=255, blank=False, null=True) + subdivision_2_iso = models.CharField(max_length=3, blank=False, null=True) + subdivision_2_name = models.CharField(max_length=255, blank=False, null=True) + city_name = models.CharField(max_length=255, blank=False, null=True) + metro_code = models.PositiveSmallIntegerField(null=True) + time_zone = models.CharField(max_length=60, blank=False, null=True) + is_in_european_union = models.BooleanField(null=True) + + updated = models.DateTimeField(auto_now=True) + + class Meta: + db_table = "thl_geoname" + indexes = [ + models.Index(fields=["updated"]), + ] + + +class LedgerDirection(models.IntegerChoices): + # This choice of Positive/Negative per direction is arbitrary and is + # just used for ease of multiplying numbers together later instead of + # using strings. + CREDIT = -1, "credit" + DEBIT = 1, "debit" + + +class LedgerAccount(models.Model): + """ + A ledger_account is an account in a double-entry accounting system. Each + ledger account can optionally be associated with a uuid in another + table, such as a brokerage product or user. + + Further reading: https://docs.moderntreasury.com/ledgers/docs/digital-wallet-tutorial?tab=Transactions-API + """ + + uuid = models.UUIDField(null=False, primary_key=True) + + # Name which could be used to display this account + display_name = models.CharField(max_length=64, null=False) + + # A fully qualified name which could be used for the purposes of grouping + # and placing this account into a hierarchical structure. The elements + # are colon-separated. The fully qualified name must be unique. This + # could be used to look up an account. + qualified_name = models.CharField(max_length=255, null=False, unique=True) + + # Used to tag an account with its general "purpose". Could be used for + # group bys for reporting purposes. + # e.g. "bp_commission" which stores commission from BP payments + account_type = models.CharField(max_length=30, null=True) + + # Each account must be debit or credit-normal. + normal_balance = models.SmallIntegerField( + null=False, choices=LedgerDirection.choices + ) + + # Could be a reference to a BP, or account, or user in another table. + reference_type = models.CharField(max_length=30, null=True) + reference_uuid = models.UUIDField(null=True) + + # The currency for this account's transactions. For now, all will be "USD". + # I can imagine we could have a LedgerCurrency table that store the: + # currency_exponent, display format str, conversion_rate, etc, and then + # this would be a uuid/pk into that table. + currency = models.CharField(max_length=32, null=False) + + # The currency's smallest denomination unit. For e.g. an account + # denomination in USD has a currency_exponent of 2, because 1 cent = + # 1*10^-2 USD. I think this is just used for display purposes. + # currency_exponent = models.SmallIntegerField(null=False) + + class Meta: + db_table = "ledger_account" + indexes = [ + # This is not a unique index because an entity could have + # multiple accounts + # + # I don't think we need an index on reference_type b/c no two + # entities should have the same uuids anyway. + models.Index(fields=["reference_uuid"]), + ] + + +class LedgerTransaction(models.Model): + """ + A ledger_transaction is a transaction between two or more ledger accounts. + To create a ledger transaction, there must be at least one credit + ledger entry and one debit ledger entry. Additionally, the sum of all + credit entry amounts must equal the sum of all debit entry amounts. + """ + + id = models.BigAutoField(primary_key=True, null=False) + created = models.DateTimeField(null=False) + + # Optionally add notes to the transaction that could be displayed in + # an account statement + ext_description = models.CharField(max_length=255, null=True) + + # Optionally tag a transaction for quick and easy searching (used for + # de-duplication / locking purposes) + tag = models.CharField(max_length=255, null=True) + + class Meta: + db_table = "ledger_transaction" + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["tag"]), + ] + + +class LedgerTransactionMetadata(models.Model): + """ + Used to associate a transaction with metadata: a thl_session, or thl_wall, + or user quality history event, or multiple of each, or something + else in the future ... + """ + + id = models.BigAutoField(primary_key=True, null=False) + transaction = models.ForeignKey( + LedgerTransaction, on_delete=models.RESTRICT, null=False + ) + key = models.CharField(max_length=30) + value = models.CharField(max_length=255) + + class Meta: + db_table = "ledger_transactionmetadata" + # You can only have 1 key per transaction. So a transaction cannot + # be associated with multiple thl_session uuids for e.g., but it + # can be associated with a thl_session and uqh. + # + # If there is a need to associate a transaction with a list of + # thl_sessions (for example a bonus for 10 completes), the + # transaction should instead be associated with a single "contest" + # object that itself points to those 10 completes (or whatever). + unique_together = ("transaction", "key") + indexes = [ + models.Index(fields=["key", "value"]), + ] + + +class LedgerEntry(models.Model): + """ + A.K.K "line item". A ledger_entry represents an accounting entry within + a parent ledger transaction. + """ + + id = models.BigAutoField(primary_key=True, null=False) + direction = models.SmallIntegerField(null=False, choices=LedgerDirection.choices) + account = models.ForeignKey( + LedgerAccount, + on_delete=models.RESTRICT, + null=False, + related_name="account", + ) + # In the smallest unit of the currency being transacted. For + # USD, this is cents. + amount = models.BigIntegerField(null=False) + transaction = models.ForeignKey( + LedgerTransaction, + on_delete=models.RESTRICT, + null=False, + related_name="transaction", + ) + + class Meta: + db_table = "ledger_entry" + + +class LedgerAccountStatement(models.Model): + """ + Provides the starting and ending balances of a ledger account for a + specific time period. The statement could optionally apply a metadata + filter to the account. + """ + + id = models.BigAutoField(primary_key=True, null=False) + account = models.ForeignKey(LedgerAccount, on_delete=models.RESTRICT, null=False) + + # For optional filtering: key/values applied to the + # LedgerTransactionMetadata to filter transactions for this account, as + # a '&' delimited, key=value string (sorted by key). e.g. + # "transaction_type=cashout&user_id=12345" + filter_str = models.CharField(max_length=255, null=True) + + # The inclusive lower bound of the effective_at timestamp of the ledger + # entries to be included in the statement + effective_at_lower_bound = models.DateTimeField(null=False) + + # The exclusive upper bound of the effective_at timestamp of the ledger + # entries to be included in the statement + effective_at_upper_bound = models.DateTimeField(null=False) + starting_balance = models.BigIntegerField(null=False) + ending_balance = models.BigIntegerField(null=False) + + # sql query used to generate this data + sql_query = models.TextField(null=True) + + class Meta: + db_table = "ledger_accountstatement" + indexes = [ + models.Index(fields=["account", "filter_str", "effective_at_lower_bound"]), + ] + # Maybe there should be a unique index on ('account', 'filter_str', + # 'effective_at_lower_bound', 'effective_at_upper_bound') ? + # + # Maybe we should add a OPEN/CLOSED flag to indicate the time period + # is still "open", and another field with the timestamp of the last + # transaction within the statements, so that we can + # continuously update a statement. + + +class TaskAdjustment(models.Model): + """ + This used to be userprofile.UserQualityHistory. This now only stores + Task Adjustments/recons. Any other quality types should go in he + userhealth_auditlog. + + This stores a reference to a THLWall record (wall_uuid), which should + have identical source, survey_id, and started values (copied here + for convenience). + """ + + uuid = models.UUIDField(default=uuid.uuid4, primary_key=True) + + # This is 'af' (adjusted to fail), 'ac' (adjusted to complete), 'cc' + # (confirmed complete), or possibly 'ca' (cpi adjustment) (not yet + # supported). + adjusted_status = models.CharField(max_length=2, null=False) + + # External status code: marketplace's status / status code / status + # reason / whatever they call it. + ext_status_code = models.CharField(max_length=32, null=True) + + # The amount that is being adjusted. If positive, this is the amount + # added to the original payment, if negative, this amount is taken + # back (complete -> recon). This should agree with adjusted_status. + # + # This should be NULL only if the adjusted_status is cc. + # + # Note: this is in USD b/c THLWall cpi and adjusted_cpi are in USD, and + # we only ever transact in task completions in USD. + amount = models.DecimalField(decimal_places=2, max_digits=5, null=True) + + # When were we notified about this? + alerted = models.DateTimeField(null=False) + + # When we created this record. + created = models.DateTimeField(auto_now_add=True) + + # This is inferrable through the wall_uuid -> thl_session, but copied + # here for convenience + user_id = models.BigIntegerField(null=False) + + # This is the wall event that had the adjustment + wall_uuid = models.UUIDField(null=False) + + # These 3 are also inferrable through thl_wall, but copied here for + # convenience. When the user started the task that had a quality event + started = models.DateTimeField(null=False) + source = models.CharField(max_length=2, null=False) + survey_id = models.CharField(max_length=32, null=False) + + class Meta: + db_table = "thl_taskadjustment" + + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["user_id"]), + models.Index(fields=["wall_uuid"]), + ] diff --git a/generalresearch/thl_django/contest/__init__.py b/generalresearch/thl_django/contest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/contest/models.py b/generalresearch/thl_django/contest/models.py new file mode 100644 index 0000000..415abf8 --- /dev/null +++ b/generalresearch/thl_django/contest/models.py @@ -0,0 +1,115 @@ +from django.db import models + + +class Contest(models.Model): + """ """ + + id = models.BigAutoField(primary_key=True, null=False) + + uuid = models.UUIDField(null=False, unique=True) + product_id = models.UUIDField(null=False) + + name = models.CharField(max_length=128) + description = models.CharField(max_length=2048, null=True) + country_isos = models.CharField(max_length=1024, null=True) + + contest_type = models.CharField(max_length=32) + status = models.CharField(max_length=32) + + starts_at = models.DateTimeField() + terms_and_conditions = models.CharField(max_length=2048, null=True) + + end_condition = models.JSONField() + prizes = models.JSONField() + + # ---- Only set when the contest ends ---- + ended_at = models.DateTimeField(null=True) + end_reason = models.CharField(max_length=32, null=True) + # ---- END Only set when the contest ends ---- + + # ---- Contest-type-specific keys ---- + + # For raffle contests + entry_type = models.CharField(max_length=8, null=True) + entry_rule = models.JSONField(null=True) + + # These get calculated by / (are dependent on) the entries, but I'm adding + # these as fields, so we can quickly retrieve them without having to join + # on the entry table and redo the summations. + # They are nullable because they do not apply to leaderboard contests + current_participants = models.IntegerField(null=True) + current_amount = models.IntegerField(null=True) + + # For Milestone + milestone_config = models.JSONField(null=True) + # For keeping track of the number of times this milestone has been reached + win_count = models.IntegerField(null=True) + + # For LeaderboardContest + # e.g. 'leaderboard:48d6ff6664bc4767a0d8e5381f7e5cf0:us:monthly:2024-01-01:largest_user_payout' + leaderboard_key = models.CharField(max_length=128, null=True) + + # ---- END Contest-type-specific keys ---- + + created_at = models.DateTimeField(auto_now_add=True, null=False) + # updated_at gets set to created_at when object is created! + # updated_at means a property of the contest itself is modified, NOT + # including an entry being created/modified. + updated_at = models.DateTimeField(auto_now=True, null=False) + + class Meta: + db_table = "contest_contest" + indexes = [ + # id and uuid will already have an index + models.Index(fields=["product_id", "created_at"]), + models.Index(fields=["product_id", "status"]), + ] + + +class ContestEntry(models.Model): + id = models.BigAutoField(primary_key=True, null=False) + uuid = models.UUIDField(null=False, unique=True) + + # The Contest.id this entry pertains to + contest_id = models.BigIntegerField(null=False) + + amount = models.IntegerField(null=False) + user_id = models.BigIntegerField(null=False) + + created_at = models.DateTimeField(auto_now_add=True, null=False) + # updated_at gets set to created_at when object is created! + # Raffle entries are NOT modifiable, but in a milestone contest, + # we'll update the 'amount' per (contest, user). + updated_at = models.DateTimeField(auto_now=True, null=False) + + class Meta: + db_table = "contest_contestentry" + indexes = [ + # id and uuid will already have an index + models.Index(fields=["user_id", "created_at"]), + models.Index(fields=["contest_id", "user_id"]), + ] + + +class ContestWinner(models.Model): + id = models.BigAutoField(primary_key=True, null=False) + uuid = models.UUIDField(null=False, unique=True) + + # The Contest.id this entry pertains to + contest_id = models.BigIntegerField(null=False) + user_id = models.BigIntegerField(null=False) + + prize = models.JSONField(null=False) + # If it's a tie, and the prize is cash, multiple users may split a prize. + awarded_cash_amount = models.IntegerField(null=True) + + # Milestone winners are created at different times, so we need this also. + created_at = models.DateTimeField(auto_now_add=True, null=False) + + class Meta: + db_table = "contest_contestwinner" + indexes = [ + # id and uuid will already have an index + models.Index(fields=["user_id", "created_at"]), + models.Index(fields=["contest_id"]), + ] diff --git a/generalresearch/thl_django/event/__init__.py b/generalresearch/thl_django/event/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/event/models.py b/generalresearch/thl_django/event/models.py new file mode 100644 index 0000000..3e19910 --- /dev/null +++ b/generalresearch/thl_django/event/models.py @@ -0,0 +1,91 @@ +import uuid + +from django.db import models + + +class Bribe(models.Model): + """ + This is meant to store info about "manual" bribes, in which Customer + Support directly gives a user money (put into their wallet) as a result + of an email/support communication. + + Each bribe would have its own row, with the metadata about the bribe in + the data field. + + There is no "sent" field, because the financial impact of bribes is + determined by the ledger, not by presence or absence in this table. + """ + + uuid = models.UUIDField(default=uuid.uuid4, primary_key=True) + + # This is the LedgerAccount.uuid that this Payout Event is associated with. + # The user/BP is retrievable through the LedgerAccount.reference_uuid. + credit_account_uuid = models.UUIDField(null=False) + created = models.DateTimeField(auto_now_add=True) + + # In the smallest unit of the currency being transacted. For USD, this + # is cents. + amount = models.BigIntegerField(null=False) + ext_ref_id = models.CharField(max_length=64, null=True) # support ticket ID? + description = models.TextField( + null=True + ) # could be shown to the user in their transactions description + data = models.JSONField(null=True) # content of email? (optional) + + class Meta: + db_table = "event_bribe" + + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["credit_account_uuid"]), + models.Index(fields=["ext_ref_id"]), + ] + + +class Payout(models.Model): + """ + Money is paid out of a virtual wallet. + """ + + uuid = models.UUIDField(default=uuid.uuid4, primary_key=True) + + # This is the LedgerAccount.uuid that this money is being requested from. + # The user/BP is retrievable through the LedgerAccount.reference_uuid + debit_account_uuid = models.UUIDField(null=False) + + # References a row in the account_cashoutmethod table. This is the cashout + # method that was used to request this payout. (A cashout is the same + # thing as a payout) + cashout_method_uuid = models.UUIDField(null=False) + created = models.DateTimeField(auto_now_add=True) + + # In the smallest unit of the currency being transacted. For USD, this is cents. + amount = models.BigIntegerField(null=False) + + # The allowed values for `status` are defined in py-utils: + # generalresearch/models/thl/payout.py:PayoutStatus + status = models.CharField(max_length=20, null=True) + + # Used for holding an external, payouttype-specific identifier + ext_ref_id = models.CharField(max_length=64, null=True) + + # The allowed values for `payout_type` are defined in py-utils: + # generalresearch/models/thl/payout.py:PayoutType + payout_type = models.CharField(max_length=14) + + # Stores payout-type-specific information that is used to request this + # payout from the external provider. + request_data = models.JSONField(null=True) + + # Stores payout-type-specific order information that is returned from + # the external payout provider. + order_data = models.JSONField(null=True) + + class Meta: + db_table = "event_payout" + + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["debit_account_uuid"]), + models.Index(fields=["ext_ref_id"]), + ] diff --git a/generalresearch/thl_django/marketplace/__init__.py b/generalresearch/thl_django/marketplace/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/marketplace/models.py b/generalresearch/thl_django/marketplace/models.py new file mode 100644 index 0000000..7a67804 --- /dev/null +++ b/generalresearch/thl_django/marketplace/models.py @@ -0,0 +1,757 @@ +import uuid + +from django.db import models +from django.db.models import Q + + +class ProbeLog(models.Model): + """ + Table for logging probes of tasks' entry links. Typically, using playwright. + """ + + id = models.BigAutoField(primary_key=True) + + source = models.CharField(max_length=2, null=False) + survey_id = models.CharField(max_length=32, null=False) + + # When the probe started + started = models.DateTimeField(null=False) + + # The url that was probed + live_url = models.CharField(max_length=3000, null=False) + + # The relative path to the har-file generated + har_path = models.CharField(max_length=1000, null=False) + + # The result of the probe + result = models.CharField(max_length=64, null=True) + + class Meta: + db_table = "marketplace_probelog" + + indexes = [ + models.Index(fields=["source", "survey_id"]), + models.Index(fields=["started"]), + ] + + +""" +General naming notes: + - Property: describes some concept about a user. e.g. age, education level, + the car brand they are planning on buying. + - Edge: an association between a user -> property -> value. The value can + be one of multiple different types (item, numerical, string, date). + - Item: represents a concept or class. I don't want to use "class" or "object" + b/c it conflicts with python namespaces. This is a "thing" such as + "male" or "honda". + - Concept: What I'm calling something that is either a property or item, + such as a translation, which both properties and items have. + +Examples: +- Gender. It is valid in all countries. + Four possible values: male, female, non-binary, other. + Non-binary is a subclass of other? todo +- Age. Property="age_in_years". We could also have a property "birth_date" + value is an int (or for birthdate, a date). +- Hispanic. Only valid in the US & CA. Options are the same in each country + Options: Yes, No. + sub-options: Mexican, puerto rican, etc are subclasses of Yes. +- Education level: Valid in every country, but the options are different in + many countries + - "Secondary Education" in DE, "high school" in US (not translations, + these are different concepts) +- postal_code: value type is a string. We could have special + structured/hierarchical datatypes for location regions (city, county, state). +- car's fuel source (c:fuel, l:96563, ...) + +Migration Notes: +you must delete all rows from the following tables before running migration or it will fail +---mysql +delete from marketplace_externalid; +delete from marketplace_node; +delete from marketplace_property; +delete from marketplace_userprofileknowledge; +--- +and we must comment out some code that writes to and reads from the UPK first before any migrations +""" + + +class Property(models.Model): + """ + Stores the list of properties and their types + """ + + id = models.UUIDField(default=uuid.uuid4, null=False, primary_key=True) + label = models.CharField(max_length=255, null=False) + description = models.TextField() + + # * -- zero or more + # ? -- zero or one + cardinality = models.CharField(max_length=9, null=False) + + TYPE_CHOICES = ( + ("n", "numerical"), + ("x", "text"), + ("i", "item"), + # ('a', 'datetime'), + # ('t', 'time'), + # ('d', 'date'), + ) + prop_type = models.CharField(choices=TYPE_CHOICES, max_length=1, default="c") + + class Meta: + db_table = "marketplace_property" + + +class Item(models.Model): + """ + Represent things such as male or female. A item that is unambiguously + the same thing across countries, will have the same ID within a + property's range, but not across different properties. For e.g. + - "male" as a possible gender is the same thing in US & DE. + - "high school graduate" in the US is different from "Gesamtschule" + in Germany + - "Honda" as an answer to "what kind of car do you drive?" is different + than "Honda" as an answer to "what kind of car are you planning + on buying?" + """ + + id = models.UUIDField(default=uuid.uuid4, null=False, primary_key=True) + label = models.CharField(max_length=255, null=False) + description = models.TextField(null=True) # optional, for notes + + class Meta: + db_table = "marketplace_item" + + +class PropertyCountry(models.Model): + """ + This associates a property with the countries it is "allowed" to be used in. + e.g. hispanic only applies in US & CA. + + For item properties, this is kind of unnecessary, b/c we'll know it from + the PropertyConceptRange table. But for the others, who have no item + ranges, we need this (like for age). + """ + + property_id = models.UUIDField(null=True) + country_iso = models.CharField(max_length=2, null=False) + + # Used for changing how UPK is exposed. A gold standard question we've + # enumerated possible values (in that country) and (as best as possible) + # mapped them across marketplaces. A property not marked as gold-standard + # maybe has 1) marketplace qid associations & 2) category associations, + # but doesn't have a defined "range" (list of allowed items in a + # multiple choice question). Used for exposing a user's profiling data & + # for the Nudge API. + gold_standard = models.BooleanField(default=False) + + class Meta: + db_table = "marketplace_propertycountry" + + indexes = [ + models.Index(fields=["property_id"]), + models.Index(fields=["country_iso"]), + ] + + +class PropertyItemRange(models.Model): + """ + "Range" means the set of possible values for this property in this country + e.g. (gender (every country): male, female, other), + or education (us): high school, university, whatever, which is different + than the options in germany. + """ + + property_id = models.UUIDField(null=True) + item = models.ForeignKey(Item, on_delete=models.CASCADE) + country_iso = models.CharField(max_length=2, null=False) + + class Meta: + db_table = "marketplace_propertyitemrange" + indexes = [models.Index(fields=["country_iso", "property_id"])] + + +class ConceptTranslation(models.Model): + """ + One table for both properties and classes. + """ + + concept_id = models.UUIDField() + language_iso = models.CharField(max_length=3) + text = models.TextField() + + class Meta: + db_table = "marketplace_concepttranslation" + + indexes = [ + models.Index(fields=["concept_id"]), + models.Index(fields=["language_iso"]), + ] + + +class PropertyMarketplaceAssociation(models.Model): + """ + Associates a property with a marketplace's question ID (many-to-many) + """ + + property_id = models.UUIDField(null=True) + source = models.CharField(max_length=1, null=False) + question_id = models.CharField(max_length=32, null=False) + + class Meta: + db_table = "marketplace_propertymarketplaceassociation" + + indexes = [ + models.Index(fields=["source", "question_id"]), + models.Index(fields=["property_id"]), + ] + + +class Category(models.Model): + """ + https://cloud.google.com/natural-language/docs/categories + https://developers.google.com/adwords/api/docs/appendix/verticals + https://developers.google.com/adwords/api/docs/appendix/codes-formats + """ + + id = models.AutoField(primary_key=True) + uuid = models.UUIDField(unique=True) + parent = models.ForeignKey("Category", null=True, on_delete=models.SET_NULL) + adwords_vertical_id = models.CharField(max_length=8, null=True) + label = models.CharField(max_length=255) + # stores a "path-style" label for easy searching, tagging, convenience. + # e.g. '/Hobbies & Leisure/Outdoors/Fishing' + path = models.CharField(max_length=1024, null=True) + + class Meta: + db_table = "marketplace_category" + + +class PropertyCategoryAssociation(models.Model): + """ + Associates a property with a category (many-to-many) + """ + + property_id = models.UUIDField(null=True) + category = models.ForeignKey(Category, on_delete=models.CASCADE) + + class Meta: + db_table = "marketplace_propertycategoryassociation" + + indexes = [models.Index(fields=["property_id"])] + + +# class ExternalID(models.Model): +# """" +# Probably mostly for location based concepts (geocode ID for "Florida", +# or wikidata ID for whatever). +# """ +# concept_id = models.UUIDField(null=False) +# curie = models.CharField(max_length=255, null=False) +# +# class PropertyAnnotation(models.Model): +# """ +# This could be used to define a range for non item-properties +# (e.g. age), maybe... or maybe associating properties with another +# (age_in_years <-> birth_date) +# """ +# pass +# +# class ClassStatement(models.Model): +# """This could be used to associate classes between one another across +# questions. For example to link the two Honda concepts in the +# above example. +# """ +# pass + + +class UserProfileKnowledge(models.Model): + """ + This only stores the most recent knowledge per user_id/property_id/question + Purposely not using foreign keys for the property or values b/c this table is + going to be huge and this will add complexity, overhead, unintended consequences... + """ + + user_id = models.PositiveIntegerField() + property_id = models.UUIDField() + # value = models.UUIDField() + + # If we ask the user, we'll have a session and question ID. We could also + # accept a 'gender' from a BP, and we may not know exactly how it was + # asked, so we have to support no question_id. + session_id = models.UUIDField(null=True) + question_id = models.UUIDField(null=True) + + # If question_id is optional, we need the locale here also, even though it + # would be inferable from the question_id which itself is locale-scoped. + country_iso = models.CharField(max_length=2, default="us") + + # I don't think lang should be a field here. If we ask the question, we'll + # know from the question_id the lang. Otherwise, we may not know what + # lang the question was asked in. The way we have the itemrange set up, + # the lang does not affect the possible options, only the translation. + # I think it is a really, really small edge case in which the language + # would affect an answer here. + # language_iso = models.CharField(max_length=3, default='eng') + + # when this specific edge (user, prop, value) was created/updated/added/changed + created = models.DateTimeField(auto_now=True) + + class Meta: + abstract = True + + +class UserProfileKnowledgeItem(UserProfileKnowledge): + """ + Same as UserProfileKnowledge but for value type of numerical + """ + + value = models.UUIDField(null=False) + + class Meta: + db_table = "marketplace_userprofileknowledgeitem" + indexes = [ + models.Index(fields=["user_id"]), + models.Index(fields=["created"]), + models.Index(fields=["property_id"]), + ] + + +class UserProfileKnowledgeNumerical(UserProfileKnowledge): + """ + Same as UserProfileKnowledge but for value type of numerical + """ + + value = models.FloatField(null=False) + + class Meta: + db_table = "marketplace_userprofileknowledgenumerical" + indexes = [ + models.Index(fields=["user_id"]), + models.Index(fields=["created"]), + models.Index(fields=["property_id"]), + ] + + +class UserProfileKnowledgeText(UserProfileKnowledge): + """ + Same as UserProfileKnowledge but for value type of numerical + """ + + value = models.CharField(max_length=1024) + + class Meta: + db_table = "marketplace_userprofileknowledgetext" + indexes = [ + models.Index(fields=["user_id"]), + models.Index(fields=["created"]), + models.Index(fields=["property_id"]), + ] + + +class Question(models.Model): + """ + Stores the info about a Question that is asked to a user. + """ + + id = models.UUIDField(default=uuid.uuid4, primary_key=True, null=False) + + # Used for detecting changes to marketplace questions (that were changed + # by the marketplace) + md5sum = models.CharField(max_length=32, null=True) + country_iso = models.CharField(max_length=2, default="us") + language_iso = models.CharField(max_length=3, default="eng") + + # this is either a upk code (e.g. gr:gender) or a marketplace question ID + # (e.g. c:gender (cint's gender question)) + property_code = models.CharField(max_length=64) + + # conforms to question.schema.json (lives in thl-yieldman/mrpq/jsonschema) + data = models.JSONField(default=dict) + + # shortcut for determining if the data.task_score > 0 + is_live = models.BooleanField(default=False) + + # Optionally describes custom, manual or automatic, modifications made to + # a question. for e.g.: marking it as never to be asked, or adding a + # "None of the above" option + custom = models.JSONField(default=dict) + + # When this question was last modified + last_updated = models.DateTimeField(null=True) + + # Human-readable template for explaining how a user's answer to this + # question affects eligibility + explanation_template = models.TextField(max_length=255, null=True) + + # A very short, natural-language explanation fragment that can be combined + # with others into a single sentence + explanation_fragment_template = models.TextField(max_length=255, null=True) + + class Meta: + db_table = "marketplace_question" + + indexes = [ + models.Index(fields=["last_updated"]), + models.Index(fields=["property_code"]), + ] + + +class UserQuestionAnswer(models.Model): + """ + Stores the info about the event of a user answering a question. + This is distinct from UPK b/c, for e.g. a user could be asked + a) what is your gender? (male, female) or + b) what is your gender? (male, female, other) + The UPK table would store the user's latest "gr:gender" -> answer + The user question answer would store the info about which question they were asked, + and what they answered. + """ + + question = models.ForeignKey("Question", on_delete=models.DO_NOTHING) + created = models.DateTimeField() + session_id = models.UUIDField(null=True) + user_id = models.IntegerField() + + # The user's answer to the question. Stores the actual value for text + # questions, or the selected choice's codes for MC. Always a list!! + # e.g. ["92116"] for a text entry, or ["3", "7", "9"] for multiple choice + answer = models.JSONField(default=list) + + # We'll save in here the marketplace answers that we've inferred/generated + # from the answer e.g. {"l:123": ["1", "2"], "c:643": ["6", "5"]} + calc_answer = models.JSONField(default=dict) + + class Meta: + db_table = "marketplace_userquestionanswer" + + indexes = [ + models.Index(fields=["user_id", "question_id", "-created"]), + models.Index(fields=["created"]), + ] + + +class UserGroup(models.Model): + # Used for tracking user-user identity + # If userA == userB and userB == userC, then userA == userC (transitive) + user_id = models.PositiveIntegerField() + user_group = models.UUIDField() + created = models.DateTimeField(null=False) + + class Meta: + db_table = "marketplace_usergroup" + + unique_together = ("user_id", "user_group") + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["user_id"]), + models.Index(fields=["user_group"]), + ] + + +class Buyer(models.Model): + id = models.BigAutoField(primary_key=True) + + # The marketplace's 2-letter code {l, c, d, h, s, ...} + source = models.CharField(max_length=2, null=False) + + # The marketplace's ID/code for this buyer + code = models.CharField(max_length=128, null=False) + + # optional text name for the buyer, if available + label = models.CharField(max_length=255, null=True) + + # when this entry was made, or when the buyer was first seen + created = models.DateTimeField(auto_now_add=True, null=False) + + class Meta: + db_table = "marketplace_buyer" + + unique_together = ("source", "code") + indexes = [models.Index(fields=["created"])] + + +class BuyerGroup(models.Model): + """ + If we know that a buyer is the same buyer across different marketplaces, + we can link them here. + + Constraints here enforce: + - a buyer can only be in 1 group, once (no duplicates) + - a group can have multiple buyers + """ + + id = models.BigAutoField(primary_key=True) + + # This is the buyer group's universal ID (can expose this) + group = models.UUIDField(default=uuid.uuid4, null=False) + + # OneToOneField: Same thing as a ForeignKey with unique = True + buyer = models.OneToOneField(Buyer, on_delete=models.RESTRICT) + created = models.DateTimeField(auto_now_add=True, null=False) + + class Meta: + db_table = "marketplace_buyergroup" + + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["group"]), + ] + + +class Survey(models.Model): + id = models.BigAutoField(primary_key=True) + # The "unique" key in this table is: + # (source, survey_id) + source = models.CharField(max_length=2, null=False) + survey_id = models.CharField(max_length=32, null=False) + buyer = models.ForeignKey(Buyer, null=True, on_delete=models.PROTECT) + created_at = models.DateTimeField(auto_now_add=True, null=False) + updated_at = models.DateTimeField(auto_now=True, null=False) + + # This I'm not sure about. We want data to be able to return in an + # offerwall "why" a user is eligible for this survey. The complexity + # is mapping the mp-specific question codes to the question ids we + # ask, and unsure where best to do that. + # + # Also, 99% are going to have age/gender, which is a waste of data, so + # maybe we want to structure this differently. + # # used_question_ids = models.JSONField(default=list) + # Going with this instead. We'll have to structure it clearly in pydantic: + eligibility_criteria = models.JSONField(null=True) + + is_live = models.BooleanField(null=False) + is_recontact = models.BooleanField(default=False) + + # .... more metadata: survey platform/host, category, etc ... + + class Meta: + db_table = "marketplace_survey" + indexes = [ + models.Index(fields=["source", "is_live"]), + # Tiny index compared to ----^, but only if we filter WHERE is_live = TRUE + models.Index( + fields=["source"], + name="survey_live_by_source", + condition=models.Q(is_live=True), + ), + models.Index(fields=["created_at"]), + models.Index(fields=["updated_at"]), + ] + constraints = [ + models.UniqueConstraint( + fields=["source", "survey_id"], + name="uniq_survey_source_survey_id", + ) + ] + + +class SurveyCategory(models.Model): + """ + Associates a Survey with one or more Categories, with an optional strength / weight. + """ + + id = models.BigAutoField(primary_key=True) + survey = models.ForeignKey(Survey, on_delete=models.CASCADE) + category = models.ForeignKey(Category, on_delete=models.RESTRICT) + + # Strength / confidence / relevance. 0.0–1.0 probability + # The sum(strength) for a survey should add up to 1. + strength = models.FloatField( + null=True, + help_text="Relative relevance or confidence (0–1)", + ) + + class Meta: + db_table = "marketplace_surveycategory" + constraints = [ + models.UniqueConstraint( + fields=["survey", "category"], + name="uniq_survey_category", + ) + ] + + +class SurveyStat(models.Model): + id = models.BigAutoField(primary_key=True) + + # The "unique" key in this table (what all stats are calculated for) is: + # ((source, survey_id)=survey, quota_id, country_iso, version) + survey = models.ForeignKey(Survey, on_delete=models.RESTRICT) + + # We could calculate stats for a specific quota. This should be nullable, + # but that is problematic in a unique key b/c NULL != NULL, so the comparison + # doesn't use the index properly. Instead of null, use a sentinel value + # like "__all__". + quota_id = models.CharField(max_length=32, null=False) + + # We could also have stats per country, if a survey is open to multiple countries. + # Use 'ZZ' if a survey is open to any country, and we calculate the stats pooled. + country_iso = models.CharField(max_length=2, null=False) + + cpi = models.DecimalField(max_digits=8, decimal_places=5, null=False) + complete_too_fast_cutoff = models.IntegerField(help_text="Seconds") + + # ---- Distributions ---- + + prescreen_conv_alpha = models.FloatField() + prescreen_conv_beta = models.FloatField() + + conv_alpha = models.FloatField() + conv_beta = models.FloatField() + + dropoff_alpha = models.FloatField() + dropoff_beta = models.FloatField() + + completion_time_mu = models.FloatField() + completion_time_sigma = models.FloatField() + + # Eligibility modeled probabilistically + mobile_eligible_alpha = models.FloatField() + mobile_eligible_beta = models.FloatField() + + desktop_eligible_alpha = models.FloatField() + desktop_eligible_beta = models.FloatField() + + tablet_eligible_alpha = models.FloatField() + tablet_eligible_beta = models.FloatField() + + # ---- Scalar risk / quality metrics ---- + + long_fail_rate = models.FloatField() + user_report_coeff = models.FloatField() + recon_likelihood = models.FloatField() + + # Survey penalty gets converted to score_x0 = 0 and score_x1 = (1-penalty) + # (these are the coefficients that'll be applied to the final score, e.g. + # score_x0 + {x}*score_x1 + {x}*score_x2^2 ... ) + score_x0 = models.FloatField() + score_x1 = models.FloatField() + + # generalized/predicated score + score = models.FloatField() + + # ---- Metadata ---- + # We can use this to compare yield-management strategies, or A/B test stats, etc... + version = models.PositiveIntegerField(help_text="Bump when logic changes") + + # Set when a row is created, updated, or turned not live (even if stats didn't change) + updated_at = models.DateTimeField(auto_now=True) + + # These are de-normalized from the Survey for ease of join / SQL operations. + # They should match exactly the fields on the referenced Survey. + survey_is_live = models.BooleanField(null=False) + survey_survey_id = models.CharField(max_length=32, null=False) + survey_source = models.CharField(max_length=2, null=False) + + class Meta: + db_table = "marketplace_surveystat" + indexes = [ + models.Index( + fields=["survey"], + name="surveystat_live_survey_idx", + condition=Q(survey_is_live=True), + ), + ] + constraints = [ + models.UniqueConstraint( + fields=[ + "survey", + "quota_id", + "country_iso", + "version", + ], + name="uniq_surveystat_survey_quota_country_version", + ) + ] + + +# class SurveyStatBP(models.Model): +# """ +# Defines Brokerage Product-specific adjustments +# for a survey. +# +# Notes: +# - For the survey as a whole, no quota, country_iso, version. +# - This is currently only used for rate-limiting completes from +# a BP into a particular survey. +# - This table is "sparse" in that 99% of live surveys +# won't have any bp-specific adjustments. +# """ +# +# id = models.BigAutoField(primary_key=True) +# +# product_id = models.UUIDField(null=False) +# +# survey = models.ForeignKey( +# Survey, +# on_delete=models.CASCADE, +# ) +# +# # Survey penalty gets converted to score_x0 = 0 and score_x1 = (1-penalty) +# # (these are the coefficients that'll be applied to the final score, e.g. +# # score_x0 + {x}*score_x1 + {x}*score_x2^2 ... ) +# score_x0 = models.FloatField() +# score_x1 = models.FloatField() + + +# +# class SurveyStatusBucket(models.Model): +# """ +# Aggregated counts of wall.status +# Grouped by: +# ((source, survey_id)=survey, quota_id, country_iso, product_id, bucket_start, bucket_size, status) +# """ +# +# id = models.BigAutoField(primary_key=True) +# +# survey = models.ForeignKey(Survey, on_delete=models.RESTRICT) +# +# quota_id = models.CharField(max_length=32, null=True) +# country_iso = models.CharField(max_length=2, null=False) +# +# product_id = models.UUIDField(null=False) +# +# status = models.CharField(max_length=1, null=True) +# +# count = models.PositiveIntegerField(null=False, default=0) +# +# bucket_start = models.DateTimeField( +# help_text="UTC start of aggregation bucket" +# ) +# +# BUCKET_SIZES = ( +# (3600, "hour"), +# (86400, "day"), +# ) +# +# bucket_size = models.PositiveSmallIntegerField( +# help_text="Bucket size in seconds (e.g. 3600, 86400)", +# choices=BUCKET_SIZES, +# ) +# +# updated_at = models.DateTimeField(auto_now=True) +# +# class Meta: +# db_table = "marketplace_surveystatusbucket" +# constraints = [ +# models.UniqueConstraint( +# fields=[ +# "survey", +# "quota_id", +# "country_iso", +# "product_id", +# "status", +# "bucket_start", +# "bucket_size", +# ], +# nulls_distinct=False, +# name="uniq_surveystatuscount", +# ) +# ] +# indexes = [ +# models.Index(fields=["survey", "bucket_start"]), +# models.Index(fields=["status", "bucket_start"]), +# models.Index(fields=["product_id", "bucket_start"]), +# ] diff --git a/generalresearch/thl_django/migrations/0001_initial.py b/generalresearch/thl_django/migrations/0001_initial.py new file mode 100644 index 0000000..ecae35a --- /dev/null +++ b/generalresearch/thl_django/migrations/0001_initial.py @@ -0,0 +1,1066 @@ +# Generated by Django 6.0 on 2025-12-26 20:53 + +import django.db.models.deletion +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='Item', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('label', models.CharField(max_length=255)), + ('description', models.TextField(null=True)), + ], + options={ + 'db_table': 'marketplace_item', + }, + ), + migrations.CreateModel( + name='Language', + fields=[ + ('code', models.CharField(help_text='three-letter language code', max_length=3, primary_key=True, serialize=False)), + ('name', models.CharField(help_text='language name', max_length=255)), + ], + options={ + 'db_table': 'userprofile_language', + }, + ), + migrations.CreateModel( + name='PayoutMethod', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('user_id', models.IntegerField(db_index=True, null=True)), + ('default', models.BooleanField(default=False)), + ('enabled', models.BooleanField(default=True)), + ('method', models.CharField(choices=[('a', 'AMT'), ('c', 'ACH'), ('t', 'Tango'), ('p', 'PAYPAL')], default='t', max_length=1)), + ('recipient', models.CharField(blank=True, max_length=200, null=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ], + options={ + 'db_table': 'userprofile_payoutmethod', + 'ordering': ('-created',), + 'get_latest_by': 'created', + }, + ), + migrations.CreateModel( + name='Property', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('label', models.CharField(max_length=255)), + ('description', models.TextField()), + ('cardinality', models.CharField(max_length=9)), + ('prop_type', models.CharField(choices=[('n', 'numerical'), ('x', 'text'), ('i', 'item')], default='c', max_length=1)), + ], + options={ + 'db_table': 'marketplace_property', + }, + ), + migrations.CreateModel( + name='THLUser', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(unique=True)), + ('product_id', models.UUIDField()), + ('product_user_id', models.CharField(max_length=128)), + ('created', models.DateTimeField()), + ('last_seen', models.DateTimeField()), + ('last_country_iso', models.CharField(max_length=2, null=True)), + ('last_geoname_id', models.IntegerField(null=True)), + ('last_ip', models.GenericIPAddressField(null=True)), + ('blocked', models.BooleanField(default=False)), + ], + options={ + 'db_table': 'thl_user', + }, + ), + migrations.CreateModel( + name='THLWall', + fields=[ + ('uuid', models.UUIDField(primary_key=True, serialize=False)), + ('source', models.CharField(max_length=2)), + ('buyer_id', models.CharField(max_length=32, null=True)), + ('survey_id', models.CharField(max_length=32)), + ('req_survey_id', models.CharField(max_length=32)), + ('cpi', models.DecimalField(decimal_places=5, max_digits=8)), + ('req_cpi', models.DecimalField(decimal_places=5, max_digits=8)), + ('started', models.DateTimeField()), + ('finished', models.DateTimeField(null=True)), + ('status', models.CharField(default=None, max_length=1, null=True)), + ('status_code_1', models.SmallIntegerField(null=True)), + ('status_code_2', models.SmallIntegerField(null=True)), + ('ext_status_code_1', models.CharField(max_length=32, null=True)), + ('ext_status_code_2', models.CharField(max_length=32, null=True)), + ('ext_status_code_3', models.CharField(max_length=32, null=True)), + ('report_value', models.SmallIntegerField(null=True)), + ('report_notes', models.CharField(max_length=255, null=True)), + ('adjusted_status', models.CharField(max_length=2, null=True)), + ('adjusted_cpi', models.DecimalField(decimal_places=5, max_digits=8, null=True)), + ('adjusted_timestamp', models.DateTimeField(null=True)), + ], + options={ + 'db_table': 'thl_wall', + }, + ), + migrations.CreateModel( + name='UserAuditLog', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('user_id', models.BigIntegerField()), + ('created', models.DateTimeField()), + ('level', models.PositiveSmallIntegerField(default=0)), + ('event_type', models.CharField(max_length=64)), + ('event_msg', models.CharField(max_length=256, null=True)), + ('event_value', models.FloatField(null=True)), + ], + options={ + 'db_table': 'userhealth_auditlog', + }, + ), + migrations.CreateModel( + name='UserGroup', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.PositiveIntegerField()), + ('user_group', models.UUIDField()), + ('created', models.DateTimeField()), + ], + options={ + 'db_table': 'marketplace_usergroup', + }, + ), + migrations.CreateModel( + name='UserHealthIPHistory', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.BigIntegerField()), + ('ip', models.GenericIPAddressField()), + ('created', models.DateTimeField(auto_now_add=True)), + ('forwarded_ip1', models.GenericIPAddressField(null=True)), + ('forwarded_ip2', models.GenericIPAddressField(null=True)), + ('forwarded_ip3', models.GenericIPAddressField(null=True)), + ('forwarded_ip4', models.GenericIPAddressField(null=True)), + ('forwarded_ip5', models.GenericIPAddressField(null=True)), + ('forwarded_ip6', models.GenericIPAddressField(null=True)), + ], + options={ + 'db_table': 'userhealth_iphistory', + }, + ), + migrations.CreateModel( + name='UserHealthWebSocketIPHistory', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.BigIntegerField()), + ('ip', models.GenericIPAddressField()), + ('created', models.DateTimeField(auto_now_add=True)), + ('last_seen', models.DateTimeField(auto_now_add=True)), + ], + options={ + 'db_table': 'userhealth_iphistory_ws', + }, + ), + migrations.CreateModel( + name='UserProfileKnowledgeItem', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.PositiveIntegerField()), + ('property_id', models.UUIDField()), + ('session_id', models.UUIDField(null=True)), + ('question_id', models.UUIDField(null=True)), + ('country_iso', models.CharField(default='us', max_length=2)), + ('created', models.DateTimeField(auto_now=True)), + ('value', models.UUIDField()), + ], + options={ + 'db_table': 'marketplace_userprofileknowledgeitem', + }, + ), + migrations.CreateModel( + name='UserProfileKnowledgeNumerical', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.PositiveIntegerField()), + ('property_id', models.UUIDField()), + ('session_id', models.UUIDField(null=True)), + ('question_id', models.UUIDField(null=True)), + ('country_iso', models.CharField(default='us', max_length=2)), + ('created', models.DateTimeField(auto_now=True)), + ('value', models.FloatField()), + ], + options={ + 'db_table': 'marketplace_userprofileknowledgenumerical', + }, + ), + migrations.CreateModel( + name='UserProfileKnowledgeText', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.PositiveIntegerField()), + ('property_id', models.UUIDField()), + ('session_id', models.UUIDField(null=True)), + ('question_id', models.UUIDField(null=True)), + ('country_iso', models.CharField(default='us', max_length=2)), + ('created', models.DateTimeField(auto_now=True)), + ('value', models.CharField(max_length=1024)), + ], + options={ + 'db_table': 'marketplace_userprofileknowledgetext', + }, + ), + migrations.CreateModel( + name='UserQuestionAnswer', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', models.DateTimeField()), + ('session_id', models.UUIDField(null=True)), + ('user_id', models.IntegerField()), + ('answer', models.JSONField(default=list)), + ('calc_answer', models.JSONField(default=dict)), + ], + options={ + 'db_table': 'marketplace_userquestionanswer', + }, + ), + migrations.CreateModel( + name='UserStat', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('user_id', models.PositiveIntegerField()), + ('key', models.CharField(max_length=255)), + ('value', models.FloatField(null=True)), + ('date', models.DateTimeField(auto_now=True)), + ], + options={ + 'db_table': 'userprofile_userstat', + }, + ), + migrations.CreateModel( + name='Bribe', + fields=[ + ('uuid', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('credit_account_uuid', models.UUIDField()), + ('created', models.DateTimeField(auto_now_add=True)), + ('amount', models.BigIntegerField()), + ('ext_ref_id', models.CharField(max_length=64, null=True)), + ('description', models.TextField(null=True)), + ('data', models.JSONField(null=True)), + ], + options={ + 'db_table': 'event_bribe', + 'indexes': [models.Index(fields=['created'], name='event_bribe_created_765d8d_idx'), models.Index(fields=['credit_account_uuid'], name='event_bribe_credit__05f3ba_idx'), models.Index(fields=['ext_ref_id'], name='event_bribe_ext_ref_0ddf91_idx')], + }, + ), + migrations.CreateModel( + name='BrokerageProduct', + fields=[ + ('id', models.UUIDField(primary_key=True, serialize=False)), + ('id_int', models.BigIntegerField(unique=True)), + ('name', models.CharField(max_length=255)), + ('team_id', models.UUIDField(null=True)), + ('business_id', models.UUIDField(null=True)), + ('created', models.DateTimeField(auto_now_add=True, null=True)), + ('enabled', models.BooleanField(default=True)), + ('payments_enabled', models.BooleanField(default=True)), + ('commission', models.DecimalField(decimal_places=6, max_digits=6, null=True)), + ('redirect_url', models.URLField(null=True)), + ('grs_domain', models.CharField(max_length=200, null=True)), + ('profiling_config', models.JSONField(default=dict)), + ('user_health_config', models.JSONField(default=dict)), + ('yield_man_config', models.JSONField(default=dict)), + ('offerwall_config', models.JSONField(default=dict)), + ('session_config', models.JSONField(default=dict)), + ('payout_config', models.JSONField(default=dict)), + ('user_create_config', models.JSONField(default=dict)), + ], + options={ + 'db_table': 'userprofile_brokerageproduct', + 'unique_together': {('team_id', 'name')}, + }, + ), + migrations.CreateModel( + name='BrokerageProductConfig', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('key', models.CharField(max_length=255)), + ('value', models.JSONField(default=dict)), + ('product', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='thl_django.brokerageproduct')), + ], + options={ + 'db_table': 'userprofile_brokerageproductconfig', + }, + ), + migrations.CreateModel( + name='BrokerageProductTag', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('product_id', models.BigIntegerField()), + ('tag', models.CharField(max_length=64)), + ], + options={ + 'db_table': 'userprofile_brokerageproducttag', + 'unique_together': {('product_id', 'tag')}, + }, + ), + migrations.CreateModel( + name='Buyer', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('source', models.CharField(max_length=2)), + ('code', models.CharField(max_length=128)), + ('label', models.CharField(max_length=255, null=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ], + options={ + 'db_table': 'marketplace_buyer', + 'indexes': [models.Index(fields=['created'], name='marketplace_created_b168c4_idx')], + 'unique_together': {('source', 'code')}, + }, + ), + migrations.CreateModel( + name='BuyerGroup', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('group', models.UUIDField(default=uuid.uuid4)), + ('created', models.DateTimeField(auto_now_add=True)), + ('buyer', models.OneToOneField(on_delete=django.db.models.deletion.RESTRICT, to='thl_django.buyer')), + ], + options={ + 'db_table': 'marketplace_buyergroup', + }, + ), + migrations.CreateModel( + name='CashoutMethod', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('last_updated', models.DateTimeField(auto_now=True)), + ('is_live', models.BooleanField(default=False)), + ('provider', models.CharField(max_length=32)), + ('ext_id', models.CharField(max_length=255, null=True)), + ('name', models.CharField(max_length=512)), + ('data', models.JSONField(default=dict)), + ('user_id', models.PositiveIntegerField(null=True)), + ], + options={ + 'db_table': 'accounting_cashoutmethod', + 'indexes': [models.Index(fields=['user_id'], name='accounting__user_id_3064f8_idx'), models.Index(fields=['provider', 'ext_id'], name='accounting__provide_b10797_idx')], + }, + ), + migrations.CreateModel( + name='Category', + fields=[ + ('id', models.AutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(unique=True)), + ('adwords_vertical_id', models.CharField(max_length=8, null=True)), + ('label', models.CharField(max_length=255)), + ('path', models.CharField(max_length=1024, null=True)), + ('parent', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='thl_django.category')), + ], + options={ + 'db_table': 'marketplace_category', + }, + ), + migrations.CreateModel( + name='ConceptTranslation', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('concept_id', models.UUIDField()), + ('language_iso', models.CharField(max_length=3)), + ('text', models.TextField()), + ], + options={ + 'db_table': 'marketplace_concepttranslation', + 'indexes': [models.Index(fields=['concept_id'], name='marketplace_concept_e2bbff_idx'), models.Index(fields=['language_iso'], name='marketplace_languag_dad088_idx')], + }, + ), + migrations.CreateModel( + name='Contest', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(unique=True)), + ('product_id', models.UUIDField()), + ('name', models.CharField(max_length=128)), + ('description', models.CharField(max_length=2048, null=True)), + ('country_isos', models.CharField(max_length=1024, null=True)), + ('contest_type', models.CharField(max_length=32)), + ('status', models.CharField(max_length=32)), + ('starts_at', models.DateTimeField()), + ('terms_and_conditions', models.CharField(max_length=2048, null=True)), + ('end_condition', models.JSONField()), + ('prizes', models.JSONField()), + ('ended_at', models.DateTimeField(null=True)), + ('end_reason', models.CharField(max_length=32, null=True)), + ('entry_type', models.CharField(max_length=8, null=True)), + ('entry_rule', models.JSONField(null=True)), + ('current_participants', models.IntegerField(null=True)), + ('current_amount', models.IntegerField(null=True)), + ('milestone_config', models.JSONField(null=True)), + ('win_count', models.IntegerField(null=True)), + ('leaderboard_key', models.CharField(max_length=128, null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ], + options={ + 'db_table': 'contest_contest', + 'indexes': [models.Index(fields=['product_id', 'created_at'], name='contest_con_product_bb9938_idx'), models.Index(fields=['product_id', 'status'], name='contest_con_product_c8fc09_idx')], + }, + ), + migrations.CreateModel( + name='ContestEntry', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(unique=True)), + ('contest_id', models.BigIntegerField()), + ('amount', models.IntegerField()), + ('user_id', models.BigIntegerField()), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ], + options={ + 'db_table': 'contest_contestentry', + 'indexes': [models.Index(fields=['user_id', 'created_at'], name='contest_con_user_id_8666d0_idx'), models.Index(fields=['contest_id', 'user_id'], name='contest_con_contest_c9ec32_idx')], + }, + ), + migrations.CreateModel( + name='ContestWinner', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(unique=True)), + ('contest_id', models.BigIntegerField()), + ('user_id', models.BigIntegerField()), + ('prize', models.JSONField()), + ('awarded_cash_amount', models.IntegerField(null=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ], + options={ + 'db_table': 'contest_contestwinner', + 'indexes': [models.Index(fields=['user_id', 'created_at'], name='contest_con_user_id_3215e1_idx'), models.Index(fields=['contest_id'], name='contest_con_contest_bae153_idx')], + }, + ), + migrations.CreateModel( + name='GeoName', + fields=[ + ('geoname_id', models.PositiveIntegerField(primary_key=True, serialize=False)), + ('continent_code', models.CharField(max_length=2)), + ('continent_name', models.CharField(max_length=32)), + ('country_iso', models.CharField(max_length=2, null=True)), + ('country_name', models.CharField(max_length=64, null=True)), + ('subdivision_1_iso', models.CharField(max_length=3, null=True)), + ('subdivision_1_name', models.CharField(max_length=255, null=True)), + ('subdivision_2_iso', models.CharField(max_length=3, null=True)), + ('subdivision_2_name', models.CharField(max_length=255, null=True)), + ('city_name', models.CharField(max_length=255, null=True)), + ('metro_code', models.PositiveSmallIntegerField(null=True)), + ('time_zone', models.CharField(max_length=60, null=True)), + ('is_in_european_union', models.BooleanField(null=True)), + ('updated', models.DateTimeField(auto_now=True)), + ], + options={ + 'db_table': 'thl_geoname', + 'indexes': [models.Index(fields=['updated'], name='thl_geoname_updated_765034_idx')], + }, + ), + migrations.CreateModel( + name='IPInformation', + fields=[ + ('ip', models.GenericIPAddressField(primary_key=True, serialize=False)), + ('geoname_id', models.PositiveIntegerField(null=True)), + ('country_iso', models.CharField(max_length=2, null=True)), + ('registered_country_iso', models.CharField(max_length=2, null=True)), + ('is_anonymous', models.BooleanField(null=True)), + ('is_anonymous_vpn', models.BooleanField(null=True)), + ('is_hosting_provider', models.BooleanField(null=True)), + ('is_public_proxy', models.BooleanField(null=True)), + ('is_tor_exit_node', models.BooleanField(null=True)), + ('is_residential_proxy', models.BooleanField(null=True)), + ('autonomous_system_number', models.IntegerField(null=True)), + ('autonomous_system_organization', models.CharField(max_length=255, null=True)), + ('domain', models.CharField(blank=True, max_length=255, null=True)), + ('isp', models.CharField(max_length=255, null=True)), + ('mobile_country_code', models.CharField(max_length=3, null=True)), + ('mobile_network_code', models.CharField(max_length=3, null=True)), + ('network', models.CharField(max_length=56, null=True)), + ('organization', models.CharField(max_length=255, null=True)), + ('static_ip_score', models.FloatField(null=True)), + ('user_type', models.CharField(max_length=64, null=True)), + ('postal_code', models.CharField(blank=True, max_length=20, null=True)), + ('latitude', models.DecimalField(decimal_places=6, max_digits=10, null=True)), + ('longitude', models.DecimalField(decimal_places=6, max_digits=10, null=True)), + ('accuracy_radius', models.PositiveSmallIntegerField(null=True)), + ('updated', models.DateTimeField(auto_now=True)), + ], + options={ + 'db_table': 'thl_ipinformation', + 'indexes': [models.Index(fields=['updated'], name='thl_ipinfor_updated_a17fec_idx')], + }, + ), + migrations.CreateModel( + name='LedgerAccount', + fields=[ + ('uuid', models.UUIDField(primary_key=True, serialize=False)), + ('display_name', models.CharField(max_length=64)), + ('qualified_name', models.CharField(max_length=255, unique=True)), + ('account_type', models.CharField(max_length=30, null=True)), + ('normal_balance', models.SmallIntegerField(choices=[(-1, 'credit'), (1, 'debit')])), + ('reference_type', models.CharField(max_length=30, null=True)), + ('reference_uuid', models.UUIDField(null=True)), + ('currency', models.CharField(max_length=32)), + ], + options={ + 'db_table': 'ledger_account', + 'indexes': [models.Index(fields=['reference_uuid'], name='ledger_acco_referen_df449c_idx')], + }, + ), + migrations.CreateModel( + name='LedgerAccountStatement', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('filter_str', models.CharField(max_length=255, null=True)), + ('effective_at_lower_bound', models.DateTimeField()), + ('effective_at_upper_bound', models.DateTimeField()), + ('starting_balance', models.BigIntegerField()), + ('ending_balance', models.BigIntegerField()), + ('sql_query', models.TextField(null=True)), + ('account', models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, to='thl_django.ledgeraccount')), + ], + options={ + 'db_table': 'ledger_accountstatement', + }, + ), + migrations.CreateModel( + name='LedgerTransaction', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('created', models.DateTimeField()), + ('ext_description', models.CharField(max_length=255, null=True)), + ('tag', models.CharField(max_length=255, null=True)), + ], + options={ + 'db_table': 'ledger_transaction', + 'indexes': [models.Index(fields=['created'], name='ledger_tran_created_091140_idx'), models.Index(fields=['tag'], name='ledger_tran_tag_48e33a_idx')], + }, + ), + migrations.CreateModel( + name='LedgerEntry', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('direction', models.SmallIntegerField(choices=[(-1, 'credit'), (1, 'debit')])), + ('amount', models.BigIntegerField()), + ('account', models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, related_name='account', to='thl_django.ledgeraccount')), + ('transaction', models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, related_name='transaction', to='thl_django.ledgertransaction')), + ], + options={ + 'db_table': 'ledger_entry', + }, + ), + migrations.CreateModel( + name='LedgerTransactionMetadata', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('key', models.CharField(max_length=30)), + ('value', models.CharField(max_length=255)), + ('transaction', models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, to='thl_django.ledgertransaction')), + ], + options={ + 'db_table': 'ledger_transactionmetadata', + }, + ), + migrations.CreateModel( + name='Payout', + fields=[ + ('uuid', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('debit_account_uuid', models.UUIDField()), + ('cashout_method_uuid', models.UUIDField()), + ('created', models.DateTimeField(auto_now_add=True)), + ('amount', models.BigIntegerField()), + ('status', models.CharField(max_length=20, null=True)), + ('ext_ref_id', models.CharField(max_length=64, null=True)), + ('payout_type', models.CharField(max_length=14)), + ('request_data', models.JSONField(null=True)), + ('order_data', models.JSONField(null=True)), + ], + options={ + 'db_table': 'event_payout', + 'indexes': [models.Index(fields=['created'], name='event_payou_created_b8b87c_idx'), models.Index(fields=['debit_account_uuid'], name='event_payou_debit_a_3ae0ae_idx'), models.Index(fields=['ext_ref_id'], name='event_payou_ext_ref_a519ac_idx')], + }, + ), + migrations.CreateModel( + name='ProbeLog', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('source', models.CharField(max_length=2)), + ('survey_id', models.CharField(max_length=32)), + ('started', models.DateTimeField()), + ('live_url', models.CharField(max_length=3000)), + ('har_path', models.CharField(max_length=1000)), + ('result', models.CharField(max_length=64, null=True)), + ], + options={ + 'db_table': 'marketplace_probelog', + 'indexes': [models.Index(fields=['source', 'survey_id'], name='marketplace_source_cfaed3_idx'), models.Index(fields=['started'], name='marketplace_started_057aa4_idx')], + }, + ), + migrations.CreateModel( + name='PropertyCategoryAssociation', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('property_id', models.UUIDField(null=True)), + ('category', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='thl_django.category')), + ], + options={ + 'db_table': 'marketplace_propertycategoryassociation', + }, + ), + migrations.CreateModel( + name='PropertyCountry', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('property_id', models.UUIDField(null=True)), + ('country_iso', models.CharField(max_length=2)), + ('gold_standard', models.BooleanField(default=False)), + ], + options={ + 'db_table': 'marketplace_propertycountry', + 'indexes': [models.Index(fields=['property_id'], name='marketplace_propert_3eda38_idx'), models.Index(fields=['country_iso'], name='marketplace_country_5e4fb6_idx')], + }, + ), + migrations.CreateModel( + name='PropertyItemRange', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('property_id', models.UUIDField(null=True)), + ('country_iso', models.CharField(max_length=2)), + ('item', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='thl_django.item')), + ], + options={ + 'db_table': 'marketplace_propertyitemrange', + }, + ), + migrations.CreateModel( + name='PropertyMarketplaceAssociation', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('property_id', models.UUIDField(null=True)), + ('source', models.CharField(max_length=1)), + ('question_id', models.CharField(max_length=32)), + ], + options={ + 'db_table': 'marketplace_propertymarketplaceassociation', + 'indexes': [models.Index(fields=['source', 'question_id'], name='marketplace_source_0ad453_idx'), models.Index(fields=['property_id'], name='marketplace_propert_9f1981_idx')], + }, + ), + migrations.CreateModel( + name='Question', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('md5sum', models.CharField(max_length=32, null=True)), + ('country_iso', models.CharField(default='us', max_length=2)), + ('language_iso', models.CharField(default='eng', max_length=3)), + ('property_code', models.CharField(max_length=64)), + ('data', models.JSONField(default=dict)), + ('is_live', models.BooleanField(default=False)), + ('custom', models.JSONField(default=dict)), + ('last_updated', models.DateTimeField(null=True)), + ], + options={ + 'db_table': 'marketplace_question', + 'indexes': [models.Index(fields=['last_updated'], name='marketplace_last_up_9147b8_idx'), models.Index(fields=['property_code'], name='marketplace_propert_c8d11e_idx')], + }, + ), + migrations.CreateModel( + name='Survey', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('source', models.CharField(max_length=2)), + ('survey_id', models.CharField(max_length=32)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('eligibility_criteria', models.JSONField(null=True)), + ('is_live', models.BooleanField()), + ('is_recontact', models.BooleanField(default=False)), + ('buyer', models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, to='thl_django.buyer')), + ], + options={ + 'db_table': 'marketplace_survey', + }, + ), + migrations.CreateModel( + name='SurveyCategory', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('strength', models.FloatField(help_text='Relative relevance or confidence (0–1)')), + ('category', models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, to='thl_django.category')), + ('survey', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='thl_django.survey')), + ], + options={ + 'db_table': 'marketplace_surveycategory', + }, + ), + migrations.CreateModel( + name='SurveyStat', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('quota_id', models.CharField(max_length=32)), + ('country_iso', models.CharField(max_length=2)), + ('cpi', models.DecimalField(decimal_places=5, max_digits=8)), + ('complete_too_fast_cutoff', models.IntegerField(help_text='Seconds')), + ('prescreen_conv_alpha', models.FloatField()), + ('prescreen_conv_beta', models.FloatField()), + ('conv_alpha', models.FloatField()), + ('conv_beta', models.FloatField()), + ('dropoff_alpha', models.FloatField()), + ('dropoff_beta', models.FloatField()), + ('completion_time_mu', models.FloatField()), + ('completion_time_sigma', models.FloatField()), + ('mobile_eligible_alpha', models.FloatField()), + ('mobile_eligible_beta', models.FloatField()), + ('desktop_eligible_alpha', models.FloatField()), + ('desktop_eligible_beta', models.FloatField()), + ('tablet_eligible_alpha', models.FloatField()), + ('tablet_eligible_beta', models.FloatField()), + ('long_fail_rate', models.FloatField()), + ('user_report_coeff', models.FloatField()), + ('recon_likelihood', models.FloatField()), + ('score_x0', models.FloatField()), + ('score_x1', models.FloatField()), + ('score', models.FloatField()), + ('version', models.PositiveIntegerField(help_text='Bump when logic changes')), + ('updated_at', models.DateTimeField(auto_now=True)), + ('survey', models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, to='thl_django.survey')), + ], + options={ + 'db_table': 'marketplace_surveystat', + }, + ), + migrations.CreateModel( + name='TaskAdjustment', + fields=[ + ('uuid', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), + ('adjusted_status', models.CharField(max_length=2)), + ('ext_status_code', models.CharField(max_length=32, null=True)), + ('amount', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('alerted', models.DateTimeField()), + ('created', models.DateTimeField(auto_now_add=True)), + ('user_id', models.BigIntegerField()), + ('wall_uuid', models.UUIDField()), + ('started', models.DateTimeField()), + ('source', models.CharField(max_length=2)), + ('survey_id', models.CharField(max_length=32)), + ], + options={ + 'db_table': 'thl_taskadjustment', + 'indexes': [models.Index(fields=['created'], name='thl_taskadj_created_372998_idx'), models.Index(fields=['user_id'], name='thl_taskadj_user_id_e87483_idx'), models.Index(fields=['wall_uuid'], name='thl_taskadj_wall_uu_c23480_idx')], + }, + ), + migrations.CreateModel( + name='THLSession', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('uuid', models.UUIDField(unique=True)), + ('user_id', models.BigIntegerField()), + ('started', models.DateTimeField()), + ('finished', models.DateTimeField(null=True)), + ('loi_min', models.SmallIntegerField(null=True)), + ('loi_max', models.SmallIntegerField(null=True)), + ('user_payout_min', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('user_payout_max', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('country_iso', models.CharField(max_length=2, null=True)), + ('device_type', models.SmallIntegerField(null=True)), + ('ip', models.GenericIPAddressField(null=True)), + ('status', models.CharField(default=None, max_length=1, null=True)), + ('status_code_1', models.SmallIntegerField(null=True)), + ('status_code_2', models.SmallIntegerField(null=True)), + ('payout', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('user_payout', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('adjusted_status', models.CharField(max_length=2, null=True)), + ('adjusted_payout', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('adjusted_user_payout', models.DecimalField(decimal_places=2, max_digits=5, null=True)), + ('adjusted_timestamp', models.DateTimeField(null=True)), + ('url_metadata', models.JSONField(null=True)), + ], + options={ + 'db_table': 'thl_session', + 'indexes': [models.Index(fields=['user_id', 'started'], name='thl_session_user_id_72123d_idx'), models.Index(fields=['started'], name='thl_session_started_d5984e_idx'), models.Index(fields=['country_iso'], name='thl_session_country_33a433_idx'), models.Index(fields=['status'], name='thl_session_status_d578b7_idx'), models.Index(fields=['status_code_1'], name='thl_session_status__4c18db_idx'), models.Index(condition=models.Q(('adjusted_status__isnull', False)), fields=['adjusted_status'], name='thl_session_adj_status_nn_idx'), models.Index(condition=models.Q(('adjusted_timestamp__isnull', False)), fields=['adjusted_timestamp'], name='thl_session_adj_ts_nn_idx'), models.Index(fields=['device_type'], name='thl_session_device__5baa4f_idx'), models.Index(fields=['ip'], name='thl_session_ip_0bb4e0_idx')], + }, + ), + migrations.CreateModel( + name='THLUserMetadata', + fields=[ + ('user', models.OneToOneField(on_delete=django.db.models.deletion.RESTRICT, primary_key=True, serialize=False, to='thl_django.thluser')), + ('email_address', models.CharField(max_length=320, null=True)), + ('email_sha256', models.CharField(max_length=64, null=True)), + ('email_sha1', models.CharField(max_length=40, null=True)), + ('email_md5', models.CharField(max_length=32, null=True)), + ], + options={ + 'db_table': 'thl_usermetadata', + }, + ), + migrations.AddIndex( + model_name='thluser', + index=models.Index(fields=['created'], name='thl_user_created_4f8f22_idx'), + ), + migrations.AddIndex( + model_name='thluser', + index=models.Index(fields=['last_seen'], name='thl_user_last_se_fe5137_idx'), + ), + migrations.AddIndex( + model_name='thluser', + index=models.Index(fields=['last_country_iso'], name='thl_user_last_co_ece962_idx'), + ), + migrations.AlterUniqueTogether( + name='thluser', + unique_together={('product_id', 'product_user_id')}, + ), + migrations.AddField( + model_name='thlwall', + name='session', + field=models.ForeignKey(on_delete=django.db.models.deletion.RESTRICT, related_name='session', to='thl_django.thlsession'), + ), + migrations.AddIndex( + model_name='userauditlog', + index=models.Index(fields=['created'], name='userhealth__created_633ca3_idx'), + ), + migrations.AddIndex( + model_name='userauditlog', + index=models.Index(fields=['user_id', 'created'], name='userhealth__user_id_e64509_idx'), + ), + migrations.AddIndex( + model_name='userauditlog', + index=models.Index(fields=['level', 'created'], name='userhealth__level_17f32e_idx'), + ), + migrations.AddIndex( + model_name='userauditlog', + index=models.Index(fields=['event_type', 'created'], name='userhealth__event_t_a45197_idx'), + ), + migrations.AddIndex( + model_name='usergroup', + index=models.Index(fields=['created'], name='marketplace_created_eecfde_idx'), + ), + migrations.AddIndex( + model_name='usergroup', + index=models.Index(fields=['user_id'], name='marketplace_user_id_a9b3ed_idx'), + ), + migrations.AddIndex( + model_name='usergroup', + index=models.Index(fields=['user_group'], name='marketplace_user_gr_47fab2_idx'), + ), + migrations.AlterUniqueTogether( + name='usergroup', + unique_together={('user_id', 'user_group')}, + ), + migrations.AddIndex( + model_name='userhealthiphistory', + index=models.Index(fields=['user_id', 'created'], name='userhealth__user_id_0f7e18_idx'), + ), + migrations.AddIndex( + model_name='userhealthiphistory', + index=models.Index(fields=['created'], name='userhealth__created_3cd6b7_idx'), + ), + migrations.AddIndex( + model_name='userhealthiphistory', + index=models.Index(fields=['ip'], name='userhealth__ip_eb3911_idx'), + ), + migrations.AddIndex( + model_name='userhealthwebsocketiphistory', + index=models.Index(fields=['user_id', 'created'], name='userhealth__user_id_c11198_idx'), + ), + migrations.AddIndex( + model_name='userhealthwebsocketiphistory', + index=models.Index(fields=['user_id', 'last_seen'], name='userhealth__user_id_1e0473_idx'), + ), + migrations.AddIndex( + model_name='userhealthwebsocketiphistory', + index=models.Index(fields=['created'], name='userhealth__created_0b3299_idx'), + ), + migrations.AddIndex( + model_name='userhealthwebsocketiphistory', + index=models.Index(fields=['last_seen'], name='userhealth__last_se_740e03_idx'), + ), + migrations.AddIndex( + model_name='userhealthwebsocketiphistory', + index=models.Index(fields=['ip'], name='userhealth__ip_4f31d3_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgeitem', + index=models.Index(fields=['user_id'], name='marketplace_user_id_30ee59_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgeitem', + index=models.Index(fields=['created'], name='marketplace_created_f5aa37_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgeitem', + index=models.Index(fields=['property_id'], name='marketplace_propert_e74b55_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgenumerical', + index=models.Index(fields=['user_id'], name='marketplace_user_id_8c520e_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgenumerical', + index=models.Index(fields=['created'], name='marketplace_created_6185aa_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgenumerical', + index=models.Index(fields=['property_id'], name='marketplace_propert_09a69d_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgetext', + index=models.Index(fields=['user_id'], name='marketplace_user_id_29dcc6_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgetext', + index=models.Index(fields=['created'], name='marketplace_created_842729_idx'), + ), + migrations.AddIndex( + model_name='userprofileknowledgetext', + index=models.Index(fields=['property_id'], name='marketplace_propert_72d583_idx'), + ), + migrations.AddField( + model_name='userquestionanswer', + name='question', + field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='thl_django.question'), + ), + migrations.AddIndex( + model_name='userstat', + index=models.Index(fields=['date'], name='userprofile_date_ec0d70_idx'), + ), + migrations.AddIndex( + model_name='userstat', + index=models.Index(fields=['user_id'], name='userprofile_user_id_e1f8da_idx'), + ), + migrations.AlterUniqueTogether( + name='userstat', + unique_together={('key', 'user_id')}, + ), + migrations.AlterUniqueTogether( + name='brokerageproductconfig', + unique_together={('product', 'key')}, + ), + migrations.AddIndex( + model_name='buyergroup', + index=models.Index(fields=['created'], name='marketplace_created_ff147a_idx'), + ), + migrations.AddIndex( + model_name='buyergroup', + index=models.Index(fields=['group'], name='marketplace_group_4d716b_idx'), + ), + migrations.AddIndex( + model_name='ledgeraccountstatement', + index=models.Index(fields=['account', 'filter_str', 'effective_at_lower_bound'], name='ledger_acco_account_32b783_idx'), + ), + migrations.AddIndex( + model_name='ledgertransactionmetadata', + index=models.Index(fields=['key', 'value'], name='ledger_tran_key_4e20eb_idx'), + ), + migrations.AlterUniqueTogether( + name='ledgertransactionmetadata', + unique_together={('transaction', 'key')}, + ), + migrations.AddIndex( + model_name='propertycategoryassociation', + index=models.Index(fields=['property_id'], name='marketplace_propert_bf7dff_idx'), + ), + migrations.AddIndex( + model_name='propertyitemrange', + index=models.Index(fields=['country_iso', 'property_id'], name='marketplace_country_bbc7ce_idx'), + ), + migrations.AddIndex( + model_name='survey', + index=models.Index(fields=['source', 'is_live'], name='marketplace_source_c2ce68_idx'), + ), + migrations.AddIndex( + model_name='survey', + index=models.Index(condition=models.Q(('is_live', True)), fields=['source'], name='survey_live_by_source'), + ), + migrations.AddIndex( + model_name='survey', + index=models.Index(fields=['created_at'], name='marketplace_created_6b8446_idx'), + ), + migrations.AddIndex( + model_name='survey', + index=models.Index(fields=['updated_at'], name='marketplace_updated_414ab2_idx'), + ), + migrations.AddConstraint( + model_name='survey', + constraint=models.UniqueConstraint(fields=('source', 'survey_id'), name='uniq_survey_source_survey_id'), + ), + migrations.AddConstraint( + model_name='surveycategory', + constraint=models.UniqueConstraint(fields=('survey', 'category'), name='uniq_survey_category'), + ), + migrations.AddIndex( + model_name='surveystat', + index=models.Index(fields=['updated_at'], name='marketplace_updated_439a2d_idx'), + ), + migrations.AddConstraint( + model_name='surveystat', + constraint=models.UniqueConstraint(fields=('survey', 'quota_id', 'country_iso', 'version'), name='uniq_surveystat_survey_quota_country_version'), + ), + migrations.AddIndex( + model_name='thlusermetadata', + index=models.Index(fields=['email_address'], name='thl_usermet_email_a_2414aa_idx'), + ), + migrations.AddIndex( + model_name='thlusermetadata', + index=models.Index(fields=['email_sha256'], name='thl_usermet_email_s_b37322_idx'), + ), + migrations.AddIndex( + model_name='thlusermetadata', + index=models.Index(fields=['email_sha1'], name='thl_usermet_email_s_816978_idx'), + ), + migrations.AddIndex( + model_name='thlusermetadata', + index=models.Index(fields=['email_md5'], name='thl_usermet_email_m_deff9d_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(fields=['started'], name='thl_wall_started_091924_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(fields=['source', 'survey_id', 'started'], name='thl_wall_source_016b11_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(fields=['source', 'status'], name='thl_wall_source_a6d26f_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(fields=['source', 'status_code_1'], name='thl_wall_source_feb05b_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(condition=models.Q(('adjusted_status__isnull', False)), fields=['adjusted_status'], name='thl_wall_adj_status_nn_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(condition=models.Q(('adjusted_timestamp__isnull', False)), fields=['adjusted_timestamp'], name='thl_wall_adj_ts_nn_idx'), + ), + migrations.AddIndex( + model_name='thlwall', + index=models.Index(fields=['cpi'], name='thl_wall_cpi_0481c1_idx'), + ), + migrations.AlterUniqueTogether( + name='thlwall', + unique_together={('session', 'source', 'survey_id')}, + ), + migrations.AddIndex( + model_name='userquestionanswer', + index=models.Index(fields=['user_id', 'question_id', '-created'], name='marketplace_user_id_3c045f_idx'), + ), + migrations.AddIndex( + model_name='userquestionanswer', + index=models.Index(fields=['created'], name='marketplace_created_336ac8_idx'), + ), + ] diff --git a/generalresearch/thl_django/migrations/0002_surveystat_is_live_alter_surveycategory_strength_and_more.py b/generalresearch/thl_django/migrations/0002_surveystat_is_live_alter_surveycategory_strength_and_more.py new file mode 100644 index 0000000..211c48a --- /dev/null +++ b/generalresearch/thl_django/migrations/0002_surveystat_is_live_alter_surveycategory_strength_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 6.0 on 2025-12-28 16:49 + +from django.db import migrations, models +from django.contrib.postgres.operations import AddIndexConcurrently + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ("thl_django", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="surveystat", + name="is_live", + field=models.BooleanField(default=False), + ), + migrations.AlterField( + model_name="surveycategory", + name="strength", + field=models.FloatField( + help_text="Relative relevance or confidence (0–1)", null=True + ), + ), + AddIndexConcurrently( + model_name="surveystat", + index=models.Index( + condition=models.Q(("is_live", True)), + fields=["survey"], + name="surveystat_live_survey_idx", + ), + ), + ] diff --git a/generalresearch/thl_django/migrations/0003_remove_surveystat_surveystat_live_survey_idx_and_more.py b/generalresearch/thl_django/migrations/0003_remove_surveystat_surveystat_live_survey_idx_and_more.py new file mode 100644 index 0000000..ecaf0a9 --- /dev/null +++ b/generalresearch/thl_django/migrations/0003_remove_surveystat_surveystat_live_survey_idx_and_more.py @@ -0,0 +1,45 @@ +# Generated by Django 6.0 on 2025-12-29 21:22 + +from django.db import migrations, models +from django.contrib.postgres.operations import ( + AddIndexConcurrently, + RemoveIndexConcurrently, +) + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ("thl_django", "0002_surveystat_is_live_alter_surveycategory_strength_and_more"), + ] + + operations = [ + RemoveIndexConcurrently( + model_name="surveystat", + name="surveystat_live_survey_idx", + ), + migrations.RenameField( + model_name="surveystat", + old_name="is_live", + new_name="survey_is_live", + ), + migrations.AddField( + model_name="surveystat", + name="survey_source", + field=models.CharField(max_length=2, null=True), + ), + migrations.AddField( + model_name="surveystat", + name="survey_survey_id", + field=models.CharField(max_length=32, null=True), + ), + AddIndexConcurrently( + model_name="surveystat", + index=models.Index( + condition=models.Q(("survey_is_live", True)), + fields=["survey"], + name="surveystat_live_survey_idx", + ), + ), + ] diff --git a/generalresearch/thl_django/migrations/0004_alter_surveystat_survey_is_live_and_more.py b/generalresearch/thl_django/migrations/0004_alter_surveystat_survey_is_live_and_more.py new file mode 100644 index 0000000..e791759 --- /dev/null +++ b/generalresearch/thl_django/migrations/0004_alter_surveystat_survey_is_live_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 6.0 on 2025-12-29 23:21 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('thl_django', '0003_remove_surveystat_surveystat_live_survey_idx_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='surveystat', + name='survey_is_live', + field=models.BooleanField(), + ), + migrations.AlterField( + model_name='surveystat', + name='survey_source', + field=models.CharField(max_length=2), + ), + migrations.AlterField( + model_name='surveystat', + name='survey_survey_id', + field=models.CharField(max_length=32), + ), + ] diff --git a/generalresearch/thl_django/migrations/0005_remove_surveystat_marketplace_updated_439a2d_idx.py b/generalresearch/thl_django/migrations/0005_remove_surveystat_marketplace_updated_439a2d_idx.py new file mode 100644 index 0000000..0922075 --- /dev/null +++ b/generalresearch/thl_django/migrations/0005_remove_surveystat_marketplace_updated_439a2d_idx.py @@ -0,0 +1,17 @@ +# Generated by Django 6.0 on 2025-12-31 22:24 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("thl_django", "0004_alter_surveystat_survey_is_live_and_more"), + ] + + operations = [ + migrations.RemoveIndex( + model_name="surveystat", + name="marketplace_updated_439a2d_idx", + ), + ] diff --git a/generalresearch/thl_django/migrations/0006_remove_thlsession_thl_session_status_d578b7_idx_and_more.py b/generalresearch/thl_django/migrations/0006_remove_thlsession_thl_session_status_d578b7_idx_and_more.py new file mode 100644 index 0000000..e2492ab --- /dev/null +++ b/generalresearch/thl_django/migrations/0006_remove_thlsession_thl_session_status_d578b7_idx_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 6.0 on 2026-01-02 17:38 + +from django.db import migrations +from django.contrib.postgres.operations import RemoveIndexConcurrently + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ('thl_django', '0005_remove_surveystat_marketplace_updated_439a2d_idx'), + ] + + operations = [ + RemoveIndexConcurrently( + model_name='thlsession', + name='thl_session_status_d578b7_idx', + ), + RemoveIndexConcurrently( + model_name='thlsession', + name='thl_session_status__4c18db_idx', + ), + RemoveIndexConcurrently( + model_name='thlwall', + name='thl_wall_source_a6d26f_idx', + ), + RemoveIndexConcurrently( + model_name='thlwall', + name='thl_wall_source_feb05b_idx', + ), + RemoveIndexConcurrently( + model_name='thlwall', + name='thl_wall_cpi_0481c1_idx', + ), + ] diff --git a/generalresearch/thl_django/migrations/0007_table_params.py b/generalresearch/thl_django/migrations/0007_table_params.py new file mode 100644 index 0000000..0feeb58 --- /dev/null +++ b/generalresearch/thl_django/migrations/0007_table_params.py @@ -0,0 +1,68 @@ +# Generated by Django 6.0 on 2026-01-02 18:36 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("thl_django", "0006_remove_thlsession_thl_session_status_d578b7_idx_and_more"), + ] + + operations = [ + migrations.RunSQL( + sql=""" + ALTER TABLE marketplace_userquestionanswer SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE ledger_transactionmetadata SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE ledger_entry SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE ledger_transaction SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE userhealth_iphistory SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE marketplace_userprofileknowledgeitem SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE userhealth_auditlog SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE thl_ipinformation SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE marketplace_userprofileknowledgenumerical SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + ALTER TABLE marketplace_userprofileknowledgetext SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2 + ); + + ALTER TABLE thl_wall SET ( + autovacuum_analyze_scale_factor = 0.002, autovacuum_analyze_threshold = 10000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 5000, + fillfactor = 85 + ); + ALTER TABLE thl_session SET ( + autovacuum_analyze_scale_factor = 0.002, autovacuum_analyze_threshold = 10000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 5000, + fillfactor = 85 + ); + ALTER TABLE thl_user SET ( + autovacuum_analyze_scale_factor = 0.002, autovacuum_analyze_threshold = 10000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 5000, + fillfactor = 85 + ); + ALTER TABLE marketplace_surveystat SET ( + autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 50000, + fillfactor = 60 + ); + """, + ), + ] diff --git a/generalresearch/thl_django/migrations/0008_question_explanation_fragment_template_and_more.py b/generalresearch/thl_django/migrations/0008_question_explanation_fragment_template_and_more.py new file mode 100644 index 0000000..eed7440 --- /dev/null +++ b/generalresearch/thl_django/migrations/0008_question_explanation_fragment_template_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 6.0 on 2026-01-29 18:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("thl_django", "0007_table_params"), + ] + + operations = [ + migrations.AddField( + model_name="question", + name="explanation_fragment_template", + field=models.TextField(max_length=255, null=True), + ), + migrations.AddField( + model_name="question", + name="explanation_template", + field=models.TextField(max_length=255, null=True), + ), + ] diff --git a/generalresearch/thl_django/migrations/__init__.py b/generalresearch/thl_django/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/postgres-table-tuning.md b/generalresearch/thl_django/postgres-table-tuning.md new file mode 100644 index 0000000..3da1a5d --- /dev/null +++ b/generalresearch/thl_django/postgres-table-tuning.md @@ -0,0 +1,62 @@ + +## PostgreSQL Table Storage & Autovacuum Tuning + +For certain very large, append-only tables, we want to adjust some setting +to get the vacuum and analyze to run within a reasonable frequency. + +### Manual +First, manually analyze all tables. Note, ensure you do this with a user that has +permission, or it will just do nothing. Verify with `ANALYZE VERBOSE event_bribe;` +```postgresql +ANALYZE +``` + +### Autoanalyze +For tables that are large and typically append only, the default auto-analyze config +results in analyze almost never getting run. e.g. if a table has 10 million rows, +and the autovacuum_analyze_scale_factor default is 10%, then 1 million rows +have to be inserted or updated before the auto-analyze runs. + +```postgresql +ALTER TABLE marketplace_userquestionanswer SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE ledger_transactionmetadata SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE ledger_entry SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE ledger_transaction SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE userhealth_iphistory SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE marketplace_userprofileknowledgeitem SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE userhealth_auditlog SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE thl_ipinformation SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE marketplace_userprofileknowledgenumerical SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); +ALTER TABLE marketplace_userprofileknowledgetext SET (autovacuum_analyze_scale_factor = 0.01, autovacuum_analyze_threshold = 50000, autovacuum_vacuum_scale_factor = 0.2); + +ALTER TABLE thl_wall SET ( + autovacuum_analyze_scale_factor = 0.002, autovacuum_analyze_threshold = 10000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 5000, + fillfactor = 85 +); +ALTER TABLE thl_session SET ( + autovacuum_analyze_scale_factor = 0.002, autovacuum_analyze_threshold = 10000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 5000, + fillfactor = 85 +); +ALTER TABLE thl_user SET ( + autovacuum_analyze_scale_factor = 0.002, autovacuum_analyze_threshold = 10000, + autovacuum_vacuum_scale_factor = 0.02, autovacuum_vacuum_threshold = 5000, + fillfactor = 85 +); + +``` + +### SurveyStats + +```postgresql +ALTER TABLE marketplace_surveystat SET (fillfactor = 60); +ALTER TABLE marketplace_surveystat SET ( + autovacuum_analyze_scale_factor = 0.01, + autovacuum_analyze_threshold = 50000, + autovacuum_vacuum_scale_factor = 0.02, + autovacuum_vacuum_threshold = 50000 +); + + +``` \ No newline at end of file diff --git a/generalresearch/thl_django/postgres.md b/generalresearch/thl_django/postgres.md new file mode 100644 index 0000000..a946440 --- /dev/null +++ b/generalresearch/thl_django/postgres.md @@ -0,0 +1,143 @@ +## Suggested Postgres Settings + +**/etc/postgresql/18/main/postgresql.conf** + +For both primary and any replicas config +```text + +shared_buffers = 2048MB # Default 128MB (25% of total RAM) +effective_cache_size = 6GB # (~75% of total RAM) +work_mem = 12MB # Default 4MB +maintenance_work_mem = 256MB # Default 64MB + +# On fast ssds +random_page_cost = 1.3 +effective_io_concurrency = 50 + +statement_timeout = 60min # default unlimited +lock_timeout = 60s # default unlimited +idle_in_transaction_session_timeout = 10min # default unlimited + +min_wal_size = 1GB +max_wal_size = 4GB +max_parallel_workers_per_gather = 4 +max_parallel_maintenance_workers = 4 +max_parallel_workers = 8 +``` + +## Suggested Postgres Settings for Read Replica instances + +In the primary's config +```text +wal_level=replica +max_wal_senders=5 +hot_standby=on +wal_sender_timeout=10s +wal_keep_size=4GB +``` + +In the replica's config +```text +hot_standby_feedback = on +``` + +```postgresql +--- To run before pg_basebackup +SELECT pg_create_physical_replication_slot('xxx_dr_001'); + +--- To run if pg_basebackup fails +SELECT pg_drop_replication_slot('xxx_dr_001'); +``` + +```text +pg_basebackup -h ___ -D /var/lib/postgresql/18/main -U repl____ -P -v -R -X stream -C -S 'xxx_dr_001' +``` + +## Setting up Percona Monitoring + +Install the pmm-client. As of 2025-12-16 Percona doesn't have 3.5 on the Trixie +source, so we use bookworm still. + +```bash +sudo apt-get install gpgv lsb-release gnupg curl vim -y +wget https://repo.percona.com/apt/percona-release_latest.generic_all.deb +sudo dpkg -i percona-release_latest.generic_all.deb +sudo percona-release enable pmm3-client + +sed -i 's/trixie/bookworm/g' /etc/apt/sources.list.d/percona-pmm3-client-release.list +sed -i 's/trixie/bookworm/g' /etc/apt/sources.list.d/percona-prel-release.list +sed -i 's/trixie/bookworm/g' /etc/apt/sources.list.d/percona-telemetry-release.list + +sudo apt update +sudo apt install -y pmm-client +pmm-admin --version +``` + +With the base pmm-client installed, go ahead and setup the enhanced +pg_stat_monitor system. + +```bash +percona-release setup ppg-18 +apt-get install percona-pg-stat-monitor18 +``` + +Add these to the postgresql server's configuration file + +`vim /etc/postgresql/18/main/postgresql.conf` + +```text +shared_preload_libraries = 'pg_stat_monitor' +# Options: https://docs.percona.com/pg-stat-monitor/configuration.html#parameters-description +pg_stat_monitor.pgsm_query_max_len = 2048 +pg_stat_monitor.pgsm_normalized_query = 1 +pg_stat_monitor.pgsm_track = all +pg_stat_monitor.pgsm_enable_query_plan = 0 +pg_stat_monitor.pgsm_extract_comments = 1 +pg_stat_monitor.pgsm_track_planning = 0 +``` + +Setup a `pmm` reporting user for the database, and make sure to explicitly +allow it localhost connection permission. + +``` +vim /etc/postgresql/18/main/pg_hba.conf +host all pmm 127.0.0.1/32 scram-sha-256 +``` + +Create the extension, confirm it's working, and make sure the pmm user's +permissions are good. + +`sudo -u postgres psql`: + +```postgresql +ALTER SYSTEM SET pg_stat_monitor.pgsm_enable_query_plan = off; +SELECT pg_reload_conf(); + +-- Only create on the Master instance +CREATE EXTENSION IF NOT EXISTS pg_stat_monitor; +CREATE EXTENSION IF NOT EXISTS pg_stat_statements; +SELECT * FROM pg_available_extensions WHERE name = 'pg_stat_monitor'; +\dx pg_stat_monitor +SELECT COUNT(*) FROM pg_stat_monitor; + +CREATE USER pmm WITH PASSWORD '______'; +GRANT pg_monitor TO pmm; +GRANT CONNECT ON DATABASE postgres TO pmm; + +\c postgres +GRANT pg_monitor TO pmm; + +\c ______ +GRANT CONNECT ON DATABASE ______ TO pmm; +GRANT USAGE ON SCHEMA public TO pmm; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO pmm; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO pmm; +``` + +``` + +pmm-admin config --server-insecure-tls --server-url=https://______:______@______.internal:443 +pmm-admin remove postgresql ______-postgresql +pmm-admin add postgresql --username=pmm --password='______' --host=127.0.0.1 --port=5432 --service-name=______-postgresql-__ --query-source=pgstatmonitor + +``` \ No newline at end of file diff --git a/generalresearch/thl_django/userhealth/__init__.py b/generalresearch/thl_django/userhealth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/userhealth/models.py b/generalresearch/thl_django/userhealth/models.py new file mode 100644 index 0000000..1982098 --- /dev/null +++ b/generalresearch/thl_django/userhealth/models.py @@ -0,0 +1,127 @@ +from django.db import models + + +class UserHealthIPHistory(models.Model): + user_id = models.BigIntegerField(null=False) + ip = models.GenericIPAddressField() + created = models.DateTimeField(auto_now_add=True) + # Store any IPs in the X-Forwarded-For header, in order starting + # with forwarded_ip1 + forwarded_ip1 = models.GenericIPAddressField(null=True) + forwarded_ip2 = models.GenericIPAddressField(null=True) + forwarded_ip3 = models.GenericIPAddressField(null=True) + forwarded_ip4 = models.GenericIPAddressField(null=True) + forwarded_ip5 = models.GenericIPAddressField(null=True) + forwarded_ip6 = models.GenericIPAddressField(null=True) + + class Meta: + """ + We should NOT have a unique index on ('user_id', 'ip') b/c we should + insert a duplicate row (w a new timestamp) if this user is still using + this IP after N days (~7?). So that we never have to look too far back + to get a user's "current" IP. + """ + + db_table = "userhealth_iphistory" + indexes = [ + models.Index(fields=["user_id", "created"]), + models.Index(fields=["created"]), + models.Index(fields=["ip"]), + ] + + +class UserHealthWebSocketIPHistory(models.Model): + """ + Table for logging any websocket request that came from our GRS page. + + field:last_seen - the latest timestamp of user's particular IP address + that he hit us with before he switched IP address (or before NOW) + + Example using user_id,IP,timestamp: + 12345, x.x.x.x, 2023-11-11 16:00 + 12345, x.x.x.x, 2023-11-11 16:02 + 12345, x.x.x.x, 2023-11-11 16:03 + 12345, y.y.y.y, 2023-11-11 16:05 + 12345, x.x.x.x, 2023-11-11 16:07 + 98765, x.x.x.x, 2023-11-11 16:08 + 12345, z.z.z.z, 2023-11-11 16:10 + 12345, y.y.y.y, 2023-11-11 16:12 + 12345, y.y.y.y, 2024-11-11 16:15 + + Then Mysql data: + user_id, IP, created, last_seen + 12345, x.x.x.x, 2023-11-11 16:00, 2023-11-11 16:03 + 12345, y.y.y.y, 2023-11-11 16:05, 2023-11-11 16:05 + 12345, x.x.x.x, 2023-11-11 16:07, 2023-11-11 16:07 + 98765, x.x.x.x, 2023-11-11 16:08, 2023-11-11 16:08 + 12345, z.z.z.z, 2023-11-11 16:10, 2023-11-11 16:10 + 12345, y.y.y.y, 2023-11-11 16:12, 2024-11-11 16:15 + """ + + user_id = models.BigIntegerField(null=False) + ip = models.GenericIPAddressField() + created = models.DateTimeField(auto_now_add=True) + last_seen = models.DateTimeField(auto_now_add=True) + + class Meta: + """ + I'm not sure about the use case of index (user_id, last_seen). But + inserting data in this table is not in the hot path, so let's keep it. + """ + + db_table = "userhealth_iphistory_ws" + indexes = [ + models.Index(fields=["user_id", "created"]), + models.Index(fields=["user_id", "last_seen"]), + models.Index(fields=["created"]), + models.Index(fields=["last_seen"]), + models.Index(fields=["ip"]), + ] + + +class UserAuditLog(models.Model): + """ + Table for logging "actions" taken by a user or "events" related to a user + """ + + # The table will have a default autoincrement key + id = models.BigAutoField(primary_key=True) + + # The user this event pertains to + user_id = models.BigIntegerField(null=False) + + # When this event happened + created = models.DateTimeField(null=False) + + # The level of importance for this event. Works the same as python + # logging levels. It is an integer 0 - 50, and implementers of this + # field could map it to the predefined levels: (CRITICAL, ERROR, WARNING, + # INFO, DEBUG). + # + # This is NOT the same concept as the "strength" of whatever event happened; + # it is just for sorting, filtering and display purposes. For e.g. + # multiple level 20 events != the "importance" of one level 40 event. + level = models.PositiveSmallIntegerField(null=False, default=0) + + # The "class" or "type" or event that happened. + # e.g. "upk-audit", "ip-audit", "entrance-limit" + event_type = models.CharField(max_length=64, null=False) + + # The event message. Could be displayed on user's page + event_msg = models.CharField(max_length=256, null=True) + + # Optionally store a numeric value associated with this event. For e.g. + # if we recalculate the user's normalized recon rate, and it is "high", + # we could store an event like (event_type="recon-rate", + # event_msg="higher than allowed recon rate", event_value=0.42) + event_value = models.FloatField(null=True) + + class Meta: + db_table = "userhealth_auditlog" + + indexes = [ + models.Index(fields=["created"]), + models.Index(fields=["user_id", "created"]), + models.Index(fields=["level", "created"]), + models.Index(fields=["event_type", "created"]), + ] diff --git a/generalresearch/thl_django/userprofile/__init__.py b/generalresearch/thl_django/userprofile/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/thl_django/userprofile/models.py b/generalresearch/thl_django/userprofile/models.py new file mode 100644 index 0000000..b474bef --- /dev/null +++ b/generalresearch/thl_django/userprofile/models.py @@ -0,0 +1,178 @@ +import uuid + +from django.db import models + + +class UserStat(models.Model): + """This is for storing userstats calculated by yieldman. Only one user_id, + key is allowed and the value gets updated. + """ + + user_id = models.PositiveIntegerField() + key = models.CharField(max_length=255) + value = models.FloatField(null=True) + date = models.DateTimeField(auto_now=True) + + class Meta: + db_table = "userprofile_userstat" + + unique_together = ("key", "user_id") + indexes = [ + models.Index(fields=["date"]), + models.Index(fields=["user_id"]), + ] + + +class BrokerageProduct(models.Model): + """Represents a FSB, or other Product on General Research""" + + id = models.UUIDField(primary_key=True) + id_int = models.BigIntegerField(null=False, unique=True) + + name = models.CharField(max_length=255, unique=False) + # For migration, then change to false + team_id = models.UUIDField(null=True) + business_id = models.UUIDField(null=True) + + # We can back-pop the created timestamps from GR + created = models.DateTimeField(auto_now_add=True, null=True) + + enabled = models.BooleanField(default=True) + payments_enabled = models.BooleanField(default=True) + + # --- Config fields (some of these used to be in BrokerageProductConfig) --- + + # The commission percentage we charge. Should be between 0 and 1 + # inclusive. Temporarily null=True + commission = models.DecimalField(null=True, decimal_places=6, max_digits=6) + + # Where users are redirected to after finishing a task. Formerly known + # as callback_uri. This is temporarily null=True. + redirect_url = models.URLField(null=True) + + # The domain to use for GRS. Formerly known as harmonizer_domain. This is + # temporarily null=True. + grs_domain = models.CharField(max_length=200, null=True) + + # Stores config for the Profiling experience. FKA harmonizer_config. + # (e.g. task_injection_freq_mult, n_questions) + profiling_config = models.JSONField(default=dict) + + # Stores config for UserHealth (e.g. allow_ban_iphist, conversion_cutoff) + user_health_config = models.JSONField(default=dict) + + # Stores config for yield management (e.g. conversion_factor_adj). These + # are things that pertain to single tasks. + yield_man_config = models.JSONField(default=dict) + + # Stores config for offerwall creation (e.g. min_bin_size, n_bins) + offerwall_config = models.JSONField(default=dict) + + # Stores config for session creation (e.g. max_session_len, + # max_session_hard_retry, min_payout, etc) + session_config = models.JSONField(default=dict) + + # Store config for payouts and user payouts (payout_transformation, + # payout_format, etc.) + payout_config = models.JSONField(default=dict) + + # Store configuration regarding user creation. See: models/thl/product.py:UserCreateConfig + user_create_config = models.JSONField(default=dict) + + class Meta: + db_table = "userprofile_brokerageproduct" + + # Each name has to be unique within a team, but there can be multiple + # BPs with the same name overall + unique_together = ("team_id", "name") + + +class BrokerageProductConfig(models.Model): + """ + Represents the configuration settings for a FSB, or other Product + on General Research + """ + + product = models.ForeignKey( + BrokerageProduct, null=False, on_delete=models.DO_NOTHING + ) + key = models.CharField(max_length=255) + value = models.JSONField(default=dict) + + class Meta: + db_table = "userprofile_brokerageproductconfig" + + unique_together = ("product", "key") + + +class BrokerageProductTag(models.Model): + """ + Stores Tags for brokerage products which can be used to annotate + supplier traffic + """ + + id = models.BigAutoField(primary_key=True, null=False) + + product_id = models.BigIntegerField(null=False) + + # The allowed values are defined in models/thl/supplier_tag.py + tag = models.CharField(max_length=64, null=False) + + class Meta: + db_table = "userprofile_brokerageproducttag" + + # Tags are unique per product + unique_together = ("product_id", "tag") + + +class Language(models.Model): + """ + Languages we allow user's to do tasks in + Uses the ISO 639-2/B system. + + https://en.wikipedia.org/wiki/List_of_ISO_639-2_codes + """ + + code = models.CharField( + primary_key=True, max_length=3, help_text="three-letter language code" + ) + name = models.CharField(max_length=255, help_text="language name") + + class Meta: + db_table = "userprofile_language" + + +class PayoutMethod(models.Model): + """ + + ***Deprecated*** Nothing uses this + + An "Account" for users to send money to. Separated out as it + shouldn't be tied to authentication, and a user might want to + send to multiple places + """ + + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + user_id = models.IntegerField(null=True, db_index=True) + + default = models.BooleanField(default=False) + # We'll never delete these, so need a way to monitor + enabled = models.BooleanField(default=True) + + PAYOUT_CHOICES = ( + ("a", "AMT"), + ("c", "ACH"), + ("t", "Tango"), + ("p", "PAYPAL"), + ) + method = models.CharField(choices=PAYOUT_CHOICES, max_length=1, default="t") + recipient = models.CharField(max_length=200, blank=True, null=True) + + updated = models.DateTimeField(auto_now=True) + created = models.DateTimeField(auto_now_add=True) + + class Meta: + db_table = "userprofile_payoutmethod" + + ordering = ("-created",) + get_latest_by = "created" diff --git a/generalresearch/utils/__init__.py b/generalresearch/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/utils/aggregation.py b/generalresearch/utils/aggregation.py new file mode 100644 index 0000000..b168e4c --- /dev/null +++ b/generalresearch/utils/aggregation.py @@ -0,0 +1,14 @@ +from collections import defaultdict +from typing import Dict, List + + +def group_by_year(records: List[Dict], datetime_field: str) -> Dict[int, List]: + """Memory efficient - processes records one at a time""" + by_year = defaultdict(list) + + for record in records: + # Extract year from ISO string without full datetime parsing + year = int(record[datetime_field][:4]) + by_year[year].append(record) + + return dict(by_year) diff --git a/generalresearch/utils/copying_cache.py b/generalresearch/utils/copying_cache.py new file mode 100644 index 0000000..ea13f69 --- /dev/null +++ b/generalresearch/utils/copying_cache.py @@ -0,0 +1,21 @@ +from copy import deepcopy +from functools import wraps +from typing import Callable + + +def deepcopy_return(fn: Callable) -> Callable: + """ + Using this as a decorator to decorate lru_cached functions, because if we + store mutable objects in the cache and then modify them in place, + it would mutate in the cache, which typically we don't want. + + See also: # https://stackoverflow.com/a/54909677/1991066, which I'm not + using because it wraps the lru_cache and prevents us from accessing + the methods on it (like cache_clear()). + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + return deepcopy(fn(*args, **kwargs)) + + return wrapper diff --git a/generalresearch/utils/enum.py b/generalresearch/utils/enum.py new file mode 100644 index 0000000..14a31de --- /dev/null +++ b/generalresearch/utils/enum.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import inspect +import re +from enum import EnumMeta +from typing import Dict + + +class ReprEnumMeta(EnumMeta): + def as_openapi(self) -> str: + return "\n".join([f" - `{e.value}` = {e.name}" for e in self]) + + def as_openapi_with_value_descriptions(self) -> str: + descriptions = get_enum_comments(self) + + # This doesn't work in Python 3.12, so check if None + val = self.__doc__ + if val: + return f"{val.strip()}\n\nAllowed values: \n" + "\n".join( + [f" - __{e.value}__ *({e.name})*: {descriptions[e.name]}" for e in self] + ) + else: + return f"\nAllowed values: \n" + "\n".join( + [f" - __{e.value}__ *({e.name})*: {descriptions[e.name]}" for e in self] + ) + + def as_openapi_with_value_descriptions_name(self) -> str: + # For use when the allowed values are the enum's NAME (like in the + # task status's status_code_1) + descriptions = get_enum_comments(self) + + # This doesn't work in Python 3.12, so check if None + val = self.__doc__ + if val: + return f"{val.strip()}\n\nAllowed values: \n" + "\n".join( + [f" - __{e.name}__: {descriptions[e.name]}" for e in self] + ) + else: + return f"\nAllowed values: \n" + "\n".join( + [f" - __{e.name}__: {descriptions[e.name]}" for e in self] + ) + + +def get_enum_comments(enum_class) -> Dict: + source = inspect.getsource(enum_class) + # Regular expression to match multi-line comments and enum values + pattern = re.compile(r"((?:\s*#.*?\n)+)\s*(\w+)\s*=") + matches = pattern.findall(source) + comments_dict = {} + for match in matches: + comment = [] + for line in match[0].strip().split("\n")[::-1]: + if line == "": + # Don't match empty lines in between comments + break + comment.append(line) + comment = "\n".join(comment[::-1]) + comment = comment.replace("\n", " ").replace("#", "").strip() + comment = re.sub(r"\s+", " ", comment) + comments_dict[match[1]] = comment + return comments_dict diff --git a/generalresearch/utils/grpc_logger.py b/generalresearch/utils/grpc_logger.py new file mode 100644 index 0000000..a98a13b --- /dev/null +++ b/generalresearch/utils/grpc_logger.py @@ -0,0 +1,78 @@ +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import time + +handler = TimedRotatingFileHandler( + "grpc_access.log", when="midnight", backupCount=3, encoding="utf-8" +) +handler.setFormatter(logging.Formatter("%(message)s")) + +logger = logging.getLogger("grpc_logger") +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False # avoid duplicate logs if root logger is used elsewhere + +try: + # generalresearch should NOT have a grpc dependency, so put + # this whole thing in a try-catch.. + import grpc + + class LoggingInterceptor(grpc.ServerInterceptor): + def intercept_service(self, continuation, handler_call_details): + method = handler_call_details.method + handler = continuation(handler_call_details) + + if handler is None: + return None + + def log_and_call(handler_func, request, context): + start_time = time.time() + code = grpc.StatusCode.INTERNAL + try: + response = handler_func(request, context) + code = context.code() or grpc.StatusCode.OK + return response + except Exception as e: + code = context.code() or grpc.StatusCode.INTERNAL + raise e + finally: + duration_ms = int((time.time() - start_time) * 1000) + peer = context.peer() or "unknown" + logger.info( + json.dumps( + { + "method": method, + "code_value": code.value[0], + "code_name": code.value[1], + "duration": duration_ms, + "peer": peer, + "time": start_time, + } + ) + ) + + if handler.unary_unary: + return grpc.unary_unary_rpc_method_handler( + lambda request, context: log_and_call( + handler.unary_unary, request, context + ), + request_deserializer=handler.request_deserializer, + response_serializer=handler.response_serializer, + ) + + elif handler.unary_stream: + return grpc.unary_stream_rpc_method_handler( + lambda request, context: log_and_call( + handler.unary_stream, request, context + ), + request_deserializer=handler.request_deserializer, + response_serializer=handler.response_serializer, + ) + + else: + return handler + +except ImportError as e: + print(e) + LoggingInterceptor = None diff --git a/generalresearch/wall_status_codes/__init__.py b/generalresearch/wall_status_codes/__init__.py new file mode 100644 index 0000000..102b72a --- /dev/null +++ b/generalresearch/wall_status_codes/__init__.py @@ -0,0 +1,105 @@ +from typing import Tuple, Optional + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import Status, StatusCode1 +from generalresearch.models.thl.session import Wall +from generalresearch.wall_status_codes import ( + dynata, + fullcircle, + innovate, + morning, + pollfish, + precision, + spectrum, + sago, + cint, + lucid, + prodege, + repdata, +) + + +def annotate_status_code( + source: Source, + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple[Status, Optional[StatusCode1], Optional[str]]: + """ + :params ext_status_code_1: marketplace-dependent code + :params ext_status_code_2: marketplace-dependent code + :params ext_status_code_3: marketplace-dependent code + + returns: (status, status_code_1, status_code_2) + """ + if source == Source.DALIA: + return Status.FAIL, StatusCode1.UNKNOWN, None + if source == Source.PULLEY: + return Status.FAIL, StatusCode1.UNKNOWN, None + return { + Source.CINT: cint.annotate_status_code, + Source.DYNATA: dynata.annotate_status_code, + Source.FULL_CIRCLE: fullcircle.annotate_status_code, + Source.INNOVATE: innovate.annotate_status_code, + Source.LUCID: lucid.annotate_status_code, + Source.MORNING_CONSULT: morning.annotate_status_code, + Source.POLLFISH: pollfish.annotate_status_code, + Source.PRECISION: precision.annotate_status_code, + Source.PRODEGE: prodege.annotate_status_code, + Source.SAGO: sago.annotate_status_code, + Source.SPECTRUM: spectrum.annotate_status_code, + Source.REPDATA: repdata.annotate_status_code, + }[source](ext_status_code_1, ext_status_code_2, ext_status_code_3) + + +def stop_marketplace_session(wall: Wall) -> bool: + if wall.source == Source.DYNATA: + return dynata.stop_marketplace_session( + wall.status_code_1, wall.ext_status_code_1 + ) + + elif wall.status_code_1 in { + StatusCode1.PS_QUALITY, + StatusCode1.BUYER_QUALITY_FAIL, + StatusCode1.PS_BLOCKED, + StatusCode1.UNKNOWN, + }: + return True + + return False + + +def is_soft_fail(wall: Wall) -> bool: + # Assuming this is already a fail... ignored otherwise + if wall.source == Source.FULL_CIRCLE: + # todo: this may not have been set when this is called?? + return fullcircle.is_soft_fail(wall.elapsed) + + elif wall.status_code_1 in { + StatusCode1.BUYER_FAIL, + StatusCode1.BUYER_QUALITY_FAIL, + StatusCode1.UNKNOWN, + }: + return False + + return True + + +# def stop_marketplace_session(source: Source, status_code_1: StatusCode1, ext_status_code_1: Optional[str] = None): +# # Each marketplace can have their own version, or use this one as the default +# if source == Source.DYNATA: +# return dynata.stop_marketplace_session(status_code_1, ext_status_code_1) +# if status_code_1 in {StatusCode1.PS_QUALITY, StatusCode1.BUYER_QUALITY_FAIL, StatusCode1.PS_BLOCKED}: +# return True +# return False +# +# +# def is_soft_fail(source: Source, status_code_1: StatusCode1, elapsed: Optional[timedelta] = None): +# # Each marketplace can have their own version, or use this one as the default +# # Assuming this is already a fail... ignored otherwise +# if source == Source.FULL_CIRCLE: +# assert elapsed is not None +# return fullcircle.is_soft_fail(elapsed) +# if status_code_1 in {StatusCode1.BUYER_FAIL, StatusCode1.BUYER_QUALITY_FAIL}: +# return False +# return True diff --git a/generalresearch/wall_status_codes/cint.py b/generalresearch/wall_status_codes/cint.py new file mode 100644 index 0000000..0fa929c --- /dev/null +++ b/generalresearch/wall_status_codes/cint.py @@ -0,0 +1,15 @@ +from typing import Optional + +from generalresearch.wall_status_codes import lucid + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +): + return lucid.annotate_status_code( + ext_status_code_1=ext_status_code_1, + ext_status_code_2=ext_status_code_2, + ext_status_code_3=ext_status_code_3, + ) diff --git a/generalresearch/wall_status_codes/dynata.py b/generalresearch/wall_status_codes/dynata.py new file mode 100644 index 0000000..f37fca5 --- /dev/null +++ b/generalresearch/wall_status_codes/dynata.py @@ -0,0 +1,128 @@ +""" +https://developers.dynata.com/docs/rex/branches/main/dispositions +checked by Greg 2023-10-10 +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_name = { + "0.0": "Unknown", + "0.1": "Missing Language", + "0.2": "Missing Respondent ID", + "0.3": "Declined Consent", + "0.4": "Underage", + "0.5": "Invalid Locale", + "0.6": "Invalid Country", + "0.7": "Invalid Language", + "0.8": "Inactive Respondent", + "0.9": "Respondent Not Found", + "1.0": "Complete", + "1.1": "Partial Complete", + "2.1": "Dynata Disqualification", + "2.2": "Client Disqualification", + "2.3": "Incompatible Country", + "2.4": "Incompatible Language", + "2.5": "Incompatible Device", + "2.6": "Filter Disqualification", + "2.7": "Quota Disqualification", + "2.8": "Undisclosed Filter", # this is important for ym, should penalize + "2.10": "Partner Disqualification", + "3.0": "Default Over Quota", + "3.1": "Dynata Over Quota", + "3.2": "Client Over Quota", + "3.3": "Dynata Closed Quota", + "3.10": "Quota Group Not Open", + "3.11": "Quota Group Field Schedule", + "3.12": "Quota Group Click Balance", + "3.20": "Quota Cell Not Open", + "3.21": "Quota Cell Field Schedule", + "3.22": "Quota Cell Click Balance", + "4.0": "Duplicate", + "4.1": "Duplicate Respondent", + "4.2": "Category Exclusion", + "5.0": "General Quality", + "5.1": "Answer Quality", + "5.2": "Speeding", + "5.3": "Suspended", # aka blocked + "5.4": "Predicted Reconciliation", + "5.10": "Daily Limit", +} + +status_map = defaultdict( + lambda: Status.FAIL, **{"1.0": Status.COMPLETE, "1.1": Status.COMPLETE} +) +status_codes_ext_map = { + StatusCode1.COMPLETE: ["1.0", "1.1"], + StatusCode1.BUYER_FAIL: ["2.2", "3.2"], + StatusCode1.BUYER_QUALITY_FAIL: ["5.1", "5.2"], + StatusCode1.PS_BLOCKED: ["5.3"], + StatusCode1.PS_QUALITY: [ + "0.0", + "0.1", + "0.2", + "0.3", + "0.4", + "0.5", + "0.6", + "0.7", + "0.8", + "0.9", + "5.0", + "5.4", + ], + StatusCode1.PS_DUPLICATE: ["4.0", "4.1", "4.2"], + StatusCode1.PS_FAIL: ["2.1", "2.3", "2.4", "2.5", "2.6", "2.7", "2.8", "2.10"], + StatusCode1.PS_OVERQUOTA: [ + "3.0", + "3.1", + "3.3", + "3.10", + "3.11", + "3.12", + "3.20", + "3.21", + "3.22", + "5.10", + ], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: this is from the callback url params: + disposition and status, '.'-joined + :params ext_status_code_2: not used + :params ext_status_code_3: not used + + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_1, StatusCode1.UNKNOWN) + + return status, status_code, None + + +def stop_marketplace_session(status_code_1, ext_status_code_1) -> bool: + if ext_status_code_1.startswith("5"): + # '5.10' is the user hit a Daily Limit, so they should not be sent in again today + return True + + if status_code_1 in { + StatusCode1.PS_QUALITY, + StatusCode1.BUYER_QUALITY_FAIL, + StatusCode1.PS_BLOCKED, + }: + return True + + return False diff --git a/generalresearch/wall_status_codes/fullcircle.py b/generalresearch/wall_status_codes/fullcircle.py new file mode 100644 index 0000000..08f2c44 --- /dev/null +++ b/generalresearch/wall_status_codes/fullcircle.py @@ -0,0 +1,58 @@ +""" +fullcircle doesn't really have status codes. there is no way to distinguish +between pre-screen and client terminations. We're going to call them +buyer fails for the wall and reporting, but for yield management purposes +we'll try to infer based on the time spent in survey. +""" + +from collections import defaultdict +from datetime import timedelta +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_map = { + "1": "Complete", + "2": "Terminate", + "3": "Over-quota", + "4": "Quality terminate", +} + +status_map = defaultdict(lambda: Status.FAIL, **{"1": Status.COMPLETE}) +status_codes_ext_map = { + StatusCode1.COMPLETE: ["1"], + StatusCode1.BUYER_FAIL: ["2", "3"], + StatusCode1.BUYER_QUALITY_FAIL: ["4"], + StatusCode1.PS_BLOCKED: [], + StatusCode1.PS_QUALITY: [], + StatusCode1.PS_DUPLICATE: [], + StatusCode1.PS_FAIL: [], + StatusCode1.PS_OVERQUOTA: [], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: this is from the callback url param 's' + :params ext_status_code_2: not used + :params ext_status_code_3: not used + + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_1, StatusCode1.UNKNOWN) + return status, status_code, None + + +def is_soft_fail(elapsed: timedelta) -> bool: + # Full circle has no status codes differentiating client vs PS failure. We need to make a + # determination based on the elapsed time. + return elapsed.total_seconds() < 60 diff --git a/generalresearch/wall_status_codes/innovate.py b/generalresearch/wall_status_codes/innovate.py new file mode 100644 index 0000000..0d11425 --- /dev/null +++ b/generalresearch/wall_status_codes/innovate.py @@ -0,0 +1,117 @@ +""" +https://innovatemr.stoplight.io/docs/supplier-api/ZG9jOjEzNzYxMTg2-statuses-term-reasons-and-categories +Term Reasons xls + +This is super confusing. We get a "status" (1 through 8), a term reason, (and a +category?). The status we can't use directly, because they call a quality +term (8) due to both dedupes and due to actual quality issues. So, some we +can map directly, and some we have to look at the category. +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_innovate = { + "1": "Complete", + "2": "Buyer Fail", + "3": "Buyer Over Quota", + "4": "Buyer Quality Term", + "5": "PS Termination", + # We can't use this directly because the sub reasons are not all the same + "7": "PS Over Quota", # same as above. They think Quota full is a PS Term + "8": "PS Quality Term", # same as above. they think dupe is a quality term + "0": "PS Abandon", + "6": "Buyer Abandon", # really it is "Buyer Abandon" +} +status_map = defaultdict( + lambda: Status.FAIL, + **{"1": Status.COMPLETE, "0": Status.ABANDON, "6": Status.ABANDON}, +) +status_codes_ext_map = { + StatusCode1.BUYER_FAIL: ["2", "3"], + StatusCode1.BUYER_QUALITY_FAIL: ["4"], + StatusCode1.PS_BLOCKED: [], + StatusCode1.PS_QUALITY: ["8"], + StatusCode1.PS_DUPLICATE: [], + StatusCode1.PS_FAIL: ["5"], + StatusCode1.PS_OVERQUOTA: ["7"], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + +category_innovate = { + "Selected threat potential score at joblevel not allow the survey": StatusCode1.PS_QUALITY, + "OE Validation": StatusCode1.PS_QUALITY, + "Unique IP": StatusCode1.PS_DUPLICATE, + "Unique PID": StatusCode1.PS_DUPLICATE, + # 'Duplicated to token {token} and Group {groupID}': StatusCode1.PS_DUPLICATE, + # 'Duplicate Due to Multi Groups: Token {token} and Group {groupID}': StatusCode1.PS_DUPLICATE, + # todo: we should not send them into this marketplace for a day? + "User has attended {count} survey in 5 range": StatusCode1.PS_FAIL, + "PII_OPT": StatusCode1.PS_QUALITY, + "Recaptcha": StatusCode1.PS_QUALITY, + "URL Manipulation - Multiple Tries": StatusCode1.PS_QUALITY, + "URL Manipulation": StatusCode1.PS_QUALITY, + "Quota closed": StatusCode1.PS_OVERQUOTA, + "OpinionRoute Timeout Error": StatusCode1.PS_FAIL, + "OpinionRoute Error": StatusCode1.PS_FAIL, + "Invalid opinionRoute Token": StatusCode1.PS_FAIL, + "GEOIP": StatusCode1.PS_QUALITY, + "Speeder": StatusCode1.PS_QUALITY, + "Error respondent risk is too high": StatusCode1.PS_QUALITY, + "Group NA": StatusCode1.PS_OVERQUOTA, + "Job NA": StatusCode1.PS_OVERQUOTA, + "Supplier NA": StatusCode1.PS_OVERQUOTA, + "This survey Country mismatch": StatusCode1.PS_QUALITY, + "DeviceType": StatusCode1.PS_FAIL, + "Off hours": StatusCode1.PS_FAIL, + "Panel Duplicate": StatusCode1.PS_DUPLICATE, + "Not Eligible(sameSurveyElimination)": StatusCode1.PS_FAIL, + "ClientQualTerm": StatusCode1.BUYER_QUALITY_FAIL, + "BlockedRespondent": StatusCode1.PS_BLOCKED, +} + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + Only quality terminate (4 and 8), and PS term (5) return a term_reason (af=). + + :params ext_status_code_1: this is from the callback url param '&ac=' + :params ext_status_code_2: callback url param '&af=' + :params ext_status_code_3: not used + + returns: (status, status_code_1, status_code_2) status_code_2 is always None + """ + status = status_map[ext_status_code_1] + if status == Status.COMPLETE: + return status, StatusCode1.COMPLETE, None + # First use the 1 through 8 status code. Then, using the reason (af=), if available, + # try to maybe reclassify it. + if ext_status_code_1 not in ext_status_code_map: + return status, StatusCode1.UNKNOWN, None + status_code = ext_status_code_map.get(ext_status_code_1, StatusCode1.UNKNOWN) + if ext_status_code_2 in category_innovate: + status_code = category_innovate[ext_status_code_2] + if ext_status_code_2: + # Some of these have ids in them... so we have to pattern match it + if ( + "Duplicated to token " in ext_status_code_2 + or "Duplicate Due to Multi Groups" in ext_status_code_2 + ): + # innovate calls this 8 (quality fail). I think its a PS Dupe ... + status_code = StatusCode1.PS_DUPLICATE + elif "User has attended " in ext_status_code_2: + status_code = StatusCode1.PS_FAIL + elif "RED_HERRING_" in ext_status_code_2: + # innovate calls this 5 (PS fail), I think its a quality fail + status_code = StatusCode1.PS_QUALITY + + return status, status_code, None diff --git a/generalresearch/wall_status_codes/lucid.py b/generalresearch/wall_status_codes/lucid.py new file mode 100644 index 0000000..bf49b45 --- /dev/null +++ b/generalresearch/wall_status_codes/lucid.py @@ -0,0 +1,128 @@ +""" +https://support.lucidhq.com/s/article/Lucid-Marketplace-Response-Codes +https://support.lucidhq.com/s/article/Client-Response-Codes +https://support.lucidhq.com/s/article/Collecting-Data-From-Redirects +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +mp_codes = { + "-6": "Pre-Client Intermediary Page Drop Off", + "-5": "Failure in the Post Answer Behavior", + "-1": "Failure to Load the Lucid Marketplace", + "1": "Currently in Screener or Drop", + "3": "Respondent Sent to the Client Survey", + "21": "Industry Lockout", + "23": "Standard Qualification", + "24": "Custom Qualification", + "120": "Pre-Client Survey Opt Out", + "122": "Return to Marketplace Opt Out", + "123": "Max Client Survey Entries", + "124": "Max Time in Router", + "125": "Max Time in Router Warning Opt Out", + "126": "Max Answer Limit", + "30": "Unique IP", + "31": "RelevantID Duplicate", + "32": "Invalid Traffic", + "35": "Supplier PID Duplicate", + "36": "Cookie Duplicate", + "37": "GEO IP Mismatch", + "38": "RelevantID** Fraud Profile", + "131": "Supplier Encryption Failure", + "132": "Blocked PID", + "133": "Blocked IP", + "134": "Max Completes per Day Terminate", + "138": "Survey Group Cookie Duplicate", + "139": "Survey Group Supplier PID Duplicate", + "230": "Survey Group Unique IP", + "234": "Blocked Country IP", + "236": "No Privacy Consent", + "237": "Minimum Age", + "238": "Found on Deny List", + "240": "Invalid Browser", + "241": "Respondent Threshold Limit", + "242": "Respondent Quality Score", + "243": "Marketplace Signature Check", + "40": "Quota Full", + "41": "Supplier Allocation", + "42": "Survey Closed for Entry", + "50": "CPI Below Supplier’s Rate Card", + "98": "End of Router", +} + +# todo: finish, there's a bunch more +client_status_map = { + "30": StatusCode1.BUYER_QUALITY_FAIL, + "33": StatusCode1.BUYER_QUALITY_FAIL, + "34": StatusCode1.BUYER_QUALITY_FAIL, + "35": StatusCode1.BUYER_QUALITY_FAIL, +} + +status_map = defaultdict(lambda: Status.FAIL, **{"s": Status.COMPLETE}) +status_codes_ext_map = { + StatusCode1.COMPLETE: [], + StatusCode1.BUYER_FAIL: ["3"], + StatusCode1.BUYER_QUALITY_FAIL: [], + StatusCode1.PS_BLOCKED: ["32", "132", "133", "234", "236", "237", "238", "242"], + StatusCode1.PS_QUALITY: [ + "37", + "38", + "131", + "132", + "133", + "234", + "237", + "238", + "240", + "243", + ], + StatusCode1.PS_DUPLICATE: ["21", "35", "36", "30", "31", "138", "139", "230"], + StatusCode1.PS_FAIL: [ + "-6", + "-5", + "-1", + "1", + "120", + "122", + "125", + "134", + "23", + "24", + "123", + "124", + "236", + "241", + "50", + "98", + "126", + ], + StatusCode1.PS_OVERQUOTA: ["40", "41", "42"], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: this indicates which callback url was hit. possible values {'s', *anything else*} + :params ext_status_code_2: this is from the callback url params: InitialStatus + :params ext_status_code_3: this is from the callback url params: ClientStatus + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + if ext_status_code_2 == "3": + status_code = client_status_map.get(ext_status_code_3, StatusCode1.BUYER_FAIL) + else: + status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + if status == Status.COMPLETE: + status_code = StatusCode1.COMPLETE + return status, status_code, None diff --git a/generalresearch/wall_status_codes/morning.py b/generalresearch/wall_status_codes/morning.py new file mode 100644 index 0000000..2539b55 --- /dev/null +++ b/generalresearch/wall_status_codes/morning.py @@ -0,0 +1,122 @@ +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +""" +Status IDs. +We don't use these except for "complete". We use the "id" instead which is +the detailed status + +complete: The survey was completed successfully. +failure: The respondent was rejected for an unknown reason. These will be investigated. +over_quota: The respondent qualified for a quota, but there were no open completes available. +quality_termination: The respondent was rejected for quality reasons. +screenout: The respondent did not qualify for the survey or quota. +timeout: The respondent completed the survey after the timeout period had expired. +in_progress: The respondent interview session is still in progress, such as in the prescreener or survey. +""" + +short_code_to_status_codes_morning = { + "att_che": "attention_check", + "banned": "banned", + "bid_clo": "bid_closed", + "bi_no_fo": "bid_not_found", + "bid_pau": "bid_paused", + "complete": "complete", + "co_in_fb": "country_invalid_for_bid", + "deduplic": "deduplicated", + "excluded": "excluded", + "fa_pr_ca": "failed_prescreener_captcha", + "failure": "failure", + "inactive": "inactive", + "in_ad_qu": "in_additional_questions", + "ineligib": "ineligible", + "in_pre": "in_prescreener", + "in_sur": "in_survey", + "in_su_fa": "in_survey_failure", + "in_su_oq": "in_survey_over_quota", + "in_su_sc": "in_survey_screenout", + "in_en_pa": "invalid_entry_parameters", + "in_en_si": "invalid_entry_signature", + "la_in_fb": "language_invalid_for_bid", + "no_co_av": "no_completes_available", + "no_co_re": "no_completes_required", + "pr_co_er": "prescreener_completion_error", + "pre_tim": "prescreener_timeout", + "qu_te_ot": "quality_termination_other", + "qu_in_fb": "quota_invalid_for_bid", + "reentry": "reentry", + "speeding": "speeding", + "straight": "straightlining", + "sur_tim": "survey_timeout", + "tem_ban": "temporarily_banned", +} +status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE}) +status_codes_ext_map = { + StatusCode1.COMPLETE: ["complete"], + StatusCode1.BUYER_FAIL: [ + "in_survey_failure", + "in_survey_over_quota", + "in_survey_screenout", + "survey_timeout", + ], + StatusCode1.BUYER_QUALITY_FAIL: [ + "quality_termination_other", + "open_ended_response", + "speeding", + "attention_check", + "straightlining", + "suspect_response_pattern", + ], + StatusCode1.PS_BLOCKED: ["banned", "temporarily_banned"], + StatusCode1.PS_QUALITY: [ + "prescreener_attention_check", + "country_invalid_for_bid", + "failed_prescreener_captcha", + "inactive", + "invalid_entry_signature", + "invalid_entry_parameters", + ], + StatusCode1.PS_DUPLICATE: ["reentry", "deduplicated", "excluded"], + StatusCode1.PS_FAIL: [ + "failure", + "language_invalid_for_bid", + "minimum_cost_per_interview", + "prescreener_completion_error", + "prescreener_timeout", + "ineligible", + ], + StatusCode1.PS_OVERQUOTA: [ + "bid_not_found", + "bid_closed", + "bid_paused", + "no_completes_available", + "no_completes_required", + "quota_invalid_for_bid", + ], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: from callback url params: &sti={{status_id}} + :params ext_status_code_2: from callback url params: &sdi={{status_detail_id}} + :params ext_status_code_3: not used + + returns: (status, status_code_1, status_code_2) + """ + # We pretty much do not use the status_id because it is Morning's status category, which + # lumps de-dupes into a different category, and doesn't differentiate between in-client + # and not terms. + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + return status, status_code, None diff --git a/generalresearch/wall_status_codes/pollfish.py b/generalresearch/wall_status_codes/pollfish.py new file mode 100644 index 0000000..fbe1169 --- /dev/null +++ b/generalresearch/wall_status_codes/pollfish.py @@ -0,0 +1,80 @@ +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_map = { + "quo_ful": "quota_full", + "sur_clo": "survey_closed", + "profilin": "profiling", + "screenou": "screenout", + "duplicat": "duplicate", + "security": "security", + "geomissm": "geomissmatch", + "quality": "quality", + "has_ans": "hasty_answers", + "gibberis": "gibberish", + "captcha": "captcha", + "tp_term": "third_party_termination", + "tp_fraud": "third_party_termination_fraud", + "tp_qual": "third_party_termination_quality", + "use_rej": "user_rejection", + "vpn": "vpn", + "sur_exp": "survey_expired", + "und_pro": "underage_profiling", + "ban_phr": "banned_phrase", + "dis_rul": "disqualification_rule", + "str_lin": "straight_lining", + "su_al_ta": "survey_already_taken", + "complete": "complete", +} +status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE}) +status_codes_ext_map = { + StatusCode1.COMPLETE: ["complete"], + StatusCode1.BUYER_FAIL: ["third_party_termination", "screenout"], + StatusCode1.BUYER_QUALITY_FAIL: [ + "third_party_termination_fraud", + "third_party_termination_quality", + "disqualification_rule", + "hasty_answers", + "gibberish", + "banned_phrase", + "straight_lining", + ], + StatusCode1.PS_BLOCKED: [], + StatusCode1.PS_QUALITY: [ + "security", + "geomissmatch", + "quality", + "captcha", + "vpn", + ], + StatusCode1.PS_DUPLICATE: ["duplicate", "survey_already_taken"], + StatusCode1.PS_FAIL: [ + "profiling", + "underage_profiling", + "user_rejection", + "underage_profiling", + ], + StatusCode1.PS_OVERQUOTA: ["quota_full", "survey_closed", "survey_expired"], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: from callback url params: &sti={{status_id}} + :params ext_status_code_2: from callback url params: &sdi={{status_detail_id}} + :params ext_status_code_3: not used + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + return status, status_code, None diff --git a/generalresearch/wall_status_codes/precision.py b/generalresearch/wall_status_codes/precision.py new file mode 100644 index 0000000..147685a --- /dev/null +++ b/generalresearch/wall_status_codes/precision.py @@ -0,0 +1,96 @@ +""" +https://integrations.precisionsample.com/api.html#survey%20status%20ID's + +Possible statuses: {'s', 't', 'q', 'r', 'f'} +s - complete, t - failed ot terminated, q - quota full, r - rejected, +f - client approved the Preliminary complete as Final Complete +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_precision = { + "10": "Complete", + "20": "Client Terminate", + "21": "PS Terminate", + "22": "PS Terminate - Device Fail", + "23": "PS Terminate - Survey Closed", + "24": "PS Terminate - Recaptcha Fail", + "25": "PS Terminate - Exclusion Click", + "30": "Client Over Quota", + "31": "PS Over Quota", + "32": "PS Over Quota - Allocation Full", + "40": "Quality Sentinel Fail - Publisher Dupe", + "41": "Quality Sentinel Fail - Survey Dupe", + "42": "Quality Sentinel Fail - Back Button", + "43": "Quality Sentinel Fail - Verity", + "44": "Quality Sentinel Fail - Research Defender Dupe", + "45": "Quality Sentinel Fail - Research Defender Fraud", + "46": "Quality Sentinel Fail - Geo-Validation Fail", + "47": "Quality Sentinel Fail - Fraud Member", + "48": "Quality Sentinel Fail-Postal Code Country Mismatche", + "50": "Ghost Complete", + "51": "Ghost Complete - Math Validation Failed", + "52": "Ghost Complete - Encryption Fail", + "53": "Ghost Complete - Prescreen Skip", + "54": "Ghost Complete - Default Complete Lin", + "55": "Ghost Complete - Old URL", + "56": "Ghost Complete - Wrong Guid", + "70": "Quality Sentinel Fail - Sentry Tech Fail", + "71": "Quality Sentinel Fail - Sentry Behavioral Fail", + # These are secondary statuses and shouldn't be in the wall, but + # just in case precision doesn't validate + "60": "Client Reject", + "80": "Final Complete", +} +status_map = defaultdict(lambda: Status.FAIL, **{"s": Status.COMPLETE}) +status_codes_ext_map = { + StatusCode1.COMPLETE: ["10"], + StatusCode1.BUYER_FAIL: ["20", "30"], + StatusCode1.BUYER_QUALITY_FAIL: ["60"], + StatusCode1.PS_BLOCKED: ["44"], + StatusCode1.PS_QUALITY: [ + "24", + "43", + "45", + "46", + "47", + "48", + "50", + "51", + "52", + "53", + "54", + "55", + "56", + "60", + "70", + "71", + ], + StatusCode1.PS_DUPLICATE: ["40", "41", "42", "25"], + StatusCode1.PS_FAIL: ["21", "22"], + StatusCode1.PS_OVERQUOTA: ["31", "32", "23"], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: from callback url params: status + :params ext_status_code_2: from callback url params: code + :params ext_status_code_3: not used + + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + return status, status_code, None diff --git a/generalresearch/wall_status_codes/prodege.py b/generalresearch/wall_status_codes/prodege.py new file mode 100644 index 0000000..8ce4d22 --- /dev/null +++ b/generalresearch/wall_status_codes/prodege.py @@ -0,0 +1,56 @@ +""" +https://developer.prodege.com/surveys-feed/term-reasons +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_map = defaultdict(lambda: Status.FAIL, **{"1": Status.COMPLETE}) +status_code_map = { + StatusCode1.COMPLETE: [], + StatusCode1.BUYER_FAIL: ["1", "2"], + StatusCode1.BUYER_QUALITY_FAIL: ["10", "12"], + StatusCode1.PS_BLOCKED: ["33"], + StatusCode1.PS_QUALITY: [ + "3", + "5", + "15", + "16", + "23", + "27", + "34", + "35", + "36", + "37", + "39", + ], + StatusCode1.PS_DUPLICATE: ["4", "17", "19", "20", "24", "32"], + StatusCode1.PS_FAIL: ["8", "21", "22"], + StatusCode1.PS_OVERQUOTA: ["13", "28", "29", "30", "31", "38"], +} + +status_class = dict() +for k, v in status_code_map.items(): + for vv in v: + status_class[status_code_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: status from redirect url + :params ext_status_code_2: termreason from redirect url + :params ext_status_code_3: dqquestionid, not used. + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = status_class.get(ext_status_code_2, StatusCode1.UNKNOWN) + if status == Status.COMPLETE: + assert ext_status_code_2 is None + status_code = StatusCode1.COMPLETE + return status, status_code, None diff --git a/generalresearch/wall_status_codes/repdata.py b/generalresearch/wall_status_codes/repdata.py new file mode 100644 index 0000000..5e575c4 --- /dev/null +++ b/generalresearch/wall_status_codes/repdata.py @@ -0,0 +1,89 @@ +""" +Status codes are in a xlsx file. See thl-repdata readme +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_name = { + "2": "Search Failed", + "3": "Activity Failed", + "4": "Review Failed", + "1000": "Complete", + "2000": "Client Side Term", + "3000": "Survey Quality Term (Client Side)", + "4000": "General Overquota (Client Side)", + "5001": "Stream Closed (Research Desk)", + "5002": "Speeder Term (<30sec or <20% of LOI)", + "5003": "Qualification mismatch", + "5004": "Device compatibility mismatch", + "5101": "Attempted bypass of Client Survey", + "5102": "Encryption Failure", + "6001": "Overall quota achieved (Research Desk)", + "6002": "Sub-Quota achieved (Research Desk)", + "6003": "In-Survey maximum exceeded (Research Desk)", +} +# See: 02, and 13 are de-dupes +rd_threat_name = { + "02": "Duplicate entrant into survey", + "03": "Emulator Usage", + "04": "VPN usage detected", + "05": "TOR network detected", + "06": "Public proxy server detected", + "07": "Web proxy service used", + "08": "Web crawler usage detected", + "09": "Internet fraudster detected", + "10": "Retail and ad-tech fraudster detected", + "11": "Subnet detected", + "12": "Recent Abuse detected", + "13": "Duplicate Survey Group detected", + "14": "Navigator Webdriver detected", + "15": "Developer Tool detected", + "16": "Web RTC Detected", + "17": "Proxy Detected", + "18": "MaxMind Failure", +} + +status_map = defaultdict(lambda: Status.FAIL, **{"complete": Status.COMPLETE}) +status_code_map = { + StatusCode1.COMPLETE: ["1000"], + StatusCode1.BUYER_FAIL: ["2000", "4000"], + StatusCode1.BUYER_QUALITY_FAIL: ["3000"], + StatusCode1.PS_BLOCKED: [], + StatusCode1.PS_QUALITY: ["2", "3", "4", "5002", "5101", "5102"], + StatusCode1.PS_DUPLICATE: [], + StatusCode1.PS_FAIL: ["5003", "5004"], + StatusCode1.PS_OVERQUOTA: ["5001", "6001", "6002", "6003"], +} + +status_class = dict() +for k, v in status_code_map.items(): + for vv in v: + status_class[status_code_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: the redirect urls category (as defined in url param 549f3710b) + {'term', 'overquota', 'fraud', 'complete'} + :params ext_status_code_2: the "isc" (inbound_sub_code) + :params ext_status_code_3: "rdThreat". only used when isc=2 + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = status_class.get(ext_status_code_2, StatusCode1.UNKNOWN) + + if ext_status_code_3 in {"02", "13"}: + status_code = StatusCode1.PS_DUPLICATE + if status == Status.COMPLETE: + assert ( + status_code == StatusCode1.COMPLETE + ), "inconsistent status codes for complete" + + return status, status_code, None diff --git a/generalresearch/wall_status_codes/sago.py b/generalresearch/wall_status_codes/sago.py new file mode 100644 index 0000000..ac354ce --- /dev/null +++ b/generalresearch/wall_status_codes/sago.py @@ -0,0 +1,193 @@ +""" +https://developer-beta.market-cube.com/api-details#api=definition-api&operation=get-api-v1-definition-return-status-list +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_schlesinger = { + "1": "Complete", + "2": "Buyer Fail", + "3": "Buyer Fail", + "4": "Buyer Fail", + "6": "PS Quality Term", # includes Duplicate + "8": "PS Term", + "10": "PS Over Quota", + "15": "Buyer Fail", # not in documentation + "0": "Abandon", # really it is "PS Abandon" + "11": "Abandon", # really it is "Buyer Abandon" +} + +status_reason_name = { + "1": "Not a Unique Sample Cube User", + "4": "GeoIP - wrong country", + "7": "Duplicate - not a unique IP", + "9": "Security Terminate - supplier not allocated to the survey", + "16": "Client redirect - SHA-1 mismatch", + "18": "Supplier entry - encryption failure", + "22": "Trap question failure", + "26": "Terminate - Min LOI logic", + "27": "Terminate - Max LOI logic", + "29": "Security Terminate - Survey closed", + "30": "Recontact - wrong PID user", + "31": "Security Terminate - financial termination", + "33": "Security Terminate - Unique Link error", + "34": "Partial Complete - SOI", + "35": "Terminate - survey does not allow desktop devices", + "36": "Terminate - survey does not allow mobile devices", + "37": "Recontact - terminate", + "38": "Terminate - quality check on multi-punch", + "39": "Reconciled - terminate", + "40": "Reconciled - complete", + "41": "Already a Complete", + "42": "RelevantID - duplicate attempt", + "43": "RelevantID - fraud profile score too high", + "44": "RelevantID - wrong country", + "45": "RelevantID - call failed", + "46": "Sample Cube - overquota", + "47": "Terminate - demographic qualifications", + "48": "Complete - not reconciled", + "49": "Client - overquota", + "50": "Terminate - client side logic", + "51": "Drop Out on landing page", + "52": "Security terminate - client side logic", + "54": "Client redirect - drop out / in progress", + "55": "MaxMind - IP blacklisted", + "56": "Terminate - survey does not allow tablet devices", + "57": "GeoIP - State check, US only", + "58": "Drop Out on qualifications", + "60": "Duplicate - survey group unique IP", + "61": "Duplicate - survey group unique SID", + "62": "Duplicate - survey group RelevantID dupe", + "63": "Security Terminate - exceeded supplier allocation", + "64": "Terminate - custom demographic qualifications", + "65": "Client - entry encryption error", + "66": "RelevantID - internal failure", + "67": "Bot detected", + "68": "Client redirect - secret value missing", + "69": "Complete - supplier reservation", + "70": "Client - authentication error", + "71": "Terminate - blocked IE browser version", + "72": "Duplicate - device user ID check", + "73": "Duplicate - survey group Device User ID dupe", + "74": "AgeDemoTerminateInconsistency", + "75": "Terminate - inconsistent gender", + "76": "Terminate - inconsistent zip", + "77": "Drop Out before client entry", + "78": "Not live survey completion", + "79": "LinkedIn - failed to login", + "80": "Linkedin - drop out on login", + "81": "Client redirect - S2S not fired", + "82": "Bad User", + "84": "Security Terminate - Speeder", + "85": "Sample Chain - fraudster", + "86": "Sample Chain - wrong country", + "87": "Sample Chain - duplicate", + "88": "Terminate - demo terminate on advanced logic", + "89": "Sample Chain - survey group duplicate", + "90": "Duplicate - client logic", + "91": "Security Terminate - Unique Link time-out", + "92": "Security Terminate - Unique Link internal server error", + "93": "Security Terminate - Unique Link exceeded expected CPI", + "94": "Sample Chain - cross panel deduplication", + "95": "First time entry exception failure", + "96": "Prescreener start init", + "97": "Prescreener start exception", + "98": "Client - qualification error", + "99": "Block User", + "100": "MBDGoogleStart", + "101": "Reconciled to Complete - Late", + "102": "Reconciled to Terminate - Late", + "103": "No Matching Unique ID", + "104": "Matching Unique ID Already Attempted", + "105": "Already a Terminate", + "106": "Unique link Already Attempted", + "107": "Internal Quality Score", + "108": "Bad IP - VPN", + "109": "Bad IP - Proxy", + "110": "Client - No Surveys", + "111": "Fraud S2S", + "112": "Bad IPQS Fraud Score", + "113": "Invalid Redirect Url", + "114": "Client Eligibility Logic", + "115": "Unique Link Blank Response", + "116": "Unique Link Client Failure", + "117": "Sample Chain-Terminate on OE Quality", +} + +status_map = defaultdict( + lambda: Status.FAIL, **{"1": Status.COMPLETE, "0": Status.ABANDON} +) + +status_codes_ext_map = { + StatusCode1.COMPLETE: ["48"], + StatusCode1.BUYER_FAIL: ["16", "29", "49", "50", "78", "114", "110", "114"], + StatusCode1.BUYER_QUALITY_FAIL: ["26", "52", "68", "81", "84"], + StatusCode1.PS_BLOCKED: ["99"], + StatusCode1.PS_QUALITY: [ + "1", + "7", + "42", + "43", + "44", + "58", + "60", + "61", + "62", + "72", + "73", + "74", + "75", + "76", + "85", + "86", + "87", + "89", + "90", + "99", + "112", + "120", + ], + StatusCode1.PS_DUPLICATE: [ + "1", + "7", + "8", + "42", + "60", + "61", + "62", + "72", + "73", + "87", + "89", + "90", + ], + StatusCode1.PS_FAIL: ["7", "29", "36", "47", "56", "58", "64"], + StatusCode1.PS_OVERQUOTA: ["29", "46", "33", "31"], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: from callback url params: scstatus + :params ext_status_code_2: from callback url params: scsecuritystatus + :params ext_status_code_3: not used + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + # According to personal communication, scsecuritystatus may not always + # come back for completes. Going to ignore it if the status is complete + if status == Status.COMPLETE: + status_code = StatusCode1.COMPLETE + return status, status_code, None diff --git a/generalresearch/wall_status_codes/spectrum.py b/generalresearch/wall_status_codes/spectrum.py new file mode 100644 index 0000000..0cf5814 --- /dev/null +++ b/generalresearch/wall_status_codes/spectrum.py @@ -0,0 +1,165 @@ +""" +https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33613201/Minimizing+Clickwaste+with+ps+rstatus +""" + +from collections import defaultdict +from typing import Optional, Tuple + +from generalresearch.models.thl.definitions import Status, StatusCode1 + +status_codes_spectrum = { + "11": "PS Drop", + "12": "PS Quota Full Core", + "13": "PS Termination Core", + "14": "PS Side In Progress", + "15": "PS Quality", + "16": "Buyer In Progress", + "17": "Buyer Quota Full", + "18": "Buyer Termination", + "19": "Buyer Drop", + "20": "Buyer Quality Termination", + "21": "Complete", + "22": "PS Survey Closed Termination", + "23": "PS Survey Paused Termination", + "24": "PS Unopened Quota Term", + "25": "PS Supplier Allocation Full", + "26": "PS Past Participation Fail", + "27": "PS Supplier Quota Allocation Full", + "28": "PS Invalid Survey", + "29": "PS LOI Threshold Failure", + "30": "Buyer Security (De-Dupe)", + "31": "Buyer Hash Failure", + "32": "PS Grouping Termination", + "33": "Buyer Reconcilliation Reject", + "35": "PS No matched quotas", + "36": "PS Max IP Throttling Termination", + "37": "PS Quota Throttling Termination", + "38": "PS PSID Geo Termination", + "40": "PS GeoIP Fail", + "41": "PS Bot Fail", + "42": "PS BlackList Fail", + "43": "PS Anonymous Fail", + "44": "PS Include Fail", + "45": "PS Termination Extended", + "46": "PS Termination Custom", + "47": "PS Quota Full Extended", + "48": "PS Quota Full Custom", + "49": "PS Include Fail", + "50": "PS Exclude Fail", + "51": "Invalid Supplier", + "52": "PSID Service Fail", + "55": "PS Unique Link Termination", + "56": "Unauthorized Augment", + "57": "PS Supplier Quota Full", + "58": "PS Supplier Quota Throttling Termination", + "59": "Buyer Config Error", + "60": "PS_Js_Fail", + "62": "Ps_NoPureScore", + "63": "PS_Blacklist_Data_Quality", + "64": "PS_Blacklist_Data_Quality_2", + "67": "PS_SC_Fraudster_Fail", + "68": "PS_SC_Threat_Fail", + "69": "PS_TC_Termination", + "70": "PS_DF_DUPE", + "71": "ScHashFail", + "73": "PS_Transaction_Fraud", + "74": "PS_Respondent_Redirect_Fail", # this apparently means dedupe + "75": "PS_Blacklist_Data_Quality_4", + "76": "PS_DQ_Screener_Invalid", + "77": "PS_Supply_Inbound_Hash_Security", + "78": "PS_DQ_Honeypot_Fail", + "79": "PS_PureText_Dedupe_Fail", + "80": "PS_AI_Text_Fail", + "81": "PS_Puretext_Language_Fail", + "82": "PS_Survey_Signature_Fail", + "83": "PS_Browser_Manipulation_Fail", + "84": "Buyer_PS_API_Fail", # no idea what this means + "85": "PS_RD_Predupe", + "86": "PS_DF_Dupe_Grouping", + "87": "PS_Supplier_Invalid_Bbsec", + "88": "PS_Supplier_Allocation_Throttle", +} +status_map = defaultdict(lambda: Status.FAIL, **{"21": Status.COMPLETE}) +status_codes_ext_map = { + StatusCode1.COMPLETE: ["21"], + StatusCode1.BUYER_FAIL: ["16", "17", "18", "19", "30", "59", "84"], + StatusCode1.BUYER_QUALITY_FAIL: ["20", "31"], + StatusCode1.PS_BLOCKED: ["42", "75"], + StatusCode1.PS_QUALITY: [ + "15", + "29", + "38", + "40", + "41", + "43", + "60", + "62", + "63", + "64", + "65", + "67", + "68", + "69", + "71", + "73", + "76", + "77", + "78", + "79", + "81", + "82", + "83", + "87", + ], + StatusCode1.PS_DUPLICATE: ["26", "32", "69", "70", "74", "85", "86"], + StatusCode1.PS_FAIL: [ + "11", + "13", + "35", + "45", + "46", + "49", + "50", + ], + StatusCode1.PS_OVERQUOTA: [ + "12", + "22", + "23", + "24", + "25", + "36", + "37", + "47", + "48", + "57", + "58", + "58", + "88", + ], + StatusCode1.UNKNOWN: [ + "72", + ], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[status_codes_ext_map.get(vv, vv)] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: from url params: ps_rstatus + https://purespectrum.atlassian.net/wiki/spaces/PA/pages/33613201/Minimizing+Clickwaste+with+ps+rstatus + :params ext_status_code_2: not used + :params ext_status_code_3: not used + + returns: (status, status_code_1, status_code_2) + """ + status = status_map[ext_status_code_1] + status_code = ext_status_code_map.get(ext_status_code_1, StatusCode1.UNKNOWN) + + return status, status_code, None diff --git a/generalresearch/wall_status_codes/wxet.py b/generalresearch/wall_status_codes/wxet.py new file mode 100644 index 0000000..1cbf568 --- /dev/null +++ b/generalresearch/wall_status_codes/wxet.py @@ -0,0 +1,87 @@ +from collections import defaultdict +from typing import Optional, Dict, Tuple + +from generalresearch.models.thl.definitions import StatusCode1, Status +from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + WXETStatusCode2, +) + +status_map: Dict[WXETStatus, Status] = defaultdict( + lambda: Status.FAIL, **{WXETStatus.COMPLETE: Status.COMPLETE} +) +status_codes_ext_map = { + StatusCode1.COMPLETE: [WXETStatusCode1.COMPLETE], + StatusCode1.BUYER_FAIL: [ + WXETStatusCode1.BUYER_DUPLICATE, + WXETStatusCode1.BUYER_FAIL, + WXETStatusCode1.BUYER_OVER_QUOTA, + WXETStatusCode1.BUYER_TASK_NOT_AVAILABLE, + ], + StatusCode1.BUYER_QUALITY_FAIL: [WXETStatusCode1.BUYER_QUALITY_FAIL], + StatusCode1.BUYER_ABANDON: [WXETStatusCode1.BUYER_ABANDON], + StatusCode1.PS_BLOCKED: [], + StatusCode1.PS_QUALITY: [], + StatusCode1.PS_DUPLICATE: [], + StatusCode1.PS_ABANDON: [WXETStatusCode1.WXET_ABANDON], + StatusCode1.PS_FAIL: [WXETStatusCode1.WXET_FAIL], + StatusCode1.PS_OVERQUOTA: [], + StatusCode1.UNKNOWN: [], + StatusCode1.MARKETPLACE_FAIL: [WXETStatusCode1.BUYER_POSTBACK_NOT_RECEIVED], +} +ext_status_code_map = dict() +for k, v in status_codes_ext_map.items(): + for vv in v: + ext_status_code_map[vv] = k + +status_code2_map = { + StatusCode1.PS_QUALITY: [], + StatusCode1.PS_DUPLICATE: [ + WXETStatusCode2.WORKER_INELIGIBLE, + WXETStatusCode2.WORKER_EXCLUDED, + WXETStatusCode2.RE_ENTRY, + ], + StatusCode1.PS_OVERQUOTA: [ + WXETStatusCode2.SUPPLY_CONFIG_RESTRICTED, + WXETStatusCode2.WORKER_RATE_LIMITED, + WXETStatusCode2.TASK_RATE_LIMITED, + WXETStatusCode2.TASK_NOT_FOUND, + WXETStatusCode2.TASK_NOT_AVAILABLE, + WXETStatusCode2.TASK_NOT_FUNDED, + WXETStatusCode2.TASK_NO_FINISHES_AVAILABLE, + WXETStatusCode2.TASK_CONNECTOR_NO_FINISHES_AVAILABLE, + WXETStatusCode2.INVALID_ALLOCATION_SELECTION, + WXETStatusCode2.TASK_VERSION_MISMATCH, + ], +} +ext_status_code2_map = dict() +for k, v in status_code2_map.items(): + for vv in v: + ext_status_code2_map[vv] = k + + +def annotate_status_code( + ext_status_code_1: str, + ext_status_code_2: Optional[str] = None, + ext_status_code_3: Optional[str] = None, +) -> Tuple: + """ + :params ext_status_code_1: WXETStatus + :params ext_status_code_2: WXETStatusCode1 + :params ext_status_code_3: WXETStatusCode2 + returns: (status, status_code_1, status_code_2) + """ + ext_status_code_1 = WXETStatus(ext_status_code_1) + ext_status_code_2 = WXETStatusCode1(ext_status_code_2) + ext_status_code_3 = ( + WXETStatusCode2(ext_status_code_3) if ext_status_code_3 else None + ) + status = status_map[ext_status_code_1] + status_code_1 = ext_status_code_map.get(ext_status_code_2, StatusCode1.UNKNOWN) + status_code_2 = None + + if ext_status_code_2 == WXETStatusCode1.WXET_FAIL: + status_code_2 = ext_status_code2_map.get(ext_status_code_3) + + return status, status_code_1, status_code_2 diff --git a/generalresearch/wxet/__init__.py b/generalresearch/wxet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/wxet/models/__init__.py b/generalresearch/wxet/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generalresearch/wxet/models/definitions.py b/generalresearch/wxet/models/definitions.py new file mode 100644 index 0000000..50853dd --- /dev/null +++ b/generalresearch/wxet/models/definitions.py @@ -0,0 +1,320 @@ +from enum import Enum +from typing import Optional, Tuple + +from generalresearch.currency import USDMill +from generalresearch.utils.enum import ReprEnumMeta + + +class IncExcFilterType(str, Enum, metaclass=ReprEnumMeta): + INCLUDE = "include" + EXCLUDE = "exclude" + + +# Note: This is exactly the same as the py-utils:models/thl/definitions.py:Status. +# Keeping this because the comments (and as a result, the documentation) +# is slightly different, and specific to wxet. +class WXETStatus(str, Enum, metaclass=ReprEnumMeta): + """ + The outcome of a task attempt. If the attempt is still in progress, the status will be NULL. + """ + + # Worker completed the task. + COMPLETE = "c" + + # Worker did not complete task. They were rejected by either WXET or buyer. + FAIL = "f" + + # Worker abandoned the task. Only set if the Buyer informs us that the + # user took some action to exit out of the task + ABANDON = "a" + + # Worker either abandoned the task or was never returned. After a + # pre-determined amount of time (configurable), any task that does + # not have a status will time out. + # todo: setup the timeout logic for wxet + TIMEOUT = "t" + + +# Basically same note as for WxetStatus for WallAdjustedStatus +class WXETAdjustedStatus(str, Enum, metaclass=ReprEnumMeta): + # Task was reconciled to complete + ADJUSTED_TO_COMPLETE = "ac" + # Task was reconciled to incomplete + ADJUSTED_TO_FAIL = "af" + # The cpi for a task was adjusted + CPI_ADJUSTMENT = "ca" + # The user was redirected without a Postback, but the postback was then "immediately" + # recieved. The supplier thinks this was a failure. This is distinct from an + # actual adjustment to complete. + POSTBACK_COMPLETE = "pc" + + +class WXETStatusCode1(int, Enum, metaclass=ReprEnumMeta): + """ + __High level status code for outcome of the attempt.__ + This should only be NULL if the WXETStatus is ABANDON or TIMEOUT + """ + + # This shouldn't be returned. + UNKNOWN = 1 + # Worker failed to be sent into a task. + WXET_FAIL = 2 + # The worker abandoned/timed out within wxet before being sent to task + WXET_ABANDON = 3 + + # This should never happen + # Buyer is explicitly blocked by the marketplace. + # WXET_BUYER_BLOCKED = 4 + + # Open statuses + # WXET_ = 5 + # WXET_ = 6 + # WXET_ = 7 + # WXET_ = 8 + # WXET_ = 9 + # WXET_ = 10 + + # Values below here are considered buyer-rejections + + # Worker rejected by buyer for over quota + BUYER_OVER_QUOTA = 11 + # Worker rejected by buyer for duplicate entrance + BUYER_DUPLICATE = 16 + # Worker unable to enter task due it not being available. + BUYER_TASK_NOT_AVAILABLE = 12 + # Worker abandoned/timed out within the task + BUYER_ABANDON = 13 + # Worker terminated in buyer task + BUYER_FAIL = 14 + # Worker terminated in buyer task due to quality reasons + BUYER_QUALITY_FAIL = 15 + + # Worker redirect with no callback received + BUYER_POSTBACK_NOT_RECEIVED = 30 + + # Completed the Task + COMPLETE = 99 + + @property + def is_pre_task_entry_fail(self) -> bool: + """This property helper indicates if the WXET Attempt made it into + the WXET Account's (eg: the "buyer"'s) Task. + """ + return False if self.value > 10 else True + + +class WXETStatusCode2(int, Enum, metaclass=ReprEnumMeta): + """ + __Status Detail__ + These are generally only set if the StatusCode1 is WXET_FAIL, + but don't *have* to be. Can be NULL even if StatusCode1 is WXET_FAIL + """ + + # Unexpected error + INTERNAL_ERROR = 1 + + # 2,3,4,5,6 not implemented + + # Worker does not meet supply configuration + SUPPLY_CONFIG_RESTRICTED = 2 + # Worker is ineligible due to include/exclude rules on this task + WORKER_INELIGIBLE = 3 + # Worker is excluded due to prior participation in another task via exclusion rules + WORKER_EXCLUDED = 4 + # Worker is rate limited + WORKER_RATE_LIMITED = 5 + # Task is rate limited + TASK_RATE_LIMITED = 6 + + # Worker has previously entered this task + RE_ENTRY = 7 + + # Worker was sent to a task which does not exist + TASK_NOT_FOUND = 10 + + # Task is no longer live + TASK_NOT_AVAILABLE = 11 + + # Task is not funded. (this shouldn't happen because if the task is live, it must be funded) + TASK_NOT_FUNDED = 12 + + # The task's required_finish_count was reached. + TASK_NO_FINISHES_AVAILABLE = 18 + + # The upper_limit was met within a task's connectors (e.g. within matching quotas) + TASK_CONNECTOR_NO_FINISHES_AVAILABLE = 14 + + # Worker was sent to a Task with Allocation(s) which are not valid. (quota specified + # is not associated with the task specified) + INVALID_ALLOCATION_SELECTION = 15 + + # Worker was sent to a previous version of Task. (not implemented) + TASK_VERSION_MISMATCH = 16 + + # Not eligible due to not passing the task's task_group_filters + TASK_GROUP_FILTERS_FAIL = 17 + + # Not eligible due to not passing the task's respondent filters + RESPONDENT_FILTERS_FAIL = 18 + + # Not eligible due to not passing the task's scheduled fielding + SCHEDULED_FIELDING_FAIL = 19 + + +def check_wxet_status_consistent( + status: WXETStatus, + status_code_1: Optional[WXETStatusCode1] = None, + status_code_2: Optional[WXETStatusCode2] = None, +) -> bool: + """ + Raises an AssertionError if inconsistent + """ + + if status == WXETStatus.COMPLETE: + assert ( + status_code_1 == WXETStatusCode1.COMPLETE + ), "Invalid StatusCode1 when Status=COMPLETE. Use WXETStatusCode1.COMPLETE" + + if status == WXETStatus.ABANDON: + assert status_code_1 in { + WXETStatusCode1.WXET_ABANDON, + WXETStatusCode1.BUYER_ABANDON, + }, "Invalid StatusCode1 when Status=ABANDON. Use WXET_ABANDON or BUYER_ABANDON" + + if status == WXETStatus.FAIL: + # status_code_1 can be anything except complete or abandon + assert status_code_1 not in { + WXETStatusCode1.COMPLETE, + WXETStatusCode1.WXET_ABANDON, + WXETStatusCode1.BUYER_ABANDON, + }, "Invalid StatusCode1 when Status=FAIL." + + # (Currently), status code 2 is only used if WXETStatusCode1 is wxet fail + if status_code_2 is not None: + assert ( + status_code_1 == WXETStatusCode1.WXET_FAIL + ), "status_code_1 should be WXET_FAIL if a status_code_2 is set" + + return True + + +def check_wxet_adjusted_status_attempt_consistent( + status: WXETStatus, + status_code_1: Optional[WXETStatusCode1] = None, + cpi: Optional[USDMill] = None, + adjusted_status: Optional[WXETAdjustedStatus] = None, + adjusted_cpi: Optional[USDMill] = None, + new_adjusted_status: Optional[WXETAdjustedStatus] = None, + new_adjusted_cpi: Optional[USDMill] = None, +) -> Tuple[bool, str]: + """ + Raises an AssertionError if inconsistent. + - status, status_code_1, adjusted_status, adjusted_cpi, cpi are the attempt's CURRENT values + - new_adjusted_status & new_adjusted_cpi are attempting to be set + We are checking if the adjustment is allowed, based on the attempt's current status. + """ + try: + _check_wxet_adjusted_status_attempt_consistent( + status=status, + status_code_1=status_code_1, + cpi=cpi, + adjusted_status=adjusted_status, + adjusted_cpi=adjusted_cpi, + new_adjusted_status=new_adjusted_status, + new_adjusted_cpi=new_adjusted_cpi, + ) + except AssertionError as e: + return False, str(e) + return True, "" + + +def _check_wxet_adjusted_status_attempt_consistent( + status: WXETStatus, + status_code_1: Optional[WXETStatusCode1] = None, + cpi: Optional[USDMill] = None, + adjusted_status: Optional[WXETAdjustedStatus] = None, + adjusted_cpi: Optional[USDMill] = None, + new_adjusted_status: Optional[WXETAdjustedStatus] = None, + new_adjusted_cpi: Optional[USDMill] = None, +) -> None: + """ + Raises an AssertionError if inconsistent. + - status, status_code_1, adjusted_status, adjusted_cpi, cpi are the attempt's CURRENT values + - new_adjusted_status & new_adjusted_cpi are attempting to be set + We are checking if the adjustment is allowed, based on the attempt's current status. + """ + # Check the original attempt actually even entered the client survey + if status_code_1 and status_code_1.is_pre_task_entry_fail: + raise AssertionError("pre-task entry fail, can't adjust status") + + # Check that we're actually changing something + if adjusted_status == new_adjusted_status and adjusted_cpi == new_adjusted_cpi: + raise AssertionError(f"attempt is already {adjusted_status=}, {adjusted_cpi=}") + + # adjusted_status/adjusted_cpi agreement + _check_wxet_adjusted_status_consistent( + adjusted_status=new_adjusted_status, adjusted_cpi=new_adjusted_cpi + ) + if new_adjusted_status == WXETAdjustedStatus.ADJUSTED_TO_FAIL: + assert new_adjusted_cpi == USDMill( + 0 + ), "adjusted_cpi should be 0 if adjusted_status is ADJUSTED_TO_FAIL" + elif new_adjusted_status == WXETAdjustedStatus.ADJUSTED_TO_COMPLETE: + assert ( + new_adjusted_cpi == cpi + ), "adjusted_cpi should be equal to the original cpi if adjusted_status is ADJUSTED_TO_COMPLETE" + elif new_adjusted_status == WXETAdjustedStatus.CPI_ADJUSTMENT: + assert new_adjusted_cpi != cpi and new_adjusted_cpi != USDMill( + 0 + ), "adjusted_cpi should be different than the original cpi if CPI_ADJUSTMENT" + elif adjusted_status is None: + # It'll be None if we are going, for e.g. Complete -> Fail -> Complete + assert new_adjusted_cpi is None, "adjusted_cpi should be None" + + # status / adjusted_status agreement + if status == WXETStatus.COMPLETE: + assert ( + new_adjusted_status != WXETAdjustedStatus.ADJUSTED_TO_COMPLETE + ), "adjusted status can't be ADJUSTED_TO_COMPLETE if the status is COMPLETE" + elif status == WXETStatus.FAIL: + assert ( + new_adjusted_status != WXETAdjustedStatus.ADJUSTED_TO_FAIL + ), "adjusted status can't be ADJUSTED_TO_FAIL if the status is FAIL" + else: + # status is None/timeout/abandon, which we treat as a fail anyway + assert ( + new_adjusted_status != WXETAdjustedStatus.ADJUSTED_TO_FAIL + ), "attempt is already a failure" + + # adjusted_status / new_adjusted_status agreement + if new_adjusted_status == WXETAdjustedStatus.CPI_ADJUSTMENT: + assert ( + new_adjusted_cpi != adjusted_cpi + ), f"adjusted_cpi is already {adjusted_cpi}" + + +def _check_wxet_adjusted_status_consistent( + adjusted_status: Optional[WXETAdjustedStatus] = None, + adjusted_cpi: Optional[USDMill] = None, +) -> None: + """ + Raises an AssertionError if inconsistent. + - adjusted_status & adjusted_cpi are attempting to be set + """ + # adjusted_status/adjusted_cpi agreement + adjusted_cpi = adjusted_cpi if adjusted_cpi is not None else USDMill(0) + if adjusted_status == WXETAdjustedStatus.ADJUSTED_TO_FAIL: + assert adjusted_cpi == USDMill( + 0 + ), "adjusted_cpi should be 0 if adjusted_status is ADJUSTED_TO_FAIL" + elif adjusted_status == WXETAdjustedStatus.ADJUSTED_TO_COMPLETE: + assert adjusted_cpi != USDMill( + 0 + ), "adjusted_cpi should be equal to the original cpi if adjusted_status is ADJUSTED_TO_COMPLETE" + elif adjusted_status == WXETAdjustedStatus.CPI_ADJUSTMENT: + assert adjusted_cpi != USDMill( + 0 + ), "adjusted_cpi should be different than the original cpi if CPI_ADJUSTMENT" + elif adjusted_status is None: + # It'll be None if we are going, for e.g. Complete -> Fail -> Complete + pass diff --git a/generalresearch/wxet/models/finish_type.py b/generalresearch/wxet/models/finish_type.py new file mode 100644 index 0000000..2aa4c7f --- /dev/null +++ b/generalresearch/wxet/models/finish_type.py @@ -0,0 +1,94 @@ +from enum import Enum +from typing import Set, Optional + +from generalresearch.utils.enum import ReprEnumMeta +from generalresearch.wxet.models.definitions import WXETStatus, WXETStatusCode1 + + +class FinishType(str, Enum, metaclass=ReprEnumMeta): + """A Task can be classified as "finished" based on different outcomes. +
+ This controls how the `Task.required_finish_count` value + is consumed, which informs Suppliers how many "remaining spots" are + available for this Task.
+ Note: This has nothing to do with deciding when a worker will be paid, it is + only for counting attempts towards the required_finish_count. + """ + + # Worker made it to the client Task (eg: Pay a Worker to go to a website, + # view content, or otherwise enter a Task that may not redirect them) + # An entrance only "counts" if the user was redirected into the client's + # task. This is indicated by WXETStatusCode1 >= 10. + # This is also commonly referred to as "clicks". + ENTRANCE = "entrance" + + # Buyer reports the attempt status as Complete + COMPLETE = "complete" + + # Buyer reports the attempt status as Fail or Complete. This is + # used when Abandons are not wanted to be counted. + FAIL_OR_COMPLETE = "fail_or_complete" + + # Buyer reports the attempt status as Fail + FAIL = "fail" + + @property + def finish_statuses(self) -> Set[Optional[WXETStatus]]: + """For this particular FinishType, what are the different WXETStatus + values that are consider + """ + + match self: + case FinishType.ENTRANCE: + # When an attempt occurs and the user has not yet returned, the status will be None. + # This has to count towards the finish count if the FinishType is ENTRANCE. + return { + None, + WXETStatus.ABANDON, + WXETStatus.COMPLETE, + WXETStatus.FAIL, + WXETStatus.TIMEOUT, + } + + case FinishType.COMPLETE: + return {WXETStatus.COMPLETE} + + case FinishType.FAIL: + return {WXETStatus.FAIL} + + case FinishType.FAIL_OR_COMPLETE: + return {WXETStatus.FAIL, WXETStatus.COMPLETE} + + case _: + raise ValueError() + + +def is_a_finish( + status: Optional[WXETStatus], + status_code_1: Optional[WXETStatusCode1], + finish_type: Optional[FinishType], +) -> bool: + """Determines if a wall event should be considered a finish or not. + + :param status: The status of the wall event. + :param status_code_1: The status_code_1 of the wall event. + :param finish_type: The finish_type of the task. + """ + + if status: + assert isinstance(status, WXETStatus), "Invalid status" + + if status_code_1: + assert isinstance(status_code_1, WXETStatusCode1), "Invalid status_code_1" + + if status is None: + assert status_code_1 is None, "Cannot provide status_code_1 without a status" + + # If the Worker never entered the Task, then it is not a Finish, + # regardless of the FinishType. This `is_pre_task_entry_fail` tells us + # if the Worker ever left WXET and actually made it into the Task + # experience. + if status_code_1 is not None and status_code_1.is_pre_task_entry_fail: + return False + + return status in finish_type.finish_statuses diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..daa79d1 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +python_version = 3.13 +strict_equality = True +allow_untyped_defs = True +disable_error_code = union-attr, import-untyped +allow_redefinition = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8039ae1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "generalresearch" +version = "3.3.4" +description = "Python Utilities for General Research" +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "Faker", + "PyMySQL", + "psycopg", + "cachetools", + "decorator", + "limits", + "more-itertools", + "numpy", + "pandera", + "protobuf", + "pycountry", + "pydantic-extra-types", + "pydantic-settings", + "pydantic[email]", + "pytest", + "pylibmc", + "pymemcache", + "redis", + "requests", + "scipy", + "sentry-sdk", + "slackclient", + "ua-parser", + "user-agents", + "wrapt", +] +[project.optional-dependencies] +django = [ + "Django>=5.2", + "psycopg>=3.1", +] + + +[tool.setuptools.packages.find] +where = ["."] +include = [ + "generalresearch", + "generalresearch.*", + "test_utils", + "test_utils.*", +] + +[tool.setuptools.package-data] +"generalresearch" = ["locales/*.json", "resources/*.csv"] +"test_utils" = ["managers/upk/*.gz"] + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7a80011 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,108 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +annotated-types==0.7.0 +anyio==4.10.0 +attrs==25.3.0 +boto3==1.40.19 +botocore==1.40.19 +CacheControl==0.14.3 +cachetools==6.1.0 +certifi==2025.8.3 +cffi==1.17.1 +charset-normalizer==3.4.3 +click==8.2.2 +cloudpickle==3.1.1 +coverage==7.10.5 +cryptography==45.0.6 +dask==2025.7.0 +decorator==5.2.1 +Deprecated==1.2.18 +distributed==2025.7.0 +dnspython==2.7.0 +ecdsa==0.19.1 +email-validator==2.3.0 +Faker==37.6.0 +frozenlist==1.7.0 +fsspec==2025.7.0 +geoip2==4.7.0 +idna==3.10 +importlib_metadata==8.7.0 +iniconfig==2.1.0 +Jinja2==3.1.6 +jmespath==1.0.1 +jsonpickle==5.0.0rc1 +limits==5.5.0 +locket==1.0.0 +MarkupSafe==3.0.2 +maxminddb==2.8.2 +more-itertools==10.7.0 +msgpack==1.1.1 +multidict==6.6.4 +mypy_extensions==1.1.0 +numpy==2.3.2 +opentelemetry-api==1.36.0 +opentelemetry-sdk==1.36.0 +opentelemetry-semantic-conventions==0.57b0 +outcome==1.3.0.post0 +packaging==25.0 +pandas==2.3.2 +pandera==0.26.1 +partd==1.4.2 +phonenumbers==9.0.12 +pluggy==1.6.0 +propcache==0.3.2 +protobuf==6.32.0 +psutil==7.0.0 +psycopg==3.2.9 +psycopg-binary==3.2.9 +pyarrow==21.0.0 +pyasn1==0.6.1 +pycountry==24.6.1 +pycparser==2.22 +pydantic==2.11.7 +pydantic-extra-types==2.10.5 +pydantic-settings==2.10.1 +pydantic_core==2.33.2 +Pygments==2.19.2 +pylibmc==1.6.3 +pymemcache==4.0.0 +PyMySQL==1.1.1 +pytest==8.4.1 +pytest-anyio==0.0.0 +pytest-cov==6.2.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-jose==3.5.0 +pytz==2025.2 +PyYAML==6.0.2 +redis==6.4.0 +requests==2.32.5 +rsa==4.9.1 +s3transfer==0.13.1 +scipy==1.16.1 +sentry-sdk==3.0.0a5 +setuptools==80.9.0 +six==1.17.0 +slackclient==2.9.4 +sniffio==1.3.1 +sortedcontainers==2.4.0 +tblib==3.1.0 +toolz==1.0.0 +tornado==6.5.2 +trio==0.30.0 +typeguard==4.4.4 +typing-inspect==0.9.0 +typing-inspection==0.4.1 +typing_extensions==4.15.0 +tzdata==2025.2 +ua-parse==1.0.1 +ua-parser==1.0.1 +ua-parser-builtins==0.19.0.dev79 +urllib3==2.5.0 +user-agents==2.2.0 +wheel==0.46.1 +wrapt==1.17.3 +yarl==1.20.1 +zict==3.0.0 +zipp==3.23.0 diff --git a/test_utils/__init__.py b/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/conftest.py b/test_utils/conftest.py new file mode 100644 index 0000000..7acafc5 --- /dev/null +++ b/test_utils/conftest.py @@ -0,0 +1,310 @@ +import os +import shutil +from datetime import datetime, timezone +from os.path import join as pjoin +from typing import TYPE_CHECKING, Callable +from uuid import uuid4 + +import pytest +import redis +from dotenv import load_dotenv +from pydantic import MariaDBDsn +from redis import Redis + +from generalresearch.pg_helper import PostgresConfig +from generalresearch.redis_helper import RedisConfig +from generalresearch.sql_helper import SqlHelper + +if TYPE_CHECKING: + from generalresearch.currency import USDCent + from generalresearch.models.thl.session import Status + from generalresearch.config import GRLBaseSettings + + +@pytest.fixture(scope="session") +def env_file_path(pytestconfig): + root_path = pytestconfig.rootpath + env_path = os.path.join(root_path, ".env.test") + + if os.path.exists(env_path): + load_dotenv(dotenv_path=env_path, override=True) + + return env_path + + +@pytest.fixture(scope="session") +def settings(env_file_path) -> "GRLBaseSettings": + from generalresearch.config import GRLBaseSettings + + s = GRLBaseSettings(_env_file=env_file_path) + + if s.thl_mkpl_rr_db is not None: + if s.spectrum_rw_db is None: + s.spectrum_rw_db = MariaDBDsn(f"{s.thl_mkpl_rw_db}unittest-thl-spectrum") + if s.spectrum_rr_db is None: + s.spectrum_rr_db = MariaDBDsn(f"{s.thl_mkpl_rr_db}unittest-thl-spectrum") + + s.mnt_gr_api_dir = pjoin("/tmp", f"test-{uuid4().hex[:12]}") + + return s + + +# === Database Connectors === + + +@pytest.fixture(scope="session") +def thl_web_rr(settings) -> PostgresConfig: + assert "/unittest-" in settings.thl_web_rr_db.path + + return PostgresConfig( + dsn=settings.thl_web_rr_db, + connect_timeout=1, + statement_timeout=5, + ) + + +@pytest.fixture(scope="session") +def thl_web_rw(settings) -> PostgresConfig: + assert "/unittest-" in settings.thl_web_rw_db.path + + return PostgresConfig( + dsn=settings.thl_web_rw_db, + connect_timeout=1, + statement_timeout=5, + ) + + +@pytest.fixture(scope="session") +def gr_db(settings) -> PostgresConfig: + assert "/unittest-" in settings.gr_db.path + return PostgresConfig(dsn=settings.gr_db, connect_timeout=5, statement_timeout=2) + + +@pytest.fixture(scope="session") +def spectrum_rw(settings) -> SqlHelper: + assert "/unittest-" in settings.spectrum_rw_db.path + + return SqlHelper( + dsn=settings.spectrum_rw_db, + read_timeout=2, + write_timeout=1, + connect_timeout=2, + ) + + +@pytest.fixture(scope="session") +def grliq_db(settings) -> PostgresConfig: + assert "/unittest-" in settings.grliq_db.path + + # test_words = {"localhost", "127.0.0.1", "unittest", "grliq-test"} + # assert any(w in str(postgres_config.dsn) for w in test_words), "check grliq postgres_config" + # assert "grliqdeceezpocymo" not in str(postgres_config.dsn), "check grliq postgres_config" + + return PostgresConfig( + dsn=settings.grliq_db, + connect_timeout=2, + statement_timeout=2, + ) + + +@pytest.fixture(scope="session") +def thl_redis(settings) -> "Redis": + # todo: this should get replaced with redisconfig (in most places) + # I'm not sure where this would be? in the domain name? + assert "unittest" in str(settings.thl_redis) or "127.0.0.1" in str( + settings.thl_redis + ) + + return redis.Redis.from_url( + **{ + "url": str(settings.thl_redis), + "decode_responses": True, + "socket_timeout": settings.redis_timeout, + "socket_connect_timeout": settings.redis_timeout, + } + ) + + +@pytest.fixture(scope="session") +def thl_redis_config(settings) -> RedisConfig: + assert "unittest" in str(settings.thl_redis) or "127.0.0.1" in str( + settings.thl_redis + ) + return RedisConfig( + dsn=settings.thl_redis, + decode_responses=True, + socket_timeout=settings.redis_timeout, + socket_connect_timeout=settings.redis_timeout, + ) + + +@pytest.fixture(scope="session") +def gr_redis_config(settings) -> "RedisConfig": + assert "unittest" in str(settings.gr_redis) or "127.0.0.1" in str(settings.gr_redis) + + return RedisConfig( + dsn=settings.gr_redis, + decode_responses=True, + socket_timeout=settings.redis_timeout, + socket_connect_timeout=settings.redis_timeout, + ) + + +@pytest.fixture(scope="session") +def gr_redis(settings) -> "Redis": + assert "unittest" in str(settings.gr_redis) or "127.0.0.1" in str(settings.gr_redis) + return redis.Redis.from_url( + **{ + "url": str(settings.gr_redis), + "decode_responses": True, + "socket_timeout": settings.redis_timeout, + "socket_connect_timeout": settings.redis_timeout, + } + ) + + +@pytest.fixture() +def gr_redis_async(settings): + assert "unittest" in str(settings.gr_redis) or "127.0.0.1" in str(settings.gr_redis) + + import redis.asyncio as redis_async + + return redis_async.Redis.from_url( + str(settings.gr_redis), + decode_responses=True, + socket_timeout=0.20, + socket_connect_timeout=0.20, + ) + + +# === Random helpers === + + +@pytest.fixture(scope="function") +def start(): + return datetime(year=1900, month=1, day=1, tzinfo=timezone.utc) + + +@pytest.fixture +def wall_status(request) -> "Status": + from generalresearch.models.thl.session import Status + + return request.param if hasattr(request, "wall_status") else Status.COMPLETE + + +@pytest.fixture(scope="function") +def utc_now() -> "datetime": + from datetime import datetime, timezone + + return datetime.now(tz=timezone.utc) + + +@pytest.fixture(scope="function") +def utc_hour_ago() -> "datetime": + from datetime import datetime, timezone, timedelta + + return datetime.now(tz=timezone.utc) - timedelta(hours=1) + + +@pytest.fixture(scope="function") +def utc_day_ago() -> "datetime": + from datetime import datetime, timezone, timedelta + + return datetime.now(tz=timezone.utc) - timedelta(hours=24) + + +@pytest.fixture(scope="function") +def utc_90days_ago() -> "datetime": + from datetime import datetime, timezone, timedelta + + return datetime.now(tz=timezone.utc) - timedelta(days=90) + + +@pytest.fixture(scope="function") +def utc_60days_ago() -> "datetime": + from datetime import datetime, timezone, timedelta + + return datetime.now(tz=timezone.utc) - timedelta(days=60) + + +@pytest.fixture(scope="function") +def utc_30days_ago() -> "datetime": + from datetime import datetime, timezone, timedelta + + return datetime.now(tz=timezone.utc) - timedelta(days=30) + + +# === Clean up === + + +@pytest.fixture(scope="function") +def delete_df_collection(thl_web_rw, create_main_accounts) -> Callable: + from generalresearch.incite.collections import ( + DFCollection, + DFCollectionType, + ) + + def _teardown_events(coll: "DFCollection"): + match coll.data_type: + case DFCollectionType.LEDGER: + for table in [ + "ledger_transactionmetadata", + "ledger_entry", + "ledger_transaction", + "ledger_account", + ]: + thl_web_rw.execute_write( + query=f"DELETE FROM {table};", + ) + create_main_accounts() + + case DFCollectionType.WALL | DFCollectionType.SESSION: + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.execute("SET CONSTRAINTS ALL DEFERRED") + for table in [ + "thl_wall", + "thl_session", + ]: + c.execute( + query=f"DELETE FROM {table};", + ) + + case DFCollectionType.USER: + for table in ["thl_usermetadata", "thl_user"]: + thl_web_rw.execute_write( + query=f"DELETE FROM {table};", + ) + + case _: + thl_web_rw.execute_write( + query=f"DELETE FROM {coll.data_type.value};", + ) + + return _teardown_events + + +# === GR Related === + + +@pytest.fixture(scope="function") +def amount_1(request) -> "USDCent": + from generalresearch.currency import USDCent + + return USDCent(1) + + +@pytest.fixture(scope="function") +def amount_100(request) -> "USDCent": + from generalresearch.currency import USDCent + + return USDCent(100) + + +def clear_directory(path): + for entry in os.listdir(path): + full_path = os.path.join(path, entry) + if os.path.isfile(full_path) or os.path.islink(full_path): + os.unlink(full_path) # remove file or symlink + elif os.path.isdir(full_path): + shutil.rmtree(full_path) # remove folder diff --git a/test_utils/grliq/__init__.py b/test_utils/grliq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/grliq/conftest.py b/test_utils/grliq/conftest.py new file mode 100644 index 0000000..1818794 --- /dev/null +++ b/test_utils/grliq/conftest.py @@ -0,0 +1,28 @@ +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest + + +if TYPE_CHECKING: + from generalresearch.grliq.models.forensic_data import GrlIqData + + +@pytest.fixture(scope="function") +def mnt_grliq_archive_dir(settings): + return settings.mnt_grliq_archive_dir + + +@pytest.fixture(scope="function") +def grliq_data() -> "GrlIqData": + from generalresearch.grliq.models.forensic_data import GrlIqData + from generalresearch.grliq.managers import DUMMY_GRLIQ_DATA + + g: GrlIqData = DUMMY_GRLIQ_DATA[1]["data"] + + g.id = None + g.uuid = uuid4().hex + g.created_at = datetime.now(tz=timezone.utc) + g.timestamp = g.created_at - timedelta(seconds=10) + return g diff --git a/test_utils/grliq/managers/__init__.py b/test_utils/grliq/managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/grliq/managers/conftest.py b/test_utils/grliq/managers/conftest.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/grliq/models/__init__.py b/test_utils/grliq/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/grliq/models/conftest.py b/test_utils/grliq/models/conftest.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/incite/__init__.py b/test_utils/incite/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/incite/collections/__init__.py b/test_utils/incite/collections/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/incite/collections/conftest.py b/test_utils/incite/collections/conftest.py new file mode 100644 index 0000000..1b61ed5 --- /dev/null +++ b/test_utils/incite/collections/conftest.py @@ -0,0 +1,205 @@ +from datetime import timedelta, datetime +from typing import TYPE_CHECKING, Optional, Callable + +import pytest + +from test_utils.incite.conftest import mnt_filepath +from test_utils.conftest import clear_directory + +if TYPE_CHECKING: + from generalresearch.incite.collections import DFCollection + from generalresearch.incite.base import GRLDatasets, DFCollectionType + from generalresearch.incite.collections.thl_web import LedgerDFCollection + from generalresearch.incite.collections.thl_web import ( + WallDFCollection, + SessionDFCollection, + TaskAdjustmentDFCollection, + UserDFCollection, + AuditLogDFCollection, + ) + + +@pytest.fixture(scope="function") +def user_collection( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, + thl_web_rr, +) -> "UserDFCollection": + from generalresearch.incite.collections.thl_web import ( + UserDFCollection, + DFCollectionType, + ) + + return UserDFCollection( + start=start, + finished=start + duration, + offset=offset, + pg_config=thl_web_rr, + archive_path=mnt_filepath.archive_path(enum_type=DFCollectionType.USER), + ) + + +@pytest.fixture(scope="function") +def wall_collection( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, + thl_web_rr, +) -> "WallDFCollection": + from generalresearch.incite.collections.thl_web import ( + WallDFCollection, + DFCollectionType, + ) + + return WallDFCollection( + start=start, + finished=start + duration if duration else None, + offset=offset, + pg_config=thl_web_rr, + archive_path=mnt_filepath.archive_path(enum_type=DFCollectionType.WALL), + ) + + +@pytest.fixture(scope="function") +def session_collection( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, + thl_web_rr, +) -> "SessionDFCollection": + from generalresearch.incite.collections.thl_web import ( + SessionDFCollection, + DFCollectionType, + ) + + return SessionDFCollection( + start=start, + finished=start + duration if duration else None, + offset=offset, + pg_config=thl_web_rr, + archive_path=mnt_filepath.archive_path(enum_type=DFCollectionType.SESSION), + ) + + +# IPInfoDFCollection +# IPHistoryDFCollection +# IPHistoryWSDFCollection + +# @pytest.fixture +# def ip_history_collection(mnt_filepath, offset, duration, start, +# thl_web_rw) -> IPHistoryDFCollection: +# return IPHistoryDFCollection( +# start=start, +# finished=start + duration, +# offset=offset, +# pg_config=thl_web_rw, +# archive_path=mnt_filepath.archive_path(enum_type=DFCollectionType.IP_HISTORY), +# ) + + +@pytest.fixture(scope="function") +def task_adj_collection( + mnt_filepath: "GRLDatasets", + offset: str, + duration: Optional[timedelta], + start: datetime, + thl_web_rr, +) -> "TaskAdjustmentDFCollection": + from generalresearch.incite.collections.thl_web import ( + TaskAdjustmentDFCollection, + DFCollectionType, + ) + + return TaskAdjustmentDFCollection( + start=start, + finished=start + duration if duration else duration, + offset=offset, + pg_config=thl_web_rr, + archive_path=mnt_filepath.archive_path( + enum_type=DFCollectionType.TASK_ADJUSTMENT + ), + ) + + +@pytest.fixture(scope="function") +def auditlog_collection( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, + thl_web_rr, +) -> "AuditLogDFCollection": + from generalresearch.incite.collections.thl_web import ( + AuditLogDFCollection, + DFCollectionType, + ) + + return AuditLogDFCollection( + start=start, + finished=start + duration, + offset=offset, + pg_config=thl_web_rr, + archive_path=mnt_filepath.archive_path(enum_type=DFCollectionType.LEDGER), + ) + + +@pytest.fixture(scope="function") +def ledger_collection( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, + thl_web_rr, +) -> "LedgerDFCollection": + from generalresearch.incite.collections.thl_web import ( + LedgerDFCollection, + DFCollectionType, + ) + + return LedgerDFCollection( + start=start, + finished=start + duration if duration else duration, + offset=offset, + pg_config=thl_web_rr, + archive_path=mnt_filepath.archive_path(enum_type=DFCollectionType.LEDGER), + ) + + +@pytest.fixture(scope="function") +def rm_ledger_collection(ledger_collection) -> Callable: + def _rm_ledger_collection(): + clear_directory(ledger_collection.archive_path) + + return _rm_ledger_collection + + +# -------------------------- +# Generic / Base +# -------------------------- + + +@pytest.fixture(scope="function") +def df_collection( + mnt_filepath, + df_collection_data_type: "DFCollectionType", + offset, + duration, + utc_90days_ago, + thl_web_rr, +) -> "DFCollection": + from generalresearch.incite.collections import DFCollection + + start = utc_90days_ago.replace(microsecond=0) + + return DFCollection( + data_type=df_collection_data_type, + archive_path=mnt_filepath.archive_path(enum_type=df_collection_data_type), + offset=offset, + pg_config=thl_web_rr, + start=start, + finished=start + duration, + ) diff --git a/test_utils/incite/conftest.py b/test_utils/incite/conftest.py new file mode 100644 index 0000000..759467a --- /dev/null +++ b/test_utils/incite/conftest.py @@ -0,0 +1,201 @@ +from datetime import datetime, timezone, timedelta +from os.path import join as pjoin +from pathlib import Path +from random import choice as randchoice +from shutil import rmtree +from typing import Callable, TYPE_CHECKING, Optional +from uuid import uuid4 + +import pytest +from _pytest.fixtures import SubRequest +from faker import Faker + +from test_utils.managers.ledger.conftest import session_with_tx_factory +from test_utils.models.conftest import session_factory + +if TYPE_CHECKING: + from generalresearch.models.thl.user import User + from generalresearch.incite.base import GRLDatasets + from generalresearch.incite.mergers import MergeType + from generalresearch.incite.collections import ( + DFCollection, + DFCollectionType, + DFCollectionItem, + ) + +fake = Faker() + + +@pytest.fixture(scope="function") +def mnt_gr_api_dir(request: SubRequest, settings): + p = Path(settings.mnt_gr_api_dir) + p.mkdir(parents=True, exist_ok=True) + + from generalresearch.models.admin.request import ReportType + + for e in list(ReportType): + Path(pjoin(p, e.value)).mkdir(exist_ok=True) + + def tmp_file_teardown(): + assert "/mnt/" not in str(p), ( + "Under no condition, testing or otherwise should we have code delete " + " any folders or potential data on a network mount" + ) + + rmtree(p) + + request.addfinalizer(tmp_file_teardown) + + return p + + +@pytest.fixture(scope="function") +def event_report_request(utc_hour_ago, start): + from generalresearch.models.admin.request import ( + ReportRequest, + ReportType, + ) + + return ReportRequest.model_validate( + { + "report_type": ReportType.POP_EVENT, + "interval": "5min", + "start": start, + } + ) + + +@pytest.fixture(scope="function") +def session_report_request(utc_hour_ago, start): + from generalresearch.models.admin.request import ( + ReportRequest, + ReportType, + ) + + return ReportRequest.model_validate( + { + "report_type": ReportType.POP_SESSION, + "interval": "5min", + "start": start, + } + ) + + +@pytest.fixture(scope="function") +def mnt_filepath(request: SubRequest) -> "GRLDatasets": + """Creates a temporary file path for all DFCollections & Mergers parquet + files. + """ + from generalresearch.incite.base import GRLDatasets, NFSMount + + instance = GRLDatasets( + data_src=Path(pjoin("/tmp", f"test-{uuid4().hex[:12]}")), + incite=NFSMount(point="thl-incite"), + ) + + def tmp_file_teardown(): + assert "/mnt/" not in str(instance.data_src), ( + "Under no condition, testing or otherwise should we have code delete " + " any folders or potential data on a network mount" + ) + + rmtree(instance.data_src) + + request.addfinalizer(tmp_file_teardown) + + return instance + + +@pytest.fixture(scope="function") +def start(utc_90days_ago) -> "datetime": + s = utc_90days_ago.replace(microsecond=0) + return s + + +@pytest.fixture(scope="function") +def offset() -> str: + return "15min" + + +@pytest.fixture(scope="function") +def duration() -> Optional["timedelta"]: + return timedelta(hours=1) + + +@pytest.fixture(scope="function") +def df_collection_data_type() -> "DFCollectionType": + from generalresearch.incite.collections import DFCollectionType + + return DFCollectionType.TEST + + +@pytest.fixture(scope="function") +def merge_type() -> "MergeType": + from generalresearch.incite.mergers import MergeType + + return MergeType.TEST + + +@pytest.fixture(scope="function") +def incite_item_factory( + session_factory, + product, + user_factory, + session_with_tx_factory, +) -> Callable: + def _incite_item_factory( + item: "DFCollectionItem", + observations: int = 3, + user: Optional["User"] = None, + ): + from generalresearch.incite.collections import ( + DFCollection, + DFCollectionType, + ) + from generalresearch.models.thl.session import Source + + collection: DFCollection = item._collection + data_type: DFCollectionType = collection.data_type + + for idx in range(5): + item_time = fake.date_time_between( + start_date=item.start, end_date=item.finish, tzinfo=timezone.utc + ) + + match data_type: + case DFCollectionType.USER: + user_factory(product=product, created=item_time) + + case DFCollectionType.LEDGER: + session_with_tx_factory(started=item_time, user=user) + + case DFCollectionType.WALL: + u = ( + user + if user + else user_factory(product=product, created=item_time) + ) + session_factory( + user=u, + started=item_time, + wall_source=randchoice(list(Source)), + ) + + case DFCollectionType.SESSION: + u = ( + user + if user + else user_factory(product=product, created=item_time) + ) + session_factory( + user=u, + started=item_time, + wall_source=randchoice(list(Source)), + ) + + case _: + raise ValueError("Unsupported DFCollectionItem") + + return None + + return _incite_item_factory diff --git a/test_utils/incite/mergers/__init__.py b/test_utils/incite/mergers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/incite/mergers/conftest.py b/test_utils/incite/mergers/conftest.py new file mode 100644 index 0000000..e4e3bdd --- /dev/null +++ b/test_utils/incite/mergers/conftest.py @@ -0,0 +1,247 @@ +from datetime import timedelta, datetime +from typing import TYPE_CHECKING, Optional, Callable + +import pytest + +from test_utils.conftest import clear_directory +from test_utils.incite.conftest import mnt_filepath + +if TYPE_CHECKING: + from generalresearch.incite.mergers import MergeType + from generalresearch.incite.mergers.ym_wall_summary import ( + YMWallSummaryMerge, + YMWallSummaryMergeItem, + ) + from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge + from generalresearch.incite.mergers.ym_survey_wall import YMSurveyWallMerge + from generalresearch.incite.base import GRLDatasets + from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, + ) + from generalresearch.incite.mergers.foundations.enriched_task_adjust import ( + EnrichedTaskAdjustMerge, + ) + from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, + ) + from generalresearch.incite.mergers.foundations.user_id_product import ( + UserIdProductMerge, + ) + from generalresearch.incite.mergers.ym_survey_wall import ( + YMSurveyWallMergeCollectionItem, + ) + + +# -------------------------- +# Merges +# -------------------------- + + +@pytest.fixture(scope="function") +def rm_pop_ledger_merge(pop_ledger_merge) -> Callable: + def _rm_pop_ledger_merge(): + clear_directory(pop_ledger_merge.archive_path) + + return _rm_pop_ledger_merge + + +@pytest.fixture(scope="function") +def pop_ledger_merge( + mnt_filepath: "GRLDatasets", + offset: str, + start: datetime, + duration: timedelta, +) -> "PopLedgerMerge": + from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge + from generalresearch.incite.mergers import MergeType + + return PopLedgerMerge( + start=start, + finished=start + duration if duration else None, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=MergeType.POP_LEDGER), + ) + + +@pytest.fixture(scope="function") +def pop_ledger_merge_item( + start, + pop_ledger_merge, +) -> "PopLedgerMergeItem": + from generalresearch.incite.mergers.pop_ledger import PopLedgerMergeItem + + return PopLedgerMergeItem( + start=start, + _collection=pop_ledger_merge, + ) + + +@pytest.fixture(scope="function") +def ym_survey_wall_merge( + mnt_filepath: "GRLDatasets", + start: datetime, +) -> "YMSurveyWallMerge": + from generalresearch.incite.mergers.ym_survey_wall import YMSurveyWallMerge + from generalresearch.incite.mergers import MergeType + + return YMSurveyWallMerge( + start=None, + offset="10D", + archive_path=mnt_filepath.archive_path(enum_type=MergeType.YM_SURVEY_WALL), + ) + + +@pytest.fixture(scope="function") +def ym_survey_wall_merge_item( + start, ym_survey_wall_merge +) -> "YMSurveyWallMergeCollectionItem": + from generalresearch.incite.mergers.ym_survey_wall import ( + YMSurveyWallMergeCollectionItem, + ) + + return YMSurveyWallMergeCollectionItem( + start=start, + _collection=pop_ledger_merge, + ) + + +@pytest.fixture(scope="function") +def ym_wall_summary_merge( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, +) -> "YMWallSummaryMerge": + from generalresearch.incite.mergers.ym_wall_summary import YMWallSummaryMerge + from generalresearch.incite.mergers import MergeType + + return YMWallSummaryMerge( + start=start, + finished=start + duration, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=MergeType.POP_LEDGER), + ) + + +def ym_wall_summary_merge_item( + start, ym_wall_summary_merge +) -> "YMWallSummaryMergeItem": + from generalresearch.incite.mergers.ym_wall_summary import ( + YMWallSummaryMergeItem, + ) + + return YMWallSummaryMergeItem( + start=start, + _collection=pop_ledger_merge, + ) + + +# -------------------------- +# Merges: Foundations +# -------------------------- + + +@pytest.fixture(scope="function") +def enriched_session_merge( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, +) -> "EnrichedSessionMerge": + from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, + ) + from generalresearch.incite.mergers import MergeType + + return EnrichedSessionMerge( + start=start, + finished=start + duration if duration else None, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=MergeType.ENRICHED_SESSION), + ) + + +@pytest.fixture(scope="function") +def enriched_task_adjust_merge( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, +) -> "EnrichedTaskAdjustMerge": + from generalresearch.incite.mergers.foundations.enriched_task_adjust import ( + EnrichedTaskAdjustMerge, + ) + from generalresearch.incite.mergers import MergeType + + return EnrichedTaskAdjustMerge( + start=start, + finished=start + duration, + offset=offset, + archive_path=mnt_filepath.archive_path( + enum_type=MergeType.ENRICHED_TASK_ADJUST + ), + ) + + +@pytest.fixture(scope="function") +def enriched_wall_merge( + mnt_filepath: "GRLDatasets", + offset: str, + duration: timedelta, + start: datetime, +) -> "EnrichedWallMerge": + from generalresearch.incite.mergers import MergeType + from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, + ) + + return EnrichedWallMerge( + start=start, + finished=start + duration if duration else None, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=MergeType.ENRICHED_WALL), + ) + + +@pytest.fixture(scope="function") +def user_id_product_merge( + mnt_filepath: "GRLDatasets", + duration: timedelta, + offset, + start: datetime, +) -> "UserIdProductMerge": + from generalresearch.incite.mergers.foundations.user_id_product import ( + UserIdProductMerge, + ) + from generalresearch.incite.mergers import MergeType + + return UserIdProductMerge( + start=start, + finished=start + duration, + offset=None, + archive_path=mnt_filepath.archive_path(enum_type=MergeType.USER_ID_PRODUCT), + ) + + +# -------------------------- +# Generic / Base +# -------------------------- + + +@pytest.fixture(scope="function") +def merge_collection( + mnt_filepath, + merge_type: "MergeType", + offset, + duration, + start, +): + from generalresearch.incite.mergers import MergeCollection + + return MergeCollection( + merge_type=merge_type, + start=start, + finished=start + duration, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) diff --git a/test_utils/managers/__init__.py b/test_utils/managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/managers/cashout_methods.py b/test_utils/managers/cashout_methods.py new file mode 100644 index 0000000..c338676 --- /dev/null +++ b/test_utils/managers/cashout_methods.py @@ -0,0 +1,76 @@ +from generalresearch.models.thl.wallet import PayoutType, Currency +from generalresearch.models.thl.wallet.cashout_method import ( + CashoutMethod, + TangoCashoutMethodData, + AmtCashoutMethodData, +) +import random + +from uuid import uuid4 + + +def random_ext_id(base: str = "U02"): + suffix = random.randint(0, 99999) + return f"{base}{suffix:05d}" + + +EXAMPLE_TANGO_CASHOUT_METHODS = [ + CashoutMethod( + id=uuid4().hex, + last_updated="2021-06-23T20:45:38.239182Z", + is_live=True, + type=PayoutType.TANGO, + ext_id=random_ext_id(), + name="Safeway eGift Card $25", + data=TangoCashoutMethodData( + value_type="fixed", countries=["US"], utid=random_ext_id() + ), + user=None, + image_url="https://d30s7yzk2az89n.cloudfront.net/images/brands/b694446-1200w-326ppi.png", + original_currency=Currency.USD, + min_value=2500, + max_value=2500, + ), + CashoutMethod( + id=uuid4().hex, + last_updated="2021-06-23T20:45:38.239182Z", + is_live=True, + type=PayoutType.TANGO, + ext_id=random_ext_id(), + name="Amazon.it Gift Certificate", + data=TangoCashoutMethodData( + value_type="variable", countries=["IT"], utid="U006961" + ), + user=None, + image_url="https://d30s7yzk2az89n.cloudfront.net/images/brands/b405753-1200w-326ppi.png", + original_currency=Currency.EUR, + min_value=1, + max_value=10000, + ), +] + +AMT_ASSIGNMENT_CASHOUT_METHOD = CashoutMethod( + id=uuid4().hex, + last_updated="2021-06-23T20:45:38.239182Z", + is_live=True, + type=PayoutType.AMT, + ext_id=None, + name="AMT Assignment", + data=AmtCashoutMethodData(), + user=None, + min_value=1, + max_value=5, +) + +AMT_BONUS_CASHOUT_METHOD = CashoutMethod( + id=uuid4().hex, + last_updated="2021-06-23T20:45:38.239182Z", + is_live=True, + type=PayoutType.AMT, + ext_id=None, + name="AMT Bonus", + data=AmtCashoutMethodData(), + user=None, + min_value=7, + max_value=4000, +) diff --git a/test_utils/managers/conftest.py b/test_utils/managers/conftest.py new file mode 100644 index 0000000..3a237d1 --- /dev/null +++ b/test_utils/managers/conftest.py @@ -0,0 +1,701 @@ +from typing import Callable, TYPE_CHECKING + +import pymysql +import pytest + +from generalresearch.managers.base import Permission +from generalresearch.models import Source +from test_utils.managers.cashout_methods import ( + EXAMPLE_TANGO_CASHOUT_METHODS, + AMT_ASSIGNMENT_CASHOUT_METHOD, + AMT_BONUS_CASHOUT_METHOD, +) + +if TYPE_CHECKING: + from generalresearch.grliq.managers.forensic_data import ( + GrlIqDataManager, + ) + from generalresearch.grliq.managers.forensic_events import ( + GrlIqEventManager, + ) + from generalresearch.grliq.managers.forensic_results import ( + GrlIqCategoryResultsReader, + ) + from generalresearch.managers.thl.userhealth import AuditLogManager + from generalresearch.managers.thl.payout import ( + BusinessPayoutEventManager, + ) + from generalresearch.managers.thl.maxmind.basic import ( + MaxmindBasicManager, + ) + + from generalresearch.managers.gr.authentication import ( + GRUserManager, + GRTokenManager, + ) + from generalresearch.managers.gr.business import ( + BusinessManager, + BusinessAddressManager, + BusinessBankAccountManager, + ) + from generalresearch.managers.gr.team import ( + TeamManager, + MembershipManager, + ) + from generalresearch.managers.thl.contest_manager import ContestManager + from generalresearch.managers.thl.ipinfo import ( + GeoIpInfoManager, + IPGeonameManager, + IPInformationManager, + ) + from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerTransactionManager, + LedgerManager, + LedgerAccountManager, + ) + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + from generalresearch.managers.thl.maxmind import MaxmindManager + from generalresearch.managers.thl.payout import PayoutEventManager + from generalresearch.managers.thl.payout import ( + PayoutEventManager, + UserPayoutEventManager, + BrokerageProductPayoutEventManager, + ) + from generalresearch.managers.thl.product import ProductManager + from generalresearch.managers.thl.session import SessionManager + from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, + ) + from generalresearch.managers.thl.user_manager.user_metadata_manager import ( + UserMetadataManager, + ) + from generalresearch.managers.thl.userhealth import ( + AuditLogManager, + IPRecordManager, + UserIpHistoryManager, + IPGeonameManager, + IPInformationManager, + IPRecordManager, + ) + from generalresearch.managers.thl.wall import ( + WallManager, + WallCacheManager, + ) + from generalresearch.managers.thl.task_adjustment import ( + TaskAdjustmentManager, + ) + + +# === THL === + + +@pytest.fixture(scope="session") +def ltxm(thl_web_rw, thl_redis_config) -> "LedgerTransactionManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerTransactionManager, + ) + + return LedgerTransactionManager( + sql_helper=thl_web_rw, + permissions=[Permission.CREATE, Permission.READ], + testing=True, + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def lam(thl_web_rw, thl_redis_config) -> "LedgerAccountManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerAccountManager, + ) + + return LedgerAccountManager( + pg_config=thl_web_rw, + permissions=[Permission.CREATE, Permission.READ], + testing=True, + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def lm(thl_web_rw, thl_redis_config) -> "LedgerManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerManager, + ) + + return LedgerManager( + pg_config=thl_web_rw, + permissions=[ + Permission.CREATE, + Permission.READ, + Permission.UPDATE, + Permission.DELETE, + ], + testing=True, + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def thl_lm(thl_web_rw, thl_redis_config) -> "ThlLedgerManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ledger_manager.thl_ledger import ( + ThlLedgerManager, + ) + + return ThlLedgerManager( + pg_config=thl_web_rw, + permissions=[ + Permission.CREATE, + Permission.READ, + Permission.UPDATE, + Permission.DELETE, + ], + testing=True, + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def payout_event_manager(thl_web_rw, thl_redis_config) -> "PayoutEventManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.payout import PayoutEventManager + + return PayoutEventManager( + pg_config=thl_web_rw, + permissions=[Permission.CREATE, Permission.READ], + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def user_payout_event_manager(thl_web_rw, thl_redis_config) -> "UserPayoutEventManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.payout import UserPayoutEventManager + + return UserPayoutEventManager( + pg_config=thl_web_rw, + permissions=[Permission.CREATE, Permission.READ], + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def brokerage_product_payout_event_manager( + thl_web_rw, thl_redis_config +) -> "BrokerageProductPayoutEventManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.payout import ( + BrokerageProductPayoutEventManager, + ) + + return BrokerageProductPayoutEventManager( + pg_config=thl_web_rw, + permissions=[Permission.CREATE, Permission.READ], + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def business_payout_event_manager( + thl_web_rw, thl_redis_config +) -> "BusinessPayoutEventManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.payout import ( + BusinessPayoutEventManager, + ) + + return BusinessPayoutEventManager( + pg_config=thl_web_rw, + permissions=[Permission.CREATE, Permission.READ], + redis_config=thl_redis_config, + ) + + +@pytest.fixture(scope="session") +def product_manager(thl_web_rw) -> "ProductManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.product import ProductManager + + return ProductManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def user_manager(settings, thl_web_rw, thl_web_rr) -> "UserManager": + assert "/unittest-" in thl_web_rw.dsn.path + assert "/unittest-" in thl_web_rr.dsn.path + + from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, + ) + + return UserManager( + pg_config=thl_web_rw, + pg_config_rr=thl_web_rr, + redis=settings.redis, + ) + + +@pytest.fixture(scope="session") +def user_metadata_manager(thl_web_rw) -> "UserMetadataManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.user_manager.user_metadata_manager import ( + UserMetadataManager, + ) + + return UserMetadataManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def session_manager(thl_web_rw) -> "SessionManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.session import SessionManager + + return SessionManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def wall_manager(thl_web_rw) -> "WallManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.wall import WallManager + + return WallManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def wall_cache_manager(thl_web_rw, thl_redis_config) -> "WallCacheManager": + # assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.wall import WallCacheManager + + return WallCacheManager(pg_config=thl_web_rw, redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def task_adjustment_manager(thl_web_rw) -> "TaskAdjustmentManager": + # assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.task_adjustment import ( + TaskAdjustmentManager, + ) + + return TaskAdjustmentManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def contest_manager(thl_web_rw) -> "ContestManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.contest_manager import ContestManager + + return ContestManager( + pg_config=thl_web_rw, + permissions=[ + Permission.CREATE, + Permission.READ, + Permission.UPDATE, + Permission.DELETE, + ], + ) + + +@pytest.fixture(scope="session") +def category_manager(thl_web_rw): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.category import CategoryManager + + return CategoryManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def buyer_manager(thl_web_rw): + # assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.buyer import BuyerManager + + return BuyerManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def survey_manager(thl_web_rw): + # assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.survey import SurveyManager + + return SurveyManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def surveystat_manager(thl_web_rw): + # assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.survey import SurveyStatManager + + return SurveyStatManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def surveypenalty_manager(thl_redis_config): + from generalresearch.managers.thl.survey_penalty import SurveyPenaltyManager + + return SurveyPenaltyManager(redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def upk_schema_manager(thl_web_rw): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.profiling.schema import ( + UpkSchemaManager, + ) + + return UpkSchemaManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def user_upk_manager(thl_web_rw, thl_redis_config): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.profiling.user_upk import ( + UserUpkManager, + ) + + return UserUpkManager(pg_config=thl_web_rw, redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def question_manager(thl_web_rw, thl_redis_config): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.profiling.question import ( + QuestionManager, + ) + + return QuestionManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def uqa_manager(thl_web_rw, thl_redis_config): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.profiling.uqa import UQAManager + + return UQAManager(redis_config=thl_redis_config, pg_config=thl_web_rw) + + +@pytest.fixture(scope="function") +def uqa_manager_clear_cache(uqa_manager, user): + # On successive py-test/jenkins runs, the cache may contain + # the previous run's info (keyed under the same user_id) + uqa_manager.clear_cache(user) + yield + uqa_manager.clear_cache(user) + + +@pytest.fixture(scope="session") +def audit_log_manager(thl_web_rw) -> "AuditLogManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.userhealth import AuditLogManager + + return AuditLogManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def ip_geoname_manager(thl_web_rw) -> "IPGeonameManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ipinfo import IPGeonameManager + + return IPGeonameManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def ip_information_manager(thl_web_rw) -> "IPInformationManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ipinfo import IPInformationManager + + return IPInformationManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def ip_record_manager(thl_web_rw, thl_redis_config) -> "IPRecordManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.userhealth import IPRecordManager + + return IPRecordManager(pg_config=thl_web_rw, redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def user_iphistory_manager(thl_web_rw, thl_redis_config) -> "UserIpHistoryManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.userhealth import ( + UserIpHistoryManager, + ) + + return UserIpHistoryManager(pg_config=thl_web_rw, redis_config=thl_redis_config) + + +@pytest.fixture(scope="function") +def user_iphistory_manager_clear_cache(user_iphistory_manager, user): + # On successive py-test/jenkins runs, the cache may contain + # the previous run's info (keyed under the same user_id) + user_iphistory_manager.delete_user_ip_history_cache(user_id=user.user_id) + yield + user_iphistory_manager.delete_user_ip_history_cache(user_id=user.user_id) + + +@pytest.fixture(scope="session") +def geoipinfo_manager(thl_web_rw, thl_redis_config) -> "GeoIpInfoManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.ipinfo import GeoIpInfoManager + + return GeoIpInfoManager(pg_config=thl_web_rw, redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def maxmind_basic_manager() -> "MaxmindBasicManager": + from generalresearch.managers.thl.maxmind.basic import ( + MaxmindBasicManager, + ) + + return MaxmindBasicManager(data_dir="/tmp/") + + +@pytest.fixture(scope="session") +def maxmind_manager(thl_web_rw, thl_redis_config) -> "MaxmindManager": + assert "/unittest-" in thl_web_rw.dsn.path + + from generalresearch.managers.thl.maxmind import MaxmindManager + + return MaxmindManager(pg_config=thl_web_rw, redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def cashout_method_manager(thl_web_rw): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.cashout_method import ( + CashoutMethodManager, + ) + + return CashoutMethodManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def event_manager(thl_redis_config): + from generalresearch.managers.events import EventManager + + return EventManager(redis_config=thl_redis_config) + + +@pytest.fixture(scope="session") +def user_streak_manager(thl_web_rw): + assert "/unittest-" in thl_web_rw.dsn.path + from generalresearch.managers.thl.user_streak import ( + UserStreakManager, + ) + + return UserStreakManager(pg_config=thl_web_rw) + + +@pytest.fixture(scope="session") +def uqa_db_index(thl_web_rw): + # There were some custom indices created not through django. + # Make sure the index used in the index hint exists + assert "/unittest-" in thl_web_rw.dsn.path + + # query = f"""create index idx_user_id + # on `{thl_web_rw.db}`.marketplace_userquestionanswer (user_id);""" + # try: + # thl_web_rw.execute_sql_query(query, commit=True) + # except pymysql.OperationalError as e: + # if "Duplicate key name 'idx_user_id'" not in str(e): + # raise + return None + + +@pytest.fixture(scope="session") +def delete_cashoutmethod_db(thl_web_rw) -> Callable: + def _delete_cashoutmethod_db(): + thl_web_rw.execute_write( + query="DELETE FROM accounting_cashoutmethod;", + ) + + return _delete_cashoutmethod_db + + +@pytest.fixture(scope="session") +def setup_cashoutmethod_db(cashout_method_manager, delete_cashoutmethod_db): + delete_cashoutmethod_db() + for x in EXAMPLE_TANGO_CASHOUT_METHODS: + cashout_method_manager.create(x) + cashout_method_manager.create(AMT_ASSIGNMENT_CASHOUT_METHOD) + cashout_method_manager.create(AMT_BONUS_CASHOUT_METHOD) + return None + + +# === THL: Marketplaces === + + +@pytest.fixture(scope="session") +def spectrum_manager(spectrum_rw): + from generalresearch.managers.spectrum.survey import ( + SpectrumSurveyManager, + ) + + return SpectrumSurveyManager(sql_helper=spectrum_rw) + + +# === GR === +@pytest.fixture(scope="session") +def business_manager(gr_db, gr_redis_config) -> "BusinessManager": + from generalresearch.redis_helper import RedisConfig + + assert "/unittest-" in gr_db.dsn.path + assert isinstance(gr_redis_config, RedisConfig) + + from generalresearch.managers.gr.business import BusinessManager + + return BusinessManager( + pg_config=gr_db, + redis_config=gr_redis_config, + ) + + +@pytest.fixture(scope="session") +def business_address_manager(gr_db) -> "BusinessAddressManager": + assert "/unittest-" in gr_db.dsn.path + + from generalresearch.managers.gr.business import BusinessAddressManager + + return BusinessAddressManager(pg_config=gr_db) + + +@pytest.fixture(scope="session") +def business_bank_account_manager(gr_db) -> "BusinessBankAccountManager": + assert "/unittest-" in gr_db.dsn.path + + from generalresearch.managers.gr.business import ( + BusinessBankAccountManager, + ) + + return BusinessBankAccountManager(pg_config=gr_db) + + +@pytest.fixture(scope="session") +def team_manager(gr_db, gr_redis_config) -> "TeamManager": + assert "/unittest-" in gr_db.dsn.path + + from generalresearch.managers.gr.team import TeamManager + + return TeamManager(pg_config=gr_db, redis_config=gr_redis_config) + + +@pytest.fixture(scope="session") +def gr_um(gr_db, gr_redis_config) -> "GRUserManager": + assert "/unittest-" in gr_db.dsn.path + + from generalresearch.managers.gr.authentication import GRUserManager + + return GRUserManager(pg_config=gr_db, redis_config=gr_redis_config) + + +@pytest.fixture(scope="session") +def gr_tm(gr_db) -> "GRTokenManager": + assert "/unittest-" in gr_db.dsn.path + + from generalresearch.managers.gr.authentication import GRTokenManager + + return GRTokenManager(pg_config=gr_db) + + +@pytest.fixture(scope="session") +def membership_manager(gr_db) -> "MembershipManager": + assert "/unittest-" in gr_db.dsn.path + + from generalresearch.managers.gr.team import MembershipManager + + return MembershipManager(pg_config=gr_db) + + +# === GRL IQ === + + +@pytest.fixture(scope="session") +def grliq_dm(grliq_db) -> "GrlIqDataManager": + assert "/unittest-" in grliq_db.dsn.path + + from generalresearch.grliq.managers.forensic_data import ( + GrlIqDataManager, + ) + + return GrlIqDataManager(postgres_config=grliq_db) + + +@pytest.fixture(scope="session") +def grliq_em(grliq_db) -> "GrlIqEventManager": + assert "/unittest-" in grliq_db.dsn.path + + from generalresearch.grliq.managers.forensic_events import ( + GrlIqEventManager, + ) + + return GrlIqEventManager(postgres_config=grliq_db) + + +@pytest.fixture(scope="session") +def grliq_crr(grliq_db) -> "GrlIqCategoryResultsReader": + assert "/unittest-" in grliq_db.dsn.path + + from generalresearch.grliq.managers.forensic_results import ( + GrlIqCategoryResultsReader, + ) + + return GrlIqCategoryResultsReader(postgres_config=grliq_db) + + +@pytest.fixture(scope="session") +def delete_buyers_surveys(thl_web_rw, buyer_manager): + # assert "/unittest-" in thl_web_rw.dsn.path + thl_web_rw.execute_write( + """ + DELETE FROM marketplace_surveystat + WHERE survey_id IN ( + SELECT id + FROM marketplace_survey + WHERE source = %(source)s + );""", + params={"source": Source.TESTING.value}, + ) + thl_web_rw.execute_write( + """ + DELETE FROM marketplace_survey + WHERE buyer_id IN ( + SELECT id + FROM marketplace_buyer + WHERE source = %(source)s + );""", + params={"source": Source.TESTING.value}, + ) + thl_web_rw.execute_write( + """ + DELETE from marketplace_buyer + WHERE source=%(source)s; + """, + params={"source": Source.TESTING.value}, + ) + buyer_manager.populate_caches() diff --git a/test_utils/managers/contest/__init__.py b/test_utils/managers/contest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/managers/contest/conftest.py b/test_utils/managers/contest/conftest.py new file mode 100644 index 0000000..c2d4ef6 --- /dev/null +++ b/test_utils/managers/contest/conftest.py @@ -0,0 +1,295 @@ +from datetime import datetime, timezone +from decimal import Decimal +from typing import Callable, TYPE_CHECKING +from uuid import uuid4 + +import pytest + +from generalresearch.currency import USDCent + +if TYPE_CHECKING: + from generalresearch.models.thl.contest.contest import Contest + from generalresearch.models.thl.contest import ( + ContestPrize, + ContestEndCondition, + ) + + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + from generalresearch.models.thl.contest.io import contest_create_to_contest + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestCreate, + ) + from generalresearch.models.thl.contest.milestone import ( + MilestoneContestCreate, + ContestEntryTrigger, + MilestoneContestEndCondition, + ) + from generalresearch.models.thl.contest.raffle import ( + ContestEntryType, + ) + from generalresearch.models.thl.contest.raffle import ( + RaffleContestCreate, + ) + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + +@pytest.fixture +def raffle_contest_create() -> "RaffleContestCreate": + from generalresearch.models.thl.contest.raffle import ( + RaffleContestCreate, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ContestEndCondition, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + from generalresearch.models.thl.contest.raffle import ( + ContestEntryType, + ) + + # This is what we'll get from the fastapi endpoint + return RaffleContestCreate( + name="test", + contest_type=ContestType.RAFFLE, + entry_type=ContestEntryType.CASH, + prizes=[ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100), + ) + ], + end_condition=ContestEndCondition(target_entry_amount=USDCent(100)), + ) + + +@pytest.fixture +def raffle_contest_in_db( + product_user_wallet_yes: "Product", + raffle_contest_create: "RaffleContestCreate", + contest_manager, +) -> "Contest": + return contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=raffle_contest_create + ) + + +@pytest.fixture +def raffle_contest( + product_user_wallet_yes: "Product", raffle_contest_create: "RaffleContestCreate" +) -> "Contest": + from generalresearch.models.thl.contest.io import contest_create_to_contest + + return contest_create_to_contest( + product_id=product_user_wallet_yes.uuid, contest_create=raffle_contest_create + ) + + +@pytest.fixture(scope="function") +def raffle_contest_factory( + product_user_wallet_yes: "Product", + raffle_contest_create: "RaffleContestCreate", + contest_manager, +) -> Callable: + def _create_contest(**kwargs): + raffle_contest_create.update(**kwargs) + return contest_manager.create( + product_id=product_user_wallet_yes.uuid, + contest_create=raffle_contest_create, + ) + + return _create_contest + + +@pytest.fixture +def milestone_contest_create() -> "MilestoneContestCreate": + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + from generalresearch.models.thl.contest.milestone import ( + MilestoneContestCreate, + ContestEntryTrigger, + MilestoneContestEndCondition, + ) + + # This is what we'll get from the fastapi endpoint + return MilestoneContestCreate( + name="Win a 50% bonus for 7 days and a $1 bonus after your first 3 completes!", + description="only valid for the first 5 users", + contest_type=ContestType.MILESTONE, + prizes=[ + ContestPrize( + name="50% for 7 days", + kind=ContestPrizeKind.PROMOTION, + estimated_cash_value=USDCent(0), + ), + ContestPrize( + name="$1 Bonus", + kind=ContestPrizeKind.CASH, + cash_amount=USDCent(1_00), + estimated_cash_value=USDCent(1_00), + ), + ], + end_condition=MilestoneContestEndCondition( + ends_at=datetime(year=2030, month=1, day=1, tzinfo=timezone.utc), + max_winners=5, + ), + entry_trigger=ContestEntryTrigger.TASK_COMPLETE, + target_amount=3, + ) + + +@pytest.fixture +def milestone_contest_in_db( + product_user_wallet_yes: "Product", + milestone_contest_create: "MilestoneContestCreate", + contest_manager, +) -> "Contest": + return contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=milestone_contest_create + ) + + +@pytest.fixture +def milestone_contest( + product_user_wallet_yes: "Product", + milestone_contest_create: "MilestoneContestCreate", +) -> "Contest": + from generalresearch.models.thl.contest.io import contest_create_to_contest + + return contest_create_to_contest( + product_id=product_user_wallet_yes.uuid, contest_create=milestone_contest_create + ) + + +@pytest.fixture(scope="function") +def milestone_contest_factory( + product_user_wallet_yes: "Product", + milestone_contest_create: "MilestoneContestCreate", + contest_manager, +) -> Callable: + def _create_contest(**kwargs): + milestone_contest_create.update(**kwargs) + return contest_manager.create( + product_id=product_user_wallet_yes.uuid, + contest_create=milestone_contest_create, + ) + + return _create_contest + + +@pytest.fixture +def leaderboard_contest_create( + product_user_wallet_yes: "Product", +) -> "LeaderboardContestCreate": + from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContestCreate, + ) + from generalresearch.models.thl.contest import ( + ContestPrize, + ) + from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, + ) + + # This is what we'll get from the fastapi endpoint + return LeaderboardContestCreate( + name="test", + contest_type=ContestType.LEADERBOARD, + prizes=[ + ContestPrize( + name="$15 Cash", + estimated_cash_value=USDCent(15_00), + cash_amount=USDCent(15_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=1, + ), + ContestPrize( + name="$10 Cash", + estimated_cash_value=USDCent(10_00), + cash_amount=USDCent(10_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=2, + ), + ], + leaderboard_key=f"leaderboard:{product_user_wallet_yes.uuid}:us:daily:2025-01-01:complete_count", + ) + + +@pytest.fixture +def leaderboard_contest_in_db( + product_user_wallet_yes: "Product", + leaderboard_contest_create: "LeaderboardContestCreate", + contest_manager, +) -> "Contest": + return contest_manager.create( + product_id=product_user_wallet_yes.uuid, + contest_create=leaderboard_contest_create, + ) + + +@pytest.fixture +def leaderboard_contest( + product_user_wallet_yes: "Product", + leaderboard_contest_create: "LeaderboardContestCreate", +): + from generalresearch.models.thl.contest.io import contest_create_to_contest + + return contest_create_to_contest( + product_id=product_user_wallet_yes.uuid, + contest_create=leaderboard_contest_create, + ) + + +@pytest.fixture(scope="function") +def leaderboard_contest_factory( + product_user_wallet_yes: "Product", + leaderboard_contest_create: "LeaderboardContestCreate", + contest_manager, +) -> Callable: + def _create_contest(**kwargs): + leaderboard_contest_create.update(**kwargs) + return contest_manager.create( + product_id=product_user_wallet_yes.uuid, + contest_create=leaderboard_contest_create, + ) + + return _create_contest + + +@pytest.fixture +def user_with_money( + request, user_factory, product_user_wallet_yes: "Product", thl_lm +) -> "User": + from generalresearch.models.thl.user import User + + params = getattr(request, "param", dict()) or {} + min_balance = int(params.get("min_balance", USDCent(1_00))) + + user: User = user_factory(product=product_user_wallet_yes) + wallet = thl_lm.get_account_or_create_user_wallet(user) + balance = thl_lm.get_account_balance(wallet) + todo = min_balance - balance + if todo > 0: + # # Put money in user's wallet + thl_lm.create_tx_user_bonus( + user=user, + ref_uuid=uuid4().hex, + description="bonus", + amount=Decimal(todo) / 100, + ) + print(f"wallet balance: {thl_lm.get_user_wallet_balance(user=user)}") + + return user diff --git a/test_utils/managers/ledger/__init__.py b/test_utils/managers/ledger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/managers/ledger/conftest.py b/test_utils/managers/ledger/conftest.py new file mode 100644 index 0000000..b96d612 --- /dev/null +++ b/test_utils/managers/ledger/conftest.py @@ -0,0 +1,678 @@ +from datetime import datetime +from decimal import Decimal +from random import randint +from typing import Optional, Dict, Callable, TYPE_CHECKING +from uuid import uuid4 + +import pytest + +from generalresearch.currency import USDCent +from test_utils.models.conftest import ( + product_factory, + user, + product, + user_factory, + product_user_wallet_no, + wall, + product_amt_true, + product_user_wallet_yes, + session_factory, + session, + wall_factory, + payout_config, +) + +_ = ( + user_factory, + product_user_wallet_no, + wall, + product_amt_true, + product_user_wallet_yes, + session_factory, + session, + wall_factory, + payout_config, +) + +if TYPE_CHECKING: + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + LedgerTransaction, + ) + from generalresearch.models.thl.ledger import ( + LedgerEntry, + LedgerAccount, + ) + from generalresearch.models.thl.payout import UserPayoutEvent + + +@pytest.fixture(scope="function") +def ledger_account(request, lm, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + LedgerAccount, + ) + + account_type = getattr(request, "account_type", AccountType.CASH) + direction = getattr(request, "direction", Direction.CREDIT) + + acct_uuid = uuid4().hex + qn = ":".join([currency, account_type, acct_uuid]) + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{acct_uuid}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=direction, + ) + return lm.create_account(account=acct_model) + + +@pytest.fixture(scope="function") +def ledger_account_factory(request, thl_lm, lm, currency) -> Callable: + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + LedgerAccount, + ) + + def _ledger_account_factory( + product, + account_type: AccountType = AccountType.CASH, + direction: Direction = Direction.CREDIT, + ): + thl_lm.get_account_or_create_bp_wallet(product=product) + acct_uuid = uuid4().hex + qn = ":".join([currency, account_type, acct_uuid]) + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{acct_uuid}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=direction, + ) + return lm.create_account(account=acct_model) + + return _ledger_account_factory + + +@pytest.fixture(scope="function") +def ledger_account_credit(request, lm, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import Direction, AccountType + + account_type = AccountType.REVENUE + acct_uuid = uuid4().hex + + qn = ":".join([currency, account_type, acct_uuid]) + from generalresearch.models.thl.ledger import LedgerAccount + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{acct_uuid}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=Direction.CREDIT, + ) + return lm.create_account(account=acct_model) + + +@pytest.fixture(scope="function") +def ledger_account_debit(request, lm, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import Direction, AccountType + + account_type = AccountType.EXPENSE + acct_uuid = uuid4().hex + + qn = ":".join([currency, account_type, acct_uuid]) + from generalresearch.models.thl.ledger import LedgerAccount + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{acct_uuid}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=Direction.DEBIT, + ) + return lm.create_account(account=acct_model) + + +@pytest.fixture(scope="function") +def tag(request, lm) -> str: + from generalresearch.currency import LedgerCurrency + + return ( + request.param + if hasattr(request, "tag") + else f"{LedgerCurrency.TEST}:{uuid4().hex}" + ) + + +@pytest.fixture(scope="function") +def usd_cent(request) -> USDCent: + amount = randint(99, 9_999) + return request.param if hasattr(request, "usd_cent") else USDCent(amount) + + +@pytest.fixture(scope="function") +def bp_payout_event( + product, usd_cent, business_payout_event_manager, thl_lm +) -> "UserPayoutEvent": + return business_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=usd_cent, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + +@pytest.fixture +def bp_payout_event_factory(brokerage_product_payout_event_manager, thl_lm) -> Callable: + from generalresearch.models.thl.product import Product + from generalresearch.currency import USDCent + + def _create_bp_payout_event( + product: Product, usd_cent: USDCent, ext_ref_id: Optional[str] = None + ): + return brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=usd_cent, + ext_ref_id=ext_ref_id, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + return _create_bp_payout_event + + +@pytest.fixture(scope="function") +def currency(lm) -> "LedgerCurrency": + # return request.param if hasattr(request, "currency") else LedgerCurrency.TEST + return lm.currency + + +@pytest.fixture(scope="function") +def tx_metadata(request) -> Optional[Dict[str, str]]: + return ( + request.param + if hasattr(request, "tx_metadata") + else {f"key-{uuid4().hex[:10]}": uuid4().hex} + ) + + +@pytest.fixture(scope="function") +def ledger_tx( + request, + ledger_account_credit, + ledger_account_debit, + tag, + currency, + tx_metadata, + lm, +) -> "LedgerTransaction": + from generalresearch.models.thl.ledger import Direction, LedgerEntry + + amount = int(Decimal("1.00") * 100) + + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + return lm.create_tx(entries=entries, tag=tag, metadata=tx_metadata) + + +@pytest.fixture(scope="function") +def create_main_accounts(lm, currency) -> Callable: + def _create_main_accounts(): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = LedgerAccount( + display_name="Cash flow task complete", + qualified_name=f"{currency.value}:revenue:task_complete", + normal_balance=Direction.CREDIT, + account_type=AccountType.REVENUE, + currency=lm.currency, + ) + lm.get_account_or_create(account=account) + + account = LedgerAccount( + display_name="Operating Cash Account", + qualified_name=f"{currency.value}:cash", + normal_balance=Direction.DEBIT, + account_type=AccountType.CASH, + currency=currency, + ) + + lm.get_account_or_create(account=account) + + return _create_main_accounts + + +@pytest.fixture(scope="function") +def delete_ledger_db(thl_web_rw) -> Callable: + def _delete_ledger_db(): + for table in [ + "ledger_transactionmetadata", + "ledger_entry", + "ledger_transaction", + "ledger_account", + ]: + thl_web_rw.execute_write( + query=f"DELETE FROM {table};", + ) + + return _delete_ledger_db + + +@pytest.fixture(scope="function") +def wipe_main_accounts(thl_web_rw, lm, currency) -> Callable: + def _wipe_main_accounts(): + db_table = thl_web_rw.db_name + qual_names = [ + f"{currency.value}:revenue:task_complete", + f"{currency.value}:cash", + ] + + res = thl_web_rw.execute_sql_query( + query=f""" + SELECT lt.id as ltid, le.id as leid, tmd.id as tmdid, la.uuid as lauuid + FROM `{db_table}`.`ledger_transaction` AS lt + LEFT JOIN `{db_table}`.ledger_entry le + ON lt.id = le.transaction_id + LEFT JOIN `{db_table}`.ledger_account la + ON la.uuid = le.account_id + LEFT JOIN `{db_table}`.ledger_transactionmetadata tmd + ON lt.id = tmd.transaction_id + WHERE la.qualified_name IN %s + """, + params=[qual_names], + ) + + lt = {x["ltid"] for x in res if x["ltid"]} + le = {x["leid"] for x in res if x["leid"]} + tmd = {x["tmdid"] for x in res if x["tmdid"]} + la = {x["lauuid"] for x in res if x["lauuid"]} + + thl_web_rw.execute_sql_query( + query=f""" + DELETE FROM `{db_table}`.`ledger_transactionmetadata` + WHERE id IN %s + """, + params=[tmd], + commit=True, + ) + + thl_web_rw.execute_sql_query( + query=f""" + DELETE FROM `{db_table}`.`ledger_entry` + WHERE id IN %s + """, + params=[le], + commit=True, + ) + + thl_web_rw.execute_sql_query( + query=f""" + DELETE FROM `{db_table}`.`ledger_transaction` + WHERE id IN %s + """, + params=[lt], + commit=True, + ) + + thl_web_rw.execute_sql_query( + query=f""" + DELETE FROM `{db_table}`.`ledger_account` + WHERE uuid IN %s + """, + params=[la], + commit=True, + ) + + return _wipe_main_accounts + + +@pytest.fixture(scope="function") +def account_cash(lm, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = LedgerAccount( + display_name="Operating Cash Account", + qualified_name=f"{currency.value}:cash", + normal_balance=Direction.DEBIT, + account_type=AccountType.CASH, + currency=currency, + ) + return lm.get_account_or_create(account=account) + + +@pytest.fixture(scope="function") +def account_revenue_task_complete(lm, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = LedgerAccount( + display_name="Cash flow task complete", + qualified_name=f"{currency.value}:revenue:task_complete", + normal_balance=Direction.CREDIT, + account_type=AccountType.REVENUE, + currency=currency, + ) + return lm.get_account_or_create(account=account) + + +@pytest.fixture(scope="function") +def account_expense_tango(lm, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = LedgerAccount( + display_name="Tango Fee", + qualified_name=f"{currency.value}:expense:tango_fee", + normal_balance=Direction.DEBIT, + account_type=AccountType.EXPENSE, + currency=currency, + ) + return lm.get_account_or_create(account=account) + + +@pytest.fixture(scope="function") +def user_account_user_wallet(lm, user, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = LedgerAccount( + display_name=f"{user.uuid} Wallet", + qualified_name=f"{currency.value}:user_wallet:{user.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.USER_WALLET, + reference_type="user", + reference_uuid=user.uuid, + currency=currency, + ) + return lm.get_account_or_create(account=account) + + +@pytest.fixture(scope="function") +def product_account_bp_wallet(lm, product, currency) -> "LedgerAccount": + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = LedgerAccount.model_validate( + dict( + display_name=f"{product.name} Wallet", + qualified_name=f"{currency.value}:bp_wallet:{product.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.BP_WALLET, + reference_type="bp", + reference_uuid=product.uuid, + currency=currency, + ) + ) + return lm.get_account_or_create(account=account) + + +@pytest.fixture(scope="function") +def setup_accounts(product_factory, lm, user, currency) -> None: + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + # BP's wallet and a revenue from their commissions account. + p1 = product_factory() + + account = LedgerAccount( + display_name=f"Revenue from {p1.name} commission", + qualified_name=f"{currency.value}:revenue:bp_commission:{p1.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.REVENUE, + reference_type="bp", + reference_uuid=p1.uuid, + currency=currency, + ) + lm.get_account_or_create(account=account) + + account = LedgerAccount.model_validate( + dict( + display_name=f"{p1.name} Wallet", + qualified_name=f"{currency.value}:bp_wallet:{p1.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.BP_WALLET, + reference_type="bp", + reference_uuid=p1.uuid, + currency=currency, + ) + ) + lm.get_account_or_create(account=account) + + # BP's wallet, user's wallet, and a revenue from their commissions account. + p2 = product_factory() + account = LedgerAccount( + display_name=f"Revenue from {p2.name} commission", + qualified_name=f"{currency.value}:revenue:bp_commission:{p2.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.REVENUE, + reference_type="bp", + reference_uuid=p2.uuid, + currency=currency, + ) + lm.get_account_or_create(account) + + account = LedgerAccount( + display_name=f"{p2.name} Wallet", + qualified_name=f"{currency.value}:bp_wallet:{p2.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.BP_WALLET, + reference_type="bp", + reference_uuid=p2.uuid, + currency=currency, + ) + lm.get_account_or_create(account) + + account = LedgerAccount( + display_name=f"{user.uuid} Wallet", + qualified_name=f"{currency.value}:user_wallet:{user.uuid}", + normal_balance=Direction.CREDIT, + account_type=AccountType.USER_WALLET, + reference_type="user", + reference_uuid=user.uuid, + currency="test", + ) + lm.get_account_or_create(account=account) + + +@pytest.fixture(scope="function") +def session_with_tx_factory( + user_factory, + product, + session_factory, + session_manager, + wall_manager, + utc_hour_ago, + thl_lm, +) -> Callable: + from generalresearch.models.thl.session import ( + Status, + Session, + StatusCode1, + ) + from generalresearch.models.thl.user import User + + def _session_with_tx_factory( + user: User, + final_status: Status = Status.COMPLETE, + wall_req_cpi: Decimal = Decimal(".50"), + started: datetime = utc_hour_ago, + ) -> Session: + s: Session = session_factory( + user=user, + wall_count=2, + final_status=final_status, + wall_req_cpi=wall_req_cpi, + started=started, + ) + last_wall = s.wall_events[-1] + + wall_manager.finish( + wall=last_wall, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=last_wall.finished, + ) + + status, status_code_1 = s.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s.determine_payments() + session_manager.finish_with_status( + session=s, + finished=last_wall.finished, + payout=bp_pay, + user_payout=user_pay, + status=status, + status_code_1=status_code_1, + ) + + thl_lm.create_tx_task_complete( + wall=last_wall, + user=user, + created=last_wall.finished, + force=True, + ) + + thl_lm.create_tx_bp_payment(session=s, created=last_wall.finished, force=True) + + return s + + return _session_with_tx_factory + + +@pytest.fixture(scope="function") +def adj_to_fail_with_tx_factory(session_manager, wall_manager, thl_lm) -> Callable: + from generalresearch.models.thl.session import ( + Session, + ) + from datetime import timedelta + from generalresearch.models.thl.definitions import WallAdjustedStatus + + def _adj_to_fail_with_tx_factory( + session: Session, + created: datetime, + ) -> None: + w1 = wall_manager.get_wall_events(session_id=session.id)[-1] + + # This is defined in `thl-grpc/thl/user_quality_history/recons.py:150` + # so we can't use it as part of this test anyway to add rows to the + # thl_taskadjustment table anyway.. until we created a + # TaskAdjustment Manager to put into py-utils! + + # create_task_adjustment_event( + # wall, + # user, + # adjusted_status, + # amount_usd=amount_usd, + # alert_time=alert_time, + # ext_status_code=ext_status_code, + # ) + + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=Decimal("0.00"), + adjusted_timestamp=created, + ) + + thl_lm.create_tx_task_adjustment( + wall=w1, + user=session.user, + created=created + timedelta(milliseconds=1), + ) + + session.wall_events = wall_manager.get_wall_events(session_id=session.id) + session_manager.adjust_status(session=session) + + thl_lm.create_tx_bp_adjustment( + session=session, created=created + timedelta(milliseconds=2) + ) + + return None + + return _adj_to_fail_with_tx_factory + + +@pytest.fixture(scope="function") +def adj_to_complete_with_tx_factory(session_manager, wall_manager, thl_lm) -> Callable: + from generalresearch.models.thl.session import ( + Session, + ) + from datetime import timedelta + from generalresearch.models.thl.definitions import WallAdjustedStatus + + def _adj_to_complete_with_tx_factory( + session: Session, + created: datetime, + ) -> None: + w1 = wall_manager.get_wall_events(session_id=session.id)[-1] + + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.req_cpi, + adjusted_timestamp=created, + ) + + thl_lm.create_tx_task_adjustment( + wall=w1, + user=session.user, + created=created + timedelta(milliseconds=1), + ) + + session.wall_events = wall_manager.get_wall_events(session_id=session.id) + session_manager.adjust_status(session=session) + + thl_lm.create_tx_bp_adjustment( + session=session, created=created + timedelta(milliseconds=2) + ) + + return None + + return _adj_to_complete_with_tx_factory diff --git a/test_utils/managers/upk/__init__.py b/test_utils/managers/upk/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/managers/upk/conftest.py b/test_utils/managers/upk/conftest.py new file mode 100644 index 0000000..61be924 --- /dev/null +++ b/test_utils/managers/upk/conftest.py @@ -0,0 +1,161 @@ +import os +import time +from typing import Optional +from uuid import UUID + +import pandas as pd +import pytest + +from generalresearch.pg_helper import PostgresConfig + + +def insert_data_from_csv( + thl_web_rw: PostgresConfig, + table_name: str, + fp: Optional[str] = None, + disable_fk_checks: bool = False, + df: Optional[pd.DataFrame] = None, +): + assert fp is not None or df is not None and not (fp is not None and df is not None) + if fp: + df = pd.read_csv(fp, dtype=str) + df = df.where(pd.notnull(df), None) + cols = list(df.columns) + col_str = ", ".join(cols) + values_str = ", ".join(["%s"] * len(cols)) + if "id" in df.columns and len(df["id"].iloc[0]) == 36: + df["id"] = df["id"].map(lambda x: UUID(x).hex) + args = df.to_dict("tight")["data"] + + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + if disable_fk_checks: + c.execute("SET CONSTRAINTS ALL DEFERRED") + c.executemany( + f"INSERT INTO {table_name} ({col_str}) VALUES ({values_str})", + params_seq=args, + ) + conn.commit() + + +@pytest.fixture(scope="session") +def category_data(thl_web_rw, category_manager) -> None: + fp = os.path.join(os.path.dirname(__file__), "marketplace_category.csv.gz") + insert_data_from_csv( + thl_web_rw, + fp=fp, + table_name="marketplace_category", + disable_fk_checks=True, + ) + # Don't strictly need to do this, but probably we should + category_manager.populate_caches() + cats = category_manager.categories.values() + path_id = {c.path: c.id for c in cats} + data = [ + {"id": c.id, "parent_id": path_id[c.parent_path]} for c in cats if c.parent_path + ] + query = """ + UPDATE marketplace_category + SET parent_id = %(parent_id)s + WHERE id = %(id)s; + """ + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=data) + conn.commit() + + +@pytest.fixture(scope="session") +def property_data(thl_web_rw) -> None: + fp = os.path.join(os.path.dirname(__file__), "marketplace_property.csv.gz") + insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_property") + + +@pytest.fixture(scope="session") +def item_data(thl_web_rw) -> None: + fp = os.path.join(os.path.dirname(__file__), "marketplace_item.csv.gz") + insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_item") + + +@pytest.fixture(scope="session") +def propertycategoryassociation_data( + thl_web_rw, category_data, property_data, category_manager +) -> None: + table_name = "marketplace_propertycategoryassociation" + fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") + # Need to lookup category pk from uuid + category_manager.populate_caches() + df = pd.read_csv(fp, dtype=str) + df["category_id"] = df["category_id"].map( + lambda x: category_manager.categories[x].id + ) + insert_data_from_csv(thl_web_rw, df=df, table_name=table_name) + + +@pytest.fixture(scope="session") +def propertycountry_data(thl_web_rw, property_data) -> None: + fp = os.path.join(os.path.dirname(__file__), "marketplace_propertycountry.csv.gz") + insert_data_from_csv(thl_web_rw, fp=fp, table_name="marketplace_propertycountry") + + +@pytest.fixture(scope="session") +def propertymarketplaceassociation_data(thl_web_rw, property_data) -> None: + table_name = "marketplace_propertymarketplaceassociation" + fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") + insert_data_from_csv(thl_web_rw, fp=fp, table_name=table_name) + + +@pytest.fixture(scope="session") +def propertyitemrange_data(thl_web_rw, property_data, item_data) -> None: + table_name = "marketplace_propertyitemrange" + fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") + insert_data_from_csv(thl_web_rw, fp=fp, table_name=table_name) + + +@pytest.fixture(scope="session") +def question_data(thl_web_rw) -> None: + table_name = "marketplace_question" + fp = os.path.join(os.path.dirname(__file__), f"{table_name}.csv.gz") + insert_data_from_csv( + thl_web_rw, fp=fp, table_name=table_name, disable_fk_checks=True + ) + + +@pytest.fixture(scope="session") +def clear_upk_tables(thl_web_rw): + tables = [ + "marketplace_propertyitemrange", + "marketplace_propertymarketplaceassociation", + "marketplace_propertycategoryassociation", + "marketplace_category", + "marketplace_item", + "marketplace_property", + "marketplace_propertycountry", + "marketplace_question", + ] + table_str = ", ".join(tables) + + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.execute(f"TRUNCATE {table_str} RESTART IDENTITY CASCADE;") + conn.commit() + + +@pytest.fixture(scope="session") +def upk_data( + clear_upk_tables, + category_data, + property_data, + item_data, + propertycategoryassociation_data, + propertycountry_data, + propertymarketplaceassociation_data, + propertyitemrange_data, + question_data, +) -> None: + # Wait a second to make sure the HarmonizerCache refresh loop pulls these in + time.sleep(2) + + +def test_fixtures(upk_data): + pass diff --git a/test_utils/managers/upk/marketplace_category.csv.gz b/test_utils/managers/upk/marketplace_category.csv.gz new file mode 100644 index 0000000..0f8ec1c Binary files /dev/null and b/test_utils/managers/upk/marketplace_category.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_item.csv.gz b/test_utils/managers/upk/marketplace_item.csv.gz new file mode 100644 index 0000000..c12c5d8 Binary files /dev/null and b/test_utils/managers/upk/marketplace_item.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_property.csv.gz b/test_utils/managers/upk/marketplace_property.csv.gz new file mode 100644 index 0000000..a781d1d Binary files /dev/null and b/test_utils/managers/upk/marketplace_property.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_propertycategoryassociation.csv.gz b/test_utils/managers/upk/marketplace_propertycategoryassociation.csv.gz new file mode 100644 index 0000000..5b4ea19 Binary files /dev/null and b/test_utils/managers/upk/marketplace_propertycategoryassociation.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_propertycountry.csv.gz b/test_utils/managers/upk/marketplace_propertycountry.csv.gz new file mode 100644 index 0000000..5d2a637 Binary files /dev/null and b/test_utils/managers/upk/marketplace_propertycountry.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_propertyitemrange.csv.gz b/test_utils/managers/upk/marketplace_propertyitemrange.csv.gz new file mode 100644 index 0000000..84f4f0e Binary files /dev/null and b/test_utils/managers/upk/marketplace_propertyitemrange.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_propertymarketplaceassociation.csv.gz b/test_utils/managers/upk/marketplace_propertymarketplaceassociation.csv.gz new file mode 100644 index 0000000..6b9fd1c Binary files /dev/null and b/test_utils/managers/upk/marketplace_propertymarketplaceassociation.csv.gz differ diff --git a/test_utils/managers/upk/marketplace_question.csv.gz b/test_utils/managers/upk/marketplace_question.csv.gz new file mode 100644 index 0000000..bcfc3ad Binary files /dev/null and b/test_utils/managers/upk/marketplace_question.csv.gz differ diff --git a/test_utils/models/__init__.py b/test_utils/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/models/conftest.py b/test_utils/models/conftest.py new file mode 100644 index 0000000..ecfd82b --- /dev/null +++ b/test_utils/models/conftest.py @@ -0,0 +1,608 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint, choice as randchoice +from typing import Callable, TYPE_CHECKING, Optional, List, Dict +from uuid import uuid4 + +import pytest +from pydantic import AwareDatetime, PositiveInt + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + WALL_ALLOWED_STATUS_STATUS_CODE, + Status, +) +from test_utils.managers.conftest import ( + product_manager, + user_manager, + wall_manager, + session_manager, + gr_um, + membership_manager, + team_manager, + business_manager, + business_address_manager, +) +from generalresearch.models.thl.survey.model import Survey, Buyer + +if TYPE_CHECKING: + from generalresearch.models.thl.userhealth import AuditLog, AuditLogLevel + from generalresearch.models.thl.payout import UserPayoutEvent + from generalresearch.models.gr.authentication import GRUser, GRToken + from generalresearch.models.gr.team import Team, Membership + from generalresearch.models.gr.business import ( + Business, + BusinessAddress, + BusinessBankAccount, + ) + from generalresearch.models.thl.user import User + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.session import Session, Wall + from generalresearch.currency import USDCent + from generalresearch.models.thl.product import ( + PayoutConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, + ) + from generalresearch.models.thl.user_iphistory import IPRecord + from generalresearch.models.thl.ipinfo import IPGeoname, IPInformation + + +# === THL === + + +@pytest.fixture(scope="function") +def user(request, product_manager, user_manager, thl_web_rr) -> "User": + product = getattr(request, "product", None) + + if product is None: + product = product_manager.create_dummy() + + u = user_manager.create_dummy(product_id=product.id) + u.prefetch_product(pg_config=thl_web_rr) + + return u + + +@pytest.fixture +def user_with_wallet( + request, user_factory, product_user_wallet_yes: "Product" +) -> "User": + # A user on a product with user wallet enabled, but they have no money + return user_factory(product=product_user_wallet_yes) + + +@pytest.fixture +def user_with_wallet_amt(request, user_factory, product_amt_true: "Product") -> "User": + # A user on a product with user wallet enabled, on AMT, but they have no money + return user_factory(product=product_amt_true) + + +@pytest.fixture(scope="function") +def user_factory(user_manager, thl_web_rr) -> Callable: + def _create_user(product: "Product", created: Optional[datetime] = None): + u = user_manager.create_dummy(product=product, created=created) + u.prefetch_product(pg_config=thl_web_rr) + + return u + + return _create_user + + +@pytest.fixture(scope="function") +def wall_factory(wall_manager) -> Callable: + def _create_wall( + session: "Session", wall_status: "Status", req_cpi: Optional[Decimal] = None + ): + + assert session.started <= datetime.now( + tz=timezone.utc + ), "Session can't start in the future" + + if session.wall_events: + # Subsequent Wall events + wall = session.wall_events[-1] + assert not wall.finished, "Can't add new Walls until prior finishes" + # wall_started = last_wall.started + timedelta(milliseconds=1) + else: + # First Wall Event in a session + wall_started = session.started + timedelta(milliseconds=1) + + wall = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=wall_started, + req_cpi=req_cpi, + ) + session.append_wall_event(w=wall) + + options = list(WALL_ALLOWED_STATUS_STATUS_CODE.get(wall_status, {})) + wall.finish( + finished=wall.started + timedelta(seconds=randint(a=60 * 2, b=60 * 10)), + status=wall_status, + status_code_1=randchoice(options), + ) + + return wall + + return _create_wall + + +@pytest.fixture(scope="function") +def wall(session, user, wall_manager) -> Optional["Wall"]: + from generalresearch.models.thl.task_status import StatusCode1 + + wall = wall_manager.create_dummy(session_id=session.id, user_id=user.user_id) + # thl_session.append_wall_event(wall) + wall.finish( + finished=wall.started + timedelta(seconds=randint(a=60 * 2, b=60 * 10)), + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + ) + return wall + + +@pytest.fixture(scope="function") +def session_factory( + wall_factory, session_manager, wall_manager, utc_hour_ago +) -> Callable: + from generalresearch.models.thl.session import Source + + def _create_session( + user: "User", + # Wall details + wall_count: int = 5, + wall_req_cpi: Decimal = Decimal(".50"), + wall_req_cpis: Optional[List[Decimal]] = None, + wall_statuses: Optional[List[Status]] = None, + wall_source: Source = Source.TESTING, + # Session details + final_status: Status = Status.COMPLETE, + started: datetime = utc_hour_ago, + ) -> "Session": + if wall_req_cpis: + assert len(wall_req_cpis) == wall_count + if wall_statuses: + assert len(wall_statuses) == wall_count + + s = session_manager.create_dummy(started=started, user=user, country_iso="us") + for idx in range(wall_count): + if idx == 0: + # First Wall Event in a session + wall_started = s.started + timedelta(milliseconds=1) + else: + # Subsequent Wall events + last_wall = s.wall_events[-1] + assert last_wall.finished, "Can't add new Walls until prior finishes" + wall_started = last_wall.started + timedelta(milliseconds=1) + + w = wall_manager.create_dummy( + session_id=s.id, + source=wall_source, + user_id=s.user_id, + started=wall_started, + req_cpi=wall_req_cpis[idx] if wall_req_cpis else wall_req_cpi, + ) + s.append_wall_event(w=w) + + # If it's the last wall in the session, respect the final_status + # value for the Session + if wall_statuses: + _final_status = wall_statuses[idx] + else: + _final_status = final_status if idx == wall_count - 1 else Status.FAIL + + options = list(WALL_ALLOWED_STATUS_STATUS_CODE.get(_final_status, {})) + wall_manager.finish( + wall=w, + status=_final_status, + status_code_1=randchoice(options), + finished=w.started + timedelta(seconds=randint(a=60 * 2, b=60 * 10)), + ) + + return s + + return _create_session + + +@pytest.fixture(scope="function") +def finished_session_factory( + session_factory, session_manager, utc_hour_ago +) -> Callable: + from generalresearch.models.thl.session import Source + + def _create_finished_session( + user: "User", + # Wall details + wall_count: int = 5, + wall_req_cpi: Decimal = Decimal(".50"), + wall_req_cpis: Optional[List[Decimal]] = None, + wall_statuses: Optional[List[Status]] = None, + wall_source: Source = Source.TESTING, + # Session details + final_status: Status = Status.COMPLETE, + started: datetime = utc_hour_ago, + ) -> "Session": + s: Session = session_factory( + user=user, + wall_count=wall_count, + wall_req_cpi=wall_req_cpi, + wall_req_cpis=wall_req_cpis, + wall_statuses=wall_statuses, + wall_source=wall_source, + final_status=final_status, + started=started, + ) + status, status_code_1 = s.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s.determine_payments() + session_manager.finish_with_status( + s, + finished=s.wall_events[-1].finished, + payout=bp_pay, + user_payout=user_pay, + status=status, + status_code_1=status_code_1, + ) + return s + + return _create_finished_session + + +@pytest.fixture(scope="function") +def session(user, session_manager, wall_manager) -> "Session": + from generalresearch.models.thl.session import Wall, Session + + session: Session = session_manager.create_dummy(user=user, country_iso="us") + wall: Wall = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=session.started, + ) + session.append_wall_event(w=wall) + + return session + + +@pytest.fixture +def product(request, product_manager) -> "Product": + from generalresearch.managers.thl.product import ProductManager + + team = getattr(request, "team", None) + business = getattr(request, "business", None) + + product_manager: ProductManager + return product_manager.create_dummy( + team_id=team.uuid if team else None, + business_id=business.uuid if business else None, + ) + + +@pytest.fixture +def product_factory(product_manager) -> Callable: + def _create_product( + team: Optional["Team"] = None, + business: Optional["Business"] = None, + commission_pct: Decimal = Decimal("0.05"), + ): + return product_manager.create_dummy( + team_id=team.uuid if team else None, + business_id=business.uuid if business else None, + commission_pct=commission_pct, + ) + + return _create_product + + +@pytest.fixture(scope="function") +def payout_config(request) -> "PayoutConfig": + from generalresearch.models.thl.product import ( + PayoutConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, + ) + + return ( + request.param + if hasattr(request, "payout_config") + else PayoutConfig( + payout_format="${payout/100:.2f}", + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.40), + ), + ) + ) + + +@pytest.fixture(scope="function") +def product_user_wallet_yes(payout_config, product_manager) -> "Product": + from generalresearch.models.thl.product import UserWalletConfig + from generalresearch.managers.thl.product import ProductManager + + product_manager: ProductManager + return product_manager.create_dummy( + payout_config=payout_config, user_wallet_config=UserWalletConfig(enabled=True) + ) + + +@pytest.fixture(scope="function") +def product_user_wallet_no(product_manager) -> "Product": + from generalresearch.models.thl.product import UserWalletConfig + from generalresearch.managers.thl.product import ProductManager + + product_manager: ProductManager + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=False) + ) + + +@pytest.fixture(scope="function") +def product_amt_true(product_manager, payout_config) -> "Product": + from generalresearch.models.thl.product import UserWalletConfig + + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(amt=True, enabled=True), + payout_config=payout_config, + ) + + +@pytest.fixture(scope="function") +def bp_payout_factory( + thl_lm, product_manager, business_payout_event_manager +) -> Callable: + def _create_bp_payout( + product: Optional["Product"] = None, + amount: Optional["USDCent"] = None, + ext_ref_id: Optional[str] = None, + created: Optional[AwareDatetime] = None, + skip_wallet_balance_check: bool = False, + skip_one_per_day_check: bool = False, + ) -> "UserPayoutEvent": + from generalresearch.currency import USDCent + + product = product or product_manager.create_dummy() + amount = amount or USDCent(randint(1, 99_99)) + + return business_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=amount, + ext_ref_id=ext_ref_id, + created=created, + skip_wallet_balance_check=skip_wallet_balance_check, + skip_one_per_day_check=skip_one_per_day_check, + ) + + return _create_bp_payout + + +# === GR === + + +@pytest.fixture(scope="function") +def business(request, business_manager) -> "Business": + from generalresearch.managers.gr.business import BusinessManager + + business_manager: BusinessManager + return business_manager.create_dummy() + + +@pytest.fixture(scope="function") +def business_address(request, business, business_address_manager) -> "BusinessAddress": + from generalresearch.managers.gr.business import BusinessAddressManager + + business_address_manager: BusinessAddressManager + return business_address_manager.create_dummy(business_id=business.id) + + +@pytest.fixture(scope="function") +def business_bank_account( + request, business, business_bank_account_manager +) -> "BusinessBankAccount": + from generalresearch.managers.gr.business import BusinessBankAccountManager + + business_bank_account_manager: BusinessBankAccountManager + return business_bank_account_manager.create_dummy(business_id=business.id) + + +@pytest.fixture(scope="function") +def team(request, team_manager) -> "Team": + from generalresearch.managers.gr.team import TeamManager + + team_manager: TeamManager + return team_manager.create_dummy() + + +@pytest.fixture(scope="function") +def gr_user(gr_um) -> "GRUser": + from generalresearch.managers.gr.authentication import GRUserManager + + gr_um: GRUserManager + return gr_um.create_dummy() + + +@pytest.fixture(scope="function") +def gr_user_cache(gr_user, gr_db, thl_web_rr, gr_redis_config): + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + return gr_user + + +@pytest.fixture(scope="function") +def gr_user_factory(gr_um) -> Callable: + def _create_gr_user(): + return gr_um.create_dummy() + + return _create_gr_user + + +@pytest.fixture() +def gr_user_token(gr_user, gr_tm, gr_db) -> "GRToken": + gr_tm.create(user_id=gr_user.id) + gr_user.prefetch_token(pg_config=gr_db) + + return gr_user.token + + +@pytest.fixture() +def gr_user_token_header(gr_user_token) -> Dict: + return gr_user_token.auth_header + + +@pytest.fixture(scope="function") +def membership(request, team, gr_user, team_manager) -> "Membership": + assert team.id, "Team must be saved" + assert gr_user.id, "GRUser must be saved" + return team_manager.add_user(team=team, gr_user=gr_user) + + +@pytest.fixture(scope="function") +def membership_factory( + team: "Team", gr_user: "GRUser", membership_manager, team_manager, gr_um +) -> Callable: + from generalresearch.managers.gr.team import MembershipManager + + membership_manager: MembershipManager + + def _create_membership(**kwargs): + _team = kwargs.get("team", team_manager.create_dummy()) + _gr_user = kwargs.get("gr_user", gr_um.create_dummy()) + + return membership_manager.create(team=_team, gr_user=_gr_user) + + return _create_membership + + +@pytest.fixture(scope="function") +def audit_log(audit_log_manager, user) -> "AuditLog": + from generalresearch.managers.thl.userhealth import AuditLogManager + + audit_log_manager: AuditLogManager + return audit_log_manager.create_dummy(user_id=user.user_id) + + +@pytest.fixture(scope="function") +def audit_log_factory(audit_log_manager) -> Callable: + from generalresearch.managers.thl.userhealth import AuditLogManager + + audit_log_manager: AuditLogManager + + def _create_audit_log( + user_id: PositiveInt, + level: Optional["AuditLogLevel"] = None, + event_type: Optional[str] = None, + event_msg: Optional[str] = None, + event_value: Optional[float] = None, + ): + return audit_log_manager.create_dummy( + user_id=user_id, + level=level, + event_type=event_type, + event_msg=event_msg, + event_value=event_value, + ) + + return _create_audit_log + + +@pytest.fixture(scope="function") +def ip_geoname(ip_geoname_manager) -> "IPGeoname": + from generalresearch.managers.thl.ipinfo import IPGeonameManager + + ip_geoname_manager: IPGeonameManager + return ip_geoname_manager.create_dummy() + + +@pytest.fixture(scope="function") +def ip_information(ip_information_manager, ip_geoname) -> "IPInformation": + from generalresearch.managers.thl.ipinfo import IPInformationManager + + ip_information_manager: IPInformationManager + return ip_information_manager.create_dummy( + geoname_id=ip_geoname.geoname_id, country_iso=ip_geoname.country_iso + ) + + +@pytest.fixture(scope="function") +def ip_information_factory(ip_information_manager) -> Callable: + from generalresearch.managers.thl.ipinfo import IPInformationManager + + ip_information_manager: IPInformationManager + + def _create_ip_info(ip: str, geoname: "IPGeoname", **kwargs): + return ip_information_manager.create_dummy( + ip=ip, + geoname_id=geoname.geoname_id, + country_iso=geoname.country_iso, + **kwargs, + ) + + return _create_ip_info + + +@pytest.fixture(scope="function") +def ip_record(ip_record_manager, ip_geoname, user) -> "IPRecord": + from generalresearch.managers.thl.userhealth import IPRecordManager + + ip_record_manager: IPRecordManager + + return ip_record_manager.create_dummy(user_id=user.user_id) + + +@pytest.fixture(scope="function") +def ip_record_factory(ip_record_manager, user) -> Callable: + from generalresearch.managers.thl.userhealth import IPRecordManager + + ip_record_manager: IPRecordManager + + def _create_ip_record(user_id: PositiveInt, ip: Optional[str] = None): + return ip_record_manager.create_dummy(user_id=user_id, ip=ip) + + return _create_ip_record + + +@pytest.fixture(scope="session") +def buyer(buyer_manager) -> Buyer: + buyer_code = uuid4().hex + buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=[buyer_code]) + b = Buyer( + source=Source.TESTING, code=buyer_code, label=f"test-buyer-{buyer_code[:8]}" + ) + buyer_manager.update(b) + return b + + +@pytest.fixture(scope="session") +def buyer_factory(buyer_manager) -> Callable: + + def inner(): + return buyer_manager.bulk_get_or_create( + source=Source.TESTING, codes=[uuid4().hex] + )[0] + + return inner + + +@pytest.fixture(scope="session") +def survey(survey_manager, buyer) -> Survey: + s = Survey(source=Source.TESTING, survey_id=uuid4().hex, buyer_code=buyer.code) + survey_manager.create_bulk([s]) + return s + + +@pytest.fixture(scope="session") +def survey_factory(survey_manager, buyer_factory) -> Callable: + + def inner(buyer: Optional[Buyer] = None) -> Survey: + buyer = buyer or buyer_factory() + s = Survey( + source=Source.TESTING, + survey_id=uuid4().hex, + buyer_code=buyer.code, + buyer_id=buyer.id, + ) + survey_manager.create_bulk([s]) + return s + + return inner diff --git a/test_utils/spectrum/__init__.py b/test_utils/spectrum/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test_utils/spectrum/conftest.py b/test_utils/spectrum/conftest.py new file mode 100644 index 0000000..b7887f6 --- /dev/null +++ b/test_utils/spectrum/conftest.py @@ -0,0 +1,79 @@ +import logging + +import time + +import pytest +from datetime import datetime, timezone +from generalresearch.managers.spectrum.survey import ( + SpectrumSurveyManager, + SpectrumCriteriaManager, +) +from generalresearch.models.spectrum.survey import SpectrumSurvey +from generalresearch.sql_helper import SqlHelper + +from .surveys_json import SURVEYS_JSON, CONDITIONS + + +@pytest.fixture(scope="session") +def spectrum_rw(settings) -> SqlHelper: + print(f"{settings.spectrum_rw_db=}") + logging.info(f"{settings.spectrum_rw_db=}") + assert "/unittest-" in settings.spectrum_rw_db.path + return SqlHelper( + dsn=settings.spectrum_rw_db, + read_timeout=2, + write_timeout=1, + connect_timeout=2, + ) + + +@pytest.fixture(scope="session") +def spectrum_criteria_manager(spectrum_rw) -> SpectrumCriteriaManager: + assert "/unittest-" in spectrum_rw.dsn.path + return SpectrumCriteriaManager(spectrum_rw) + + +@pytest.fixture(scope="session") +def spectrum_survey_manager(spectrum_rw) -> SpectrumSurveyManager: + assert "/unittest-" in spectrum_rw.dsn.path + return SpectrumSurveyManager(spectrum_rw) + + +@pytest.fixture(scope="session") +def setup_spectrum_surveys( + spectrum_rw, spectrum_survey_manager, spectrum_criteria_manager +): + now = datetime.now(timezone.utc) + # make sure these example surveys exist in db + surveys = [SpectrumSurvey.model_validate_json(x) for x in SURVEYS_JSON] + for s in surveys: + s.modified_api = datetime.now(tz=timezone.utc) + spectrum_survey_manager.create_or_update(surveys) + spectrum_criteria_manager.update(CONDITIONS) + + # and make sure they have allocation for 687 + spectrum_rw.execute_sql_query( + f""" + INSERT IGNORE INTO `{spectrum_rw.db}`.spectrum_supplier + (supplier_id, name, api_key, secret_key, username, password) + VALUES (%s, %s, %s, %s, %s, %s)""", + ["687", "GRL", "x", "x", "x", "x"], + commit=True, + ) + supplier687_pk = spectrum_rw.execute_sql_query( + f""" + select id from `{spectrum_rw.db}`.spectrum_supplier where supplier_id = '687'""" + )[0]["id"] + conn = spectrum_rw.make_connection() + c = conn.cursor() + c.executemany( + f""" + INSERT IGNORE INTO `{spectrum_rw.db}`.spectrum_surveysupplier + (created, surveySig, supplier_id, survey_id) + VALUES (%s, %s, %s, %s) + """, + [[now, "xxx", supplier687_pk, s.survey_id] for s in surveys], + ) + conn.commit() + # Wait a second to make sure the spectrum-grpc pulls these from the db into global-vars + time.sleep(1) diff --git a/test_utils/spectrum/surveys_json.py b/test_utils/spectrum/surveys_json.py new file mode 100644 index 0000000..eb747a5 --- /dev/null +++ b/test_utils/spectrum/surveys_json.py @@ -0,0 +1,140 @@ +from generalresearch.models import LogicalOperator +from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + SpectrumSurvey, +) +from generalresearch.models.thl.survey.condition import ConditionValueType + +SURVEYS_JSON = [ + '{"cpi":"3.90","country_isos":["us"],"language_isos":["eng"],"buyer_id":"215","bid_loi":780,"source":"s",' + '"used_question_ids":["1235","212"],"survey_id":"111111","survey_name":"Exciting New Survey #14472374",' + '"status":22,"field_end_date":"2023-03-02T07:05:36.261000Z","category_code":"232","calculation_type":"COMPLETES",' + '"requires_pii":false,"survey_exclusions":"13947261,14126487,14361592,14376811,14385771,14387789,14472374",' + '"exclusion_period":30,"bid_ir":0.2,"overall_loi":null,"overall_ir":null,"last_block_loi":null,' + '"last_block_ir":null,"project_last_complete_date":null,"country_iso":"us","language_iso":"eng",' + '"include_psids":null,"exclude_psids":null' + ',"qualifications":["ee5e842","e6e0b0b"],"quotas":[{"remaining_count":100,' + '"condition_hashes":["32cbf31"]}],"conditions":null,"created_api":"2023-02-28T07:05:36.698000Z",' + '"modified_api":"2024-03-10T09:43:40.030000Z","updated":"2024-05-30T21:52:46.431612Z","is_live":true' + "}", + '{"cpi":"3.90","country_isos":["us"],"language_isos":["eng"],"buyer_id":"215","bid_loi":780,"source":"s",' + '"used_question_ids":["1235","212"],"survey_id":"14472374","survey_name":"Exciting New Survey #14472374",' + '"status":22,"field_end_date":"2023-03-02T07:05:36.261000Z","category_code":"232","calculation_type":"COMPLETES",' + '"requires_pii":false,"survey_exclusions":"13947261,14126487,14361592,14376811,14385771,14387789,14472374",' + '"exclusion_period":30,"bid_ir":0.2,"overall_loi":null,"overall_ir":null,"last_block_loi":null,' + '"last_block_ir":null,"project_last_complete_date":null,"country_iso":"us","language_iso":"eng",' + '"include_psids":null,"exclude_psids":"0408319875e9dbffdc09e86671ad5636,23c4c66ecbc465906d0b0fd798740e64,' + '861df4603df3b7f754b8d4b89cbdb313","qualifications":["ee5e842","e6e0b0b"],"quotas":[{"remaining_count":100,' + '"condition_hashes":["32cbf31"]}],"conditions":null,"created_api":"2023-02-28T07:05:36.698000Z",' + '"modified_api":"2024-03-10T09:43:40.030000Z","updated":"2024-05-30T21:52:46.431612Z","is_live":true' + "}", + '{"cpi":"3.90","country_isos":["us"],"language_isos":["eng"],"buyer_id":"215","bid_loi":780,"source":"s",' + '"used_question_ids":["1235","212"],"survey_id":"12345","survey_name":"Exciting New Survey #14472374",' + '"status":22,"field_end_date":"2023-03-02T07:05:36.261000Z","category_code":"232","calculation_type":"COMPLETES",' + '"requires_pii":false,"survey_exclusions":"13947261,14126487,14361592,14376811,14385771,14387789,14472374",' + '"exclusion_period":30,"bid_ir":0.2,"overall_loi":null,"overall_ir":null,"last_block_loi":null,' + '"last_block_ir":null,"project_last_complete_date":null,"country_iso":"us","language_iso":"eng",' + '"include_psids":"7d043991b1494dbbb57786b11c88239c","exclude_psids":null' + ',"qualifications":["ee5e842","e6e0b0b"],"quotas":[{"remaining_count":100,' + '"condition_hashes":["32cbf31"]}],"conditions":null,"created_api":"2023-02-28T07:05:36.698000Z",' + '"modified_api":"2024-03-10T09:43:40.030000Z","updated":"2024-05-30T21:52:46.431612Z","is_live":true' + "}", + '{"cpi":"1.40","country_isos":["us"],"language_isos":["eng"],"buyer_id":"233","bid_loi":null,"source":"s",' + '"used_question_ids":["245","244","212","211","225"],"survey_id":"14970164","survey_name":"Exciting New Survey ' + '#14970164","status":22,"field_end_date":"2024-05-07T16:18:33.000000Z","category_code":"232",' + '"calculation_type":"COMPLETES","requires_pii":false,"survey_exclusions":"14970164,29690277",' + '"exclusion_period":30,"bid_ir":null,"overall_loi":900,"overall_ir":0.56,"last_block_loi":600,' + '"last_block_ir":0.01,"project_last_complete_date":"2024-05-28T04:12:56.297000Z","country_iso":"us",' + '"language_iso":"eng","include_psids":null,"exclude_psids":"01c7156fd9639737effbbdebd7fd66f6,' + "0508b88f4991bac8b10e9de74ce80194,0a51c627d77cef41f802e51a00126697,15b888176ac4781c2c978a9a05c396f8," + "17bc146b4f7fb05c7058d25da70c6a44,29935289c1f86a4144aab2e12652f305,2fe9d1d451efca10eba4fa4e5e2b74c9," + "c3527b7ef570a1571ea19870f3c25600,cdf2771d57cda9f1bf334382b2b7afd8,cebf3ec50395d973310ea526457dd5a0," + "cf3877cfc15e2e6ef2a56a7a7a37f3d3,dfa691e6d060e3643d5731df30be9f69,e0cb49537182660826aa351e1187809f," + 'edb6d280113ca49561f25fdcb500fde6,fbfba66cfad602f1c26e61e6174eb1f7,fd4307b16fd15e8534a4551c9b6872fc",' + '"qualifications":["1ab337d","a01aa68","437774f","dc6065b","82b6ad6"],"quotas":[{"remaining_count":242,' + '"condition_hashes":["c23c0b9"]},{"remaining_count":0,"condition_hashes":["5b8c6cf"]},{"remaining_count":126,' + '"condition_hashes":["ac35a6e"]},{"remaining_count":110,"condition_hashes":["5e7e5aa"]},{"remaining_count":108,' + '"condition_hashes":["9a7aef3"]},{"remaining_count":127,"condition_hashes":["4f75127"]},{"remaining_count":0,' + '"condition_hashes":["95437ed"]},{"remaining_count":17,"condition_hashes":["b4b7b95"]},{"remaining_count":16,' + '"condition_hashes":["0ab0ae6"]},{"remaining_count":8,"condition_hashes":["6e86fb5"]},{"remaining_count":12,' + '"condition_hashes":["24de31e"]},{"remaining_count":69,"condition_hashes":["6bdf350"]},{"remaining_count":411,' + '"condition_hashes":["c94d422"]}],"conditions":null,"created_api":"2023-03-30T22:47:36.324000Z",' + '"modified_api":"2024-05-30T13:07:16.489000Z","updated":"2024-05-30T21:52:37.493282Z","is_live":true,' + '"all_hashes":["c94d422","b4b7b95","6bdf350","6e86fb5","82b6ad6","24de31e","1ab337d","c23c0b9","9a7aef3",' + '"ac35a6e","95437ed","5b8c6cf","437774f","a01aa68","5e7e5aa","4f75127","0ab0ae6","dc6065b"]}', + '{"cpi":"1.23","country_isos":["au"],"language_isos":["eng"],"buyer_id":"215","bid_loi":780,"source":"s",' + '"used_question_ids":[],"survey_id":"69420","survey_name":"Everyone is eligible AU",' + '"status":22,"field_end_date":"2023-03-02T07:05:36.261000Z","category_code":"232","calculation_type":"COMPLETES",' + '"requires_pii":false,"survey_exclusions":"13947261,14126487,14361592,14376811,14385771,14387789,14472374",' + '"exclusion_period":30,"bid_ir":0.2,"overall_loi":null,"overall_ir":null,"last_block_loi":null,' + '"last_block_ir":null,"project_last_complete_date":null,"country_iso":"au","language_iso":"eng",' + '"include_psids":null,"exclude_psids":null' + ',"qualifications":[],"quotas":[{"remaining_count":100,' + '"condition_hashes":[]}],"conditions":null,"created_api":"2023-02-28T07:05:36.698000Z",' + '"modified_api":"2024-03-10T09:43:40.030000Z","updated":"2024-05-30T21:52:46.431612Z","is_live":true' + "}", + '{"cpi":"1.23","country_isos":["us"],"language_isos":["eng"],"buyer_id":"215","bid_loi":780,"source":"s",' + '"used_question_ids":[],"survey_id":"69421","survey_name":"Everyone is eligible US",' + '"status":22,"field_end_date":"2023-03-02T07:05:36.261000Z","category_code":"232","calculation_type":"COMPLETES",' + '"requires_pii":false,"survey_exclusions":"13947261,14126487,14361592,14376811,14385771,14387789,14472374",' + '"exclusion_period":30,"bid_ir":0.2,"overall_loi":null,"overall_ir":null,"last_block_loi":null,' + '"last_block_ir":null,"project_last_complete_date":null,"country_iso":"us","language_iso":"eng",' + '"include_psids":null,"exclude_psids":null' + ',"qualifications":[],"quotas":[{"remaining_count":100,' + '"condition_hashes":[]}],"conditions":null,"created_api":"2023-02-28T07:05:36.698000Z",' + '"modified_api":"2024-03-10T09:43:40.030000Z","updated":"2024-05-30T21:52:46.431612Z","is_live":true' + "}", + # For partial eligibility + '{"cpi":"1.23","country_isos":["us"],"language_isos":["eng"],"buyer_id":"215","bid_loi":780,"source":"s",' + '"used_question_ids":["1031", "212"],"survey_id":"999000","survey_name":"Pet owners",' + '"status":22,"field_end_date":"2023-03-02T07:05:36.261000Z","category_code":"232","calculation_type":"COMPLETES",' + '"requires_pii":false,"survey_exclusions":"13947261",' + '"exclusion_period":30,"bid_ir":0.2,"overall_loi":null,"overall_ir":null,"last_block_loi":null,' + '"last_block_ir":null,"project_last_complete_date":null,"country_iso":"us","language_iso":"eng",' + '"include_psids":null,"exclude_psids":null' + ',"qualifications":["0039b0c", "00f60a8"],"quotas":[{"remaining_count":100,' + '"condition_hashes":[]}],"conditions":null,"created_api":"2023-02-28T07:05:36.698000Z",' + '"modified_api":"2024-03-10T09:43:40.030000Z","updated":"2024-05-30T21:52:46.431612Z","is_live":true' + "}", +] + +# make sure hashes for 111111 are in db +c1 = SpectrumCondition( + question_id="1001", + value_type=ConditionValueType.LIST, + values=["a", "b", "c"], + negate=False, + logical_operator=LogicalOperator.OR, +) +c2 = SpectrumCondition( + question_id="1001", + value_type=ConditionValueType.LIST, + values=["a"], + negate=False, + logical_operator=LogicalOperator.OR, +) +c3 = SpectrumCondition( + question_id="1002", + value_type=ConditionValueType.RANGE, + values=["18-24", "30-32"], + negate=False, + logical_operator=LogicalOperator.OR, +) +c4 = SpectrumCondition( + question_id="212", + value_type=ConditionValueType.LIST, + values=["23", "24"], + negate=False, + logical_operator=LogicalOperator.OR, +) +c5 = SpectrumCondition( + question_id="1031", + value_type=ConditionValueType.LIST, + values=["113", "114", "121"], + negate=False, + logical_operator=LogicalOperator.OR, +) +CONDITIONS = [c1, c2, c3, c4, c5] +survey = SpectrumSurvey.model_validate_json(SURVEYS_JSON[0]) +assert c1.criterion_hash in survey.qualifications +assert c3.criterion_hash in survey.qualifications diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..30ed1c7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,19 @@ +pytest_plugins = [ + "distributed.utils_test", + "test_utils.conftest", + # -- GRL IQ + "test_utils.grliq.conftest", + "test_utils.grliq.managers.conftest", + "test_utils.grliq.models.conftest", + # -- Incite + "test_utils.incite.conftest", + "test_utils.incite.collections.conftest", + "test_utils.incite.mergers.conftest", + # -- Managers + "test_utils.managers.conftest", + "test_utils.managers.contest.conftest", + "test_utils.managers.ledger.conftest", + "test_utils.managers.upk.conftest", + # -- Models + "test_utils.models.conftest", +] diff --git a/tests/grliq/__init__.py b/tests/grliq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/grliq/managers/__init__.py b/tests/grliq/managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/grliq/managers/test_forensic_data.py b/tests/grliq/managers/test_forensic_data.py new file mode 100644 index 0000000..ed4da80 --- /dev/null +++ b/tests/grliq/managers/test_forensic_data.py @@ -0,0 +1,212 @@ +from datetime import timedelta +from uuid import uuid4 + +import pytest + +from generalresearch.grliq.models.events import TimingData, MouseEvent +from generalresearch.grliq.models.forensic_data import GrlIqData +from generalresearch.grliq.models.forensic_result import ( + GrlIqCheckerResults, + GrlIqForensicCategoryResult, +) + +try: + from psycopg.errors import UniqueViolation +except ImportError: + pass + + +class TestGrlIqDataManager: + + def test_create_dummy(self, grliq_dm): + from generalresearch.grliq.managers.forensic_data import GrlIqDataManager + from generalresearch.grliq.models.forensic_data import GrlIqData + + grliq_dm: GrlIqDataManager + gd1: GrlIqData = grliq_dm.create_dummy(is_attempt_allowed=True) + + assert isinstance(gd1, GrlIqData) + assert isinstance(gd1.results, GrlIqCheckerResults) + assert isinstance(gd1.category_result, GrlIqForensicCategoryResult) + + def test_create(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + assert grliq_data.id is not None + + with pytest.raises(UniqueViolation): + grliq_dm.create(grliq_data) + + @pytest.mark.skip(reason="todo") + def test_set_result(self): + pass + + @pytest.mark.skip(reason="todo") + def test_update_fingerprint(self): + pass + + @pytest.mark.skip(reason="todo") + def test_update_fingerprint(self): + pass + + @pytest.mark.skip(reason="todo") + def test_update_data(self): + pass + + def test_get_id(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + + res = grliq_dm.get_data(forensic_id=grliq_data.id) + assert res == grliq_data + + def test_get_uuid(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + + res = grliq_dm.get_data(forensic_uuid=grliq_data.uuid) + assert res == grliq_data + + @pytest.mark.skip(reason="todo") + def test_filter_timing_data(self): + pass + + @pytest.mark.skip(reason="todo") + def test_get_unique_user_count_by_fingerprint(self): + pass + + def test_filter_data(self, grliq_data, grliq_dm): + grliq_dm.create(grliq_data) + res = grliq_dm.filter_data(uuids=[grliq_data.uuid])[0] + assert res == grliq_data + now = res.created_at + res = grliq_dm.filter_data( + created_after=now, created_before=now + timedelta(minutes=1) + ) + assert len(res) == 1 + res = grliq_dm.filter_data( + created_after=now + timedelta(seconds=1), + created_before=now + timedelta(minutes=1), + ) + assert len(res) == 0 + + @pytest.mark.skip(reason="todo") + def test_filter_results(self): + pass + + @pytest.mark.skip(reason="todo") + def test_filter_category_results(self): + pass + + @pytest.mark.skip(reason="todo") + def test_make_filter_str(self): + pass + + def test_filter_count(self, grliq_dm, product): + res = grliq_dm.filter_count(product_id=product.uuid) + + assert isinstance(res, int) + + @pytest.mark.skip(reason="todo") + def test_filter(self): + pass + + @pytest.mark.skip(reason="todo") + def test_temporary_add_missing_fields(self): + pass + + +class TestForensicDataGetAndFilter: + + def test_events(self, grliq_dm, grliq_em): + """If load_events=True, the events and mouse_events attributes should + be an array no matter what. An empty array means that the events were + loaded, but there were no events available. + + If loaded_eventsFalse, the events and mouse_events attributes should + be None + """ + # Load Events == False + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + + instance = grliq_dm.filter_data(uuids=[forensic_uuid])[0] + assert isinstance(instance, GrlIqData) + + assert instance.events is None + assert instance.mouse_events is None + + # Load Events == True + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + # This one doesn't have any events though + assert len(instance.events) == 0 + assert len(instance.mouse_events) == 0 + + def test_timing(self, grliq_dm, grliq_em): + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + + instance = grliq_dm.filter_data(uuids=[forensic_uuid])[0] + + grliq_em.update_or_create_timing( + session_uuid=instance.mid, + timing_data=TimingData( + client_rtts=[100, 200, 150], server_rtts=[150, 120, 120] + ), + ) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + assert isinstance(instance.events, list) + assert isinstance(instance.mouse_events, list) + assert isinstance(instance.timing_data, TimingData) + + def test_events_events(self, grliq_dm, grliq_em): + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + + instance = grliq_dm.filter_data(uuids=[forensic_uuid])[0] + + grliq_em.update_or_create_events( + session_uuid=instance.mid, + events=[{"a": "b"}], + mouse_events=[], + event_start=instance.created_at, + event_end=instance.created_at + timedelta(minutes=1), + ) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + assert isinstance(instance.events, list) + assert isinstance(instance.mouse_events, list) + assert instance.timing_data is None + assert instance.events == [{"a": "b"}] + assert len(instance.mouse_events) == 0 + assert len(instance.pointer_move_events) == 0 + assert len(instance.keyboard_events) == 0 + + def test_events_click(self, grliq_dm, grliq_em): + forensic_uuid = uuid4().hex + grliq_dm.create_dummy(is_attempt_allowed=True, uuid=forensic_uuid) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + + click_event = { + "type": "click", + "pageX": 0, + "pageY": 0, + "timeStamp": 123, + "pointerType": "mouse", + } + me = MouseEvent.from_dict(click_event) + grliq_em.update_or_create_events( + session_uuid=instance.mid, + events=[click_event], + mouse_events=[], + event_start=instance.created_at, + event_end=instance.created_at + timedelta(minutes=1), + ) + instance = grliq_dm.get_data(forensic_uuid=forensic_uuid, load_events=True) + assert isinstance(instance, GrlIqData) + assert isinstance(instance.events, list) + assert isinstance(instance.mouse_events, list) + assert instance.timing_data is None + assert instance.events == [click_event] + assert instance.mouse_events == [me] + assert len(instance.pointer_move_events) == 0 + assert len(instance.keyboard_events) == 0 diff --git a/tests/grliq/managers/test_forensic_results.py b/tests/grliq/managers/test_forensic_results.py new file mode 100644 index 0000000..a837a64 --- /dev/null +++ b/tests/grliq/managers/test_forensic_results.py @@ -0,0 +1,16 @@ +class TestGrlIqCategoryResultsReader: + + def test_filter_category_results(self, grliq_dm, grliq_crr): + from generalresearch.grliq.models.forensic_result import ( + Phase, + GrlIqForensicCategoryResult, + ) + + # this is just testing that it doesn't fail + grliq_dm.create_dummy(is_attempt_allowed=True) + grliq_dm.create_dummy(is_attempt_allowed=True) + + res = grliq_crr.filter_category_results(limit=2, phase=Phase.OFFERWALL_ENTER)[0] + assert res.get("category_result") + assert isinstance(res["category_result"], GrlIqForensicCategoryResult) + assert res["user_agent"].os.family diff --git a/tests/grliq/models/__init__.py b/tests/grliq/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/grliq/models/test_forensic_data.py b/tests/grliq/models/test_forensic_data.py new file mode 100644 index 0000000..653f9a9 --- /dev/null +++ b/tests/grliq/models/test_forensic_data.py @@ -0,0 +1,49 @@ +import pytest +from pydantic import ValidationError + +from generalresearch.grliq.models.forensic_data import GrlIqData, Platform + + +class TestGrlIqData: + + def test_supported_fonts(self, grliq_data): + s = grliq_data.supported_fonts_binary + assert len(s) == 1043 + assert "Ubuntu" in grliq_data.supported_fonts + + def test_battery(self, grliq_data): + assert not grliq_data.battery_charging + assert grliq_data.battery_level == 0.41 + + def test_base(self, grliq_data): + g: GrlIqData = grliq_data + assert g.timezone == "America/Los_Angeles" + assert g.platform == Platform.LINUX_X86_64 + assert g.webgl_extensions + # ... more + + assert g.results is None + assert g.category_result is None + + s = g.model_dump_json() + g2: GrlIqData = GrlIqData.model_validate_json(s) + + assert g2.results is None + assert g2.category_result is None + + assert g == g2 + + # Testing things that will cause a validation error, should only be + # because something is "corrupt", not b/c the user is a baddie + def test_corrupt(self, grliq_data): + """Test for timestamp and timezone offset mismatch validation.""" + d = grliq_data.model_dump(mode="json") + d.update( + { + "timezone": "America/XXX", + } + ) + with pytest.raises(ValidationError) as e: + GrlIqData.model_validate(d) + + assert "Invalid timezone name" in str(e.value) diff --git a/tests/grliq/test_utils.py b/tests/grliq/test_utils.py new file mode 100644 index 0000000..d9034d5 --- /dev/null +++ b/tests/grliq/test_utils.py @@ -0,0 +1,17 @@ +from pathlib import Path +from uuid import uuid4 + + +class TestUtils: + + def test_get_screenshot_fp(self, mnt_grliq_archive_dir, utc_hour_ago): + from generalresearch.grliq.utils import get_screenshot_fp + + fp1 = get_screenshot_fp( + created_at=utc_hour_ago, + forensic_uuid=uuid4(), + grliq_archive_dir=mnt_grliq_archive_dir, + create_dir_if_not_exists=True, + ) + + assert isinstance(fp1, Path) diff --git a/tests/incite/__init__.py b/tests/incite/__init__.py new file mode 100644 index 0000000..2f736e8 --- /dev/null +++ b/tests/incite/__init__.py @@ -0,0 +1,137 @@ +# class TestParquetBehaviors(CleanTempDirectoryTestCls): +# wall_coll = WallDFCollection( +# start=GLOBAL_VARS["wall"].start, +# offset="49h", +# archive_path=f"{settings.incite_mount_dir}/raw/df-collections/{DFCollectionType.WALL.value}", +# ) +# +# def test_filters(self): +# # Using REAL data here +# start = datetime(year=2024, month=1, day=15, hour=12, tzinfo=timezone.utc) +# end = datetime(year=2024, month=1, day=15, hour=20, tzinfo=timezone.utc) +# end_max = datetime( +# year=2024, month=1, day=15, hour=20, tzinfo=timezone.utc +# ) + timedelta(hours=2) +# +# ir = pd.Interval(left=pd.Timestamp(start), right=pd.Timestamp(end)) +# wall_items = [w for w in self.wall_coll.items if w.interval.overlaps(ir)] +# ddf = self.wall_coll.ddf( +# items=wall_items, +# include_partial=True, +# force_rr_latest=False, +# columns=["started", "finished"], +# filters=[ +# ("started", ">=", start), +# ("started", "<", end), +# ], +# ) +# +# df = ddf.compute() +# self.assertIsInstance(df, pd.DataFrame) +# +# # No started=None, and they're all between the started and the end +# self.assertFalse(df.started.isna().any()) +# self.assertFalse((df.started < start).any()) +# self.assertFalse((df.started > end).any()) +# +# # Has finished=None and finished=time, so +# # the finished is all between the started and +# # the end_max +# self.assertTrue(df.finished.isna().any()) +# self.assertTrue((df.finished.dt.year == 2024).any()) +# +# self.assertFalse((df.finished > end_max).any()) +# self.assertFalse((df.finished < start).any()) +# +# # def test_user_id_list(self): +# # # Calling compute turns it into a np.ndarray +# # user_ids = self.instance.ddf( +# # columns=["user_id"] +# # ).user_id.unique().values.compute() +# # self.assertIsInstance(user_ids, np.ndarray) +# # +# # # If ddf filters work with ndarray +# # user_product_merge = +# # +# # with self.assertRaises(TypeError) as cm: +# # user_product_merge.ddf( +# # filters=[("id", "in", user_ids)]) +# # self.assertIn("Value of 'in' filter must be a list, set or tuple.", str(cm.exception)) +# # +# # # No compute == dask array +# # user_ids = self.instance.ddf( +# # columns=["user_id"] +# # ).user_id.unique().values +# # self.assertIsInstance(user_ids, da.Array) +# # +# # with self.assertRaises(TypeError) as cm: +# # user_product_merge.ddf( +# # filters=[("id", "in", user_ids)]) +# # self.assertIn("Value of 'in' filter must be a list, set or tuple.", str(cm.exception)) +# # +# # # pick a product_id (most active one) +# # self.product_id = instance.df.product_id.value_counts().index[0] +# # self.expected_columns: int = len(instance._schema.columns) +# # self.instance = instance +# +# # def test_basic(self): +# # # now try to load up the data! +# # self.instance.grouped_key = self.product_id +# # +# # # Confirm any of the items are archived +# # self.assertTrue(self.instance.progress.has_archive.eq(True).any()) +# # +# # # Confirm it returns a df +# # df = self.instance.dd().compute() +# # +# # self.assertFalse(df.empty) +# # self.assertEqual(df.shape[1], self.expected_columns) +# # self.assertGreater(df.shape[0], 1) +# # +# # # Confirm that DF only contains this product_id +# # self.assertEqual(df[df.product_id == self.product_id].shape, df.shape) +# +# # def test_god_vs_product_id(self): +# # self.instance.grouped_key = self.product_id +# # df_product_origin = self.instance.dd(columns=None, filters=None).compute() +# # +# # self.instance.grouped_key = None +# # df_god_origin = self.instance.dd(columns=None, +# # filters=[("product_id", "==", self.product_id)]).compute() +# # +# # self.assertTrue(df_god_origin.equals(df_product_origin)) +# +# # +# # instance = POPSessionMerge( +# # start=START, +# # archive_path=self.PATH, +# # group_by="product_id" +# # ) +# # instance.build(U=GLOBAL_VARS["user"], S=GLOBAL_VARS["session"], W=GLOBAL_VARS["wall"]) +# # instance.save(god_only=False) +# # +# # # pick a product_id (most active one) +# # self.product_id = instance.df.product_id.value_counts().index[0] +# # self.expected_columns: int = len(instance._schema.columns) +# # self.instance = instance +# +# +# class TestValidItem(CleanTempDirectoryTestCls): +# +# def test_interval(self): +# for k in GLOBAL_VARS.keys(): +# coll = GLOBAL_VARS[k] +# item = coll.items[0] +# ir = item.interval +# +# self.assertIsInstance(ir, pd.Interval) +# self.assertLess(a=ir.left, b=ir.right) +# +# def test_str(self): +# for k in GLOBAL_VARS.keys(): +# coll = GLOBAL_VARS[k] +# item = coll.items[0] +# +# offset = coll.offset or "–" +# +# self.assertIn(offset, str(item)) diff --git a/tests/incite/collections/__init__.py b/tests/incite/collections/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/incite/collections/test_df_collection_base.py b/tests/incite/collections/test_df_collection_base.py new file mode 100644 index 0000000..5aaa729 --- /dev/null +++ b/tests/incite/collections/test_df_collection_base.py @@ -0,0 +1,113 @@ +from datetime import datetime, timezone + +import pandas as pd +import pytest +from pandera import DataFrameSchema + +from generalresearch.incite.collections import ( + DFCollectionType, + DFCollection, +) +from test_utils.incite.conftest import mnt_filepath + +df_collection_types = [e for e in DFCollectionType if e is not DFCollectionType.TEST] + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionBase: + """None of these tests are about the DFCollection with any specific + data_type... that will be handled in other parameterized tests + + """ + + def test_init(self, mnt_filepath, df_coll_type): + """Try to initialize the DFCollection with various invalid parameters""" + with pytest.raises(expected_exception=ValueError) as cm: + DFCollection(archive_path=mnt_filepath.data_src) + assert "Must explicitly provide a data_type" in str(cm.value) + + # with pytest.raises(expected_exception=ValueError) as cm: + # DFCollection( + # data_type=DFCollectionType.TEST, archive_path=mnt_filepath.data_src + # ) + # assert "Must provide a supported data_type" in str(cm.value) + + instance = DFCollection( + data_type=DFCollectionType.WALL, archive_path=mnt_filepath.data_src + ) + assert instance.data_type == DFCollectionType.WALL + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionBaseProperties: + + @pytest.mark.skip + def test_df_collection_items(self, mnt_filepath, df_coll_type): + instance = DFCollection( + data_type=df_coll_type, + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + offset="100d", + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + assert len(instance.interval_range) == len(instance.items) + assert len(instance.items) == 366 + + def test_df_collection_progress(self, mnt_filepath, df_coll_type): + instance = DFCollection( + data_type=df_coll_type, + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + offset="100d", + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + # Progress returns a dataframe with a row each Item + assert isinstance(instance.progress, pd.DataFrame) + assert instance.progress.shape == (366, 6) + + def test_df_collection_schema(self, mnt_filepath, df_coll_type): + instance1 = DFCollection( + data_type=DFCollectionType.WALL, archive_path=mnt_filepath.data_src + ) + + instance2 = DFCollection( + data_type=DFCollectionType.SESSION, archive_path=mnt_filepath.data_src + ) + + assert instance1._schema != instance2._schema + assert isinstance(instance1._schema, DataFrameSchema) + assert isinstance(instance2._schema, DataFrameSchema) + + +class TestDFCollectionBaseMethods: + + @pytest.mark.skip + def test_initial_load(self, mnt_filepath, thl_web_rr): + instance = DFCollection( + pg_config=thl_web_rr, + data_type=DFCollectionType.USER, + start=datetime(year=2022, month=1, day=1, minute=0, tzinfo=timezone.utc), + finished=datetime(year=2022, month=1, day=1, minute=5, tzinfo=timezone.utc), + offset="2min", + archive_path=mnt_filepath.data_src, + ) + + # Confirm that there are no archives available yet + assert instance.progress.has_archive.eq(False).all() + + instance.initial_load() + assert 47 == len(instance.ddf().index) + assert instance.progress.should_archive.eq(True).all() + + # A few archives should have been made + assert not instance.progress.has_archive.eq(False).all() + + @pytest.mark.skip + def test_fetch_force_rr_latest(self): + pass + + @pytest.mark.skip + def test_force_rr_latest(self): + pass diff --git a/tests/incite/collections/test_df_collection_item_base.py b/tests/incite/collections/test_df_collection_item_base.py new file mode 100644 index 0000000..a0c0b0b --- /dev/null +++ b/tests/incite/collections/test_df_collection_item_base.py @@ -0,0 +1,72 @@ +from datetime import datetime, timezone + +import pytest + +from generalresearch.incite.collections import ( + DFCollectionType, + DFCollectionItem, + DFCollection, +) +from test_utils.incite.conftest import mnt_filepath + +df_collection_types = [e for e in DFCollectionType if e is not DFCollectionType.TEST] + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionItemBase: + + def test_init(self, mnt_filepath, df_coll_type): + collection = DFCollection( + data_type=df_coll_type, + offset="100d", + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + item = DFCollectionItem() + item._collection = collection + + assert isinstance(item, DFCollectionItem) + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionItemProperties: + + @pytest.mark.skip + def test_filename(self, df_coll_type): + pass + + +@pytest.mark.parametrize("df_coll_type", df_collection_types) +class TestDFCollectionItemMethods: + + def test_has_mysql_false(self, mnt_filepath, df_coll_type): + collection = DFCollection( + data_type=df_coll_type, + offset="100d", + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + ) + + instance1: DFCollectionItem = collection.items[0] + assert not instance1.has_mysql() + + def test_has_mysql_true(self, thl_web_rr, mnt_filepath, df_coll_type): + collection = DFCollection( + data_type=df_coll_type, + offset="100d", + start=datetime(year=1800, month=1, day=1, tzinfo=timezone.utc), + finished=datetime(year=1900, month=1, day=1, tzinfo=timezone.utc), + archive_path=mnt_filepath.archive_path(enum_type=df_coll_type), + pg_config=thl_web_rr, + ) + + # Has RR, assume unittest server is online + instance2: DFCollectionItem = collection.items[0] + assert instance2.has_mysql() + + @pytest.mark.skip + def test_update_partial_archive(self, df_coll_type): + pass diff --git a/tests/incite/collections/test_df_collection_item_thl_web.py b/tests/incite/collections/test_df_collection_item_thl_web.py new file mode 100644 index 0000000..9c3d67a --- /dev/null +++ b/tests/incite/collections/test_df_collection_item_thl_web.py @@ -0,0 +1,994 @@ +from datetime import datetime, timezone, timedelta +from itertools import product as iter_product +from os.path import join as pjoin +from pathlib import PurePath, Path +from uuid import uuid4 + +import dask.dataframe as dd +import pandas as pd +import pytest +from distributed import Client, Scheduler, Worker + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) +from faker import Faker +from pandera import DataFrameSchema +from pydantic import FilePath + +from generalresearch.incite.base import CollectionItemBase +from generalresearch.incite.collections import ( + DFCollectionItem, + DFCollectionType, +) +from generalresearch.incite.schemas import ARCHIVE_AFTER +from generalresearch.models.thl.user import User +from generalresearch.pg_helper import PostgresConfig +from generalresearch.sql_helper import PostgresDsn +from test_utils.incite.conftest import mnt_filepath, incite_item_factory + +fake = Faker() + +df_collections = [ + DFCollectionType.WALL, + DFCollectionType.SESSION, + DFCollectionType.LEDGER, + DFCollectionType.TASK_ADJUSTMENT, +] + +unsupported_mock_types = { + DFCollectionType.IP_INFO, + DFCollectionType.IP_HISTORY, + DFCollectionType.IP_HISTORY_WS, + DFCollectionType.TASK_ADJUSTMENT, +} + + +def combo_object(): + for x in iter_product( + df_collections, + ["15min", "45min", "1H"], + ): + yield x + + +class TestDFCollectionItemBase: + def test_init(self): + instance = CollectionItemBase() + assert isinstance(instance, CollectionItemBase) + assert isinstance(instance.start, datetime) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollectionItemProperties: + + def test_filename(self, df_collection_data_type, df_collection, offset): + for i in df_collection.items: + assert isinstance(i.filename, str) + + assert isinstance(i.path, PurePath) + assert i.path.name == i.filename + + assert i._collection.data_type.name.lower() in i.filename + assert i._collection.offset in i.filename + assert i.start.strftime("%Y-%m-%d-%H-%M-%S") in i.filename + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollectionItemPropertiesBase: + + def test_name(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.name, str) + + def test_finish(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.finish, datetime) + + def test_interval(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.interval, pd.Interval) + + def test_partial_filename(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.partial_filename, str) + + def test_empty_filename(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.empty_filename, str) + + def test_path(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.path, FilePath) + + def test_partial_path(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.partial_path, FilePath) + + def test_empty_path(self, df_collection_data_type, offset, df_collection): + for i in df_collection.items: + assert isinstance(i.empty_path, FilePath) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset, duration", + argvalues=list( + iter_product( + df_collections, + ["12h", "10D"], + [timedelta(days=10), timedelta(days=45)], + ) + ), +) +class TestDFCollectionItemMethod: + + def test_has_mysql( + self, + df_collection, + thl_web_rr, + offset, + duration, + df_collection_data_type, + delete_df_collection, + ): + delete_df_collection(coll=df_collection) + + df_collection.pg_config = None + for i in df_collection.items: + assert not i.has_mysql() + + # Confirm that the regular connection should work as expected + df_collection.pg_config = thl_web_rr + for i in df_collection.items: + assert i.has_mysql() + + # Make a fake connection and confirm it does NOT work + df_collection.pg_config = PostgresConfig( + dsn=PostgresDsn(f"postgres://root:@127.0.0.1/{uuid4().hex}"), + connect_timeout=5, + statement_timeout=1, + ) + for i in df_collection.items: + assert not i.has_mysql() + + @pytest.mark.skip + def test_update_partial_archive( + self, + df_collection, + offset, + duration, + thl_web_rw, + df_collection_data_type, + delete_df_collection, + ): + # for i in collection.items: + # assert i.update_partial_archive() + # assert df.created.max() < _last_time_block[1] + pass + + @pytest.mark.skip + def test_create_partial_archive( + self, + df_collection, + offset, + duration, + create_main_accounts, + thl_web_rw, + thl_lm, + df_collection_data_type, + user_factory, + product, + client_no_amm, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + assert 1 + 1 == 2 + + def test_dict( + self, + df_collection_data_type, + offset, + duration, + df_collection, + delete_df_collection, + ): + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + res = item.to_dict() + assert isinstance(res, dict) + assert len(res.keys()) == 6 + + assert isinstance(res["should_archive"], bool) + assert isinstance(res["has_archive"], bool) + assert isinstance(res["path"], Path) + assert isinstance(res["filename"], str) + + assert isinstance(res["start"], datetime) + assert isinstance(res["finish"], datetime) + assert res["start"] < res["finish"] + + def test_from_mysql( + self, + df_collection_data_type, + df_collection, + offset, + duration, + create_main_accounts, + thl_web_rw, + user_factory, + product, + incite_item_factory, + delete_df_collection, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + + delete_df_collection(coll=df_collection) + u1: User = user_factory(product=product) + + # No data has been loaded, but we can confirm the from_mysql returns + # back an empty data with the correct columns + for item in df_collection.items: + # Unlike .from_mysql_ledger(), .from_mysql_standard() will return + # back and empty df with the correct columns in place + delete_df_collection(coll=df_collection) + df = item.from_mysql() + if df_collection.data_type == DFCollectionType.LEDGER: + assert df is None + else: + assert df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + + incite_item_factory(user=u1, item=item) + + df = item.from_mysql() + assert not df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + if df_collection.data_type == DFCollectionType.LEDGER: + # The number of rows in this dataframe will change depending + # on the mocking of data. It's because if the account has + # user wallet on, then there will be more transactions for + # example. + assert df.shape[0] > 0 + + def test_from_mysql_standard( + self, + df_collection_data_type, + df_collection, + offset, + duration, + user_factory, + product, + incite_item_factory, + delete_df_collection, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + item: DFCollectionItem + + if df_collection.data_type == DFCollectionType.LEDGER: + # We're using parametrize, so this If statement is just to + # confirm other Item Types will always raise an assertion + with pytest.raises(expected_exception=AssertionError) as cm: + res = item.from_mysql_standard() + assert ( + "Can't call from_mysql_standard for Ledger DFCollectionItem" + in str(cm.value) + ) + + continue + + # Unlike .from_mysql_ledger(), .from_mysql_standard() will return + # back and empty df with the correct columns in place + df = item.from_mysql_standard() + assert df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + + incite_item_factory(user=u1, item=item) + + df = item.from_mysql_standard() + assert not df.empty + assert set(df.columns) == set(df_collection._schema.columns.keys()) + assert df.shape[0] > 0 + + def test_from_mysql_ledger( + self, + df_collection, + user, + create_main_accounts, + offset, + duration, + thl_web_rw, + thl_lm, + df_collection_data_type, + user_factory, + product, + client_no_amm, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type != DFCollectionType.LEDGER: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + item: DFCollectionItem + delete_df_collection(coll=df_collection) + + # Okay, now continue with the actual Ledger Item tests... we need + # to ensure that this item.start - item.finish range hasn't had + # any prior transactions created within that range. + assert item.from_mysql_ledger() is None + + # Create main accounts doesn't matter because it doesn't + # add any transactions to the db + assert item.from_mysql_ledger() is None + + incite_item_factory(user=u1, item=item) + df = item.from_mysql_ledger() + assert isinstance(df, pd.DataFrame) + + # Not only is this a np.int64 to int comparison, but I also know it + # isn't actually measuring anything meaningful. However, it's useful + # as it tells us if the DF contains all the correct TX Entries. I + # figured it's better to count the amount rather than just the + # number of rows. DF == transactions * 2 because there are two + # entries per transactions + # assert df.amount.sum() == total_amt + # assert total_entries == df.shape[0] + + assert not df.tx_id.is_unique + df["net"] = df.direction * df.amount + assert df.groupby("tx_id").net.sum().sum() == 0 + + def test_to_archive( + self, + df_collection, + user, + offset, + duration, + df_collection_data_type, + user_factory, + product, + client_no_amm, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + + # Load up the data that we'll be using for various to_archive + # methods. + df = item.from_mysql() + ddf = dd.from_pandas(df, npartitions=1) + + # (1) Write the basic archive, the issue is that because it's + # an empty pd.DataFrame, it never makes an actual parquet file + assert item.to_archive(ddf=ddf, is_partial=False, overwrite=False) + assert item.has_archive() + assert item.has_archive(include_empty=False) + + def test__to_archive( + self, + df_collection_data_type, + df_collection, + user_factory, + product, + offset, + duration, + client_no_amm, + user, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + """We already have a test for the "non-private" version of this, + which primarily just uses the respective Client to determine if + the ddf is_empty or not. + + Therefore, use the private test to check the manual behavior of + passing in the is_empty or overwrite. + """ + if df_collection.data_type in unsupported_mock_types: + return + + delete_df_collection(coll=df_collection) + u1: User = user_factory(product=product) + + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + + # Load up the data that we'll be using for various to_archive + # methods. Will always be empty pd.DataFrames for now... + df = item.from_mysql() + ddf = dd.from_pandas(df, npartitions=1) + + # (1) Confirm a missing ddf (shouldn't bc of type hint) should + # immediately return back False + assert not item._to_archive(ddf=None, is_empty=True) + assert not item._to_archive(ddf=None, is_empty=False) + + # (2) Setting empty overrides any possible state of the ddf + for rand_val in [df, ddf, True, 1_000]: + assert not item.empty_path.exists() + item._to_archive(ddf=rand_val, is_empty=True) + assert item.empty_path.exists() + item.empty_path.unlink() + + # (3) Trigger a warning with overwrite. First write an empty, + # then write it again with override default to confirm it worked, + # then write it again with override=False to confirm it does + # not work. + assert item._to_archive(ddf=ddf, is_empty=True) + res1 = item.empty_path.stat() + + # Returns none because it knows the file (regular, empty, or + # partial) already exists + assert not item._to_archive(ddf=ddf, is_empty=True, overwrite=False) + + # Currently override=True doesn't actually work on empty files + # because it's checked again in .set_empty() and isn't + # aware of the override flag that may be passed in to + # item._to_archive() + with pytest.raises(expected_exception=AssertionError) as cm: + item._to_archive(ddf=rand_val, is_empty=True, overwrite=True) + assert "set_empty is already set; why are you doing this?" in str(cm.value) + + # We can assert the file stats are the same because we were never + # able to go ahead and rewrite or update it in anyway + res2 = item.empty_path.stat() + assert res1 == res2 + + @pytest.mark.skip + def test_to_archive_numbered_partial( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_initial_load( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_clear_corrupt_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset, duration", + argvalues=list(iter_product(df_collections, ["12h", "10D"], [timedelta(days=15)])), +) +class TestDFCollectionItemMethodBase: + + @pytest.mark.skip + def test_path_exists(self, df_collection_data_type, offset, duration): + pass + + @pytest.mark.skip + def test_next_numbered_path(self, df_collection_data_type, offset, duration): + pass + + @pytest.mark.skip + def test_search_highest_numbered_path( + self, df_collection_data_type, offset, duration + ): + pass + + @pytest.mark.skip + def test_tmp_filename(self, df_collection_data_type, offset, duration): + pass + + @pytest.mark.skip + def test_tmp_path(self, df_collection_data_type, offset, duration): + pass + + def test_is_empty(self, df_collection_data_type, df_collection, offset, duration): + """ + test_has_empty was merged into this because item.has_empty is + an alias for is_empty.. or vis-versa + """ + + for item in df_collection.items: + assert not item.is_empty() + assert not item.has_empty() + + item.empty_path.touch() + + assert item.is_empty() + assert item.has_empty() + + def test_has_partial_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + for item in df_collection.items: + assert not item.has_partial_archive() + item.partial_path.touch() + assert item.has_partial_archive() + + def test_has_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + for item in df_collection.items: + # (1) Originally, nothing exists... so let's just make a file and + # confirm that it works if just touch that path (no validation + # occurs at all). + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + item.path.touch() + assert item.has_archive(include_empty=False) + assert item.has_archive(include_empty=True) + + item.path.unlink() + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + + # (2) Same as the above, except make an empty directory + # instead of a file + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + item.path.mkdir() + assert item.has_archive(include_empty=False) + assert item.has_archive(include_empty=True) + + item.path.rmdir() + assert not item.has_archive(include_empty=False) + assert not item.has_archive(include_empty=True) + + # (3) Rather than make a empty file or dir at the path, let's + # touch the empty_path and confirm the include_empty option + # works + + item.empty_path.touch() + assert not item.has_archive(include_empty=False) + assert item.has_archive(include_empty=True) + + def test_delete_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + for item in df_collection.items: + item: DFCollectionItem + # (1) Confirm that it doesn't raise an error or anything if we + # try to delete files or folders that do not exist + CollectionItemBase.delete_archive(generic_path=item.path) + CollectionItemBase.delete_archive(generic_path=item.empty_path) + CollectionItemBase.delete_archive(generic_path=item.partial_path) + + item.path.touch() + item.empty_path.touch() + item.partial_path.touch() + + CollectionItemBase.delete_archive(generic_path=item.path) + CollectionItemBase.delete_archive(generic_path=item.empty_path) + CollectionItemBase.delete_archive(generic_path=item.partial_path) + + assert not item.path.exists() + assert not item.empty_path.exists() + assert not item.partial_path.exists() + + def test_should_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + schema: DataFrameSchema = df_collection._schema + aa = schema.metadata[ARCHIVE_AFTER] + + # It shouldn't be None, it can be timedelta(seconds=0) + assert isinstance(aa, timedelta) + + for item in df_collection.items: + item: DFCollectionItem + + if datetime.now(tz=timezone.utc) > item.finish + aa: + assert item.should_archive() + else: + assert not item.should_archive() + + @pytest.mark.skip + def test_set_empty(self, df_collection_data_type, df_collection, offset, duration): + pass + + def test_valid_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + # Originally, nothing has been saved or anything.. so confirm it + # always comes back as None + for item in df_collection.items: + assert not item.valid_archive(generic_path=None, sample=None) + + _path = Path(pjoin(df_collection.archive_path, uuid4().hex)) + + # (1) Fail if isfile, but doesn't exist and if we can't read + # it as valid ParquetFile + assert not item.valid_archive(generic_path=_path, sample=None) + _path.touch() + assert not item.valid_archive(generic_path=_path, sample=None) + _path.unlink() + + # (2) Fail if isdir and we can't read it as a valid ParquetFile + _path.mkdir() + assert _path.is_dir() + assert not item.valid_archive(generic_path=_path, sample=None) + _path.rmdir() + + @pytest.mark.skip + def test_validate_df( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_from_archive( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + def test__to_dict(self, df_collection_data_type, df_collection, offset, duration): + + for item in df_collection.items: + res = item._to_dict() + assert isinstance(res, dict) + assert len(res.keys()) == 6 + + assert isinstance(res["should_archive"], bool) + assert isinstance(res["has_archive"], bool) + assert isinstance(res["path"], Path) + assert isinstance(res["filename"], str) + + assert isinstance(res["start"], datetime) + assert isinstance(res["finish"], datetime) + assert res["start"] < res["finish"] + + @pytest.mark.skip + def test_delete_partial( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_cleanup_partials( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + @pytest.mark.skip + def test_delete_dangling_partials( + self, df_collection_data_type, df_collection, offset, duration + ): + pass + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +async def test_client(client, s, worker): + """c,s,a are all required - the secondary Worker (b) is not required""" + + assert isinstance(client, Client) + assert isinstance(s, Scheduler) + assert isinstance(worker, Worker) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", + argvalues=combo_object(), +) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +@pytest.mark.anyio +async def test_client_parametrize(c, s, w, df_collection_data_type, offset): + """c,s,a are all required - the secondary Worker (b) is not required""" + + assert isinstance(c, Client), f"c is not Client, it's {type(c)}" + assert isinstance(s, Scheduler), f"s is not Scheduler, it's {type(s)}" + assert isinstance(w, Worker), f"w is not Worker, it's {type(w)}" + + assert df_collection_data_type is not None + assert isinstance(offset, str) + + +# I cannot figure out how to define the parametrize on the Test, but then have +# sync or async methods within it, with some having or not having the +# gen_cluster decorator set. + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset, duration", + argvalues=list(iter_product(df_collections, ["12h", "10D"], [timedelta(days=15)])), +) +class TestDFCollectionItemFunctionalTest: + + def test_to_archive_and_ddf( + self, + df_collection_data_type, + offset, + duration, + client_no_amm, + df_collection, + user, + user_factory, + product, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + delete_df_collection(coll=df_collection) + df_collection._client = client_no_amm + + # Assert that there are no pre-existing archives + assert df_collection.progress.has_archive.eq(False).all() + res = df_collection.ddf() + assert res is None + + delete_df_collection(coll=df_collection) + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + item.initial_load() + + # I know it seems weird to delete items from the database before we + # proceed with the test. However, the content should have already + # been saved out into an parquet at this point, and I am too lazy + # to write a separate teardown for a collection (and not a + # single Item) + + # Now that we went ahead with the initial_load, Assert that all + # items have archives files saved + assert isinstance(df_collection.progress, pd.DataFrame) + assert df_collection.progress.has_archive.eq(True).all() + + ddf = df_collection.ddf() + shape = df_collection._client.compute(collections=ddf.shape, sync=True) + assert shape[0] > 5 + + def test_filesize_estimate( + self, + df_collection, + user, + offset, + duration, + client_no_amm, + user_factory, + product, + df_collection_data_type, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + """A functional test to write some Parquet files for the + DFCollection and then confirm that the files get written + correctly. + + Confirm the files are written correctly by: + (1) Validating their passing the pandera schema + (2) The file or dir has an expected size on disk + """ + import pyarrow.parquet as pq + from generalresearch.models.thl.user import User + import os + + if df_collection.data_type in unsupported_mock_types: + return + delete_df_collection(coll=df_collection) + u1: User = user_factory(product=product) + + # Pick 3 random items to sample for correct filesize + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + item.initial_load(overwrite=True) + + total_bytes = 0 + for fp in pq.ParquetDataset(item.path).files: + total_bytes += os.stat(fp).st_size + + total_mb = total_bytes / 1_048_576 + + assert total_bytes > 1_000 + assert total_mb < 1 + + def test_to_archive_client( + self, + client_no_amm, + df_collection, + user_factory, + product, + offset, + duration, + df_collection_data_type, + incite_item_factory, + delete_df_collection, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=df_collection) + df_collection._client = client_no_amm + u1: User = user_factory(product=product) + + for item in df_collection.items: + item: DFCollectionItem + + if df_collection.data_type in unsupported_mock_types: + continue + + incite_item_factory(user=u1, item=item) + + # Load up the data that we'll be using for various to_archive + # methods. Will always be empty pd.DataFrames for now... + df = item.from_mysql() + ddf = dd.from_pandas(df, npartitions=1) + assert isinstance(ddf, dd.DataFrame) + + # (1) Write the basic archive, the issue is that because it's + # an empty pd.DataFrame, it never makes an actual parquet file + assert not item.has_archive() + saved = item.to_archive(ddf=ddf, is_partial=False, overwrite=False) + assert saved + assert item.has_archive(include_empty=True) + + @pytest.mark.skip + def test_get_items(self, df_collection, product, offset, duration): + with pytest.warns(expected_warning=ResourceWarning) as cm: + df_collection.get_items_last365() + assert "DFCollectionItem has missing archives" in str( + [w.message for w in cm.list] + ) + + res = df_collection.get_items_last365() + assert len(res) == len(df_collection.items) + + def test_saving_protections( + self, + client_no_amm, + df_collection_data_type, + df_collection, + incite_item_factory, + delete_df_collection, + user_factory, + product, + offset, + duration, + mnt_filepath, + ): + """Don't allow creating an archive for data that will likely be + overwritten or updated + """ + from generalresearch.models.thl.user import User + + if df_collection.data_type in unsupported_mock_types: + return + u1: User = user_factory(product=product) + + schema: DataFrameSchema = df_collection._schema + aa = schema.metadata[ARCHIVE_AFTER] + assert isinstance(aa, timedelta) + + delete_df_collection(df_collection) + for item in df_collection.items: + item: DFCollectionItem + + incite_item_factory(user=u1, item=item) + + should_archive = item.should_archive() + res = item.initial_load() + + # self.assertIn("Cannot create archive for such new data", str(cm.records)) + + # .to_archive() will return back True or False depending on if it + # was successful. We want to compare that result to the + # .should_archive() method result + assert should_archive == res + + def test_empty_item( + self, + client_no_amm, + df_collection_data_type, + df_collection, + incite_item_factory, + delete_df_collection, + user, + offset, + duration, + mnt_filepath, + ): + delete_df_collection(coll=df_collection) + + for item in df_collection.items: + assert not item.has_empty() + df: pd.DataFrame = item.from_mysql() + + # We do this check b/c the Ledger returns back None and + # I don't want it to fail when we go to make a ddf + if df is None: + item.set_empty() + else: + ddf = dd.from_pandas(df, npartitions=1) + item.to_archive(ddf=ddf) + + assert item.has_empty() + + def test_file_touching( + self, + client_no_amm, + df_collection_data_type, + df_collection, + incite_item_factory, + delete_df_collection, + user_factory, + product, + offset, + duration, + mnt_filepath, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=df_collection) + df_collection._client = client_no_amm + u1: User = user_factory(product=product) + + for item in df_collection.items: + # Confirm none of the paths exist yet + assert not item.has_archive() + assert not item.path_exists(generic_path=item.path) + assert not item.has_empty() + assert not item.path_exists(generic_path=item.empty_path) + + if df_collection.data_type in unsupported_mock_types: + assert not item.has_archive(include_empty=False) + assert not item.has_empty() + assert not item.path_exists(generic_path=item.empty_path) + else: + incite_item_factory(user=u1, item=item) + item.initial_load() + + assert item.has_archive(include_empty=False) + assert item.path_exists(generic_path=item.path) + assert not item.has_empty() diff --git a/tests/incite/collections/test_df_collection_thl_marketplaces.py b/tests/incite/collections/test_df_collection_thl_marketplaces.py new file mode 100644 index 0000000..0a77938 --- /dev/null +++ b/tests/incite/collections/test_df_collection_thl_marketplaces.py @@ -0,0 +1,75 @@ +from datetime import datetime, timezone +from itertools import product + +import pytest +from pandera import Column, Index, DataFrameSchema + +from generalresearch.incite.collections import DFCollection +from generalresearch.incite.collections import DFCollectionType +from generalresearch.incite.collections.thl_marketplaces import ( + InnovateSurveyHistoryCollection, + MorningSurveyTimeseriesCollection, + SagoSurveyHistoryCollection, + SpectrumSurveyTimeseriesCollection, +) +from test_utils.incite.conftest import mnt_filepath + + +def combo_object(): + for x in product( + [ + InnovateSurveyHistoryCollection, + MorningSurveyTimeseriesCollection, + SagoSurveyHistoryCollection, + SpectrumSurveyTimeseriesCollection, + ], + ["5min", "6H", "30D"], + ): + yield x + + +@pytest.mark.parametrize("df_coll, offset", combo_object()) +class TestDFCollection_thl_marketplaces: + + def test_init(self, mnt_filepath, df_coll, offset, spectrum_rw): + assert issubclass(df_coll, DFCollection) + + # This is stupid, but we need to pull the default from the + # Pydantic field + data_type = df_coll.model_fields["data_type"].default + assert isinstance(data_type, DFCollectionType) + + # (1) Can't be totally empty, needs a path... + with pytest.raises(expected_exception=Exception) as cm: + instance = df_coll() + + # (2) Confirm it only needs the archive_path + instance = df_coll( + archive_path=mnt_filepath.archive_path(enum_type=data_type), + ) + assert isinstance(instance, DFCollection) + + # (3) Confirm it loads with all + instance = df_coll( + archive_path=mnt_filepath.archive_path(enum_type=data_type), + sql_helper=spectrum_rw, + offset=offset, + start=datetime(year=2023, month=6, day=1, minute=0, tzinfo=timezone.utc), + finished=datetime(year=2023, month=6, day=1, minute=5, tzinfo=timezone.utc), + ) + assert isinstance(instance, DFCollection) + + # (4) Now that we initialize the Class, we can access the _schema + assert isinstance(instance._schema, DataFrameSchema) + assert isinstance(instance._schema.index, Index) + + for c in instance._schema.columns.keys(): + assert isinstance(c, str) + col = instance._schema.columns[c] + assert isinstance(col, Column) + + assert instance._schema.coerce, "coerce on all Schemas" + assert isinstance(instance._schema.checks, list) + assert len(instance._schema.checks) == 0 + assert isinstance(instance._schema.metadata, dict) + assert len(instance._schema.metadata.keys()) == 2 diff --git a/tests/incite/collections/test_df_collection_thl_web.py b/tests/incite/collections/test_df_collection_thl_web.py new file mode 100644 index 0000000..e6f464b --- /dev/null +++ b/tests/incite/collections/test_df_collection_thl_web.py @@ -0,0 +1,160 @@ +from datetime import datetime +from itertools import product + +import dask.dataframe as dd +import pandas as pd +import pytest +from pandera import DataFrameSchema + +from generalresearch.incite.collections import DFCollection, DFCollectionType + + +def combo_object(): + for x in product( + [ + DFCollectionType.USER, + DFCollectionType.WALL, + DFCollectionType.SESSION, + DFCollectionType.TASK_ADJUSTMENT, + DFCollectionType.AUDIT_LOG, + DFCollectionType.LEDGER, + ], + ["30min", "1H"], + ): + yield x + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web: + + def test_init(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection_data_type, DFCollectionType) + assert isinstance(df_collection, DFCollection) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_Properties: + + def test_items(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.items, list) + for i in df_collection.items: + assert i._collection == df_collection + + def test__schema(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection._schema, DataFrameSchema) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_BaseProperties: + + @pytest.mark.skip + def test__interval_range(self, df_collection_data_type, offset, df_collection): + pass + + def test_interval_start(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.interval_start, datetime) + + def test_interval_range(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.interval_range, list) + + def test_progress(self, df_collection_data_type, offset, df_collection): + assert isinstance(df_collection.progress, pd.DataFrame) + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_Methods: + + @pytest.mark.skip + def test_initial_loads(self, df_collection_data_type, df_collection, offset): + pass + + @pytest.mark.skip + def test_fetch_force_rr_latest( + self, df_collection_data_type, df_collection, offset + ): + pass + + @pytest.mark.skip + def test_force_rr_latest(self, df_collection_data_type, df_collection, offset): + pass + + +@pytest.mark.parametrize( + argnames="df_collection_data_type, offset", argvalues=combo_object() +) +class TestDFCollection_thl_web_BaseMethods: + + def test_fetch_all_paths(self, df_collection_data_type, offset, df_collection): + res = df_collection.fetch_all_paths( + items=None, force_rr_latest=False, include_partial=False + ) + assert isinstance(res, list) + + @pytest.mark.skip + def test_ddf(self, df_collection_data_type, offset, df_collection): + res = df_collection.ddf() + assert isinstance(res, dd.DataFrame) + + # -- cleanup -- + @pytest.mark.skip + def test_schedule_cleanup(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_cleanup(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_cleanup_partials(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_clear_tmp_archives(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_clear_corrupt_archives( + self, df_collection_data_type, offset, df_collection + ): + pass + + @pytest.mark.skip + def test_rebuild_symlinks(self, df_collection_data_type, offset, df_collection): + pass + + # -- Source timing -- + @pytest.mark.skip + def test_get_item(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_item_start(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_items(self, df_collection_data_type, offset, df_collection): + # If we get all the items from the start of the collection, it + # should include all the items! + res1 = df_collection.items + res2 = df_collection.get_items(since=df_collection.start) + assert len(res1) == len(res2) + + @pytest.mark.skip + def test_get_items_from_year(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_items_last90(self, df_collection_data_type, offset, df_collection): + pass + + @pytest.mark.skip + def test_get_items_last365(self, df_collection_data_type, offset, df_collection): + pass diff --git a/tests/incite/collections/test_df_collection_thl_web_ledger.py b/tests/incite/collections/test_df_collection_thl_web_ledger.py new file mode 100644 index 0000000..599d979 --- /dev/null +++ b/tests/incite/collections/test_df_collection_thl_web_ledger.py @@ -0,0 +1,32 @@ +# def test_loaded(self, client_no_amm, collection, new_user_fixture, pop_ledger_merge): +# collection._client = client_no_amm +# +# teardown_events(collection) +# THL_LM.create_main_accounts() +# +# for item in collection.items: +# populate_events(item, user=new_user_fixture) +# item.initial_load() +# +# ddf = collection.ddf( +# force_rr_latest=False, +# include_partial=True, +# filters=[ +# ("created", ">=", collection.start), +# ("created", "<", collection.finished), +# ], +# ) +# +# assert isinstance(ddf, dd.DataFrame) +# df = client_no_amm.compute(collections=ddf, sync=True) +# assert isinstance(df, pd.DataFrame) +# +# # Simple validation check(s) +# assert not df.tx_id.is_unique +# df["net"] = df.direction * df.amount +# assert df.groupby("tx_id").net.sum().sum() == 0 +# +# teardown_events(collection) +# +# +# diff --git a/tests/incite/mergers/__init__.py b/tests/incite/mergers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/incite/mergers/foundations/__init__.py b/tests/incite/mergers/foundations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/incite/mergers/foundations/test_enriched_session.py b/tests/incite/mergers/foundations/test_enriched_session.py new file mode 100644 index 0000000..ec15d38 --- /dev/null +++ b/tests/incite/mergers/foundations/test_enriched_session.py @@ -0,0 +1,138 @@ +from datetime import timedelta, timezone, datetime +from decimal import Decimal +from itertools import product +from typing import Optional + +from generalresearch.incite.schemas.admin_responses import ( + AdminPOPSessionSchema, +) + +import dask.dataframe as dd +import pandas as pd +import pytest + +from test_utils.incite.collections.conftest import ( + session_collection, + wall_collection, +) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + product( + ["12h", "3D"], + [timedelta(days=5)], + ) + ), +) +class TestEnrichedSession: + + def test_base( + self, + client_no_amm, + product, + user_factory, + wall_collection, + session_collection, + enriched_session_merge, + thl_web_rr, + delete_df_collection, + incite_item_factory, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=session_collection) + + u1: User = user_factory(product=product, created=session_collection.start) + + for item in session_collection.items: + incite_item_factory(item=item, user=u1) + item.initial_load() + + for item in wall_collection.items: + item.initial_load() + + enriched_session_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + # -- + + ddf = enriched_session_merge.ddf() + assert isinstance(ddf, dd.DataFrame) + + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + + assert not df.empty + + # -- Teardown + delete_df_collection(session_collection) + + +class TestEnrichedSessionAdmin: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2020, month=3, day=14, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "1d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=5) + + def test_to_admin_response( + self, + event_report_request, + enriched_session_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + ): + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory() + p2 = product_factory() + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + df = enriched_session_merge.to_admin_response( + rr=session_report_request, client=client_no_amm + ) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + assert isinstance(AdminPOPSessionSchema.validate(df), pd.DataFrame) + assert df.index.get_level_values(1).nunique() == 2 diff --git a/tests/incite/mergers/foundations/test_enriched_task_adjust.py b/tests/incite/mergers/foundations/test_enriched_task_adjust.py new file mode 100644 index 0000000..96c214f --- /dev/null +++ b/tests/incite/mergers/foundations/test_enriched_task_adjust.py @@ -0,0 +1,76 @@ +from datetime import timedelta +from itertools import product as iter_product + +import dask.dataframe as dd +import pandas as pd +import pytest + +from test_utils.incite.collections.conftest import ( + wall_collection, + task_adj_collection, + session_collection, +) +from test_utils.incite.mergers.conftest import enriched_wall_merge + + +@pytest.mark.parametrize( + argnames="offset, duration,", + argvalues=list( + iter_product( + ["12h", "3D"], + [timedelta(days=5)], + ) + ), +) +class TestEnrichedTaskAdjust: + + @pytest.mark.skip + def test_base( + self, + client_no_amm, + user_factory, + product, + task_adj_collection, + wall_collection, + session_collection, + enriched_wall_merge, + enriched_task_adjust_merge, + incite_item_factory, + delete_df_collection, + thl_web_rr, + ): + from generalresearch.models.thl.user import User + + # -- Build & Setup + delete_df_collection(coll=session_collection) + u1: User = user_factory(product=product) + + for item in session_collection.items: + incite_item_factory(user=u1, item=item) + item.initial_load() + for item in wall_collection.items: + item.initial_load() + + enriched_wall_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + enriched_task_adjust_merge.build( + client=client_no_amm, + task_adjust_coll=task_adj_collection, + enriched_wall=enriched_wall_merge, + pg_config=thl_web_rr, + ) + + # -- + + ddf = enriched_task_adjust_merge.ddf() + assert isinstance(ddf, dd.DataFrame) + + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + + assert not df.empty diff --git a/tests/incite/mergers/foundations/test_enriched_wall.py b/tests/incite/mergers/foundations/test_enriched_wall.py new file mode 100644 index 0000000..8f4995b --- /dev/null +++ b/tests/incite/mergers/foundations/test_enriched_wall.py @@ -0,0 +1,236 @@ +from datetime import timedelta, timezone, datetime +from decimal import Decimal +from itertools import product as iter_product +from typing import Optional + +import dask.dataframe as dd +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) + +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMergeItem, +) +from test_utils.incite.collections.conftest import ( + session_collection, + wall_collection, +) +from test_utils.incite.conftest import incite_item_factory +from test_utils.incite.mergers.conftest import ( + enriched_wall_merge, +) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list(iter_product(["48h", "3D"], [timedelta(days=5)])), +) +class TestEnrichedWall: + + def test_base( + self, + client_no_amm, + product, + user_factory, + wall_collection, + thl_web_rr, + session_collection, + enriched_wall_merge, + delete_df_collection, + incite_item_factory, + ): + from generalresearch.models.thl.user import User + + # -- Build & Setup + delete_df_collection(coll=session_collection) + delete_df_collection(coll=wall_collection) + u1: User = user_factory(product=product, created=session_collection.start) + + for item in session_collection.items: + incite_item_factory(item=item, user=u1) + item.initial_load() + + for item in wall_collection.items: + item.initial_load() + + enriched_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + # -- + + ddf = enriched_wall_merge.ddf() + assert isinstance(ddf, dd.DataFrame) + + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + + assert not df.empty + + def test_base_item( + self, + client_no_amm, + product, + user_factory, + wall_collection, + session_collection, + enriched_wall_merge, + delete_df_collection, + thl_web_rr, + incite_item_factory, + ): + # -- Build & Setup + delete_df_collection(coll=session_collection) + u = user_factory(product=product, created=session_collection.start) + + for item in session_collection.items: + incite_item_factory(item=item, user=u) + item.initial_load() + for item in wall_collection.items: + item.initial_load() + + enriched_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + # -- + + for item in enriched_wall_merge.items: + assert isinstance(item, EnrichedWallMergeItem) + + path = item.path + + try: + modified_time1 = path.stat().st_mtime + except (Exception,): + modified_time1 = 0 + + item.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + modified_time2 = path.stat().st_mtime + + # Merger Items can't be updated unless it's a partial, confirm + # that even after attempting to rebuild, it doesn't re-touch + # the file + assert modified_time2 == modified_time1 + + # def test_admin_pop_session_device_type(ew_merge_setup): + # self.build() + # + # rr = ReportRequest( + # report_type=ReportType.POP_EVENT, + # index0="started", + # index1="device_type", + # freq="min", + # start=start, + # ) + # + # df, categories, updated = self.instance.to_admin_response( + # rr=rr, product_ids=[self.product.id], client=client + # ) + # + # assert isinstance(df, pd.DataFrame) + # device_types_str = [str(e.value) for e in DeviceType] + # device_types = df.index.get_level_values(1).values + # assert all([dt in device_types_str for dt in device_types]) + + +class TestEnrichedWallToAdmin: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2020, month=3, day=14, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "1d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=5) + + def test_empty(self, enriched_wall_merge, client_no_amm, start): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest.model_validate({"interval": "5min", "start": start}) + + res = enriched_wall_merge.to_admin_response( + rr=rr, + client=client_no_amm, + ) + + assert isinstance(res, pd.DataFrame) + + assert res.empty + assert len(res.columns) > 5 + + def test_to_admin_response( + self, + event_report_request, + enriched_wall_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + user, + session_factory, + delete_df_collection, + product_factory, + user_factory, + start, + ): + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory() + p2 = product_factory() + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=2, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + session_coll=session_collection, + pg_config=thl_web_rr, + ) + + df = enriched_wall_merge.to_admin_response( + rr=event_report_request, client=client_no_amm + ) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + # assert len(df) == 1 + # assert user.product_id == df.reset_index().loc[0, "index1"] + assert df.index.get_level_values(1).nunique() == 2 diff --git a/tests/incite/mergers/foundations/test_user_id_product.py b/tests/incite/mergers/foundations/test_user_id_product.py new file mode 100644 index 0000000..f96bfb4 --- /dev/null +++ b/tests/incite/mergers/foundations/test_user_id_product.py @@ -0,0 +1,73 @@ +from datetime import timedelta, datetime, timezone +from itertools import product + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) + +from generalresearch.incite.mergers.foundations.user_id_product import ( + UserIdProductMergeItem, +) +from test_utils.incite.mergers.conftest import user_id_product_merge + + +@pytest.mark.parametrize( + argnames="offset, duration, start", + argvalues=list( + product( + ["12h", "3D"], + [timedelta(days=5)], + [ + (datetime.now(tz=timezone.utc) - timedelta(days=35)).replace( + microsecond=0 + ) + ], + ) + ), +) +class TestUserIDProduct: + + @pytest.mark.skip + def test_base(self, client_no_amm, user_id_product_merge): + ddf = user_id_product_merge.ddf() + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + assert not df.empty + + @pytest.mark.skip + def test_base_item(self, client_no_amm, user_id_product_merge, user_collection): + assert len(user_id_product_merge.items) == 1 + + for item in user_id_product_merge.items: + assert isinstance(item, UserIdProductMergeItem) + + path = item.path + + try: + modified_time1 = path.stat().st_mtime + except (Exception,): + modified_time1 = 0 + + user_id_product_merge.build(client=client_no_amm, user_coll=user_collection) + modified_time2 = path.stat().st_mtime + + assert modified_time2 > modified_time1 + + @pytest.mark.skip + def test_read(self, client_no_amm, user_id_product_merge): + users_ddf = user_id_product_merge.ddf() + df = client_no_amm.compute(collections=users_ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert len(df.columns) == 1 + assert str(df.product_id.dtype) == "category" diff --git a/tests/incite/mergers/test_merge_collection.py b/tests/incite/mergers/test_merge_collection.py new file mode 100644 index 0000000..692cac3 --- /dev/null +++ b/tests/incite/mergers/test_merge_collection.py @@ -0,0 +1,102 @@ +from datetime import datetime, timezone, timedelta +from itertools import product + +import pandas as pd +import pytest +from pandera import DataFrameSchema + +from generalresearch.incite.mergers import ( + MergeCollection, + MergeType, +) +from test_utils.incite.conftest import mnt_filepath + +merge_types = list(e for e in MergeType if e != MergeType.TEST) + + +@pytest.mark.parametrize( + argnames="merge_type, offset, duration, start", + argvalues=list( + product( + merge_types, + ["5min", "6h", "14D"], + [timedelta(days=30)], + [ + (datetime.now(tz=timezone.utc) - timedelta(days=35)).replace( + microsecond=0 + ) + ], + ) + ), +) +class TestMergeCollection: + + def test_init(self, mnt_filepath, merge_type, offset, duration, start): + with pytest.raises(expected_exception=ValueError) as cm: + MergeCollection(archive_path=mnt_filepath.data_src) + assert "Must explicitly provide a merge_type" in str(cm.value) + + instance = MergeCollection( + merge_type=merge_type, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + assert instance.merge_type == merge_type + + def test_items(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + offset=offset, + start=start, + finished=start + duration, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + assert len(instance.interval_range) == len(instance.items) + + def test_progress(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + offset=offset, + start=start, + finished=start + duration, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + assert isinstance(instance.progress, pd.DataFrame) + assert instance.progress.shape[0] > 0 + assert instance.progress.shape[1] == 7 + assert instance.progress["group_by"].isnull().all() + + def test_schema(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + assert isinstance(instance._schema, DataFrameSchema) + + def test_load(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + merge_type=merge_type, + start=start, + finished=start + duration, + offset=offset, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + # Confirm that there are no archives available yet + assert instance.progress.has_archive.eq(False).all() + + def test_get_items(self, mnt_filepath, merge_type, offset, duration, start): + instance = MergeCollection( + start=start, + finished=start + duration, + offset=offset, + merge_type=merge_type, + archive_path=mnt_filepath.archive_path(enum_type=merge_type), + ) + + # with pytest.raises(expected_exception=ResourceWarning) as cm: + res = instance.get_items_last365() + # assert "has missing archives", str(cm.value) + assert len(res) == len(instance.items) diff --git a/tests/incite/mergers/test_merge_collection_item.py b/tests/incite/mergers/test_merge_collection_item.py new file mode 100644 index 0000000..96f8789 --- /dev/null +++ b/tests/incite/mergers/test_merge_collection_item.py @@ -0,0 +1,66 @@ +from datetime import datetime, timezone, timedelta +from itertools import product +from pathlib import PurePath + +import pytest + +from generalresearch.incite.mergers import MergeCollectionItem, MergeType +from generalresearch.incite.mergers.foundations.enriched_session import ( + EnrichedSessionMerge, +) +from generalresearch.incite.mergers.foundations.enriched_wall import ( + EnrichedWallMerge, +) +from test_utils.incite.mergers.conftest import merge_collection + + +@pytest.mark.parametrize( + argnames="merge_type, offset, duration", + argvalues=list( + product( + [MergeType.ENRICHED_SESSION, MergeType.ENRICHED_WALL], + ["1h"], + [timedelta(days=1)], + ) + ), +) +class TestMergeCollectionItem: + + def test_file_naming(self, merge_collection, offset, duration, start): + assert len(merge_collection.items) == 25 + + items: list[MergeCollectionItem] = merge_collection.items + + for i in items: + i: MergeCollectionItem + + assert isinstance(i.path, PurePath) + assert i.path.name == i.filename + + assert i._collection.merge_type.name.lower() in i.filename + assert i._collection.offset in i.filename + assert i.start.strftime("%Y-%m-%d-%H-%M-%S") in i.filename + + def test_archives(self, merge_collection, offset, duration, start): + assert len(merge_collection.items) == 25 + + for i in merge_collection.items: + assert not i.has_archive() + assert not i.has_empty() + assert not i.is_empty() + assert not i.has_partial_archive() + assert i.has_archive() == i.path_exists(generic_path=i.path) + + res = set([i.should_archive() for i in merge_collection.items]) + assert len(res) == 1 + + def test_item_to_archive(self, merge_collection, offset, duration, start): + for item in merge_collection.items: + item: MergeCollectionItem + assert not item.has_archive() + + # TODO: setup build methods + # ddf = self.build + # saved = instance.to_archive(ddf=ddf) + # self.assertTrue(saved) + # self.assertTrue(instance.has_archive()) diff --git a/tests/incite/mergers/test_pop_ledger.py b/tests/incite/mergers/test_pop_ledger.py new file mode 100644 index 0000000..6f96108 --- /dev/null +++ b/tests/incite/mergers/test_pop_ledger.py @@ -0,0 +1,307 @@ +from datetime import timedelta, datetime, timezone +from itertools import product as iter_product +from typing import Optional + +import pandas as pd +import pytest +from distributed.utils_test import client_no_amm + +from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, +) +from test_utils.incite.collections.conftest import ledger_collection +from test_utils.incite.conftest import mnt_filepath, incite_item_factory +from test_utils.incite.mergers.conftest import pop_ledger_merge +from test_utils.managers.ledger.conftest import create_main_accounts + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "3D"], + [timedelta(days=4)], + ) + ), +) +class TestMergePOPLedger: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2020, month=3, day=14, tzinfo=timezone.utc) + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=5) + + def test_base( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + product, + user_factory, + create_main_accounts, + thl_lm, + delete_df_collection, + incite_item_factory, + delete_ledger_db, + ): + from generalresearch.models.thl.ledger import LedgerAccount + + u = user_factory(product=product, created=ledger_collection.start) + + # -- Build & Setup + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + # assert ledger_collection.start is None + # assert ledger_collection.offset is None + + for item in ledger_collection.items: + incite_item_factory(item=item, user=u) + item.initial_load() + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + + pop_ledger_merge.build( + client=client_no_amm, + ledger_coll=ledger_collection, + ) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + ddf = pop_ledger_merge.ddf() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + + user_wallet_account: LedgerAccount = thl_lm.get_account_or_create_user_wallet( + user=u + ) + cash_account: LedgerAccount = thl_lm.get_account_cash() + rev_account: LedgerAccount = thl_lm.get_account_task_complete_revenue() + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + # Pure SQL based lookups + cash_balance: int = thl_lm.get_account_balance(account=cash_account) + rev_balance: int = thl_lm.get_account_balance(account=rev_account) + assert cash_balance > rev_balance + + # (1) Test Cash Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names, + filters=[ + ("account_id", "==", cash_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + assert df["mp_payment.CREDIT"].sum() == 0 + assert cash_balance > 0 + assert df["mp_payment.DEBIT"].sum() == cash_balance + + # (2) Test Revenue Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names, + filters=[ + ("account_id", "==", rev_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert rev_balance == 0 + assert df["bp_payment.CREDIT"].sum() == 0 + assert df["mp_payment.DEBIT"].sum() == 0 + assert df["mp_payment.CREDIT"].sum() > 0 + + # -- Cleanup + delete_ledger_db() + + def test_pydantic_init( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + product, + user_factory, + create_main_accounts, + offset, + duration, + start, + thl_lm, + incite_item_factory, + delete_df_collection, + delete_ledger_db, + session_collection, + ): + from generalresearch.models.thl.ledger import LedgerAccount + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.finance import ProductBalances + + u = user_factory(product=product, created=session_collection.start) + + assert ledger_collection.finished is not None + assert isinstance(u.product, Product) + delete_ledger_db() + create_main_accounts(), + delete_df_collection(coll=ledger_collection) + + bp_account: LedgerAccount = thl_lm.get_account_or_create_bp_wallet( + product=u.product + ) + cash_account: LedgerAccount = thl_lm.get_account_cash() + rev_account: LedgerAccount = thl_lm.get_account_task_complete_revenue() + + for item in ledger_collection.items: + incite_item_factory(item=item, user=u) + item.initial_load(overwrite=True) + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + # (1) Filter by the Product Account, this means no cash_account, or + # rev_account transactions will be present in here... + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", bp_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + df = df.set_index("time_idx") + assert not df.empty + + instance = ProductBalances.from_pandas(input_data=df.sum()) + assert instance.payout == instance.net == instance.bp_payment_credit + assert instance.available_balance < instance.net + assert instance.available_balance + instance.retainer == instance.net + assert instance.balance == thl_lm.get_account_balance(bp_account) + assert df["bp_payment.CREDIT"].sum() == thl_lm.get_account_balance(bp_account) + + # (2) Filter by the Cash Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", cash_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + cash_balance: int = thl_lm.get_account_balance(account=cash_account) + assert df["bp_payment.CREDIT"].sum() == 0 + assert cash_balance > 0 + assert df["mp_payment.CREDIT"].sum() == 0 + assert df["mp_payment.DEBIT"].sum() == cash_balance + + # (2) Filter by the Revenue Account + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", rev_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + rev_balance: int = thl_lm.get_account_balance(account=rev_account) + assert rev_balance == 0 + assert df["bp_payment.CREDIT"].sum() == 0 + assert df["mp_payment.DEBIT"].sum() == 0 + assert df["mp_payment.CREDIT"].sum() > 0 + + def test_resample( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + user_factory, + product, + create_main_accounts, + offset, + duration, + start, + thl_lm, + delete_df_collection, + incite_item_factory, + ): + from generalresearch.models.thl.user import User + + assert ledger_collection.finished is not None + delete_df_collection(coll=ledger_collection) + u1: User = user_factory(product=product) + + bp_account = thl_lm.get_account_or_create_bp_wallet(product=u1.product) + + for item in ledger_collection.items: + incite_item_factory(user=u1, item=item) + item.initial_load(overwrite=True) + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + ddf = pop_ledger_merge.ddf( + columns=numerical_col_names + ["time_idx"], + filters=[ + ("account_id", "==", bp_account.uuid), + ("time_idx", ">=", ledger_collection.start), + ("time_idx", "<", last_item_finish), + ], + ) + df = client_no_amm.compute(collections=ddf, sync=True) + assert isinstance(df, pd.DataFrame) + assert isinstance(df.index, pd.Index) + assert not isinstance(df.index, pd.RangeIndex) + + # Now change the index so we can easily resample it + df = df.set_index("time_idx") + assert isinstance(df.index, pd.Index) + assert isinstance(df.index, pd.DatetimeIndex) + + bp_account_balance = thl_lm.get_account_balance(account=bp_account) + + # Initial sum + initial_sum = df.sum().sum() + # assert len(df) == 48 # msg="Original df should be 48 rows" + + # Original (1min) to 5min + df_5min = df.resample("5min").sum() + # assert len(df_5min) == 12 + assert initial_sum == df_5min.sum().sum() + + # 30min + df_30min = df.resample("30min").sum() + # assert len(df_30min) == 2 + assert initial_sum == df_30min.sum().sum() + + # 1hr + df_1hr = df.resample("1h").sum() + # assert len(df_1hr) == 1 + assert initial_sum == df_1hr.sum().sum() + + # 1 day + df_1day = df.resample("1d").sum() + # assert len(df_1day) == 1 + assert initial_sum == df_1day.sum().sum() diff --git a/tests/incite/mergers/test_ym_survey_merge.py b/tests/incite/mergers/test_ym_survey_merge.py new file mode 100644 index 0000000..4c2df6b --- /dev/null +++ b/tests/incite/mergers/test_ym_survey_merge.py @@ -0,0 +1,125 @@ +from datetime import timedelta, timezone, datetime +from itertools import product + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) + +from test_utils.incite.collections.conftest import wall_collection, session_collection +from test_utils.incite.mergers.conftest import ( + enriched_session_merge, + ym_survey_wall_merge, +) + + +@pytest.mark.parametrize( + argnames="offset, duration, start", + argvalues=list( + product( + ["12h", "3D"], + [timedelta(days=30)], + [ + (datetime.now(tz=timezone.utc) - timedelta(days=35)).replace( + microsecond=0 + ) + ], + ) + ), +) +class TestYMSurveyMerge: + """We override start, not because it's needed on the YMSurveyWall merge, + which operates on a rolling 10-day window, but because we don't want + to mock data in the wall collection and enriched_session_merge from + the 1800s and then wonder why there is no data available in the past + 10 days in the database. + """ + + def test_base( + self, + client_no_amm, + user_factory, + product, + ym_survey_wall_merge, + wall_collection, + session_collection, + enriched_session_merge, + delete_df_collection, + incite_item_factory, + thl_web_rr, + ): + from generalresearch.models.thl.user import User + + delete_df_collection(coll=session_collection) + user: User = user_factory(product=product, created=session_collection.start) + + # -- Build & Setup + assert ym_survey_wall_merge.start is None + assert ym_survey_wall_merge.offset == "10D" + + for item in session_collection.items: + incite_item_factory(item=item, user=user) + item.initial_load() + for item in wall_collection.items: + item.initial_load() + + # Confirm any of the items are archived + assert session_collection.progress.has_archive.eq(True).all() + assert wall_collection.progress.has_archive.eq(True).all() + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + assert enriched_session_merge.progress.has_archive.eq(True).all() + + ddf = enriched_session_merge.ddf() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + + ym_survey_wall_merge.build( + client=client_no_amm, + wall_coll=wall_collection, + enriched_session=enriched_session_merge, + ) + assert ym_survey_wall_merge.progress.has_archive.eq(True).all() + + # -- + + ddf = ym_survey_wall_merge.ddf() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + assert df.product_id.nunique() == 1 + assert df.team_id.nunique() == 1 + assert df.source.nunique() > 1 + + started_min_ts = df.started.min() + started_max_ts = df.started.max() + + assert type(started_min_ts) is pd.Timestamp + assert type(started_max_ts) is pd.Timestamp + + started_min: datetime = datetime.fromisoformat(str(started_min_ts)) + started_max: datetime = datetime.fromisoformat(str(started_max_ts)) + + started_delta = started_max - started_min + assert started_delta >= timedelta(days=3) diff --git a/tests/incite/schemas/__init__.py b/tests/incite/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/incite/schemas/test_admin_responses.py b/tests/incite/schemas/test_admin_responses.py new file mode 100644 index 0000000..43aa399 --- /dev/null +++ b/tests/incite/schemas/test_admin_responses.py @@ -0,0 +1,239 @@ +from datetime import datetime, timezone, timedelta +from random import sample +from typing import List + +import numpy as np +import pandas as pd +import pytest + +from generalresearch.incite.schemas import empty_dataframe_from_schema +from generalresearch.incite.schemas.admin_responses import ( + AdminPOPSchema, + SIX_HOUR_SECONDS, +) +from generalresearch.locales import Localelator + + +class TestAdminPOPSchema: + schema_df = empty_dataframe_from_schema(AdminPOPSchema) + countries = list(Localelator().get_all_countries())[:5] + dates = [datetime(year=2024, month=1, day=i, tzinfo=None) for i in range(1, 10)] + + @classmethod + def assign_valid_vals(cls, df: pd.DataFrame) -> pd.DataFrame: + for c in df.columns: + check_attrs: dict = AdminPOPSchema.columns[c].checks[0].statistics + df[c] = np.random.randint( + check_attrs["min_value"], check_attrs["max_value"], df.shape[0] + ) + + return df + + def test_empty(self): + with pytest.raises(Exception): + AdminPOPSchema.validate(pd.DataFrame()) + + def test_new_empty_df(self): + df = empty_dataframe_from_schema(AdminPOPSchema) + + assert isinstance(df, pd.DataFrame) + assert isinstance(df.index, pd.MultiIndex) + assert df.columns.size == len(AdminPOPSchema.columns) + + def test_valid(self): + # (1) Works with raw naive datetime + dates = [ + datetime(year=2024, month=1, day=i, tzinfo=None).isoformat() + for i in range(1, 10) + ] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + df = AdminPOPSchema.validate(df) + assert isinstance(df, pd.DataFrame) + + # (2) Works with isoformat naive datetime + dates = [datetime(year=2024, month=1, day=i, tzinfo=None) for i in range(1, 10)] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + df = AdminPOPSchema.validate(df) + assert isinstance(df, pd.DataFrame) + + def test_index_tz_parser(self): + tz_dates = [ + datetime(year=2024, month=1, day=i, tzinfo=timezone.utc) + for i in range(1, 10) + ] + + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[tz_dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + # Initially, they're all set with a timezone + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz == timezone.utc for ts in timestmaps]) + + # After validation, the timezone is removed + df = AdminPOPSchema.validate(df) + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is None for ts in timestmaps]) + + def test_index_tz_no_future_beyond_one_year(self): + now = datetime.now(tz=timezone.utc) + tz_dates = [now + timedelta(days=i * 365) for i in range(1, 10)] + + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[tz_dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + with pytest.raises(Exception) as cm: + AdminPOPSchema.validate(df) + + assert ( + "Index 'index0' failed element-wise validator " + "number 0: less_than(" in str(cm.value) + ) + + def test_index_only_str(self): + # --- float64 to str! --- + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, np.random.rand(1, 10)[0]], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, float) for v in vals]) + + df = AdminPOPSchema.validate(df, lazy=True) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, str) for v in vals]) + + # --- int to str --- + + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, sample(range(100), 20)], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, int) for v in vals]) + + df = AdminPOPSchema.validate(df, lazy=True) + + vals = [i for i in df.index.get_level_values(1)] + assert all([isinstance(v, str) for v in vals]) + + # a = 1 + assert isinstance(df, pd.DataFrame) + + def test_invalid_parsing(self): + # (1) Timezones AND as strings will still parse correctly + tz_str_dates = [ + datetime( + year=2024, month=1, day=1, minute=i, tzinfo=timezone.utc + ).isoformat() + for i in range(1, 10) + ] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[tz_str_dates, self.countries], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + df = AdminPOPSchema.validate(df, lazy=True) + + assert isinstance(df, pd.DataFrame) + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is None for ts in timestmaps]) + + # (2) Timezones are removed + dates = [ + datetime(year=2024, month=1, day=1, minute=i, tzinfo=timezone.utc) + for i in range(1, 10) + ] + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[dates, self.countries], names=["index0", "index1"] + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + # Has tz before validation, and none after + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is timezone.utc for ts in timestmaps]) + + df = AdminPOPSchema.validate(df, lazy=True) + + timestmaps: List[pd.Timestamp] = [i for i in df.index.get_level_values(0)] + assert all([ts.tz is None for ts in timestmaps]) + + def test_clipping(self): + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, self.countries], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + df = AdminPOPSchema.validate(df) + assert df.elapsed_avg.max() < SIX_HOUR_SECONDS + + # Now that we know it's valid, break the elapsed avg + df["elapsed_avg"] = np.random.randint( + SIX_HOUR_SECONDS, SIX_HOUR_SECONDS + 10_000, df.shape[0] + ) + assert df.elapsed_avg.max() > SIX_HOUR_SECONDS + + # Confirm it doesn't fail if the values are greater, and that + # all the values are clipped to the max + df = AdminPOPSchema.validate(df) + assert df.elapsed_avg.eq(SIX_HOUR_SECONDS).all() + + def test_rounding(self): + df = pd.DataFrame( + index=pd.MultiIndex.from_product( + iterables=[self.dates, self.countries], + names=["index0", "index1"], + ), + columns=self.schema_df.columns, + ) + df = self.assign_valid_vals(df) + + df["payout_avg"] = 2.123456789900002 + + assert df.payout_avg.sum() == 95.5555555455001 + + df = AdminPOPSchema.validate(df) + assert df.payout_avg.sum() == 95.40000000000003 diff --git a/tests/incite/schemas/test_thl_web.py b/tests/incite/schemas/test_thl_web.py new file mode 100644 index 0000000..7f4434b --- /dev/null +++ b/tests/incite/schemas/test_thl_web.py @@ -0,0 +1,70 @@ +import pandas as pd +import pytest +from pandera.errors import SchemaError + + +class TestWallSchema: + + def test_empty(self): + from generalresearch.incite.schemas.thl_web import THLWallSchema + + with pytest.raises(SchemaError): + THLWallSchema.validate(pd.DataFrame()) + + def test_index_missing(self): + from generalresearch.incite.schemas.thl_web import THLWallSchema + + df = pd.DataFrame(columns=THLWallSchema.columns.keys()) + + with pytest.raises(SchemaError) as cm: + THLWallSchema.validate(df) + + def test_no_rows(self): + from generalresearch.incite.schemas.thl_web import THLWallSchema + + df = pd.DataFrame(index=["uuid"], columns=THLWallSchema.columns.keys()) + + with pytest.raises(SchemaError) as cm: + THLWallSchema.validate(df) + + def test_new_empty_df(self): + from generalresearch.incite.schemas import empty_dataframe_from_schema + from generalresearch.incite.schemas.thl_web import THLWallSchema + + df = empty_dataframe_from_schema(THLWallSchema) + assert isinstance(df, pd.DataFrame) + assert df.columns.size == 20 + + +class TestSessionSchema: + + def test_empty(self): + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + with pytest.raises(SchemaError): + THLSessionSchema.validate(pd.DataFrame()) + + def test_index_missing(self): + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + df = pd.DataFrame(columns=THLSessionSchema.columns.keys()) + df.set_index("uuid", inplace=True) + + with pytest.raises(SchemaError) as cm: + THLSessionSchema.validate(df) + + def test_no_rows(self): + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + df = pd.DataFrame(index=["id"], columns=THLSessionSchema.columns.keys()) + + with pytest.raises(SchemaError) as cm: + THLSessionSchema.validate(df) + + def test_new_empty_df(self): + from generalresearch.incite.schemas import empty_dataframe_from_schema + from generalresearch.incite.schemas.thl_web import THLSessionSchema + + df = empty_dataframe_from_schema(THLSessionSchema) + assert isinstance(df, pd.DataFrame) + assert df.columns.size == 21 diff --git a/tests/incite/test_collection_base.py b/tests/incite/test_collection_base.py new file mode 100644 index 0000000..497e5ab --- /dev/null +++ b/tests/incite/test_collection_base.py @@ -0,0 +1,318 @@ +from datetime import datetime, timezone, timedelta +from os.path import exists as pexists, join as pjoin +from pathlib import Path +from uuid import uuid4 + +import numpy as np +import pandas as pd +import pytest +from _pytest._code.code import ExceptionInfo + +from generalresearch.incite.base import CollectionBase +from test_utils.incite.conftest import mnt_filepath + +AGO_15min = (datetime.now(tz=timezone.utc) - timedelta(minutes=15)).replace( + microsecond=0 +) +AGO_1HR = (datetime.now(tz=timezone.utc) - timedelta(hours=1)).replace(microsecond=0) +AGO_2HR = (datetime.now(tz=timezone.utc) - timedelta(hours=2)).replace(microsecond=0) + + +class TestCollectionBase: + def test_init(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.df.empty is True + + def test_init_df(self, mnt_filepath): + # Only an empty pd.DataFrame can ever be provided + instance = CollectionBase( + df=pd.DataFrame({}), archive_path=mnt_filepath.data_src + ) + assert isinstance(instance.df, pd.DataFrame) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + df=pd.DataFrame(columns=[0, 1, 2]), archive_path=mnt_filepath.data_src + ) + assert "Do not provide a pd.DataFrame" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + df=pd.DataFrame(np.random.randint(100, size=(1000, 1)), columns=["A"]), + archive_path=mnt_filepath.data_src, + ) + assert "Do not provide a pd.DataFrame" in str(cm.value) + + def test_init_start(self, mnt_filepath): + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + start=datetime.now(tz=timezone.utc) - timedelta(days=10), + archive_path=mnt_filepath.data_src, + ) + assert "Collection.start must not have microseconds" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + tz = timezone(timedelta(hours=-5), "EST") + + CollectionBase( + start=datetime(year=2000, month=1, day=1, tzinfo=tz), + archive_path=mnt_filepath.data_src, + ) + assert "Timezone is not UTC" in str(cm.value) + + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.start == datetime( + year=2018, month=1, day=1, tzinfo=timezone.utc + ) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase( + start=AGO_2HR, offset="3h", archive_path=mnt_filepath.data_src + ) + assert "Offset must be equal to, or smaller the start timestamp" in str( + cm.value + ) + + def test_init_archive_path(self, mnt_filepath): + """DirectoryPath is apparently smart enough to confirm that the + directory path exists. + """ + + # (1) Basic, confirm an existing path works + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.archive_path == mnt_filepath.data_src + + # (2) It can't point to a file + file_path = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}.zip")) + assert not pexists(file_path) + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(archive_path=file_path) + assert "Path does not point to a directory" in str(cm.value) + + # (3) It doesn't create the directory if it doesn't exist + new_path = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}/")) + assert not pexists(new_path) + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(archive_path=new_path) + assert "Path does not point to a directory" in str(cm.value) + + def test_init_offset(self, mnt_filepath): + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(offset="1:X", archive_path=mnt_filepath.data_src) + assert "Invalid offset alias provided. Please review:" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(offset=f"59sec", archive_path=mnt_filepath.data_src) + assert "Must be equal to, or longer than 1 min" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + cm: ExceptionInfo + CollectionBase(offset=f"{365 * 101}d", archive_path=mnt_filepath.data_src) + assert "String should have at most 5 characters" in str(cm.value) + + +class TestCollectionBaseProperties: + + def test_items(self, mnt_filepath): + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + instance = CollectionBase(archive_path=mnt_filepath.data_src) + x = instance.items + assert "Must override" in str(cm.value) + + def test_interval_range(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + # Private method requires the end parameter + with pytest.raises(expected_exception=AssertionError) as cm: + cm: ExceptionInfo + instance._interval_range(end=None) + assert "an end value must be provided" in str(cm.value) + + # End param must be same as started (which forces utc) + tz = timezone(timedelta(hours=-5), "EST") + with pytest.raises(expected_exception=AssertionError) as cm: + cm: ExceptionInfo + instance._interval_range(end=datetime.now(tz=tz)) + assert "Timezones must match" in str(cm.value) + + res = instance._interval_range(end=datetime.now(tz=timezone.utc)) + assert isinstance(res, pd.IntervalIndex) + assert res.closed_left + assert res.is_non_overlapping_monotonic + assert res.is_monotonic_increasing + assert res.is_unique + + def test_interval_range2(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert isinstance(instance.interval_range, list) + + # 1 hrs ago has 2 x 30min + the future 30min + OFFSET = "30min" + instance = CollectionBase( + start=AGO_1HR, offset=OFFSET, archive_path=mnt_filepath.data_src + ) + assert len(instance.interval_range) == 3 + assert instance.interval_range[0][0] == AGO_1HR + + # 1 hrs ago has 1 x 60min + the future 60min + OFFSET = "60min" + instance = CollectionBase( + start=AGO_1HR, offset=OFFSET, archive_path=mnt_filepath.data_src + ) + assert len(instance.interval_range) == 2 + + def test_progress(self, mnt_filepath): + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + instance = CollectionBase( + start=AGO_15min, offset="3min", archive_path=mnt_filepath.data_src + ) + x = instance.progress + assert "Must override" in str(cm.value) + + def test_progress2(self, mnt_filepath): + instance = CollectionBase( + start=AGO_2HR, + offset="15min", + archive_path=mnt_filepath.data_src, + ) + assert instance.df.empty + + with pytest.raises(expected_exception=NotImplementedError) as cm: + df = instance.progress + assert "Must override" in str(cm.value) + + def test_items2(self, mnt_filepath): + """There can't be a test for this because the Items need a path whic + isn't possible in the generic form + """ + instance = CollectionBase( + start=AGO_1HR, offset="5min", archive_path=mnt_filepath.data_src + ) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + items = instance.items + assert "Must override" in str(cm.value) + + # item = items[-3] + # ddf = instance.ddf(items=[item], include_partial=True, force_rr_latest=False) + # df = item.validate_ddf(ddf=ddf) + # assert isinstance(df, pd.DataFrame) + # assert len(df.columns) == 16 + # assert str(df.product_id.dtype) == "object" + # assert str(ddf.product_id.dtype) == "string" + + def test_items3(self, mnt_filepath): + instance = CollectionBase( + start=AGO_2HR, + offset="15min", + archive_path=mnt_filepath.data_src, + ) + with pytest.raises(expected_exception=NotImplementedError) as cm: + item = instance.items[0] + assert "Must override" in str(cm.value) + + +class TestCollectionBaseMethodsCleanup: + def test_fetch_force_rr_latest(self, mnt_filepath): + coll = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=Exception) as cm: + cm: ExceptionInfo + coll.fetch_force_rr_latest(sources=[]) + assert "Must override" in str(cm.value) + + def test_fetch_all_paths(self, mnt_filepath): + coll = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + cm: ExceptionInfo + coll.fetch_all_paths( + items=None, force_rr_latest=False, include_partial=False + ) + assert "Must override" in str(cm.value) + + +class TestCollectionBaseMethodsCleanup: + @pytest.mark.skip + def test_cleanup_partials(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.cleanup_partials() is None # it doesn't return anything + + def test_clear_tmp_archives(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.clear_tmp_archives() is None # it doesn't return anything + + @pytest.mark.skip + def test_clear_corrupt_archives(self, mnt_filepath): + """TODO: expand this so it actually has corrupt archives that we + check to see if they're removed + """ + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.clear_corrupt_archives() is None # it doesn't return anything + + @pytest.mark.skip + def test_rebuild_symlinks(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + assert instance.rebuild_symlinks() is None + + +class TestCollectionBaseMethodsSourceTiming: + + def test_get_item(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + i = pd.Interval(left=1, right=2, closed="left") + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_item(interval=i) + assert "Must override" in str(cm.value) + + def test_get_item_start(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + dt = datetime.now(tz=timezone.utc) + start = pd.Timestamp(dt) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_item_start(start=start) + assert "Must override" in str(cm.value) + + def test_get_items(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + dt = datetime.now(tz=timezone.utc) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items(since=dt) + assert "Must override" in str(cm.value) + + def test_get_items_from_year(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items_from_year(year=2020) + assert "Must override" in str(cm.value) + + def test_get_items_last90(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items_last90() + assert "Must override" in str(cm.value) + + def test_get_items_last365(self, mnt_filepath): + instance = CollectionBase(archive_path=mnt_filepath.data_src) + + with pytest.raises(expected_exception=NotImplementedError) as cm: + instance.get_items_last365() + assert "Must override" in str(cm.value) diff --git a/tests/incite/test_collection_base_item.py b/tests/incite/test_collection_base_item.py new file mode 100644 index 0000000..e5d1d02 --- /dev/null +++ b/tests/incite/test_collection_base_item.py @@ -0,0 +1,223 @@ +from datetime import datetime, timezone +from os.path import join as pjoin +from pathlib import Path +from uuid import uuid4 + +import dask.dataframe as dd +import pandas as pd +import pytest +from pydantic import ValidationError + +from generalresearch.incite.base import CollectionItemBase + + +class TestCollectionItemBase: + def test_init(self): + dt = datetime.now(tz=timezone.utc).replace(microsecond=0) + + instance = CollectionItemBase() + instance2 = CollectionItemBase(start=dt) + + assert isinstance(instance, CollectionItemBase) + assert isinstance(instance2, CollectionItemBase) + + assert instance.start.second == instance2.start.second + assert 0 == instance.start.microsecond == instance2.start.microsecond + + def test_init_start(self): + dt = datetime.now(tz=timezone.utc) + + with pytest.raises(expected_exception=ValidationError) as cm: + CollectionItemBase(start=dt) + + assert "CollectionItem.start must not have microsecond precision" in str( + cm.value + ) + + +class TestCollectionItemBaseProperties: + + def test_finish(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.finish + + def test_interval(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.interval + + def test_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.filename + + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_partial_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.filename + + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_empty_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.filename + + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.path + + def test_partial_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.partial_path + + def test_empty_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.empty_path + + +class TestCollectionItemBaseMethods: + + @pytest.mark.skip + def test_next_numbered_path(self): + pass + + @pytest.mark.skip + def test_search_highest_numbered_path(self): + pass + + def test_tmp_filename(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.tmp_filename() + assert "Do not use CollectionItemBase directly" in str(cm.value) + + def test_tmp_path(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.tmp_path() + + def test_is_empty(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.is_empty() + + def test_has_empty(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.has_empty() + + def test_has_partial_archive(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.has_partial_archive() + + @pytest.mark.parametrize("include_empty", [True, False]) + def test_has_archive(self, include_empty): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.has_archive(include_empty=include_empty) + + def test_delete_archive_file(self, mnt_filepath): + path1 = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}.zip")) + + # Confirm it doesn't exist, and that delete_archive() doesn't throw + # an error when trying to delete a non-existent file or folder + assert not path1.exists() + CollectionItemBase.delete_archive(generic_path=path1) + # TODO: LOG.warning(f"tried removing non-existent file: {generic_path}") + + # Create it, confirm it exists, delete it, and confirm it doesn't exist + path1.touch() + assert path1.exists() + CollectionItemBase.delete_archive(generic_path=path1) + assert not path1.exists() + + def test_delete_archive_dir(self, mnt_filepath): + path1 = Path(pjoin(mnt_filepath.data_src, f"{uuid4().hex}")) + + # Confirm it doesn't exist, and that delete_archive() doesn't throw + # an error when trying to delete a non-existent file or folder + assert not path1.exists() + CollectionItemBase.delete_archive(generic_path=path1) + # TODO: LOG.warning(f"tried removing non-existent file: {generic_path}") + + # Create it, confirm it exists, delete it, and confirm it doesn't exist + path1.mkdir() + assert path1.exists() + assert path1.is_dir() + CollectionItemBase.delete_archive(generic_path=path1) + assert not path1.exists() + + def test_should_archive(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.should_archive() + + def test_set_empty(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.set_empty() + + def test_valid_archive(self): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=AttributeError) as cm: + res = instance.valid_archive(generic_path=None, sample=None) + + +class TestCollectionItemBaseMethodsORM: + + @pytest.mark.skip + def test_from_archive(self): + pass + + @pytest.mark.parametrize("is_partial", [True, False]) + def test_to_archive(self, is_partial): + instance = CollectionItemBase() + + with pytest.raises(expected_exception=NotImplementedError) as cm: + res = instance.to_archive( + ddf=dd.from_pandas(data=pd.DataFrame()), is_partial=is_partial + ) + assert "Must override" in str(cm.value) + + @pytest.mark.skip + def test__to_dict(self): + pass + + @pytest.mark.skip + def test_delete_partial(self): + pass + + @pytest.mark.skip + def test_cleanup_partials(self): + pass + + @pytest.mark.skip + def test_delete_dangling_partials(self): + pass diff --git a/tests/incite/test_grl_flow.py b/tests/incite/test_grl_flow.py new file mode 100644 index 0000000..c632f9a --- /dev/null +++ b/tests/incite/test_grl_flow.py @@ -0,0 +1,23 @@ +class TestGRLFlow: + + def test_init(self, mnt_filepath, thl_web_rr): + from generalresearch.incite.defaults import ( + ledger_df_collection, + task_df_collection, + pop_ledger as plm, + ) + + from generalresearch.incite.collections.thl_web import ( + LedgerDFCollection, + TaskAdjustmentDFCollection, + ) + from generalresearch.incite.mergers.pop_ledger import PopLedgerMerge + + ledger_df = ledger_df_collection(ds=mnt_filepath, pg_config=thl_web_rr) + assert isinstance(ledger_df, LedgerDFCollection) + + task_df = task_df_collection(ds=mnt_filepath, pg_config=thl_web_rr) + assert isinstance(task_df, TaskAdjustmentDFCollection) + + pop_ledger = plm(ds=mnt_filepath) + assert isinstance(pop_ledger, PopLedgerMerge) diff --git a/tests/incite/test_interval_idx.py b/tests/incite/test_interval_idx.py new file mode 100644 index 0000000..ea2bced --- /dev/null +++ b/tests/incite/test_interval_idx.py @@ -0,0 +1,23 @@ +import pandas as pd +from datetime import datetime, timezone, timedelta + + +class TestIntervalIndex: + + def test_init(self): + start = datetime(year=2000, month=1, day=1) + end = datetime(year=2000, month=1, day=10) + + iv_r: pd.IntervalIndex = pd.interval_range( + start=start, end=end, freq="1d", closed="left" + ) + assert isinstance(iv_r, pd.IntervalIndex) + assert len(iv_r.to_list()) == 9 + + # If the offset is longer than the end - start it will not + # error. It will simply have 0 rows. + iv_r: pd.IntervalIndex = pd.interval_range( + start=start, end=end, freq="30d", closed="left" + ) + assert isinstance(iv_r, pd.IntervalIndex) + assert len(iv_r.to_list()) == 0 diff --git a/tests/managers/__init__.py b/tests/managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/gr/__init__.py b/tests/managers/gr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/gr/test_authentication.py b/tests/managers/gr/test_authentication.py new file mode 100644 index 0000000..53b6931 --- /dev/null +++ b/tests/managers/gr/test_authentication.py @@ -0,0 +1,125 @@ +import logging +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.models.gr.authentication import GRUser +from test_utils.models.conftest import gr_user + +SSO_ISSUER = "" + + +class TestGRUserManager: + + def test_create(self, gr_um): + from generalresearch.models.gr.authentication import GRUser + + user: GRUser = gr_um.create_dummy() + instance = gr_um.get_by_id(user.id) + assert user.id == instance.id + + instance2 = gr_um.get_by_id(user.id) + assert user.model_dump_json() == instance2.model_dump_json() + + def test_get_by_id(self, gr_user, gr_um): + with pytest.raises(expected_exception=ValueError) as cm: + gr_um.get_by_id(gr_user_id=999_999_999) + assert "GRUser not found" in str(cm.value) + + instance = gr_um.get_by_id(gr_user_id=gr_user.id) + assert instance.sub == gr_user.sub + + def test_get_by_sub(self, gr_user, gr_um): + with pytest.raises(expected_exception=ValueError) as cm: + gr_um.get_by_sub(sub=uuid4().hex) + assert "GRUser not found" in str(cm.value) + + instance = gr_um.get_by_sub(sub=gr_user.sub) + assert instance.id == gr_user.id + + def test_get_by_sub_or_create(self, gr_user, gr_um): + sub = f"{uuid4().hex}-{uuid4().hex}" + + with pytest.raises(expected_exception=ValueError) as cm: + gr_um.get_by_sub(sub=sub) + assert "GRUser not found" in str(cm.value) + + instance = gr_um.get_by_sub_or_create(sub=sub) + assert isinstance(instance, GRUser) + assert instance.sub == sub + + def test_get_all(self, gr_um): + res1 = gr_um.get_all() + assert isinstance(res1, list) + + gr_um.create_dummy() + res2 = gr_um.get_all() + assert len(res1) == len(res2) - 1 + + def test_get_by_team(self, gr_um): + res = gr_um.get_by_team(team_id=999_999_999) + assert isinstance(res, list) + assert res == [] + + def test_list_product_uuids(self, caplog, gr_user, gr_um, thl_web_rr): + with caplog.at_level(logging.WARNING): + gr_um.list_product_uuids(user=gr_user, thl_pg_config=thl_web_rr) + assert "prefetch not run" in caplog.text + + +class TestGRTokenManager: + + def test_create(self, gr_user, gr_tm): + assert gr_tm.create(user_id=gr_user.id) is None + + token = gr_tm.get_by_user_id(user_id=gr_user.id) + assert gr_user.id == token.user_id + + def test_get_by_user_id(self, gr_user, gr_tm): + assert gr_tm.create(user_id=gr_user.id) is None + + token = gr_tm.get_by_user_id(user_id=gr_user.id) + assert gr_user.id == token.user_id + + def test_prefetch_user(self, gr_user, gr_tm, gr_db, gr_redis_config): + from generalresearch.models.gr.authentication import GRToken + + gr_tm.create(user_id=gr_user.id) + + token: GRToken = gr_tm.get_by_user_id(user_id=gr_user.id) + assert token.user is None + + token.prefetch_user(pg_config=gr_db, redis_config=gr_redis_config) + assert token.user.id == gr_user.id + + def test_get_by_key(self, gr_user, gr_um, gr_tm): + gr_tm.create(user_id=gr_user.id) + token = gr_tm.get_by_user_id(user_id=gr_user.id) + + instance = gr_tm.get_by_key(api_key=token.key) + assert token.created == instance.created + + # Search for non-existent key + with pytest.raises(expected_exception=Exception) as cm: + gr_tm.get_by_key(api_key=uuid4().hex) + assert "No GRUser with token of " in str(cm.value) + + @pytest.mark.skip(reason="no idea how to actually test this...") + def test_get_by_sso_key(self, gr_user, gr_um, gr_tm, gr_redis_config): + from generalresearch.models.gr.authentication import GRToken + + api_key = "..." + jwks = { + # ... + } + + instance = gr_tm.get_by_key( + api_key=api_key, + jwks=jwks, + audience="...", + issuer=SSO_ISSUER, + gr_redis_config=gr_redis_config, + ) + + assert isinstance(instance, GRToken) diff --git a/tests/managers/gr/test_business.py b/tests/managers/gr/test_business.py new file mode 100644 index 0000000..7eb77f8 --- /dev/null +++ b/tests/managers/gr/test_business.py @@ -0,0 +1,150 @@ +from uuid import uuid4 + +import pytest + +from test_utils.models.conftest import business + + +class TestBusinessBankAccountManager: + + def test_init(self, business_bank_account_manager, gr_db): + assert business_bank_account_manager.pg_config == gr_db + + def test_create(self, business, business_bank_account_manager): + from generalresearch.models.gr.business import ( + TransferMethod, + BusinessBankAccount, + ) + + instance = business_bank_account_manager.create( + business_id=business.id, + uuid=uuid4().hex, + transfer_method=TransferMethod.ACH, + ) + assert isinstance(instance, BusinessBankAccount) + assert isinstance(instance.id, int) + + res = business_bank_account_manager.get_by_business_id( + business_id=instance.business_id + ) + assert isinstance(res, list) + assert len(res) == 1 + assert isinstance(res[0], BusinessBankAccount) + assert res[0].business_id == instance.business_id + + +class TestBusinessAddressManager: + + def test_create(self, business, business_address_manager): + from generalresearch.models.gr.business import BusinessAddress + + res = business_address_manager.create(uuid=uuid4().hex, business_id=business.id) + assert isinstance(res, BusinessAddress) + assert isinstance(res.id, int) + + +class TestBusinessManager: + + def test_create(self, business_manager): + from generalresearch.models.gr.business import Business + + instance = business_manager.create_dummy() + assert isinstance(instance, Business) + assert isinstance(instance.id, int) + + def test_get_or_create(self, business_manager): + uuid_key = uuid4().hex + + assert business_manager.get_by_uuid(business_uuid=uuid_key) is None + + instance = business_manager.get_or_create( + uuid=uuid_key, + name=f"name-{uuid4().hex[:6]}", + ) + + res = business_manager.get_by_uuid(business_uuid=uuid_key) + assert res.id == instance.id + + def test_get_all(self, business_manager): + res1 = business_manager.get_all() + assert isinstance(res1, list) + + business_manager.create_dummy() + res2 = business_manager.get_all() + assert len(res1) == len(res2) - 1 + + @pytest.mark.skip(reason="TODO") + def test_get_by_team(self): + pass + + def test_get_by_user_id( + self, business_manager, gr_user, team_manager, membership_manager + ): + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Create a Business, but don't add it to anything + b1 = business_manager.create_dummy() + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Create a Team, but don't create any Memberships + t1 = team_manager.create_dummy() + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Create a Membership for the gr_user to the Team... but it doesn't + # matter because the Team doesn't have any Business yet + m1 = membership_manager.create(team=t1, gr_user=gr_user) + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 0 + + # Add the Business to the Team... now the Business should be available + # to the gr_user + team_manager.add_business(team=t1, business=b1) + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 1 + + # Add another Business to the Team! + b2 = business_manager.create_dummy() + team_manager.add_business(team=t1, business=b2) + res = business_manager.get_by_user_id(user_id=gr_user.id) + assert len(res) == 2 + + @pytest.mark.skip(reason="TODO") + def test_get_uuids_by_user_id(self): + pass + + def test_get_by_uuid(self, business, business_manager): + instance = business_manager.get_by_uuid(business_uuid=business.uuid) + assert business.id == instance.id + + def test_get_by_id(self, business, business_manager): + instance = business_manager.get_by_id(business_id=business.id) + assert business.uuid == instance.uuid + + def test_cache_key(self, business): + assert "business:" in business.cache_key + + # def test_create_raise_on_duplicate(self): + # b_uuid = uuid4().hex + # + # # Make the first one + # business = BusinessManager.create( + # uuid=b_uuid, + # name=f"test-{b_uuid[:6]}") + # assert isinstance(business, Business) + # + # # Try to make it again + # with pytest.raises(expected_exception=psycopg.errors.UniqueViolation): + # business = BusinessManager.create( + # uuid=b_uuid, + # name=f"test-{b_uuid[:6]}") + # + # def test_get_by_team(self, team): + # for idx in range(5): + # BusinessManager.create(name=f"Business Name #{uuid4().hex[:6]}", team=team) + # + # res = BusinessManager.get_by_team(team_id=team.id) + # assert isinstance(res, list) + # assert 5 == len(res) diff --git a/tests/managers/gr/test_team.py b/tests/managers/gr/test_team.py new file mode 100644 index 0000000..9215da4 --- /dev/null +++ b/tests/managers/gr/test_team.py @@ -0,0 +1,125 @@ +from uuid import uuid4 + +from test_utils.models.conftest import team + + +class TestMembershipManager: + + def test_init(self, membership_manager, gr_db): + assert membership_manager.pg_config == gr_db + + +class TestTeamManager: + + def test_init(self, team_manager, gr_db): + assert team_manager.pg_config == gr_db + + def test_get_or_create(self, team_manager): + from generalresearch.models.gr.team import Team + + new_uuid = uuid4().hex + + team: Team = team_manager.get_or_create(uuid=new_uuid) + + assert isinstance(team, Team) + assert isinstance(team.id, int) + assert team.uuid == new_uuid + assert team.name == "< Unknown >" + + def test_get_all(self, team_manager): + res1 = team_manager.get_all() + assert isinstance(res1, list) + + team_manager.create_dummy() + res2 = team_manager.get_all() + assert len(res1) == len(res2) - 1 + + def test_create(self, team_manager): + from generalresearch.models.gr.team import Team + + team: Team = team_manager.create_dummy() + assert isinstance(team, Team) + assert isinstance(team.id, int) + + def test_add_user(self, team, team_manager, gr_um, gr_db, gr_redis_config): + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.team import Membership + + user: GRUser = gr_um.create_dummy() + + instance = team_manager.add_user(team=team, gr_user=user) + assert isinstance(instance, Membership) + + # assert team.gr_users is None + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert isinstance(team.gr_users, list) + assert len(team.gr_users) + assert team.gr_users == [user] + + def test_get_by_uuid(self, team_manager): + from generalresearch.models.gr.team import Team + + team: Team = team_manager.create_dummy() + + instance = team_manager.get_by_uuid(team_uuid=team.uuid) + assert team.id == instance.id + + def test_get_by_id(self, team_manager): + from generalresearch.models.gr.team import Team + + team: Team = team_manager.create_dummy() + + instance = team_manager.get_by_id(team_id=team.id) + assert team.uuid == instance.uuid + + def test_get_by_user(self, team, team_manager, gr_um): + from generalresearch.models.gr.authentication import GRUser + from generalresearch.models.gr.team import Team + + user: GRUser = gr_um.create_dummy() + team_manager.add_user(team=team, gr_user=user) + + res = team_manager.get_by_user(gr_user=user) + assert isinstance(res, list) + assert len(res) == 1 + instance = res[0] + assert isinstance(instance, Team) + assert instance.uuid == team.uuid + + def test_get_by_user_duplicates( + self, + gr_user_token, + gr_user, + membership, + product_factory, + membership_factory, + team, + thl_web_rr, + gr_redis_config, + gr_db, + ): + product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + gr_user.prefetch_teams( + pg_config=gr_db, + redis_config=gr_redis_config, + ) + + assert len(gr_user.teams) == 1 + + # def test_create_raise_on_duplicate(self): + # t_uuid = uuid4().hex + # + # # Make the first one + # team = TeamManager.create( + # uuid=t_uuid, + # name=f"test-{t_uuid[:6]}") + # assert isinstance(team, Team) + # + # # Try to make it again + # with pytest.raises(expected_exception=psycopg.errors.UniqueViolation): + # TeamManager.create( + # uuid=t_uuid, + # name=f"test-{t_uuid[:6]}") + # diff --git a/tests/managers/leaderboard.py b/tests/managers/leaderboard.py new file mode 100644 index 0000000..4d32dd0 --- /dev/null +++ b/tests/managers/leaderboard.py @@ -0,0 +1,274 @@ +import os +import time +import zoneinfo +from datetime import datetime, timezone +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.managers.leaderboard.manager import LeaderboardManager +from generalresearch.managers.leaderboard.tasks import hit_leaderboards +from generalresearch.models.thl.definitions import Status +from generalresearch.models.thl.user import User +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.session import Session +from generalresearch.models.thl.leaderboard import ( + LeaderboardCode, + LeaderboardFrequency, + LeaderboardRow, +) +from generalresearch.models.thl.product import ( + PayoutConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, +) + +# random uuid for leaderboard tests +product_id = uuid4().hex + + +@pytest.fixture(autouse=True) +def set_timezone(): + os.environ["TZ"] = "UTC" + time.tzset() + yield + # Optionally reset to default + os.environ.pop("TZ", None) + time.tzset() + + +@pytest.fixture +def session_factory(): + return _create_session + + +def _create_session( + product_user_id="aaa", country_iso="us", user_payout=Decimal("1.00") +): + user = User( + product_id=product_id, + product_user_id=product_user_id, + ) + user.product = Product( + id=product_id, + name="test", + redirect_url="https://www.example.com", + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.5), + ) + ), + ) + session = Session( + user=user, + started=datetime(2025, 2, 5, 6, tzinfo=timezone.utc), + id=1, + country_iso=country_iso, + status=Status.COMPLETE, + payout=Decimal("2.00"), + user_payout=user_payout, + ) + return session + + +@pytest.fixture(scope="function") +def setup_leaderboards(thl_redis): + complete_count = { + "aaa": 10, + "bbb": 6, + "ccc": 6, + "ddd": 6, + "eee": 2, + "fff": 1, + "ggg": 1, + } + sum_payout = {"aaa": 345, "bbb": 100, "ccc": 100} + max_payout = sum_payout + country_iso = "us" + for freq in [ + LeaderboardFrequency.DAILY, + LeaderboardFrequency.WEEKLY, + LeaderboardFrequency.MONTHLY, + ]: + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + thl_redis.delete(m.key) + thl_redis.zadd(m.key, complete_count) + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.SUM_PAYOUTS, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + thl_redis.delete(m.key) + thl_redis.zadd(m.key, sum_payout) + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.LARGEST_PAYOUT, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + thl_redis.delete(m.key) + thl_redis.zadd(m.key, max_payout) + + +class TestLeaderboards: + + def test_leaderboard_manager(self, setup_leaderboards, thl_redis): + country_iso = "us" + board_code = LeaderboardCode.COMPLETE_COUNT + freq = LeaderboardFrequency.DAILY + m = LeaderboardManager( + redis_client=thl_redis, + board_code=board_code, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 0, 0, 0), + ) + lb = m.get_leaderboard() + assert lb.period_start_local == datetime( + 2025, 2, 5, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="America/New_York") + ) + assert lb.period_end_local == datetime( + 2025, + 2, + 5, + 23, + 59, + 59, + 999999, + tzinfo=zoneinfo.ZoneInfo(key="America/New_York"), + ) + assert lb.period_start_utc == datetime(2025, 2, 5, 5, tzinfo=timezone.utc) + assert lb.row_count == 7 + assert lb.rows == [ + LeaderboardRow(bpuid="aaa", rank=1, value=10), + LeaderboardRow(bpuid="bbb", rank=2, value=6), + LeaderboardRow(bpuid="ccc", rank=2, value=6), + LeaderboardRow(bpuid="ddd", rank=2, value=6), + LeaderboardRow(bpuid="eee", rank=5, value=2), + LeaderboardRow(bpuid="fff", rank=6, value=1), + LeaderboardRow(bpuid="ggg", rank=6, value=1), + ] + + def test_leaderboard_manager_bpuid(self, setup_leaderboards, thl_redis): + country_iso = "us" + board_code = LeaderboardCode.COMPLETE_COUNT + freq = LeaderboardFrequency.DAILY + m = LeaderboardManager( + redis_client=thl_redis, + board_code=board_code, + freq=freq, + product_id=product_id, + country_iso=country_iso, + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(bp_user_id="fff", limit=1) + + # TODO: this won't work correctly if I request bpuid 'ggg', because it + # is ordered at the end even though it is tied, so it won't get a + # row after ('fff') + + assert lb.rows == [ + LeaderboardRow(bpuid="eee", rank=5, value=2), + LeaderboardRow(bpuid="fff", rank=6, value=1), + LeaderboardRow(bpuid="ggg", rank=6, value=1), + ] + + lb.censor() + assert lb.rows[0].bpuid == "ee*" + + def test_leaderboard_hit(self, setup_leaderboards, session_factory, thl_redis): + hit_leaderboards(redis_client=thl_redis, session=session_factory()) + + for freq in [ + LeaderboardFrequency.DAILY, + LeaderboardFrequency.WEEKLY, + LeaderboardFrequency.MONTHLY, + ]: + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=freq, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(limit=1) + assert lb.row_count == 7 + assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=11)] + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.LARGEST_PAYOUT, + freq=freq, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(limit=1) + assert lb.row_count == 3 + assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345)] + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.SUM_PAYOUTS, + freq=freq, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard(limit=1) + assert lb.row_count == 3 + assert lb.rows == [LeaderboardRow(bpuid="aaa", rank=1, value=345 + 100)] + + def test_leaderboard_hit_new_row( + self, setup_leaderboards, session_factory, thl_redis + ): + session = session_factory(product_user_id="zzz") + hit_leaderboards(redis_client=thl_redis, session=session) + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=LeaderboardFrequency.DAILY, + product_id=product_id, + country_iso="us", + within_time=datetime(2025, 2, 5, 12, 12, 12), + ) + lb = m.get_leaderboard() + assert lb.row_count == 8 + assert LeaderboardRow(bpuid="zzz", value=1, rank=6) in lb.rows + + def test_leaderboard_country(self, thl_redis): + m = LeaderboardManager( + redis_client=thl_redis, + board_code=LeaderboardCode.COMPLETE_COUNT, + freq=LeaderboardFrequency.DAILY, + product_id=product_id, + country_iso="jp", + within_time=datetime( + 2025, + 2, + 1, + ), + ) + lb = m.get_leaderboard() + assert lb.row_count == 0 + assert lb.period_start_local == datetime( + 2025, 2, 1, 0, 0, 0, tzinfo=zoneinfo.ZoneInfo(key="Asia/Tokyo") + ) + assert lb.local_start_time == "2025-02-01T00:00:00+09:00" + assert lb.local_end_time == "2025-02-01T23:59:59.999999+09:00" + assert lb.period_start_utc == datetime(2025, 1, 31, 15, tzinfo=timezone.utc) + print(lb.model_dump(mode="json")) diff --git a/tests/managers/test_events.py b/tests/managers/test_events.py new file mode 100644 index 0000000..a0fab38 --- /dev/null +++ b/tests/managers/test_events.py @@ -0,0 +1,530 @@ +import random +import time +from datetime import timedelta, datetime, timezone +from decimal import Decimal +from functools import partial +from typing import Optional +from uuid import uuid4 + +import math +import pytest +from math import floor + +from generalresearch.managers.events import EventSubscriber +from generalresearch.models import Source +from generalresearch.models.events import ( + MessageKind, + EventType, + AggregateBySource, + MaxGaugeBySource, +) +from generalresearch.models.legacy.bucket import Bucket +from generalresearch.models.thl.definitions import Status, StatusCode1 +from generalresearch.models.thl.session import Session, Wall +from generalresearch.models.thl.user import User + + +# We don't need anything in the db, so not using the db fixtures +@pytest.fixture(scope="function") +def product_id(product_manager): + return uuid4().hex + + +@pytest.fixture(scope="function") +def user_factory(product_id): + return partial(create_dummy, product_id=product_id) + + +@pytest.fixture(scope="function") +def event_subscriber(thl_redis_config, product_id): + return EventSubscriber(redis_config=thl_redis_config, product_id=product_id) + + +def create_dummy( + product_id: Optional[str] = None, product_user_id: Optional[str] = None +) -> User: + return User( + product_id=product_id, + product_user_id=product_user_id or uuid4().hex, + uuid=uuid4().hex, + created=datetime.now(tz=timezone.utc), + user_id=random.randint(0, floor(2**32 / 2)), + ) + + +class TestActiveUsers: + + def test_run_empty(self, event_manager, product_id): + res = event_manager.get_user_stats(product_id) + assert res == { + "active_users_last_1h": 0, + "active_users_last_24h": 0, + "signups_last_24h": 0, + "in_progress_users": 0, + } + + def test_run(self, event_manager, product_id, user_factory): + event_manager.clear_global_user_stats() + user1: User = user_factory() + + # No matter how many times we do this, they're only active once + event_manager.handle_user(user1) + event_manager.handle_user(user1) + event_manager.handle_user(user1) + + res = event_manager.get_user_stats(product_id) + assert res == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 1, + "in_progress_users": 0, + } + + assert event_manager.get_global_user_stats() == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 1, + "in_progress_users": 0, + } + + # Create a 2nd user in another product + product_id2 = uuid4().hex + user2: User = user_factory(product_id=product_id2) + # Change to say user was created >24 hrs ago + user2.created = user2.created - timedelta(hours=25) + event_manager.handle_user(user2) + + # And now each have 1 active user + assert event_manager.get_user_stats(product_id) == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 1, + "in_progress_users": 0, + } + # user2 was created older than 24hrs ago + assert event_manager.get_user_stats(product_id2) == { + "active_users_last_1h": 1, + "active_users_last_24h": 1, + "signups_last_24h": 0, + "in_progress_users": 0, + } + # 2 globally active + assert event_manager.get_global_user_stats() == { + "active_users_last_1h": 2, + "active_users_last_24h": 2, + "signups_last_24h": 1, + "in_progress_users": 0, + } + + def test_inprogress(self, event_manager, product_id, user_factory): + event_manager.clear_global_user_stats() + user1: User = user_factory() + user2: User = user_factory() + + # No matter how many times we do this, they're only active once + event_manager.mark_user_inprogress(user1) + event_manager.mark_user_inprogress(user1) + event_manager.mark_user_inprogress(user1) + event_manager.mark_user_inprogress(user2) + + res = event_manager.get_user_stats(product_id) + assert res["in_progress_users"] == 2 + + event_manager.unmark_user_inprogress(user1) + res = event_manager.get_user_stats(product_id) + assert res["in_progress_users"] == 1 + + # Shouldn't do anything + event_manager.unmark_user_inprogress(user1) + res = event_manager.get_user_stats(product_id) + assert res["in_progress_users"] == 1 + + def test_expiry(self, event_manager, product_id, user_factory): + event_manager.clear_global_user_stats() + user1: User = user_factory() + event_manager.handle_user(user1) + event_manager.mark_user_inprogress(user1) + sec_24hr = timedelta(hours=24).total_seconds() + + # We don't want to wait an hour to test this, so we're going to + # just confirm that the keys will expire + time.sleep(1.1) + ttl = event_manager.redis_client.httl( + f"active_users_last_1h:{product_id}", user1.product_user_id + )[0] + assert 3600 - 60 <= ttl <= 3600 + ttl = event_manager.redis_client.httl( + f"active_users_last_24h:{product_id}", user1.product_user_id + )[0] + assert sec_24hr - 60 <= ttl <= sec_24hr + + ttl = event_manager.redis_client.httl("signups_last_24h", user1.uuid)[0] + assert sec_24hr - 60 <= ttl <= sec_24hr + + ttl = event_manager.redis_client.httl("in_progress_users", user1.uuid)[0] + assert 3600 - 60 <= ttl <= 3600 + + +class TestSessionStats: + + def test_run_empty(self, event_manager, product_id): + res = event_manager.get_session_stats(product_id) + assert res == { + "session_enters_last_1h": 0, + "session_enters_last_24h": 0, + "session_fails_last_1h": 0, + "session_fails_last_24h": 0, + "session_completes_last_1h": 0, + "session_completes_last_24h": 0, + "sum_payouts_last_1h": 0, + "sum_payouts_last_24h": 0, + "sum_user_payouts_last_1h": 0, + "sum_user_payouts_last_24h": 0, + "session_avg_payout_last_24h": None, + "session_avg_user_payout_last_24h": None, + "session_complete_avg_loi_last_24h": None, + "session_fail_avg_loi_last_24h": None, + } + + def test_run(self, event_manager, product_id, user_factory, utc_now, utc_hour_ago): + event_manager.clear_global_session_stats() + + user: User = user_factory() + session = Session( + country_iso="us", + started=utc_hour_ago + timedelta(minutes=10), + user=user, + ) + event_manager.session_on_enter(session=session, user=user) + session.update( + finished=utc_now, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("1.00"), + user_payout=Decimal("0.95"), + ) + event_manager.session_on_finish(session=session, user=user) + assert event_manager.get_session_stats(product_id) == { + "session_enters_last_1h": 1, + "session_enters_last_24h": 1, + "session_fails_last_1h": 0, + "session_fails_last_24h": 0, + "session_completes_last_1h": 1, + "session_completes_last_24h": 1, + "sum_payouts_last_1h": 100, + "sum_payouts_last_24h": 100, + "sum_user_payouts_last_1h": 95, + "sum_user_payouts_last_24h": 95, + "session_avg_payout_last_24h": 100, + "session_avg_user_payout_last_24h": 95, + "session_complete_avg_loi_last_24h": round(session.elapsed.total_seconds()), + "session_fail_avg_loi_last_24h": None, + } + + # The session gets inserted into redis using the session.finished + # timestamp, no matter what time it is right now. So we can + # kind of test the expiry by setting it to 61 min ago + session2 = Session( + country_iso="us", + started=utc_hour_ago - timedelta(minutes=10), + user=user, + ) + event_manager.session_on_enter(session=session2, user=user) + session2.update( + finished=utc_hour_ago - timedelta(minutes=1), + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("2.00"), + user_payout=Decimal("1.50"), + ) + event_manager.session_on_finish(session=session2, user=user) + avg_loi = ( + round(session.elapsed.total_seconds()) + + round(session2.elapsed.total_seconds()) + ) / 2 + assert event_manager.get_session_stats(product_id) == { + "session_enters_last_1h": 1, + "session_enters_last_24h": 2, + "session_fails_last_1h": 0, + "session_fails_last_24h": 0, + "session_completes_last_1h": 1, + "session_completes_last_24h": 2, + "sum_payouts_last_1h": 100, + "sum_payouts_last_24h": 300, + "sum_user_payouts_last_1h": 95, + "sum_user_payouts_last_24h": 95 + 150, + "session_avg_payout_last_24h": math.ceil((100 + 200) / 2), + "session_avg_user_payout_last_24h": math.ceil((95 + 150) / 2), + "session_complete_avg_loi_last_24h": avg_loi, + "session_fail_avg_loi_last_24h": None, + } + + # Don't want to wait an hour, so confirm the keys will expire + name = "session_completes_last_1h:" + product_id + res = event_manager.redis_client.hgetall(name) + field = (int(utc_now.timestamp()) // 60) * 60 + field_name = str(field) + assert res == {field_name: "1"} + assert ( + 3600 - 60 < event_manager.redis_client.httl(name, field_name)[0] < 3600 + 60 + ) + + # Second BP, fail + product_id2 = uuid4().hex + user2: User = user_factory(product_id=product_id2) + session3 = Session( + country_iso="us", + started=utc_now - timedelta(minutes=1), + user=user2, + ) + event_manager.session_on_enter(session=session3, user=user) + session3.update( + finished=utc_now, + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + ) + event_manager.session_on_finish(session=session3, user=user) + avg_loi_complete = ( + round(session.elapsed.total_seconds()) + + round(session2.elapsed.total_seconds()) + ) / 2 + assert event_manager.get_session_stats(product_id) == { + "session_enters_last_1h": 2, + "session_enters_last_24h": 3, + "session_fails_last_1h": 1, + "session_fails_last_24h": 1, + "session_completes_last_1h": 1, + "session_completes_last_24h": 2, + "sum_payouts_last_1h": 100, + "sum_payouts_last_24h": 300, + "sum_user_payouts_last_1h": 95, + "sum_user_payouts_last_24h": 95 + 150, + "session_avg_payout_last_24h": math.ceil((100 + 200) / 2), + "session_avg_user_payout_last_24h": math.ceil((95 + 150) / 2), + "session_complete_avg_loi_last_24h": avg_loi_complete, + "session_fail_avg_loi_last_24h": round(session3.elapsed.total_seconds()), + } + + +class TestTaskStatsManager: + def test_empty(self, event_manager): + event_manager.clear_task_stats() + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource(total=0), + "live_tasks_max_payout": MaxGaugeBySource(value=None), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + assert event_manager.get_latest_task_stats() is None + + sm = event_manager.get_stats_message(product_id=uuid4().hex) + assert sm.data.task_created_count_last_24h.total == 0 + assert sm.data.live_tasks_max_payout.value is None + + def test(self, event_manager): + event_manager.clear_task_stats() + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=100, + live_tasks_max_payout=Decimal("1.00"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=100, by_source={Source.TESTING: 100} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=100, by_source={Source.TESTING: 100} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING2, + live_task_count=50, + live_tasks_max_payout=Decimal("2.00"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=150, by_source={Source.TESTING: 100, Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=200, by_source={Source.TESTING: 100, Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=101, + live_tasks_max_payout=Decimal("1.50"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=151, by_source={Source.TESTING: 101, Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=200, by_source={Source.TESTING: 150, Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=99, + live_tasks_max_payout=Decimal("2.50"), + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=149, by_source={Source.TESTING: 99, Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=250, by_source={Source.TESTING: 250, Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + event_manager.set_source_task_stats( + source=Source.TESTING, live_task_count=0, live_tasks_max_payout=Decimal("0") + ) + assert event_manager.get_task_stats_raw() == { + "live_task_count": AggregateBySource( + total=50, by_source={Source.TESTING2: 50} + ), + "live_tasks_max_payout": MaxGaugeBySource( + value=200, by_source={Source.TESTING2: 200} + ), + "task_created_count_last_1h": AggregateBySource(total=0), + "task_created_count_last_24h": AggregateBySource(total=0), + } + + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=0, + live_tasks_max_payout=Decimal("0"), + created_count=10, + ) + res = event_manager.get_task_stats_raw() + assert res["task_created_count_last_1h"] == AggregateBySource( + total=10, by_source={Source.TESTING: 10} + ) + assert res["task_created_count_last_24h"] == AggregateBySource( + total=10, by_source={Source.TESTING: 10} + ) + + event_manager.set_source_task_stats( + source=Source.TESTING, + live_task_count=0, + live_tasks_max_payout=Decimal("0"), + created_count=10, + ) + res = event_manager.get_task_stats_raw() + assert res["task_created_count_last_1h"] == AggregateBySource( + total=20, by_source={Source.TESTING: 20} + ) + assert res["task_created_count_last_24h"] == AggregateBySource( + total=20, by_source={Source.TESTING: 20} + ) + + event_manager.set_source_task_stats( + source=Source.TESTING2, + live_task_count=0, + live_tasks_max_payout=Decimal("0"), + created_count=1, + ) + res = event_manager.get_task_stats_raw() + assert res["task_created_count_last_1h"] == AggregateBySource( + total=21, by_source={Source.TESTING: 20, Source.TESTING2: 1} + ) + assert res["task_created_count_last_24h"] == AggregateBySource( + total=21, by_source={Source.TESTING: 20, Source.TESTING2: 1} + ) + + sm = event_manager.get_stats_message(product_id=uuid4().hex) + assert sm.data.task_created_count_last_24h.total == 21 + + +class TestChannelsSubscriptions: + def test_stats_worker( + self, + event_manager, + event_subscriber, + product_id, + user_factory, + utc_hour_ago, + utc_now, + ): + event_manager.clear_stats() + assert event_subscriber.pubsub + # We subscribed, so manually trigger the stats worker + # to get all product_ids subscribed and publish a + # stats message into that channel + event_manager.stats_worker_task() + assert product_id in event_manager.get_active_subscribers() + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.STATS + + user = user_factory() + session = Session( + country_iso="us", + started=utc_hour_ago, + user=user, + clicked_bucket=Bucket(user_payout_min=Decimal("0.50")), + id=1, + ) + + event_manager.handle_session_enter(session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.SESSION_ENTER + + wall = Wall( + req_survey_id="a", + req_cpi=Decimal("1"), + source=Source.TESTING, + session_id=session.id, + user_id=user.user_id, + ) + + event_manager.handle_task_enter(wall, session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.TASK_ENTER + + wall.update( + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=datetime.now(tz=timezone.utc), + cpi=Decimal("1"), + ) + event_manager.handle_task_finish(wall, session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.TASK_FINISH + + session.update( + finished=utc_now, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("1.00"), + user_payout=Decimal("1.00"), + ) + + event_manager.handle_session_finish(session, user) + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.EVENT + assert msg.data.event_type == EventType.SESSION_FINISH + + event_manager.stats_worker_task() + assert product_id in event_manager.get_active_subscribers() + msg = event_subscriber.get_next_message() + print(msg) + assert msg.kind == MessageKind.STATS + assert msg.data.active_users_last_1h == 1 + assert msg.data.session_enters_last_24h == 1 + assert msg.data.session_completes_last_1h == 1 + assert msg.data.signups_last_24h == 1 diff --git a/tests/managers/test_lucid.py b/tests/managers/test_lucid.py new file mode 100644 index 0000000..1a1bae7 --- /dev/null +++ b/tests/managers/test_lucid.py @@ -0,0 +1,23 @@ +import pytest + +from generalresearch.managers.lucid.profiling import get_profiling_library + +qids = ["42", "43", "45", "97", "120", "639", "15297"] + + +class TestLucidProfiling: + + @pytest.mark.skip + def test_get_library(self, thl_web_rr): + pks = [(qid, "us", "eng") for qid in qids] + qs = get_profiling_library(thl_web_rr, pks=pks) + assert len(qids) == len(qs) + + # just making sure this doesn't raise errors + for q in qs: + q.to_upk_question() + + # a lot will fail parsing because they have no options or the options are blank + # just asserting that we get some back + qs = get_profiling_library(thl_web_rr, country_iso="mx", language_iso="spa") + assert len(qs) > 100 diff --git a/tests/managers/test_userpid.py b/tests/managers/test_userpid.py new file mode 100644 index 0000000..4a3f699 --- /dev/null +++ b/tests/managers/test_userpid.py @@ -0,0 +1,68 @@ +import pytest +from pydantic import MySQLDsn + +from generalresearch.managers.marketplace.user_pid import UserPidMultiManager +from generalresearch.sql_helper import SqlHelper +from generalresearch.managers.cint.user_pid import CintUserPidManager +from generalresearch.managers.dynata.user_pid import DynataUserPidManager +from generalresearch.managers.innovate.user_pid import InnovateUserPidManager +from generalresearch.managers.morning.user_pid import MorningUserPidManager + +# from generalresearch.managers.precision import PrecisionUserPidManager +from generalresearch.managers.prodege.user_pid import ProdegeUserPidManager +from generalresearch.managers.repdata.user_pid import RepdataUserPidManager +from generalresearch.managers.sago.user_pid import SagoUserPidManager +from generalresearch.managers.spectrum.user_pid import SpectrumUserPidManager + +dsn = "" + + +class TestCintUserPidManager: + + def test_filter(self): + m = CintUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-cint"))) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + m.filter() + assert str(excinfo.value) == "Must pass ONE of user_ids, pids" + + with pytest.raises(expected_exception=AssertionError) as excinfo: + m.filter(user_ids=[1, 2, 3], pids=["ed5b47c8551d453d985501391f190d3f"]) + assert str(excinfo.value) == "Must pass ONE of user_ids, pids" + + # pids get .hex before and after + assert m.filter(pids=["ed5b47c8551d453d985501391f190d3f"]) == m.filter( + pids=["ed5b47c8-551d-453d-9855-01391f190d3f"] + ) + + # user_ids and pids are for the same 3 users + res1 = m.filter(user_ids=[61586871, 61458915, 61390116]) + res2 = m.filter( + pids=[ + "ed5b47c8551d453d985501391f190d3f", + "7e640732c59f43d1b7c00137ab66600c", + "5160aeec9c3b4dbb85420128e6da6b5a", + ] + ) + assert res1 == res2 + + +class TestUserPidMultiManager: + + def test_filter(self): + managers = [ + CintUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-cint"))), + DynataUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-dynata"))), + InnovateUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-innovate"))), + MorningUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-morning"))), + ProdegeUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-prodege"))), + RepdataUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-repdata"))), + SagoUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-sago"))), + SpectrumUserPidManager(SqlHelper(MySQLDsn(dsn + "thl-spectrum"))), + ] + m = UserPidMultiManager(sql_helper=SqlHelper(MySQLDsn(dsn)), managers=managers) + res = m.filter(user_ids=[1]) + assert len(res) == len(managers) + + res = m.filter(user_ids=[1, 2, 3]) + assert len(res) > len(managers) diff --git a/tests/managers/thl/__init__.py b/tests/managers/thl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/thl/test_buyer.py b/tests/managers/thl/test_buyer.py new file mode 100644 index 0000000..69ea105 --- /dev/null +++ b/tests/managers/thl/test_buyer.py @@ -0,0 +1,25 @@ +from generalresearch.models import Source + + +class TestBuyer: + + def test( + self, + delete_buyers_surveys, + buyer_manager, + ): + + bs = buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=["a", "b"]) + assert len(bs) == 2 + buyer_a = bs[0] + assert buyer_a.id is not None + bs2 = buyer_manager.bulk_get_or_create(source=Source.TESTING, codes=["a", "c"]) + assert len(bs2) == 2 + buyer_a2 = bs2[0] + buyer_c = bs2[1] + # a isn't created again + assert buyer_a == buyer_a2 + assert bs2[0].id is not None + + # and its cached + assert buyer_c.id == buyer_manager.source_code_pk[f"{Source.TESTING.value}:c"] diff --git a/tests/managers/thl/test_cashout_method.py b/tests/managers/thl/test_cashout_method.py new file mode 100644 index 0000000..fe561d3 --- /dev/null +++ b/tests/managers/thl/test_cashout_method.py @@ -0,0 +1,139 @@ +import pytest + +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.wallet.cashout_method import ( + CashMailCashoutMethodData, + USDeliveryAddress, + PaypalCashoutMethodData, +) +from test_utils.managers.cashout_methods import ( + EXAMPLE_TANGO_CASHOUT_METHODS, + AMT_ASSIGNMENT_CASHOUT_METHOD, + AMT_BONUS_CASHOUT_METHOD, +) + + +class TestTangoCashoutMethods: + + def test_create_and_get(self, cashout_method_manager, setup_cashoutmethod_db): + res = cashout_method_manager.filter(payout_types=[PayoutType.TANGO]) + assert len(res) == 2 + cm = [x for x in res if x.ext_id == "U025035"][0] + assert EXAMPLE_TANGO_CASHOUT_METHODS[0] == cm + + def test_user( + self, cashout_method_manager, user_with_wallet, setup_cashoutmethod_db + ): + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + # This user ONLY has the two tango cashout methods, no AMT + assert len(res) == 2 + + +class TestAMTCashoutMethods: + + def test_create_and_get(self, cashout_method_manager, setup_cashoutmethod_db): + res = cashout_method_manager.filter(payout_types=[PayoutType.AMT]) + assert len(res) == 2 + cm = [x for x in res if x.name == "AMT Assignment"][0] + assert AMT_ASSIGNMENT_CASHOUT_METHOD == cm + cm = [x for x in res if x.name == "AMT Bonus"][0] + assert AMT_BONUS_CASHOUT_METHOD == cm + + def test_user( + self, cashout_method_manager, user_with_wallet_amt, setup_cashoutmethod_db + ): + res = cashout_method_manager.get_cashout_methods(user_with_wallet_amt) + # This user has the 2 tango, plus amt bonus & assignment + assert len(res) == 4 + + +class TestUserCashoutMethods: + + def test(self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db): + delete_cashoutmethod_db() + + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 0 + + def test_cash_in_mail( + self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db + ): + delete_cashoutmethod_db() + + data = CashMailCashoutMethodData( + delivery_address=USDeliveryAddress.model_validate( + { + "name_or_attn": "Josh Ackerman", + "address": "123 Fake St", + "city": "San Francisco", + "state": "CA", + "postal_code": "12345", + } + ) + ) + cashout_method_manager.create_cash_in_mail_cashout_method( + data=data, user=user_with_wallet + ) + + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + assert res[0].data.delivery_address.postal_code == "12345" + + # try to create the same one again. should just do nothing + cashout_method_manager.create_cash_in_mail_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + + # Create with a new address, will create a new one + data.delivery_address.postal_code = "99999" + cashout_method_manager.create_cash_in_mail_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 2 + + def test_paypal( + self, cashout_method_manager, user_with_wallet, delete_cashoutmethod_db + ): + delete_cashoutmethod_db() + + data = PaypalCashoutMethodData(email="test@example.com") + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + assert res[0].data.email == "test@example.com" + + # try to create the same one again. should just do nothing + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + + # Create with a new email, will error! must delete the old one first. + # We can only have one paypal active + data.email = "test2@example.com" + with pytest.raises( + ValueError, + match="User already has a cashout method of this type. Delete the existing one and try again.", + ): + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + + cashout_method_manager.delete_cashout_method(res[0].id) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 0 + + cashout_method_manager.create_paypal_cashout_method( + data=data, user=user_with_wallet + ) + res = cashout_method_manager.get_cashout_methods(user_with_wallet) + assert len(res) == 1 + assert res[0].data.email == "test2@example.com" diff --git a/tests/managers/thl/test_category.py b/tests/managers/thl/test_category.py new file mode 100644 index 0000000..ad0f07b --- /dev/null +++ b/tests/managers/thl/test_category.py @@ -0,0 +1,100 @@ +import pytest + +from generalresearch.models.thl.category import Category + + +class TestCategory: + + @pytest.fixture + def beauty_fitness(self, thl_web_rw): + + return Category( + uuid="12c1e96be82c4642a07a12a90ce6f59e", + adwords_vertical_id="44", + label="Beauty & Fitness", + path="/Beauty & Fitness", + ) + + @pytest.fixture + def hair_care(self, beauty_fitness): + + return Category( + uuid="dd76c4b565d34f198dad3687326503d6", + adwords_vertical_id="146", + label="Hair Care", + path="/Beauty & Fitness/Hair Care", + ) + + @pytest.fixture + def hair_loss(self, hair_care): + + return Category( + uuid="aacff523c8e246888215611ec3b823c0", + adwords_vertical_id="235", + label="Hair Loss", + path="/Beauty & Fitness/Hair Care/Hair Loss", + ) + + @pytest.fixture + def category_data( + self, category_manager, thl_web_rw, beauty_fitness, hair_care, hair_loss + ): + cats = [beauty_fitness, hair_care, hair_loss] + data = [x.model_dump(mode="json") for x in cats] + # We need the parent pk's to set the parent_id. So insert all without a parent, + # then pull back all pks and map to the parents as parsed by the parent_path + query = """ + INSERT INTO marketplace_category + (uuid, adwords_vertical_id, label, path) + VALUES + (%(uuid)s, %(adwords_vertical_id)s, %(label)s, %(path)s) + ON CONFLICT (uuid) DO NOTHING; + """ + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=data) + conn.commit() + + res = thl_web_rw.execute_sql_query("SELECT id, path FROM marketplace_category") + path_id = {x["path"]: x["id"] for x in res} + data = [ + {"id": path_id[c.path], "parent_id": path_id[c.parent_path]} + for c in cats + if c.parent_path + ] + query = """ + UPDATE marketplace_category + SET parent_id = %(parent_id)s + WHERE id = %(id)s; + """ + with thl_web_rw.make_connection() as conn: + with conn.cursor() as c: + c.executemany(query=query, params_seq=data) + conn.commit() + + category_manager.populate_caches() + + def test( + self, + category_data, + category_manager, + beauty_fitness, + hair_care, + hair_loss, + ): + # category_manager on init caches the category info. This rarely/never changes so this is fine, + # but now that tests get run on a new db each time, the category_manager is inited before + # the fixtures run. so category_manager's cache needs to be rerun + + # path='/Beauty & Fitness/Hair Care/Hair Loss' + c: Category = category_manager.get_by_label("Hair Loss") + # Beauty & Fitness + assert beauty_fitness.uuid == category_manager.get_top_level(c).uuid + + c: Category = category_manager.categories[beauty_fitness.uuid] + # The root is itself + assert c == category_manager.get_category_root(c) + + # The root is Beauty & Fitness + c: Category = category_manager.get_by_label("Hair Loss") + assert beauty_fitness.uuid == category_manager.get_category_root(c).uuid diff --git a/tests/managers/thl/test_contest/__init__.py b/tests/managers/thl/test_contest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/thl/test_contest/test_leaderboard.py b/tests/managers/thl/test_contest/test_leaderboard.py new file mode 100644 index 0000000..80a88a5 --- /dev/null +++ b/tests/managers/thl/test_contest/test_leaderboard.py @@ -0,0 +1,138 @@ +from datetime import datetime, timezone, timedelta +from zoneinfo import ZoneInfo + +from generalresearch.currency import USDCent +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestEndReason, +) +from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, + LeaderboardContestCreate, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from test_utils.managers.contest.conftest import ( + leaderboard_contest_in_db as contest_in_db, + leaderboard_contest_create as contest_create, +) + + +class TestLeaderboardContestCRUD: + + def test_create( + self, + contest_create: LeaderboardContestCreate, + product_user_wallet_yes: Product, + thl_lm, + contest_manager, + ): + c = contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=contest_create + ) + c_out = contest_manager.get(c.uuid) + assert c == c_out + + assert isinstance(c, LeaderboardContest) + assert c.prize_count == 2 + assert c.status == ContestStatus.ACTIVE + # We have it set in the fixture as the daily contest for 2025-01-01 + assert c.end_condition.ends_at == datetime( + 2025, 1, 1, 23, 59, 59, 999999, tzinfo=ZoneInfo("America/New_York") + ).astimezone(tz=timezone.utc) + timedelta(minutes=90) + + def test_enter( + self, + user_with_wallet: User, + contest_in_db: LeaderboardContest, + thl_lm, + contest_manager, + user_manager, + thl_redis, + ): + contest = contest_in_db + user = user_with_wallet + + c: LeaderboardContest = contest_manager.get(contest_uuid=contest.uuid) + + c = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert c.user_rank is None + + lbm = c.get_leaderboard_manager() + lbm.hit_complete_count(user.product_user_id) + + c = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert c.user_rank == 1 + + def test_contest_ends( + self, + user_with_wallet: User, + contest_in_db: LeaderboardContest, + thl_lm, + contest_manager, + user_manager, + thl_redis, + ): + # The contest should be over. We need to trigger it. + contest = contest_in_db + contest._redis_client = thl_redis + contest._user_manager = user_manager + user = user_with_wallet + + lbm = contest.get_leaderboard_manager() + lbm.hit_complete_count(user.product_user_id) + + c = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert c.user_rank == 1 + + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid(user.product_id) + bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet) + assert bp_wallet_balance == 0 + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user) + user_balance = thl_lm.get_account_balance(user_wallet) + assert user_balance == 0 + + decision, reason = contest.should_end() + assert decision + assert reason == ContestEndReason.ENDS_AT + + contest_manager.end_contest_if_over(contest=contest, ledger_manager=thl_lm) + + c: LeaderboardContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.COMPLETED + print(c) + + user_contest = contest_manager.get_leaderboard_user_view( + contest_uuid=contest.uuid, + user=user, + redis_client=thl_redis, + user_manager=user_manager, + ) + assert len(user_contest.user_winnings) == 1 + w = user_contest.user_winnings[0] + assert w.product_user_id == user.product_user_id + assert w.prize.cash_amount == USDCent(15_00) + + # The prize is $15.00, so the user should get $15, paid by the bp + assert thl_lm.get_account_balance(account=user_wallet) == 15_00 + # contest wallet is 0, and the BP gets 20c + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=c.uuid + ) + assert thl_lm.get_account_balance(account=contest_wallet) == 0 + assert thl_lm.get_account_balance(account=bp_wallet) == -15_00 diff --git a/tests/managers/thl/test_contest/test_milestone.py b/tests/managers/thl/test_contest/test_milestone.py new file mode 100644 index 0000000..7312a64 --- /dev/null +++ b/tests/managers/thl/test_contest/test_milestone.py @@ -0,0 +1,296 @@ +from datetime import datetime, timezone + +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestEndReason, +) +from generalresearch.models.thl.contest.milestone import ( + MilestoneContest, + MilestoneContestCreate, + MilestoneUserView, + ContestEntryTrigger, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from test_utils.managers.contest.conftest import ( + milestone_contest as contest, + milestone_contest_in_db as contest_in_db, + milestone_contest_create as contest_create, + milestone_contest_factory as contest_factory, +) + + +class TestMilestoneContest: + + def test_should_end(self, contest: MilestoneContest, thl_lm, contest_manager): + # contest is active and has no entries + should, msg = contest.should_end() + assert not should, msg + + # Change so that the contest ends now + contest.end_condition.ends_at = datetime.now(tz=timezone.utc) + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.ENDS_AT + + # Change the win amount it thinks it past over the target + contest.end_condition.ends_at = None + contest.end_condition.max_winners = 10 + contest.win_count = 10 + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.MAX_WINNERS + + +class TestMilestoneContestCRUD: + + def test_create( + self, + contest_create: MilestoneContestCreate, + product_user_wallet_yes: Product, + thl_lm, + contest_manager, + ): + c = contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=contest_create + ) + c_out = contest_manager.get(c.uuid) + assert c == c_out + + assert isinstance(c, MilestoneContest) + assert c.prize_count == 2 + assert c.status == ContestStatus.ACTIVE + assert c.end_condition.max_winners == 5 + assert c.entry_trigger == ContestEntryTrigger.TASK_COMPLETE + assert c.target_amount == 3 + assert c.win_count == 0 + + def test_enter( + self, + user_with_wallet: User, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # Users CANNOT directly enter a milestone contest through the api, + # but we'll call this manager method when a trigger is hit. + contest = contest_in_db + user = user_with_wallet + + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=user, + country_iso="us", + ledger_manager=thl_lm, + incr=1, + ) + + c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.ACTIVE + assert not hasattr(c, "current_amount") + assert not hasattr(c, "current_participants") + + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 1 + + # Contest wallet should have 0 bc there is no ledger + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + assert thl_lm.get_account_balance(contest_wallet) == 0 + + # Enter again! + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=user, + country_iso="us", + ledger_manager=thl_lm, + incr=1, + ) + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 2 + + # We should have ONE entry with a value of 2 + e = contest_manager.get_entries_by_contest_id(c.id) + assert len(e) == 1 + assert e[0].amount == 2 + + def test_enter_win( + self, + user_with_wallet: User, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # User enters contest, which brings the USER'S total amount above the limit, + # and the user reaches the milestone + contest = contest_in_db + user = user_with_wallet + + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user) + user_balance = thl_lm.get_account_balance(account=user_wallet) + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid( + product_uuid=user.product_id + ) + bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet) + + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 0 + res, msg = c.is_user_eligible(country_iso="us") + assert res, msg + + # User reaches the milestone after 3 completes/whatevers. + for _ in range(3): + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=user, + country_iso="us", + ledger_manager=thl_lm, + incr=1, + ) + + # to be clear, the contest itself doesn't end! + c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.ACTIVE + + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest.uuid, user=user_with_wallet + ) + assert c.user_amount == 3 + res, msg = c.is_user_eligible(country_iso="us") + assert not res + assert msg == "User should have won already" + + assert len(c.user_winnings) == 2 + assert c.win_count == 1 + + # The prize was awarded! User should have won $1.00 + assert thl_lm.get_account_balance(user_wallet) - user_balance == 100 + # Which was paid from the BP's balance + assert thl_lm.get_account_balance(bp_wallet) - bp_wallet_balance == -100 + + # winnings = cm.get_winnings_by_user(user=user) + # assert len(winnings) == 1 + # win = winnings[0] + # assert win.product_user_id == user.product_user_id + + def test_enter_ends( + self, + user_factory, + product_user_wallet_yes: Product, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # Multiple users reach the milestone. Contest ends after 5 wins. + users = [user_factory(product=product_user_wallet_yes) for _ in range(5)] + contest = contest_in_db + + for u in users: + contest_manager.enter_milestone_contest( + contest_uuid=contest.uuid, + user=u, + country_iso="us", + ledger_manager=thl_lm, + incr=3, + ) + + c: MilestoneContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.COMPLETED + assert c.end_reason == ContestEndReason.MAX_WINNERS + + def test_trigger( + self, + user_with_wallet: User, + contest_in_db: MilestoneContest, + thl_lm, + contest_manager, + ): + # Pretend user just got a complete + cnt = contest_manager.hit_milestone_triggers( + country_iso="us", + user=user_with_wallet, + event=ContestEntryTrigger.TASK_COMPLETE, + ledger_manager=thl_lm, + ) + assert cnt == 1 + + # Assert this contest got entered + c: MilestoneUserView = contest_manager.get_milestone_user_view( + contest_uuid=contest_in_db.uuid, user=user_with_wallet + ) + assert c.user_amount == 1 + + +class TestMilestoneContestUserViews: + def test_list_user_eligible_country( + self, user_with_wallet: User, contest_factory, thl_lm, contest_manager + ): + # No contests exists + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 0 + + # Create a contest. It'll be in the US/CA + contest_factory(country_isos={"us", "ca"}) + + # Not eligible in mexico + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 0 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 1 + + # Create another, any country + contest_factory(country_isos=None) + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 1 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 2 + + def test_list_user_eligible( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + # User reaches milestone after 1 complete + c = contest_factory(target_amount=1) + user = user_with_money + + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 1 + + contest_manager.enter_milestone_contest( + contest_uuid=c.uuid, user=user, country_iso="us", ledger_manager=thl_lm + ) + + # User isn't eligible anymore + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 0 + + # But it comes back in the list entered + cs = contest_manager.get_many_by_user_entered(user=user_with_money) + assert len(cs) == 1 + c = cs[0] + assert c.user_amount == 1 + assert isinstance(c, MilestoneUserView) + assert not hasattr(c, "current_win_probability") + + # They won one contest with 2 prizes + assert len(contest_manager.get_winnings_by_user(user_with_money)) == 2 diff --git a/tests/managers/thl/test_contest/test_raffle.py b/tests/managers/thl/test_contest/test_raffle.py new file mode 100644 index 0000000..060055a --- /dev/null +++ b/tests/managers/thl/test_contest/test_raffle.py @@ -0,0 +1,474 @@ +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError +from pytest import approx + +from generalresearch.currency import USDCent +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionConditionFailedError, +) +from generalresearch.models.thl.contest import ( + ContestPrize, + ContestEntryRule, + ContestEndCondition, +) +from generalresearch.models.thl.contest.definitions import ( + ContestStatus, + ContestPrizeKind, + ContestEndReason, +) +from generalresearch.models.thl.contest.exceptions import ContestError +from generalresearch.models.thl.contest.raffle import ( + ContestEntry, + ContestEntryType, +) +from generalresearch.models.thl.contest.raffle import ( + RaffleContest, + RaffleContestCreate, + RaffleUserView, +) +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.user import User +from test_utils.managers.contest.conftest import ( + raffle_contest as contest, + raffle_contest_in_db as contest_in_db, + raffle_contest_create as contest_create, + raffle_contest_factory as contest_factory, +) + + +class TestRaffleContest: + + def test_should_end(self, contest: RaffleContest, thl_lm, contest_manager): + # contest is active and has no entries + should, msg = contest.should_end() + assert not should, msg + + # Change so that the contest ends now + contest.end_condition.ends_at = datetime.now(tz=timezone.utc) + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.ENDS_AT + + # Change the entry amount it thinks it has to over the target + contest.end_condition.ends_at = None + contest.current_amount = USDCent(100) + should, msg = contest.should_end() + assert should + assert msg == ContestEndReason.TARGET_ENTRY_AMOUNT + + +class TestRaffleContestCRUD: + + def test_create( + self, + contest_create: RaffleContestCreate, + product_user_wallet_yes: Product, + thl_lm, + contest_manager, + ): + c = contest_manager.create( + product_id=product_user_wallet_yes.uuid, contest_create=contest_create + ) + c_out = contest_manager.get(c.uuid) + assert c == c_out + + assert isinstance(c, RaffleContest) + assert c.prize_count == 1 + assert c.status == ContestStatus.ACTIVE + assert c.end_condition.target_entry_amount == USDCent(100) + assert c.current_amount == 0 + assert c.current_participants == 0 + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 60}], indirect=True) + def test_enter( + self, + user_with_money: User, + contest_in_db: RaffleContest, + thl_lm, + contest_manager, + ): + # Raffle ends at $1.00. User enters for $0.60 + print(user_with_money.product_id) + print(contest_in_db.product_id) + print(contest_in_db.uuid) + contest = contest_in_db + + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user_with_money) + user_balance = thl_lm.get_account_balance(account=user_wallet) + + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(60) + ) + entry = contest_manager.enter_contest( + contest_uuid=contest.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + c: RaffleContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.current_amount == USDCent(60) + assert c.current_participants == 1 + assert c.status == ContestStatus.ACTIVE + + c: RaffleUserView = contest_manager.get_raffle_user_view( + contest_uuid=contest.uuid, user=user_with_money + ) + assert c.user_amount == USDCent(60) + assert c.user_amount_today == USDCent(60) + assert c.projected_win_probability == approx(60 / 100, rel=0.01) + + # Contest wallet should have $0.60 + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + assert thl_lm.get_account_balance(account=contest_wallet) == 60 + # User spent 60c + assert user_balance - thl_lm.get_account_balance(account=user_wallet) == 60 + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 120}], indirect=True) + def test_enter_ends( + self, + user_with_money: User, + contest_in_db: RaffleContest, + thl_lm, + contest_manager, + ): + # User enters contest, which brings the total amount above the limit, + # and the contest should end, with a winner selected + contest = contest_in_db + + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid( + user_with_money.product_id + ) + # I bribed the user, so the balance is not 0 + bp_wallet_balance = thl_lm.get_account_balance(account=bp_wallet) + + for _ in range(2): + entry = ContestEntry( + entry_type=ContestEntryType.CASH, + user=user_with_money, + amount=USDCent(60), + ) + contest_manager.enter_contest( + contest_uuid=contest.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + c: RaffleContest = contest_manager.get(contest_uuid=contest.uuid) + assert c.status == ContestStatus.COMPLETED + print(c) + + user_contest = contest_manager.get_raffle_user_view( + contest_uuid=contest.uuid, user=user_with_money + ) + assert user_contest.current_win_probability == 1 + assert user_contest.projected_win_probability == 1 + assert len(user_contest.user_winnings) == 1 + + # todo: make a all winning method + winnings = contest_manager.get_winnings_by_user(user=user_with_money) + assert len(winnings) == 1 + win = winnings[0] + assert win.product_user_id == user_with_money.product_user_id + + # Contest wallet should have gotten zeroed out + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=contest.uuid + ) + assert thl_lm.get_account_balance(contest_wallet) == 0 + # Expense wallet gets the $1.00 expense + expense_wallet = thl_lm.get_account_or_create_bp_expense_by_uuid( + product_uuid=user_with_money.product_id, expense_name="Prize" + ) + assert thl_lm.get_account_balance(expense_wallet) == -100 + # And the BP gets 20c + assert thl_lm.get_account_balance(bp_wallet) - bp_wallet_balance == 20 + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 120}], indirect=True) + def test_enter_ends_cash_prize( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + # Same as test_enter_ends, but the prize is cash. Just + # testing the ledger methods + c = contest_factory( + prizes=[ + ContestPrize( + name="$1.00 bonus", + kind=ContestPrizeKind.CASH, + estimated_cash_value=USDCent(100), + cash_amount=USDCent(100), + ) + ] + ) + assert c.prizes[0].kind == ContestPrizeKind.CASH + + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user_with_money) + user_balance = thl_lm.get_account_balance(user_wallet) + bp_wallet = thl_lm.get_account_or_create_bp_wallet_by_uuid( + user_with_money.product_id + ) + bp_wallet_balance = thl_lm.get_account_balance(bp_wallet) + + ## Enter Contest + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(120) + ) + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + + # The prize is $1.00, so the user spent $1.20 entering, won, then got $1.00 back + assert ( + thl_lm.get_account_balance(account=user_wallet) == user_balance + 100 - 120 + ) + # contest wallet is 0, and the BP gets 20c + contest_wallet = thl_lm.get_account_or_create_contest_wallet_by_uuid( + contest_uuid=c.uuid + ) + assert thl_lm.get_account_balance(account=contest_wallet) == 0 + assert thl_lm.get_account_balance(account=bp_wallet) - bp_wallet_balance == 20 + + def test_enter_failure( + self, + user_with_wallet: User, + contest_in_db: RaffleContest, + thl_lm, + contest_manager, + ): + c = contest_in_db + user = user_with_wallet + + # Tries to enter $0 + with pytest.raises(ValidationError) as e: + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user, amount=USDCent(0) + ) + assert "Input should be greater than 0" in str(e.value) + + # User has no money + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user, amount=USDCent(20) + ) + with pytest.raises(LedgerTransactionConditionFailedError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert e.value.args[0] == "insufficient balance" + + # Tries to enter with the wrong entry type (count, on a cash contest) + entry = ContestEntry(entry_type=ContestEntryType.COUNT, user=user, amount=1) + with pytest.raises(AssertionError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "incompatible entry type" in str(e.value) + + @pytest.mark.parametrize("user_with_money", [{"min_balance": 100}], indirect=True) + def test_enter_not_eligible( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + # Max entry amount per user $0.10. Contest still ends at $1.00 + c = contest_factory( + entry_rule=ContestEntryRule( + max_entry_amount_per_user=USDCent(10), + max_daily_entries_per_user=USDCent(8), + ) + ) + c: RaffleContest = contest_manager.get(c.uuid) + assert c.entry_rule.max_entry_amount_per_user == USDCent(10) + assert c.entry_rule.max_daily_entries_per_user == USDCent(8) + + # User tries to enter $0.20 + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(20) + ) + with pytest.raises(ContestError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "Entry would exceed max amount per user." in str(e.value) + + # User tries to enter $0.10 + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(10) + ) + with pytest.raises(ContestError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "Entry would exceed max amount per user per day." in str(e.value) + + # User enters $0.08 successfully + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(8) + ) + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + + # Then can't anymore + entry = ContestEntry( + entry_type=ContestEntryType.CASH, user=user_with_money, amount=USDCent(1) + ) + with pytest.raises(ContestError) as e: + entry = contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + assert "Entry would exceed max amount per user per day." in str(e.value) + + +class TestRaffleContestUserViews: + def test_list_user_eligible_country( + self, user_with_wallet: User, contest_factory, thl_lm, contest_manager + ): + # No contests exists + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 0 + + # Create a contest. It'll be in the US/CA + contest_factory(country_isos={"us", "ca"}) + + # Not eligible in mexico + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 0 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 1 + + # Create another, any country + contest_factory(country_isos=None) + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="mx" + ) + assert len(cs) == 1 + cs = contest_manager.get_many_by_user_eligible( + user=user_with_wallet, country_iso="us" + ) + assert len(cs) == 2 + + def test_list_user_eligible( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + c = contest_factory( + end_condition=ContestEndCondition(target_entry_amount=USDCent(10)), + entry_rule=ContestEntryRule( + max_entry_amount_per_user=USDCent(1), + ), + ) + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 1 + + entry = ContestEntry( + entry_type=ContestEntryType.CASH, + user=user_with_money, + amount=USDCent(1), + ) + contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + + # User isn't eligible anymore + cs = contest_manager.get_many_by_user_eligible( + user=user_with_money, country_iso="us" + ) + assert len(cs) == 0 + + # But it comes back in the list entered + cs = contest_manager.get_many_by_user_entered(user=user_with_money) + assert len(cs) == 1 + c = cs[0] + assert c.user_amount == USDCent(1) + assert c.user_amount_today == USDCent(1) + assert c.current_win_probability == 1 + assert c.projected_win_probability == approx(1 / 10, rel=0.01) + + # And nothing won yet #todo + # cs = cm.get_many_by_user_won(user=user_with_money) + + assert len(contest_manager.get_winnings_by_user(user_with_money)) == 0 + + def test_list_user_winnings( + self, user_with_money: User, contest_factory, thl_lm, contest_manager + ): + c = contest_factory( + end_condition=ContestEndCondition(target_entry_amount=USDCent(100)), + ) + entry = ContestEntry( + entry_type=ContestEntryType.CASH, + user=user_with_money, + amount=USDCent(100), + ) + contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) + # Contest ends after 100 entry, user enters 100 entry, user wins! + ws = contest_manager.get_winnings_by_user(user_with_money) + assert len(ws) == 1 + w = ws[0] + assert w.user.user_id == user_with_money.user_id + assert w.prize == c.prizes[0] + assert w.awarded_cash_amount is None + + cs = contest_manager.get_many_by_user_won(user_with_money) + assert len(cs) == 1 + c = cs[0] + w = c.user_winnings[0] + assert w.prize == c.prizes[0] + assert w.user.user_id == user_with_money.user_id + + +class TestRaffleContestCRUDCount: + # This is a COUNT contest. No cash moves. Not really fleshed out what we'd do with this. + @pytest.mark.skip + def test_enter( + self, user_with_wallet: User, contest_factory, thl_lm, contest_manager + ): + c = contest_factory(entry_type=ContestEntryType.COUNT) + entry = ContestEntry( + entry_type=ContestEntryType.COUNT, + user=user_with_wallet, + amount=1, + ) + contest_manager.enter_contest( + contest_uuid=c.uuid, + entry=entry, + country_iso="us", + ledger_manager=thl_lm, + ) diff --git a/tests/managers/thl/test_harmonized_uqa.py b/tests/managers/thl/test_harmonized_uqa.py new file mode 100644 index 0000000..6bbbbe1 --- /dev/null +++ b/tests/managers/thl/test_harmonized_uqa.py @@ -0,0 +1,116 @@ +from datetime import datetime, timezone + +import pytest + +from generalresearch.managers.thl.profiling.uqa import UQAManager +from generalresearch.models.thl.profiling.user_question_answer import ( + UserQuestionAnswer, + DUMMY_UQA, +) +from generalresearch.models.thl.user import User + + +@pytest.mark.usefixtures("uqa_db_index", "upk_data", "uqa_manager_clear_cache") +class TestUQAManager: + + def test_init(self, uqa_manager: UQAManager, user: User): + uqas = uqa_manager.get(user) + assert len(uqas) == 0 + + def test_create(self, uqa_manager: UQAManager, user: User): + now = datetime.now(tz=timezone.utc) + uqas = [ + UserQuestionAnswer( + user_id=user.user_id, + question_id="fd5bd491b75a491aa7251159680bf1f1", + country_iso="us", + language_iso="eng", + answer=("2",), + timestamp=now, + property_code="m:job_role", + calc_answers={"m:job_role": ("2",)}, + ) + ] + uqa_manager.create(user, uqas) + + res = uqa_manager.get(user=user) + assert len(res) == 1 + assert res[0] == uqas[0] + + # Same question, so this gets updated + now = datetime.now(tz=timezone.utc) + uqas_update = [ + UserQuestionAnswer( + user_id=user.user_id, + question_id="fd5bd491b75a491aa7251159680bf1f1", + country_iso="us", + language_iso="eng", + answer=("3",), + timestamp=now, + property_code="m:job_role", + calc_answers={"m:job_role": ("3",)}, + ) + ] + uqa_manager.create(user, uqas_update) + res = uqa_manager.get(user=user) + assert len(res) == 1 + assert res[0] == uqas_update[0] + + # Add a new answer + now = datetime.now(tz=timezone.utc) + uqas_new = [ + UserQuestionAnswer( + user_id=user.user_id, + question_id="3b65220db85f442ca16bb0f1c0e3a456", + country_iso="us", + language_iso="eng", + answer=("3",), + timestamp=now, + property_code="gr:children_age_gender", + calc_answers={"gr:children_age_gender": ("3",)}, + ) + ] + uqa_manager.create(user, uqas_new) + res = uqa_manager.get(user=user) + assert len(res) == 2 + assert res[1] == uqas_update[0] + assert res[0] == uqas_new[0] + + +@pytest.mark.usefixtures("uqa_db_index", "upk_data", "uqa_manager_clear_cache") +class TestUQAManagerCache: + + def test_get_uqa_empty(self, uqa_manager: UQAManager, user: User, caplog): + res = uqa_manager.get(user=user) + assert len(res) == 0 + + res = uqa_manager.get_from_db(user=user) + assert len(res) == 0 + + # Checking that the cache has only the dummy_uqa in it + res = uqa_manager.get_from_cache(user=user) + assert res == [DUMMY_UQA] + + with caplog.at_level("INFO"): + res = uqa_manager.get(user=user) + assert f"thl-grpc:uqa-cache-v2:{user.user_id} exists" in caplog.text + assert len(res) == 0 + + def test_get_uqa(self, uqa_manager: UQAManager, user: User, caplog): + + # Now the user sends an answer + uqas = [ + UserQuestionAnswer( + question_id="5d6d9f3c03bb40bf9d0a24f306387d7c", + answer=("1",), + timestamp=datetime.now(tz=timezone.utc), + country_iso="us", + language_iso="eng", + property_code="gr:gender", + user_id=user.user_id, + calc_answers={}, + ) + ] + uqa_manager.update_cache(user=user, uqas=uqas) + res = uqa_manager.get_from_cache(user=user) + assert res == uqas diff --git a/tests/managers/thl/test_ipinfo.py b/tests/managers/thl/test_ipinfo.py new file mode 100644 index 0000000..847b00c --- /dev/null +++ b/tests/managers/thl/test_ipinfo.py @@ -0,0 +1,117 @@ +import faker + +from generalresearch.managers.thl.ipinfo import ( + IPGeonameManager, + IPInformationManager, + GeoIpInfoManager, +) +from generalresearch.models.thl.ipinfo import IPGeoname, IPInformation + +fake = faker.Faker() + + +class TestIPGeonameManager: + + def test_init(self, thl_web_rr, ip_geoname_manager: IPGeonameManager): + + instance = IPGeonameManager(pg_config=thl_web_rr) + assert isinstance(instance, IPGeonameManager) + assert isinstance(ip_geoname_manager, IPGeonameManager) + + def test_create(self, ip_geoname_manager: IPGeonameManager): + + instance = ip_geoname_manager.create_dummy() + + assert isinstance(instance, IPGeoname) + + res = ip_geoname_manager.fetch_geoname_ids(filter_ids=[instance.geoname_id]) + + assert res[0].model_dump_json() == instance.model_dump_json() + + +class TestIPInformationManager: + + def test_init(self, thl_web_rr, ip_information_manager: IPInformationManager): + instance = IPInformationManager(pg_config=thl_web_rr) + assert isinstance(instance, IPInformationManager) + assert isinstance(ip_information_manager, IPInformationManager) + + def test_create(self, ip_information_manager: IPInformationManager): + instance = ip_information_manager.create_dummy() + + assert isinstance(instance, IPInformation) + + res = ip_information_manager.fetch_ip_information(filter_ips=[instance.ip]) + + assert res[0].model_dump_json() == instance.model_dump_json() + + def test_prefetch_geoname(self, ip_information, ip_geoname, thl_web_rr): + assert isinstance(ip_information, IPInformation) + + assert ip_information.geoname_id == ip_geoname.geoname_id + assert ip_information.geoname is None + + ip_information.prefetch_geoname(pg_config=thl_web_rr) + assert isinstance(ip_information.geoname, IPGeoname) + + +class TestGeoIpInfoManager: + def test_init( + self, thl_web_rr, thl_redis_config, geoipinfo_manager: GeoIpInfoManager + ): + instance = GeoIpInfoManager(pg_config=thl_web_rr, redis_config=thl_redis_config) + assert isinstance(instance, GeoIpInfoManager) + assert isinstance(geoipinfo_manager, GeoIpInfoManager) + + def test_multi(self, ip_information_factory, ip_geoname, geoipinfo_manager): + ip = fake.ipv4_public() + ip_information_factory(ip=ip, geoname=ip_geoname) + ips = [ip] + + # This only looks up in redis. They don't exist yet + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res == {ip: None} + + # Looks up in redis, if not exists, looks in mysql, then sets + # the caches that didn't exist. + res = geoipinfo_manager.get_multi(ip_addresses=ips) + assert res[ip] is not None + + ip2 = fake.ipv4_public() + ip_information_factory(ip=ip2, geoname=ip_geoname) + ips = [ip, ip2] + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res[ip] is not None + assert res[ip2] is None + res = geoipinfo_manager.get_multi(ip_addresses=ips) + assert res[ip] is not None + assert res[ip2] is not None + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res[ip] is not None + assert res[ip2] is not None + + def test_multi_ipv6(self, ip_information_factory, ip_geoname, geoipinfo_manager): + ip = fake.ipv6() + # Make another IP that will be in the same /64 block. + ip2 = ip[:-1] + "a" if ip[-1] != "a" else ip[:-1] + "b" + ip_information_factory(ip=ip, geoname=ip_geoname) + ips = [ip, ip2] + print(f"{ips=}") + + # This only looks up in redis. They don't exist yet + res = geoipinfo_manager.get_cache_multi(ip_addresses=ips) + assert res == {ip: None, ip2: None} + + # Looks up in redis, if not exists, looks in mysql, then sets + # the caches that didn't exist. + res = geoipinfo_manager.get_multi(ip_addresses=ips) + assert res[ip].ip == ip + assert res[ip].lookup_prefix == "/64" + assert res[ip2].ip == ip2 + assert res[ip2].lookup_prefix == "/64" + # they should be the same basically, except for the ip + + def test_doesnt_exist(self, geoipinfo_manager): + ip = fake.ipv4_public() + res = geoipinfo_manager.get_multi(ip_addresses=[ip]) + assert res == {ip: None} diff --git a/tests/managers/thl/test_ledger/__init__.py b/tests/managers/thl/test_ledger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/thl/test_ledger/test_lm_accounts.py b/tests/managers/thl/test_ledger/test_lm_accounts.py new file mode 100644 index 0000000..e0d9b0b --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_accounts.py @@ -0,0 +1,268 @@ +from itertools import product as iproduct +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.currency import LedgerCurrency +from generalresearch.managers.base import Permission +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, +) +from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager +from generalresearch.models.thl.ledger import LedgerAccount, AccountType, Direction +from generalresearch.models.thl.ledger import ( + LedgerEntry, +) +from test_utils.managers.ledger.conftest import ledger_account + + +@pytest.mark.parametrize( + argnames="currency, kind, acct_id", + argvalues=list( + iproduct( + ["USD", "test", "EUR"], + ["expense", "wallet", "revenue", "cash"], + [uuid4().hex for i in range(3)], + ) + ), +) +class TestLedgerAccountManagerNoResults: + + def test_get_account_no_results(self, currency, kind, acct_id, lm): + """Try to query for accounts that we know don't exist and confirm that + we either get the expected None result or it raises the correct + exception + """ + qn = ":".join([currency, kind, acct_id]) + + # (1) .get_account is just a wrapper for .get_account_many_ but + # call it either way + assert lm.get_account(qualified_name=qn, raise_on_error=False) is None + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_account(qualified_name=qn, raise_on_error=True) + + # (2) .get_account_if_exists is another wrapper + assert lm.get_account(qualified_name=qn, raise_on_error=False) is None + + def test_get_account_no_results_many(self, currency, kind, acct_id, lm): + qn = ":".join([currency, kind, acct_id]) + + # (1) .get_many_ + assert lm.get_account_many_(qualified_names=[qn], raise_on_error=False) == [] + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_account_many_(qualified_names=[qn], raise_on_error=True) + + # (2) .get_many + assert lm.get_account_many(qualified_names=[qn], raise_on_error=False) == [] + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_account_many(qualified_names=[qn], raise_on_error=True) + + # (3) .get_accounts(..) + assert lm.get_accounts_if_exists(qualified_names=[qn]) == [] + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + lm.get_accounts(qualified_names=[qn]) + + +@pytest.mark.parametrize( + argnames="currency, account_type, direction", + argvalues=list( + iproduct( + list(LedgerCurrency), + list(AccountType), + list(Direction), + ) + ), +) +class TestLedgerAccountManagerCreate: + + def test_create_account_error_permission( + self, currency, account_type, direction, lm + ): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + acct_uuid = uuid4().hex + + account = LedgerAccount( + display_name=f"test-{uuid4().hex}", + currency=currency, + qualified_name=f"{currency.value}:{account_type.value}:{acct_uuid}", + account_type=account_type, + normal_balance=direction, + ) + + # (1) With no Permissions defined + test_lm = LedgerManager( + pg_config=lm.pg_config, + permissions=[], + redis_config=lm.redis_config, + cache_prefix=lm.cache_prefix, + testing=lm.testing, + ) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + test_lm.create_account(account=account) + assert ( + str(excinfo.value) == "LedgerManager does not have sufficient permissions" + ) + + # (2) With Permissions defined, but not CREATE + test_lm = LedgerManager( + pg_config=lm.pg_config, + permissions=[Permission.READ, Permission.UPDATE, Permission.DELETE], + redis_config=lm.redis_config, + cache_prefix=lm.cache_prefix, + testing=lm.testing, + ) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + test_lm.create_account(account=account) + assert ( + str(excinfo.value) == "LedgerManager does not have sufficient permissions" + ) + + def test_create(self, currency, account_type, direction, lm): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + + acct_uuid = uuid4().hex + qn = f"{currency.value}:{account_type.value}:{acct_uuid}" + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{uuid4().hex}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=direction, + ) + account = lm.create_account(account=acct_model) + assert isinstance(account, LedgerAccount) + + # Query for, and make sure the Account was saved in the DB + res = lm.get_account(qualified_name=qn, raise_on_error=True) + assert account.uuid == res.uuid + + def test_get_or_create(self, currency, account_type, direction, lm): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + + acct_uuid = uuid4().hex + qn = f"{currency.value}:{account_type.value}:{acct_uuid}" + + acct_model = LedgerAccount( + uuid=acct_uuid, + display_name=f"test-{uuid4().hex}", + currency=currency, + qualified_name=qn, + account_type=account_type, + normal_balance=direction, + ) + account = lm.get_account_or_create(account=acct_model) + assert isinstance(account, LedgerAccount) + + # Query for, and make sure the Account was saved in the DB + res = lm.get_account(qualified_name=qn, raise_on_error=True) + assert account.uuid == res.uuid + + +class TestLedgerAccountManagerGet: + + def test_get(self, ledger_account, lm): + res = lm.get_account(qualified_name=ledger_account.qualified_name) + assert res.uuid == ledger_account.uuid + + res = lm.get_account_many(qualified_names=[ledger_account.qualified_name]) + assert len(res) == 1 + assert res[0].uuid == ledger_account.uuid + + res = lm.get_accounts(qualified_names=[ledger_account.qualified_name]) + assert len(res) == 1 + assert res[0].uuid == ledger_account.uuid + + # TODO: I can't test the get_balance without first having Transaction + # creation working + + def test_get_balance_empty( + self, ledger_account, ledger_account_credit, ledger_account_debit, ledger_tx, lm + ): + res = lm.get_account_balance(account=ledger_account) + assert res == 0 + + res = lm.get_account_balance(account=ledger_account_credit) + assert res == 100 + + res = lm.get_account_balance(account=ledger_account_debit) + assert res == 100 + + @pytest.mark.parametrize("n_times", range(5)) + def test_get_account_filtered_balance( + self, + ledger_account, + ledger_account_credit, + ledger_account_debit, + ledger_tx, + n_times, + lm, + ): + """Try searching for random metadata and confirm it's always 0 because + Tx can be found. + """ + rand_key = f"key-{uuid4().hex[:10]}" + rand_value = uuid4().hex + + assert ( + lm.get_account_filtered_balance( + account=ledger_account, metadata_key=rand_key, metadata_value=rand_value + ) + == 0 + ) + + # Let's create a transaction with this metadata to confirm it saves + # and that we can filter it back + rand_amount = randint(10, 1_000) + + lm.create_tx( + entries=[ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=rand_amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=rand_amount, + ), + ], + metadata={rand_key: rand_value}, + ) + + assert ( + lm.get_account_filtered_balance( + account=ledger_account_credit, + metadata_key=rand_key, + metadata_value=rand_value, + ) + == rand_amount + ) + + assert ( + lm.get_account_filtered_balance( + account=ledger_account_debit, + metadata_key=rand_key, + metadata_value=rand_value, + ) + == rand_amount + ) + + def test_get_balance_timerange_empty(self, ledger_account, lm): + res = lm.get_account_balance_timerange(account=ledger_account) + assert res == 0 diff --git a/tests/managers/thl/test_ledger/test_lm_tx.py b/tests/managers/thl/test_ledger/test_lm_tx.py new file mode 100644 index 0000000..37b7ba3 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx.py @@ -0,0 +1,235 @@ +from decimal import Decimal +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.currency import LedgerCurrency +from generalresearch.managers.thl.ledger_manager.ledger import LedgerManager +from generalresearch.models.thl.ledger import ( + Direction, + LedgerEntry, + LedgerTransaction, +) + + +class TestLedgerManagerCreateTx: + + def test_create_account_error_permission(self, lm): + """Confirm that the Permission values that are set on the Ledger Manger + allow the Creation action to occur. + """ + acct_uuid = uuid4().hex + + # (1) With no Permissions defined + test_lm = LedgerManager( + pg_config=lm.pg_config, + permissions=[], + redis_config=lm.redis_config, + cache_prefix=lm.cache_prefix, + testing=lm.testing, + ) + + with pytest.raises(expected_exception=AssertionError) as excinfo: + test_lm.create_tx(entries=[]) + assert ( + str(excinfo.value) + == "LedgerTransactionManager has insufficient Permissions" + ) + + def test_create_assertions(self, ledger_account_debit, ledger_account_credit, lm): + with pytest.raises(expected_exception=ValueError) as excinfo: + lm.create_tx( + entries=[ + { + "direction": Direction.CREDIT, + "account_uuid": uuid4().hex, + "amount": randint(a=1, b=100), + } + ] + ) + assert ( + "Assertion failed, ledger transaction must have 2 or more entries" + in str(excinfo.value) + ) + + def test_create(self, ledger_account_credit, ledger_account_debit, lm): + amount = int(Decimal("1.00") * 100) + + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + # Create a Transaction and validate the operation was successful + tx = lm.create_tx(entries=entries) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert isinstance(res, LedgerTransaction) + assert len(res.entries) == 2 + assert tx.id == res.id + + def test_create_and_reverse(self, ledger_account_credit, ledger_account_debit, lm): + amount = int(Decimal("1.00") * 100) + + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + tx = lm.create_tx(entries=entries) + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.id == tx.id + + assert lm.get_account_balance(account=ledger_account_credit) == 100 + assert lm.get_account_balance(account=ledger_account_debit) == 100 + assert lm.check_ledger_balanced() is True + + # Reverse it + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + + tx = lm.create_tx(entries=entries) + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.id == tx.id + + assert lm.get_account_balance(ledger_account_credit) == 0 + assert lm.get_account_balance(ledger_account_debit) == 0 + assert lm.check_ledger_balanced() + + # subtract again + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=ledger_account_credit.uuid, + amount=amount, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=ledger_account_debit.uuid, + amount=amount, + ), + ] + tx = lm.create_tx(entries=entries) + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.id == tx.id + + assert lm.get_account_balance(ledger_account_credit) == -100 + assert lm.get_account_balance(ledger_account_debit) == -100 + assert lm.check_ledger_balanced() + + +class TestLedgerManagerGetTx: + + # @pytest.mark.parametrize("currency", [LedgerCurrency.TEST], indirect=True) + def test_get_tx_by_id(self, ledger_tx, lm): + with pytest.raises(expected_exception=AssertionError): + lm.get_tx_by_id(transaction_id=ledger_tx) + + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert res.id == ledger_tx.id + + # @pytest.mark.parametrize("currency", [LedgerCurrency.TEST], indirect=True) + def test_get_tx_by_ids(self, ledger_tx, lm): + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert res.id == ledger_tx.id + + @pytest.mark.parametrize( + "tag", [f"{LedgerCurrency.TEST}:{uuid4().hex}"], indirect=True + ) + def test_get_tx_ids_by_tag(self, ledger_tx, tag, lm): + # (1) search for a random tag + res = lm.get_tx_ids_by_tag(tag="aaa:bbb") + assert isinstance(res, set) + assert len(res) == 0 + + # (2) search for the tag that was used during ledger_transaction creation + res = lm.get_tx_ids_by_tag(tag=tag) + assert isinstance(res, set) + assert len(res) == 1 + + def test_get_tx_by_tag(self, ledger_tx, tag, lm): + # (1) search for a random tag + res = lm.get_tx_by_tag(tag="aaa:bbb") + assert isinstance(res, list) + assert len(res) == 0 + + # (2) search for the tag that was used during ledger_transaction creation + res = lm.get_tx_by_tag(tag=tag) + assert isinstance(res, list) + assert len(res) == 1 + + assert isinstance(res[0], LedgerTransaction) + assert ledger_tx.id == res[0].id + + def test_get_tx_filtered_by_account( + self, ledger_tx, ledger_account, ledger_account_debit, ledger_account_credit, lm + ): + # (1) Do basic assertion checks first + with pytest.raises(expected_exception=AssertionError) as excinfo: + lm.get_tx_filtered_by_account(account_uuid=ledger_account) + assert str(excinfo.value) == "account_uuid must be a str" + + # (2) This search doesn't return anything because this ledger account + # wasn't actually used in the entries for the ledger_transaction + res = lm.get_tx_filtered_by_account(account_uuid=ledger_account.uuid) + assert len(res) == 0 + + # (3) Either the credit or the debit example ledger_accounts wll work + # to find this transaction because they're both used in the entries + res = lm.get_tx_filtered_by_account(account_uuid=ledger_account_debit.uuid) + assert len(res) == 1 + assert res[0].id == ledger_tx.id + + res = lm.get_tx_filtered_by_account(account_uuid=ledger_account_credit.uuid) + assert len(res) == 1 + assert ledger_tx.id == res[0].id + + res2 = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert res2.model_dump_json() == res[0].model_dump_json() + + def test_filter_metadata(self, ledger_tx, tx_metadata, lm): + key, value = next(iter(tx_metadata.items())) + + # (1) Confirm a random key,value pair returns nothing + res = lm.get_tx_filtered_by_metadata( + metadata_key=f"key-{uuid4().hex[:10]}", metadata_value=uuid4().hex[:12] + ) + assert len(res) == 0 + + # (2) confirm a key,value pair return the correct results + res = lm.get_tx_filtered_by_metadata(metadata_key=key, metadata_value=value) + assert len(res) == 1 + + # assert 0 == THL_lm.get_filtered_account_balance(account2, "thl_wall", "ccc") + # assert 300 == THL_lm.get_filtered_account_balance(account1, "thl_wall", "aaa") + # assert 300 == THL_lm.get_filtered_account_balance(account2, "thl_wall", "aaa") + # assert 0 == THL_lm.get_filtered_account_balance(account3, "thl_wall", "ccc") + # assert THL_lm.check_ledger_balanced() diff --git a/tests/managers/thl/test_ledger/test_lm_tx_entries.py b/tests/managers/thl/test_ledger/test_lm_tx_entries.py new file mode 100644 index 0000000..5bf1c48 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx_entries.py @@ -0,0 +1,26 @@ +from generalresearch.models.thl.ledger import LedgerEntry + + +class TestLedgerEntryManager: + + def test_get_tx_entries_by_tx(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert len(res.entries) == 2 + + tx_entries = lm.get_tx_entries_by_tx(transaction=ledger_tx) + assert len(tx_entries) == 2 + + assert res.entries == tx_entries + assert isinstance(tx_entries[0], LedgerEntry) + + def test_get_tx_entries_by_txs(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert len(res.entries) == 2 + + tx_entries = lm.get_tx_entries_by_txs(transactions=[ledger_tx]) + assert len(tx_entries) == 2 + + assert res.entries == tx_entries + assert isinstance(tx_entries[0], LedgerEntry) diff --git a/tests/managers/thl/test_ledger/test_lm_tx_locks.py b/tests/managers/thl/test_ledger/test_lm_tx_locks.py new file mode 100644 index 0000000..df2611b --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx_locks.py @@ -0,0 +1,371 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Callable + +import pytest + +from generalresearch.managers.thl.ledger_manager.conditions import ( + generate_condition_mp_payment, +) +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionCreateLockError, + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionCreateError, +) +from generalresearch.models import Source +from generalresearch.models.thl.ledger import LedgerTransaction +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + Session, + WallAdjustedStatus, +) +from generalresearch.models.thl.user import User +from test_utils.models.conftest import user_factory, session, product_user_wallet_no + +logger = logging.getLogger("LedgerManager") + + +class TestLedgerLocks: + + def test_a( + self, + user_factory, + session_factory, + product_user_wallet_no, + create_main_accounts, + caplog, + thl_lm, + lm, + utc_hour_ago, + currency, + wall_factory, + delete_ledger_db, + ): + """ + TODO: This whole test is confusing a I don't really understand. + It needs to be better documented and explained what we want + it to do and evaluate... + """ + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_no) + s1 = session_factory( + user=user, + wall_count=3, + wall_req_cpis=[Decimal("1.23"), Decimal("3.21"), Decimal("4")], + wall_statuses=[Status.COMPLETE, Status.COMPLETE, Status.COMPLETE], + ) + + # A User does a Wall Completion in Session=1 + w1 = s1.wall_events[0] + tx = thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + assert isinstance(tx, LedgerTransaction) + + # A User does another Wall Completion in Session=1 + w2 = s1.wall_events[1] + tx = thl_lm.create_tx_task_complete(wall=w2, user=user, created=w2.started) + assert isinstance(tx, LedgerTransaction) + + # That first Wall Complete was "adjusted" to instead be marked + # as a Failure + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + tx = thl_lm.create_tx_task_adjustment(wall=w1, user=user) + assert isinstance(tx, LedgerTransaction) + + # A User does another! Wall Completion in Session=1; however, we + # don't create a transaction for it + w3 = s1.wall_events[2] + + # Make sure we clear any flags/locks first + lock_key = f"{currency.value}:thl_wall:{w3.uuid}" + lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}" + flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}" + lm.redis_client.delete(lock_name) + lm.redis_client.delete(flag_name) + + # Despite the + f1 = generate_condition_mp_payment(wall=w1) + f2 = generate_condition_mp_payment(wall=w2) + f3 = generate_condition_mp_payment(wall=w3) + assert f1(lm=lm) is False + assert f2(lm=lm) is False + assert f3(lm=lm) is True + + condition = f3 + create_tx_func = lambda: thl_lm.create_tx_task_complete_(wall=w3, user=user) + assert isinstance(create_tx_func, Callable) + assert f3(lm) is True + + lm.redis_client.delete(flag_name) + lm.redis_client.delete(lock_name) + + tx = thl_lm.create_tx_protected( + lock_key=lock_key, condition=condition, create_tx_func=create_tx_func + ) + assert f3(lm) is False + + # purposely hold the lock open + tx = None + lm.redis_client.set(lock_name, "1") + with caplog.at_level(logging.ERROR): + with pytest.raises(expected_exception=LedgerTransactionCreateLockError): + tx = thl_lm.create_tx_protected( + lock_key=lock_key, + condition=condition, + create_tx_func=create_tx_func, + ) + assert tx is None + assert "Unable to acquire lock within the time specified" in caplog.text + lm.redis_client.delete(lock_name) + + def test_locking( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + caplog, + thl_lm, + lm, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_user_wallet_no) + + # A User does a Wall complete on Session.id=1 and the transaction is + # logged to the ledger + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.23"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + + # A User does a Wall complete on Session.id=1 and the transaction is + # logged to the ledger + wall2 = Wall( + user_id=user.user_id, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.21"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall2, user=user, created=wall2.started) + + # An hour later, the first wall complete is adjusted to a Failure and + # it's tracked in the ledger + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=now + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=wall1, user=user) + + # A User does a Wall complete on Session.id=1 and the transaction + # IS NOT logged to the ledger + wall3 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("4"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + uuid="867a282d8b4d40d2a2093d75b802b629", + ) + + revenue_account = thl_lm.get_account_task_complete_revenue() + assert 0 == thl_lm.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall3.uuid, + ) + # Make sure we clear any flags/locks first + lock_key = f"test:thl_wall:{wall3.uuid}" + lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}" + flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}" + lm.redis_client.delete(lock_name) + lm.redis_client.delete(flag_name) + + # Purposely hold the lock open + lm.redis_client.set(name=lock_name, value="1") + with caplog.at_level(logging.DEBUG): + with pytest.raises(expected_exception=LedgerTransactionCreateLockError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + assert isinstance(tx, LedgerTransaction) + assert "Unable to acquire lock within the time specified" in caplog.text + + # Release the lock + lm.redis_client.delete(lock_name) + + # Set the redis flag to indicate it has been run + lm.redis_client.set(flag_name, "1") + # with self.assertLogs(logger=logger, level=logging.DEBUG) as cm2: + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + # self.assertIn("entered_lock: True, flag_set: True", cm2.output[0]) + + # Unset the flag + lm.redis_client.delete(flag_name) + + assert 0 == lm.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall3.uuid, + ) + + # Now actually run it + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + assert tx is not None + + # Run it again, should return None + # Confirm the Exception inheritance works + tx = None + with pytest.raises(expected_exception=LedgerTransactionCreateError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + assert tx is None + + # clear the redis flag, it should query the db + assert lm.redis_client.get(flag_name) is not None + lm.redis_client.delete(flag_name) + assert lm.redis_client.get(flag_name) is None + + with pytest.raises(expected_exception=LedgerTransactionCreateError): + tx = thl_lm.create_tx_task_complete( + wall=wall3, user=user, created=wall3.started + ) + + assert 400 == thl_lm.get_account_filtered_balance( + account=revenue_account, + metadata_key="thl_wall", + metadata_value=wall3.uuid, + ) + + def test_bp_payment_without_locks( + self, user_factory, product_user_wallet_no, create_main_accounts, thl_lm, lm + ): + user: User = user_factory(product=product_user_wallet_no) + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.50"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + print(thl_net, commission_amount, bp_pay, user_pay) + + # Run it 3 times without any checks, and it gets made three times! + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + thl_lm.create_tx_bp_payment_(session=session, created=wall1.started) + thl_lm.create_tx_bp_payment_(session=session, created=wall1.started) + + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + assert 48 * 3 == lm.get_account_balance(account=bp_wallet) + assert 48 * 3 == thl_lm.get_account_filtered_balance( + account=bp_wallet, metadata_key="thl_session", metadata_value=session.uuid + ) + assert lm.check_ledger_balanced() + + def test_bp_payment_with_locks( + self, user_factory, product_user_wallet_no, create_main_accounts, thl_lm, lm + ): + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.50"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall1, user, created=wall1.started) + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + print(thl_net, commission_amount, bp_pay, user_pay) + + # Make sure we clear any flags/locks first + lock_key = f"test:thl_wall:{wall1.uuid}" + lock_name = f"{lm.cache_prefix}:transaction_lock:{lock_key}" + flag_name = f"{lm.cache_prefix}:transaction_flag:{lock_key}" + lm.redis_client.delete(lock_name) + lm.redis_client.delete(flag_name) + + # Run it 3 times with check, and it gets made once! + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + with pytest.raises(expected_exception=LedgerTransactionCreateError): + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + with pytest.raises(expected_exception=LedgerTransactionCreateError): + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + assert 48 == thl_lm.get_account_balance(bp_wallet) + assert 48 == thl_lm.get_account_filtered_balance( + account=bp_wallet, + metadata_key="thl_session", + metadata_value=session.uuid, + ) + assert lm.check_ledger_balanced() diff --git a/tests/managers/thl/test_ledger/test_lm_tx_metadata.py b/tests/managers/thl/test_ledger/test_lm_tx_metadata.py new file mode 100644 index 0000000..5d12633 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_lm_tx_metadata.py @@ -0,0 +1,34 @@ +class TestLedgerMetadataManager: + + def test_get_tx_metadata_by_txs(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + assert isinstance(res.metadata, dict) + + tx_metadatas = lm.get_tx_metadata_by_txs(transactions=[ledger_tx]) + assert isinstance(tx_metadatas, dict) + assert isinstance(tx_metadatas[ledger_tx.id], dict) + + assert res.metadata == tx_metadatas[ledger_tx.id] + + def test_get_tx_metadata_ids_by_tx(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + tx_metadata_cnt = len(res.metadata.keys()) + + tx_metadata_ids = lm.get_tx_metadata_ids_by_tx(transaction=ledger_tx) + assert isinstance(tx_metadata_ids, set) + assert isinstance(list(tx_metadata_ids)[0], int) + + assert tx_metadata_cnt == len(tx_metadata_ids) + + def test_get_tx_metadata_ids_by_txs(self, ledger_tx, lm): + # First confirm the Ledger TX exists with 2 Entries + res = lm.get_tx_by_id(transaction_id=ledger_tx.id) + tx_metadata_cnt = len(res.metadata.keys()) + + tx_metadata_ids = lm.get_tx_metadata_ids_by_txs(transactions=[ledger_tx]) + assert isinstance(tx_metadata_ids, set) + assert isinstance(list(tx_metadata_ids)[0], int) + + assert tx_metadata_cnt == len(tx_metadata_ids) diff --git a/tests/managers/thl/test_ledger/test_thl_lm_accounts.py b/tests/managers/thl/test_ledger/test_thl_lm_accounts.py new file mode 100644 index 0000000..01d5fe1 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_accounts.py @@ -0,0 +1,411 @@ +from uuid import uuid4 + +import pytest + + +class TestThlLedgerManagerAccounts: + + def test_get_account_or_create_user_wallet(self, user, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_user_wallet(user=user) + assert isinstance(account, LedgerAccount) + + assert user.uuid in account.qualified_name + assert account.display_name == f"User Wallet {user.uuid}" + assert account.account_type == AccountType.USER_WALLET + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "user" + assert account.reference_uuid == user.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_account_or_create_bp_wallet(self, product, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_bp_wallet(product=product) + assert isinstance(account, LedgerAccount) + + assert product.uuid in account.qualified_name + assert account.display_name == f"BP Wallet {product.uuid}" + assert account.account_type == AccountType.BP_WALLET + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_account_or_create_bp_commission(self, product, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_bp_commission(product=product) + + assert product.uuid in account.qualified_name + assert account.display_name == f"Revenue from commission {product.uuid}" + assert account.account_type == AccountType.REVENUE + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + @pytest.mark.parametrize("expense", ["tango", "paypal", "gift", "tremendous"]) + def test_get_account_or_create_bp_expense(self, product, expense, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + ) + + account = thl_lm.get_account_or_create_bp_expense( + product=product, expense_name=expense + ) + assert product.uuid in account.qualified_name + assert account.display_name == f"Expense {expense} {product.uuid}" + assert account.account_type == AccountType.EXPENSE + assert account.normal_balance == Direction.DEBIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_or_create_bp_pending_payout_account(self, product, thl_lm, lm): + from generalresearch.currency import LedgerCurrency + from generalresearch.models.thl.ledger import ( + Direction, + AccountType, + ) + + account = thl_lm.get_or_create_bp_pending_payout_account(product=product) + + assert product.uuid in account.qualified_name + assert account.display_name == f"BP Wallet Pending {product.uuid}" + assert account.account_type == AccountType.BP_WALLET + assert account.normal_balance == Direction.CREDIT + assert account.reference_type == "bp" + assert account.reference_uuid == product.uuid + assert account.currency == LedgerCurrency.TEST + + # Actually query for it to confirm + res = lm.get_account(qualified_name=account.qualified_name, raise_on_error=True) + assert res.model_dump_json() == account.model_dump_json() + + def test_get_account_task_complete_revenue_raises( + self, delete_ledger_db, thl_lm, lm + ): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + delete_ledger_db() + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + thl_lm.get_account_task_complete_revenue() + + def test_get_account_task_complete_revenue( + self, account_cash, account_revenue_task_complete, thl_lm, lm + ): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + AccountType, + ) + + res = thl_lm.get_account_task_complete_revenue() + assert isinstance(res, LedgerAccount) + assert res.reference_type is None + assert res.reference_uuid is None + assert res.account_type == AccountType.REVENUE + assert res.display_name == "Cash flow task complete" + + def test_get_account_cash_raises(self, delete_ledger_db, thl_lm, lm): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + delete_ledger_db() + + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + thl_lm.get_account_cash() + + def test_get_account_cash(self, account_cash, thl_lm, lm): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + AccountType, + ) + + res = thl_lm.get_account_cash() + assert isinstance(res, LedgerAccount) + assert res.reference_type is None + assert res.reference_uuid is None + assert res.account_type == AccountType.CASH + assert res.display_name == "Operating Cash Account" + + def test_get_accounts(self, setup_accounts, product, user_factory, thl_lm, lm, lam): + from generalresearch.models.thl.user import User + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + user1: User = user_factory(product=product) + user2: User = user_factory(product=product) + + account1 = thl_lm.get_account_or_create_bp_wallet(product=product) + + # (1) known account and confirm it comes back + res = lm.get_account(qualified_name=account1.qualified_name) + assert account1.model_dump_json() == res.model_dump_json() + + # (2) known accounts and confirm they both come back + res = lam.get_accounts(qualified_names=[account1.qualified_name]) + assert isinstance(res, list) + assert len(res) == 1 + assert account1 in res + + # Get 2 known and 1 made up qualified names, and confirm it raises + # an error + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_accounts( + qualified_names=[ + account1.qualified_name, + f"test:bp_wall:{uuid4().hex}", + ] + ) + + def test_get_accounts_if_exists(self, product_factory, currency, thl_lm, lm): + from generalresearch.models.thl.product import Product + + p1: Product = product_factory() + p2: Product = product_factory() + + account1 = thl_lm.get_account_or_create_bp_wallet(product=p1) + account2 = thl_lm.get_account_or_create_bp_wallet(product=p2) + + # (1) known account and confirm it comes back + res = lm.get_account(qualified_name=account1.qualified_name) + assert account1.model_dump_json() == res.model_dump_json() + + # (2) known accounts and confirm they both come back + res = lm.get_accounts( + qualified_names=[account1.qualified_name, account2.qualified_name] + ) + assert isinstance(res, list) + assert len(res) == 2 + assert account1 in res + assert account2 in res + + # Get 2 known and 1 made up qualified names, and confirm only 2 + # come back + lm.get_accounts_if_exists( + qualified_names=[ + account1.qualified_name, + account2.qualified_name, + f"{currency.value}:bp_wall:{uuid4().hex}", + ] + ) + + assert isinstance(res, list) + assert len(res) == 2 + + # Confirm an empty array comes back for all unknown qualified names + res = lm.get_accounts_if_exists( + qualified_names=[ + f"{lm.currency.value}:bp_wall:{uuid4().hex}" for i in range(5) + ] + ) + assert isinstance(res, list) + assert len(res) == 0 + + def test_get_accounts_for_products(self, product_factory, thl_lm, lm): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + from generalresearch.models.thl.ledger import ( + LedgerAccount, + ) + + # Create 5 Products + product_uuids = [] + for i in range(5): + _p = product_factory() + product_uuids.append(_p.uuid) + + # Confirm that this fails.. because none of those accounts have been + # created yet + with pytest.raises(expected_exception=LedgerAccountDoesntExistError): + thl_lm.get_accounts_bp_wallet_for_products(product_uuids=product_uuids) + + # Create the bp_wallet accounts and then try again + for p_uuid in product_uuids: + thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=p_uuid) + + res = thl_lm.get_accounts_bp_wallet_for_products(product_uuids=product_uuids) + assert len(res) == len(product_uuids) + assert all([isinstance(i, LedgerAccount) for i in res]) + + +class TestLedgerAccountManager: + + def test_get_or_create(self, thl_lm, lm, lam): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + u = uuid4().hex + name = f"test-{u[:8]}" + + account = LedgerAccount( + display_name=name, + qualified_name=f"test:bp_wallet:{u}", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + currency="test", + reference_type="bp", + reference_uuid=u, + ) + + # First we want to validate that using the get_account method raises + # an error for a random LedgerAccount which we know does not exist. + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_account(qualified_name=account.qualified_name) + + # Now that we know it doesn't exist, get_or_create for it + instance = lam.get_account_or_create(account=account) + + # It should always return + assert isinstance(instance, LedgerAccount) + assert instance.reference_uuid == u + + def test_get(self, user, thl_lm, lm, lam): + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + from generalresearch.models.thl.ledger import ( + LedgerAccount, + AccountType, + ) + + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_account(qualified_name=f"test:bp_wallet:{user.product.id}") + + thl_lm.get_account_or_create_bp_wallet(product=user.product) + account = lam.get_account(qualified_name=f"test:bp_wallet:{user.product.id}") + + assert isinstance(account, LedgerAccount) + assert AccountType.BP_WALLET == account.account_type + assert user.product.uuid == account.reference_uuid + + def test_get_many(self, product_factory, thl_lm, lm, lam, currency): + from generalresearch.models.thl.product import Product + from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerAccountDoesntExistError, + ) + + p1: Product = product_factory() + p2: Product = product_factory() + + account1 = thl_lm.get_account_or_create_bp_wallet(product=p1) + account2 = thl_lm.get_account_or_create_bp_wallet(product=p2) + + # Get 1 known account and confirm it comes back + res = lam.get_account_many( + qualified_names=[account1.qualified_name, account2.qualified_name] + ) + assert isinstance(res, list) + assert len(res) == 2 + assert account1 in res + + # Get 2 known accounts and confirm they both come back + res = lam.get_account_many( + qualified_names=[account1.qualified_name, account2.qualified_name] + ) + assert isinstance(res, list) + assert len(res) == 2 + assert account1 in res + assert account2 in res + + # Get 2 known and 1 made up qualified names, and confirm only 2 come + # back. Don't raise on error, so we can confirm the array is "short" + res = lam.get_account_many( + qualified_names=[ + account1.qualified_name, + account2.qualified_name, + f"test:bp_wall:{uuid4().hex}", + ], + raise_on_error=False, + ) + assert isinstance(res, list) + assert len(res) == 2 + + # Same as above, but confirm the raise works on checking res length + with pytest.raises(LedgerAccountDoesntExistError): + lam.get_account_many( + qualified_names=[ + account1.qualified_name, + account2.qualified_name, + f"test:bp_wall:{uuid4().hex}", + ], + raise_on_error=True, + ) + + # Confirm an empty array comes back for all unknown qualified names + res = lam.get_account_many( + qualified_names=[f"test:bp_wall:{uuid4().hex}" for i in range(5)], + raise_on_error=False, + ) + assert isinstance(res, list) + assert len(res) == 0 + + def test_create_account(self, thl_lm, lm, lam): + from generalresearch.models.thl.ledger import ( + LedgerAccount, + Direction, + AccountType, + ) + + u = uuid4().hex + name = f"test-{u[:8]}" + + account = LedgerAccount( + display_name=name, + qualified_name=f"test:bp_wallet:{u}", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + currency="test", + reference_type="bp", + reference_uuid=u, + ) + + lam.create_account(account=account) + assert lam.get_account(f"test:bp_wallet:{u}") == account + assert lam.get_account_or_create(account) == account diff --git a/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py b/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py new file mode 100644 index 0000000..294d092 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_bp_payout.py @@ -0,0 +1,516 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint +from uuid import uuid4 + +import pytest +import redis +from pydantic import RedisDsn +from redis.lock import Lock + +from generalresearch.currency import USDCent +from generalresearch.managers.base import Permission +from generalresearch.managers.thl.ledger_manager.thl_ledger import ThlLedgerManager +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionConditionFailedError, + LedgerTransactionReleaseLockError, + LedgerTransactionCreateError, +) +from generalresearch.managers.thl.ledger_manager.ledger import LedgerTransaction +from generalresearch.models import Source +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.ledger import Direction, TransactionType +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + Session, +) +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.redis_helper import RedisConfig + + +def broken_acquire(self, *args, **kwargs): + raise redis.exceptions.TimeoutError("Simulated timeout during acquire") + + +def broken_release(self, *args, **kwargs): + raise redis.exceptions.TimeoutError("Simulated timeout during release") + + +class TestThlLedgerManagerBPPayout: + + def test_create_tx_with_bp_payment( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + caplog, + thl_lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("6.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now, + finished=now + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": now + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + lock_key = f"test:bp_payout:{user.product.id}" + flag_name = f"{thl_lm.cache_prefix}:transaction_flag:{lock_key}" + thl_lm.redis_client.delete(flag_name) + + payoutevent_uuid = uuid4().hex + thl_lm.create_tx_bp_payout( + product=user.product, + amount=USDCent(200), + created=now, + payoutevent_uuid=payoutevent_uuid, + ) + + payoutevent_uuid = uuid4().hex + thl_lm.create_tx_bp_payout( + product=user.product, + amount=USDCent(200), + created=now + timedelta(minutes=2), + skip_one_per_day_check=True, + payoutevent_uuid=payoutevent_uuid, + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + assert 170 == thl_lm.get_account_balance(bp_wallet_account) + assert 200 == thl_lm.get_account_balance(cash) + + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + thl_lm.create_tx_bp_payout( + user.product, + amount=USDCent(200), + created=now + timedelta(minutes=2), + skip_one_per_day_check=False, + skip_wallet_balance_check=False, + payoutevent_uuid=payoutevent_uuid, + ) + + payoutevent_uuid = uuid4().hex + with caplog.at_level(logging.INFO): + with pytest.raises(LedgerTransactionConditionFailedError): + thl_lm.create_tx_bp_payout( + user.product, + amount=USDCent(10_000), + created=now + timedelta(minutes=2), + skip_one_per_day_check=True, + skip_wallet_balance_check=False, + payoutevent_uuid=payoutevent_uuid, + ) + assert "failed condition check balance:" in caplog.text + + thl_lm.create_tx_bp_payout( + product=user.product, + amount=USDCent(10_00), + created=now + timedelta(minutes=2), + skip_one_per_day_check=True, + skip_wallet_balance_check=True, + payoutevent_uuid=payoutevent_uuid, + ) + assert 170 - 1000 == thl_lm.get_account_balance(bp_wallet_account) + + def test_create_tx(self, product, caplog, thl_lm, currency): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + + # Create a BP Payout for a Product without any activity. By issuing, + # the skip_* checks, we should be able to force it to work, and will + # then ultimately result in a negative balance + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + skip_flag_check=True, + ) + + # Check the basic attributes + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == "BP Payout" + assert ( + tx.tag + == f"{currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}" + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Check the Product's balance, it should be negative the amount that was + # paid out. That's because the Product earned nothing.. and then was + # sent something. + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + # Test some basic assertions + with caplog.at_level(logging.INFO): + with pytest.raises(expected_exception=Exception): + thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=uuid4().hex, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + skip_flag_check=False, + ) + assert "failed condition check >1 tx per day" in caplog.text + + def test_create_tx_redis_failure(self, product, thl_web_rw, thl_lm): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + now = datetime.now(tz=timezone.utc) + + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + + # Non routable IP address. Redis will fail + thl_lm_redis_0 = ThlLedgerManager( + pg_config=thl_web_rw, + permissions=[ + Permission.CREATE, + Permission.READ, + Permission.UPDATE, + Permission.DELETE, + ], + testing=True, + redis_config=RedisConfig( + dsn=RedisDsn("redis://10.255.255.1:6379"), + socket_connect_timeout=0.1, + ), + ) + + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm_redis_0.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is redis.exceptions.TimeoutError + # No txs were created + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 0 + + def test_create_tx_multiple_per_day(self, product, thl_lm): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + now = datetime.now(tz=timezone.utc) + + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount * USDCent(2), now, direction=Direction.CREDIT + ) + + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + + # Try to create another + # Will fail b/c it has the same payout event uuid + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionFlagAlreadyExistsError + + # Try to create another + # Will fail due to multiple per day + payoutevent_uuid2 = uuid4().hex + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid2, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionConditionFailedError + assert str(e.value) == ">1 tx per day" + + # Make it run by skipping one per day check + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid2, + created=datetime.now(tz=timezone.utc), + skip_one_per_day_check=True, + ) + + def test_create_tx_redis_lock_release_error(self, product, thl_lm): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount * USDCent(2), now, direction=Direction.CREDIT + ) + + original_acquire = Lock.acquire + original_release = Lock.release + Lock.acquire = broken_acquire + + # Create TX will fail on lock enter, no tx will actually get created + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionCreateError + assert str(e.value) == "Redis error: Simulated timeout during acquire" + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 0 + + Lock.acquire = original_acquire + Lock.release = broken_release + + # Create TX will fail on lock exit, after the tx was created! + with pytest.raises(expected_exception=Exception) as e: + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + assert e.type is LedgerTransactionReleaseLockError + assert str(e.value) == "Redis error: Simulated timeout during release" + + # Transaction was still created! + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 1 + Lock.release = original_release + + +class TestPayoutEventManagerBPPayout: + + def test_create(self, product, thl_lm, brokerage_product_payout_event_manager): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert brokerage_product_payout_event_manager.check_for_ledger_tx( + thl_ledger_manager=thl_lm, + product_id=product.id, + amount=rand_amount, + payout_event=pe, + ) + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + + def test_create_with_redis_error( + self, product, caplog, thl_lm, brokerage_product_payout_event_manager + ): + caplog.set_level("WARNING") + original_acquire = Lock.acquire + original_release = Lock.release + + rand_amount: USDCent = USDCent(randint(100, 1_000)) + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # Will fail on lock enter, no tx will actually get created + Lock.acquire = broken_acquire + with pytest.raises(expected_exception=Exception) as e: + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert e.type is LedgerTransactionCreateError + assert str(e.value) == "Redis error: Simulated timeout during acquire" + assert any( + "Simulated timeout during acquire. No ledger tx was created" in m + for m in caplog.messages + ) + + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + # One payout event is created, status is failed, and no ledger txs exist + assert len(txs) == 0 + pes = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.id] + ) + ) + assert len(pes) == 1 + assert pes[0].status == PayoutStatus.FAILED + pe = pes[0] + + # Fix the redis method + Lock.acquire = original_acquire + + # Try to fix the failed payout, by trying ledger tx again + brokerage_product_payout_event_manager.retry_create_bp_payout_event_tx( + product=product, thl_ledger_manager=thl_lm, payout_event_uuid=pe.uuid + ) + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 1 + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + + # And then try to run it again, it'll fail because a payout event with the same info exists + with pytest.raises(expected_exception=Exception) as e: + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert e.type is ValueError + assert "Payout event already exists!" in str(e.value) + + # We wouldn't do this in practice, because this is paying out the BP again, but + # we can if want to. + # Change the timestamp so it'll create a new payout event + now = datetime.now(tz=timezone.utc) + with pytest.raises(LedgerTransactionConditionFailedError) as e: + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + # But it will fail due to 1 per day check + assert str(e.value) == ">1 tx per day" + pe = brokerage_product_payout_event_manager.get_by_uuid(e.value.pe_uuid) + assert pe.status == PayoutStatus.FAILED + + # And if we really want to, we can make it again + now = datetime.now(tz=timezone.utc) + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + skip_one_per_day_check=True, + skip_wallet_balance_check=True, + ) + + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 2 + # since they were paid twice + assert thl_lm.get_account_balance(bp_wallet_account) == 0 - rand_amount + + Lock.release = original_release + Lock.acquire = original_acquire + + def test_create_with_redis_error_release( + self, product, caplog, thl_lm, brokerage_product_payout_event_manager + ): + caplog.set_level("WARNING") + + original_release = Lock.release + + rand_amount: USDCent = USDCent(randint(100, 1_000)) + now = datetime.now(tz=timezone.utc) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + assert thl_lm.get_account_balance(bp_wallet_account) == 0 + thl_lm.create_tx_plug_bp_wallet( + product, rand_amount, now, direction=Direction.CREDIT + ) + assert thl_lm.get_account_balance(bp_wallet_account) == rand_amount + + # Will fail on lock exit, after the tx was created! + # But it'll see that the tx was created and so everything will be fine + Lock.release = broken_release + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + created=now, + amount=rand_amount, + payout_type=PayoutType.ACH, + ) + assert any( + "Simulated timeout during release but ledger tx exists" in m + for m in caplog.messages + ) + + txs = thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet_account.uuid) + txs = [tx for tx in txs if tx.metadata["tx_type"] != "plug"] + assert len(txs) == 1 + pes = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.uuid] + ) + ) + assert len(pes) == 1 + assert pes[0].status == PayoutStatus.COMPLETE + Lock.release = original_release diff --git a/tests/managers/thl/test_ledger/test_thl_lm_tx.py b/tests/managers/thl/test_ledger/test_thl_lm_tx.py new file mode 100644 index 0000000..31c7107 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_tx.py @@ -0,0 +1,1762 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.currency import USDCent +from generalresearch.managers.thl.ledger_manager.ledger import ( + LedgerTransaction, +) +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + WALL_ALLOWED_STATUS_STATUS_CODE, +) +from generalresearch.models.thl.ledger import Direction +from generalresearch.models.thl.ledger import TransactionType +from generalresearch.models.thl.product import ( + PayoutConfig, + PayoutTransformation, + UserWalletConfig, +) +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + Session, + WallAdjustedStatus, +) +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.payout import UserPayoutEvent + +logger = logging.getLogger("LedgerManager") + + +class TestThlLedgerTxManager: + + def test_create_tx_task_complete( + self, + wall, + user, + account_revenue_task_complete, + create_main_accounts, + thl_lm, + lm, + ): + create_main_accounts() + tx = thl_lm.create_tx_task_complete(wall=wall, user=user) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_task_complete_( + self, wall, user, account_revenue_task_complete, thl_lm, lm + ): + tx = thl_lm.create_tx_task_complete_(wall=wall, user=user) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_bp_payment( + self, + session_factory, + user, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + session_manager, + ): + delete_ledger_db() + create_main_accounts() + s1 = session_factory(user=user) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=status_code_1, + finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + + tx = thl_lm.create_tx_bp_payment(session=s1) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_bp_payment_amt( + self, + session_factory, + user_factory, + product_manager, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + session_manager, + ): + delete_ledger_db() + create_main_accounts() + product = product_manager.create_dummy( + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_amt" + ) + ), + user_wallet_config=UserWalletConfig(amt=True, enabled=True), + ) + user = user_factory(product=product) + s1 = session_factory(user=user, wall_req_cpi=Decimal("1")) + + status, status_code_1 = s1.determine_session_status() + assert status == Status.COMPLETE + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments( + thl_ledger_manager=thl_lm + ) + print(thl_net, commission_amount, bp_pay, user_pay) + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=status_code_1, + finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + + tx = thl_lm.create_tx_bp_payment(session=s1) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_bp_payment_( + self, + session_factory, + user, + create_main_accounts, + thl_lm, + lm, + session_manager, + utc_hour_ago, + ): + s1 = session_factory(user=user) + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=utc_hour_ago + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + + s1.determine_payments() + tx = thl_lm.create_tx_bp_payment_(session=s1) + assert isinstance(tx, LedgerTransaction) + + res = lm.get_tx_by_id(transaction_id=tx.id) + assert res.created == tx.created + + def test_create_tx_task_adjustment( + self, wall_factory, session, user, create_main_accounts, thl_lm, lm + ): + """Create Wall event Complete, and Create a Tx Task Adjustment + + - I don't know what this does exactly... but we can confirm + the transaction comes back with balanced amounts, and that + the name of the Source is in the Tx description + """ + + wall_status = Status.COMPLETE + wall: Wall = wall_factory(session=session, wall_status=wall_status) + + tx = thl_lm.create_tx_task_adjustment(wall=wall, user=user) + assert isinstance(tx, LedgerTransaction) + res = lm.get_tx_by_id(transaction_id=tx.id) + + assert res.entries[0].amount == int(wall.cpi * 100) + assert res.entries[1].amount == int(wall.cpi * 100) + assert wall.source.name in res.ext_description + assert res.created == tx.created + + def test_create_tx_bp_adjustment(self, session, user, caplog, thl_lm, lm): + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + + # The default session fixture is just an unfinished wall event + assert len(session.wall_events) == 1 + assert session.finished is None + assert status == Status.TIMEOUT + assert status_code_1 in list( + WALL_ALLOWED_STATUS_STATUS_CODE.get(Status.TIMEOUT, {}) + ) + assert thl_net == Decimal(0) + assert commission_amount == Decimal(0) + assert bp_pay == Decimal(0) + assert user_pay is None + + # Update the finished timestamp, but nothing else. This means that + # there is no financial changes needed + session.update( + **{ + "finished": datetime.now(tz=timezone.utc) + timedelta(minutes=10), + } + ) + assert session.finished + with caplog.at_level(logging.INFO): + tx = thl_lm.create_tx_bp_adjustment(session=session) + assert tx is None + assert "No transactions needed." in caplog.text + + def test_create_tx_bp_payout(self, product, caplog, thl_lm, currency): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + + # Create a BP Payout for a Product without any activity. By issuing, + # the skip_* checks, we should be able to force it to work, and will + # then ultimately result in a negative balance + tx = thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + skip_flag_check=True, + ) + + # Check the basic attributes + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == "BP Payout" + assert ( + tx.tag + == f"{thl_lm.currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}" + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Check the Product's balance, it should be negative the amount that was + # paid out. That's because the Product earned nothing.. and then was + # sent something. + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + # Test some basic assertions + with caplog.at_level(logging.INFO): + with pytest.raises(expected_exception=Exception): + thl_lm.create_tx_bp_payout( + product=product, + amount=rand_amount, + payoutevent_uuid=uuid4().hex, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + skip_flag_check=False, + ) + assert "failed condition check >1 tx per day" in caplog.text + + def test_create_tx_bp_payout_(self, product, thl_lm, lm, currency): + rand_amount: USDCent = USDCent(randint(100, 1_000)) + payoutevent_uuid = uuid4().hex + + # Create a BP Payout for a Product without any activity. + tx = thl_lm.create_tx_bp_payout_( + product=product, + amount=rand_amount, + payoutevent_uuid=payoutevent_uuid, + created=datetime.now(tz=timezone.utc), + ) + + # Check the basic attributes + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == "BP Payout" + assert ( + tx.tag + == f"{currency.value}:{TransactionType.BP_PAYOUT.value}:{payoutevent_uuid}" + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + def test_create_tx_plug_bp_wallet( + self, product, create_main_accounts, thl_lm, lm, currency + ): + """A BP Wallet "plug" is a way to makeup discrepancies and simply + add or remove money + """ + rand_amount: USDCent = USDCent(randint(100, 1_000)) + + tx = thl_lm.create_tx_plug_bp_wallet( + product=product, + amount=rand_amount, + created=datetime.now(tz=timezone.utc), + direction=Direction.DEBIT, + skip_flag_check=False, + ) + + assert isinstance(tx, LedgerTransaction) + + # We issued the BP money they didn't earn, so now they have a + # negative balance + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + def test_create_tx_plug_bp_wallet_( + self, product, create_main_accounts, thl_lm, lm, currency + ): + """A BP Wallet "plug" is a way to fix discrepancies and simply + add or remove money. + + Similar to above, but because it's unprotected, we can immediately + issue another to see if the balance changes + """ + rand_amount: USDCent = USDCent(randint(100, 1_000)) + + tx = thl_lm.create_tx_plug_bp_wallet_( + product=product, + amount=rand_amount, + created=datetime.now(tz=timezone.utc), + direction=Direction.DEBIT, + ) + + assert isinstance(tx, LedgerTransaction) + + # We issued the BP money they didn't earn, so now they have a + # negative balance + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) * -1 + + # Issue a positive one now, and confirm the balance goes positive + thl_lm.create_tx_plug_bp_wallet_( + product=product, + amount=rand_amount + rand_amount, + created=datetime.now(tz=timezone.utc), + direction=Direction.CREDIT, + ) + balance = thl_lm.get_account_balance( + account=thl_lm.get_account_or_create_bp_wallet(product=product) + ) + assert balance == int(rand_amount) + + def test_create_tx_user_payout_request( + self, + user, + product_user_wallet_yes, + user_factory, + delete_df_collection, + thl_lm, + lm, + currency, + ): + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # The default user fixture uses a product that doesn't have wallet + # mode enabled + with pytest.raises(expected_exception=AssertionError): + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + + # Now try it for a user on a product with wallet mode + u2 = user_factory(product=product_user_wallet_yes) + + # User's pre-balance is 0 because no activity has occurred yet + pre_balance = lm.get_account_balance( + account=thl_lm.get_account_or_create_user_wallet(user=u2) + ) + assert pre_balance == 0 + + tx = thl_lm.create_tx_user_payout_request( + user=u2, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + assert isinstance(tx, LedgerTransaction) + assert tx.entries[0].amount == pe.amount + assert tx.entries[1].amount == pe.amount + assert tx.ext_description == "User Payout Paypal Request $5.00" + + # + # (TODO): This key ":user_payout:" is NOT part of the TransactionType + # enum and was manually set. It should be based off the + # TransactionType names. + # + + assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:request" + + # Post balance is -$5.00 because it comes out of the wallet before + # it's Approved or Completed + post_balance = lm.get_account_balance( + account=thl_lm.get_account_or_create_user_wallet(user=u2) + ) + assert post_balance == -500 + + def test_create_tx_user_payout_request_( + self, + user, + product_user_wallet_yes, + user_factory, + delete_ledger_db, + thl_lm, + lm, + ): + delete_ledger_db() + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + rand_description = uuid4().hex + tx = thl_lm.create_tx_user_payout_request_( + user=user, payout_event=pe, description=rand_description + ) + + assert tx.ext_description == rand_description + + post_balance = lm.get_account_balance( + account=thl_lm.get_account_or_create_user_wallet(user=user) + ) + assert post_balance == -500 + + def test_create_tx_user_payout_complete( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + currency, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + # Ensure the user starts out with nothing... + assert lm.get_account_balance(account=user_account) == 0 + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # Confirm it's not possible unless a request occurred happen + with pytest.raises(expected_exception=ValueError): + thl_lm.create_tx_user_payout_complete( + user=user, + payout_event=pe, + fee_amount=None, + skip_flag_check=False, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + # (2) Complete the request + tx = thl_lm.create_tx_user_payout_complete( + user=user, + payout_event=pe, + fee_amount=Decimal(0), + skip_flag_check=False, + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:complete" + assert isinstance(tx, LedgerTransaction) + + # The amount that comes out of the user wallet doesn't change after + # it's approved becuase it's already been withdrawn + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + def test_create_tx_user_payout_complete_( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + + # (2) Complete the request + rand_desc = uuid4().hex + + bp_expense_account = thl_lm.get_account_or_create_bp_expense( + product=user.product, expense_name="paypal" + ) + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + + tx = thl_lm.create_tx_user_payout_complete_( + user=user, + payout_event=pe, + fee_amount=Decimal("0.00"), + fee_expense_account=bp_expense_account, + fee_payer_account=bp_wallet_account, + description=rand_desc, + ) + assert tx.ext_description == rand_desc + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + def test_create_tx_user_payout_cancelled( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + # (2) Cancel the request + tx = thl_lm.create_tx_user_payout_cancelled( + user=user, + payout_event=pe, + skip_flag_check=False, + ) + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + assert tx.tag == f"{currency.value}:user_payout:{pe.uuid}:cancel" + assert isinstance(tx, LedgerTransaction) + + # Assert the balance comes back to 0 after it was cancelled + assert lm.get_account_balance(account=user_account) == 0 + + def test_create_tx_user_payout_cancelled_( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=rand_amount, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + # (1) Make a request first + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + skip_flag_check=True, + skip_wallet_balance_check=True, + ) + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount * -1 + + # (2) Cancel the request + rand_desc = uuid4().hex + tx = thl_lm.create_tx_user_payout_cancelled_( + user=user, payout_event=pe, description=rand_desc + ) + assert isinstance(tx, LedgerTransaction) + assert tx.ext_description == rand_desc + assert lm.get_account_balance(account=user_account) == 0 + + def test_create_tx_user_bonus( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + rand_ref_uuid = uuid4().hex + rand_desc = uuid4().hex + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == 0 + + tx = thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(rand_amount / 100), + ref_uuid=rand_ref_uuid, + description=rand_desc, + skip_flag_check=True, + ) + assert tx.ext_description == rand_desc + assert tx.tag == f"{thl_lm.currency.value}:user_bonus:{rand_ref_uuid}" + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount + + def test_create_tx_user_bonus_( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + user: User = user_factory(product=product_user_wallet_yes) + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + rand_amount = randint(100, 1_000) + rand_ref_uuid = uuid4().hex + rand_desc = uuid4().hex + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == 0 + + tx = thl_lm.create_tx_user_bonus_( + user=user, + amount=Decimal(rand_amount / 100), + ref_uuid=rand_ref_uuid, + description=rand_desc, + ) + assert tx.ext_description == rand_desc + assert tx.tag == f"{thl_lm.currency.value}:user_bonus:{rand_ref_uuid}" + assert tx.entries[0].amount == rand_amount + assert tx.entries[1].amount == rand_amount + + # Assert the balance came out of their user wallet + assert lm.get_account_balance(account=user_account) == rand_amount + + +class TestThlLedgerTxManagerFlows: + """Combine the various THL_LM methods to create actual "real world" + examples + """ + + def test_create_tx_task_complete( + self, user, create_main_accounts, thl_lm, lm, currency, delete_ledger_db + ): + delete_ledger_db() + create_main_accounts() + + wall1 = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.23"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + + wall2 = Wall( + user_id=1, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.21"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall2, user=user, created=wall2.started) + + cash = thl_lm.get_account_cash() + revenue = thl_lm.get_account_task_complete_revenue() + + assert lm.get_account_balance(cash) == 123 + 321 + assert lm.get_account_balance(revenue) == 123 + 321 + assert lm.check_ledger_balanced() + + assert ( + lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="d" + ) + == 123 + ) + + assert ( + lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="f" + ) + == 321 + ) + + assert ( + lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="x" + ) + == 0 + ) + + assert ( + thl_lm.get_account_filtered_balance( + account=revenue, + metadata_key="thl_wall", + metadata_value=wall1.uuid, + ) + == 123 + ) + + def test_create_transaction_task_complete_1_cent( + self, user, create_main_accounts, thl_lm, lm, currency + ): + wall1 = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("0.007"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + + assert isinstance(tx, LedgerTransaction) + + def test_create_transaction_bp_payment( + self, + user, + create_main_accounts, + thl_lm, + lm, + currency, + delete_ledger_db, + session_factory, + utc_hour_ago, + ): + delete_ledger_db() + create_main_accounts() + + s1: Session = session_factory( + user=user, + wall_count=1, + started=utc_hour_ago, + wall_source=Source.TESTING, + ) + w1: Wall = s1.wall_events[0] + + tx = thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + assert isinstance(tx, LedgerTransaction) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": s1.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + print(thl_net, commission_amount, bp_pay, user_pay) + thl_lm.create_tx_bp_payment(session=s1, created=w1.started) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission = thl_lm.get_account_or_create_bp_commission(product=user.product) + + assert 0 == lm.get_account_balance(account=revenue) + assert 50 == lm.get_account_filtered_balance( + account=revenue, + metadata_key="source", + metadata_value=Source.TESTING, + ) + assert 48 == lm.get_account_balance(account=bp_wallet) + assert 48 == lm.get_account_filtered_balance( + account=bp_wallet, + metadata_key="thl_session", + metadata_value=s1.uuid, + ) + assert 2 == thl_lm.get_account_balance(account=bp_commission) + assert thl_lm.check_ledger_balanced() + + def test_create_transaction_bp_payment_round( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + thl_lm, + lm, + currency, + ): + product_user_wallet_no.commission_pct = Decimal("0.085") + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.287"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + + print(thl_net, commission_amount, bp_pay, user_pay) + tx = thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + assert isinstance(tx, LedgerTransaction) + + def test_create_transaction_bp_payment_round2( + self, delete_ledger_db, user, create_main_accounts, thl_lm, lm, currency + ): + delete_ledger_db() + create_main_accounts() + # user must be no user wallet + # e.g. session 869b5bfa47f44b4f81cd095ed01df2ff this fails if you dont round properly + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("1.64500"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + # thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": Decimal("1.53"), + "user_payout": Decimal("1.53"), + } + ) + + thl_lm.create_tx_bp_payment(session=session, created=wall1.started) + + def test_create_transaction_bp_payment_round3( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + currency, + ): + # e.g. session ___ fails b/c we rounded incorrectly + # before, and now we are off by a penny... + user: User = user_factory(product=product_user_wallet_yes) + + wall1 = Wall( + user_id=user.user_id, + source=Source.SAGO, + req_survey_id="xxx", + req_cpi=Decimal("0.385"), + session_id=3, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall=wall1, user=user, created=wall1.started) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + # thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": session.started + timedelta(minutes=10), + "payout": Decimal("0.39"), + "user_payout": Decimal("0.26"), + } + ) + # with pytest.logs(logger, level=logging.WARNING) as cm: + # tx = thl_lm.create_transaction_bp_payment(session, created=wall1.started) + # assert "Capping bp_pay to thl_net" in cm.output[0] + + def test_create_transaction_bp_payment_user_wallet( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + delete_ledger_db, + thl_lm, + session_manager, + wall_manager, + lm, + session_factory, + currency, + utc_hour_ago, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_yes) + assert user.product.user_wallet_enabled + + s1: Session = session_factory( + user=user, + wall_count=1, + started=utc_hour_ago, + wall_req_cpi=Decimal(".50"), + wall_source=Source.TESTING, + ) + w1: Wall = s1.wall_events[0] + + thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=s1.started + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + thl_lm.create_tx_bp_payment(session=s1, created=w1.started) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission = thl_lm.get_account_or_create_bp_commission(product=user.product) + user_wallet = thl_lm.get_account_or_create_user_wallet(user=user) + + assert 0 == thl_lm.get_account_balance(account=revenue) + assert 50 == thl_lm.get_account_filtered_balance( + account=revenue, + metadata_key="source", + metadata_value=Source.TESTING, + ) + + assert 48 - 19 == thl_lm.get_account_balance(account=bp_wallet) + assert 48 - 19 == thl_lm.get_account_filtered_balance( + account=bp_wallet, + metadata_key="thl_session", + metadata_value=s1.uuid, + ) + assert 2 == thl_lm.get_account_balance(bp_commission) + assert 19 == thl_lm.get_account_balance(user_wallet) + assert 19 == thl_lm.get_account_filtered_balance( + account=user_wallet, + metadata_key="thl_session", + metadata_value=s1.uuid, + ) + + assert 0 == thl_lm.get_account_filtered_balance( + account=user_wallet, metadata_key="thl_session", metadata_value="x" + ) + assert thl_lm.check_ledger_balanced() + + +class TestThlLedgerManagerAdj: + + def test_create_tx_task_adjustment( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + utc_hour_ago, + currency, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.23"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + + thl_lm.create_tx_task_complete(wall1, user, created=wall1.started) + + wall2 = Wall( + user_id=1, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.21"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + thl_lm.create_tx_task_complete(wall2, user, created=wall2.started) + + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + print(wall1.get_cpi_after_adjustment()) + thl_lm.create_tx_task_adjustment(wall1, user) + + cash = thl_lm.get_account_cash() + revenue = thl_lm.get_account_task_complete_revenue() + + assert 123 + 321 - 123 == thl_lm.get_account_balance(account=cash) + assert 123 + 321 - 123 == thl_lm.get_account_balance(account=revenue) + assert thl_lm.check_ledger_balanced() + assert 0 == thl_lm.get_account_filtered_balance( + revenue, metadata_key="source", metadata_value="d" + ) + assert 321 == thl_lm.get_account_filtered_balance( + revenue, metadata_key="source", metadata_value="f" + ) + assert 0 == thl_lm.get_account_filtered_balance( + revenue, metadata_key="source", metadata_value="x" + ) + assert 123 - 123 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="thl_wall", metadata_value=wall1.uuid + ) + + # un-reconcile it + wall1.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + print(wall1.get_cpi_after_adjustment()) + thl_lm.create_tx_task_adjustment(wall1, user) + # and then run it again to make sure it does nothing + thl_lm.create_tx_task_adjustment(wall1, user) + + cash = thl_lm.get_account_cash() + revenue = thl_lm.get_account_task_complete_revenue() + + assert 123 + 321 - 123 + 123 == thl_lm.get_account_balance(cash) + assert 123 + 321 - 123 + 123 == thl_lm.get_account_balance(revenue) + assert thl_lm.check_ledger_balanced() + assert 123 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="d" + ) + assert 321 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="f" + ) + assert 0 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="source", metadata_value="x" + ) + assert 123 - 123 + 123 == thl_lm.get_account_filtered_balance( + account=revenue, metadata_key="thl_wall", metadata_value=wall1.uuid + ) + + def test_create_tx_bp_adjustment( + self, + user, + product_user_wallet_no, + create_main_accounts, + caplog, + thl_lm, + lm, + currency, + session_manager, + wall_manager, + session_factory, + utc_hour_ago, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(3)], + wall_statuses=[Status.COMPLETE, Status.COMPLETE], + started=utc_hour_ago, + ) + + w1: Wall = s1.wall_events[0] + w2: Wall = s1.wall_events[1] + + thl_lm.create_tx_task_complete(wall=w1, user=user, created=w1.started) + thl_lm.create_tx_task_complete(wall=w2, user=user, created=w2.started) + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=utc_hour_ago + timedelta(minutes=10), + payout=bp_pay, + user_payout=user_pay, + ) + thl_lm.create_tx_bp_payment(session=s1, created=w1.started) + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission_account = thl_lm.get_account_or_create_bp_commission( + product=user.product + ) + assert 380 == thl_lm.get_account_balance(account=bp_wallet_account) + assert 0 == thl_lm.get_account_balance(account=revenue) + assert 20 == thl_lm.get_account_balance(account=bp_commission_account) + thl_lm.check_ledger_balanced() + + # This should do nothing (since we haven't adjusted any wall events) + s1.adjust_status() + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session=s1) + + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + # self.assertEqual(380, ledger_manager.get_account_balance(bp_wallet_account)) + # self.assertEqual(0, ledger_manager.get_account_balance(revenue)) + # self.assertEqual(20, ledger_manager.get_account_balance(bp_commission_account)) + # self.assertTrue(ledger_manager.check_ledger_balanced()) + + # recon $1 survey. + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=Decimal(0), + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=w1, user=user) + # -$1.00 b/c the MP took the $1 back, but we haven't yet taken the BP payment back + assert -100 == thl_lm.get_account_balance(revenue) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session=s1) + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + assert 380 - 95 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # unrecon the $1 survey + wall_manager.adjust_status( + wall=w1, + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment( + wall=w1, + user=user, + created=utc_hour_ago + timedelta(minutes=45), + ) + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + assert 380 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20, thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + def test_create_tx_bp_adjustment_small( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + thl_lm, + lm, + utc_hour_ago, + currency, + ): + delete_ledger_db() + create_main_accounts() + + # This failed when I didn't check that `change_commission` > 0 in + # create_transaction_bp_adjustment + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("0.10"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + thl_lm.create_tx_bp_payment(session, created=wall1.started) + + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session) + + def test_create_tx_bp_adjustment_abandon( + self, + user_factory, + product_user_wallet_no, + delete_ledger_db, + session_factory, + create_main_accounts, + caplog, + thl_lm, + lm, + currency, + utc_hour_ago, + session_manager, + wall_manager, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_user_wallet_no) + s1: Session = session_factory( + user=user, final_status=Status.ABANDON, wall_req_cpi=Decimal(1) + ) + w1 = s1.wall_events[-1] + + # Adjust to complete. + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=w1, user=user) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + # And then adjust it back (it was abandon before, but now it should be + # fail (?) or back to abandon?) + wall_manager.adjust_status( + wall=w1, + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall=w1, user=user) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + bp_commission_account = thl_lm.get_account_or_create_bp_commission( + product=user.product + ) + assert 0 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 0 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # This should do nothing + s1.adjust_status() + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session=s1) + assert "No transactions needed" in caplog.text + + # Now back to complete again + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=utc_hour_ago + timedelta(hours=1), + ) + s1.adjust_status() + thl_lm.create_tx_bp_adjustment(session=s1) + assert 95 == thl_lm.get_account_balance(bp_wallet_account) + + def test_create_tx_bp_adjustment_user_wallet( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + delete_ledger_db, + caplog, + thl_lm, + lm, + currency, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(days=1) + user: User = user_factory(product=product_user_wallet_yes) + + # Create 2 Wall completes and create the respective transaction for + # them. We then create a 3rd wall event which is a failure but we + # do NOT create a transaction for it + + wall3 = Wall( + user_id=8, + source=Source.CINT, + req_survey_id="zzz", + req_cpi=Decimal("2.00"), + session_id=1, + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + started=now, + finished=now + timedelta(minutes=1), + ) + + now_w1 = now + timedelta(minutes=1, milliseconds=1) + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now_w1, + finished=now_w1 + timedelta(minutes=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + now_w2 = now + timedelta(minutes=2, milliseconds=1) + wall2 = Wall( + user_id=user.user_id, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=now_w2, + finished=now_w2 + timedelta(minutes=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall2, user=user, created=wall2.started + ) + assert isinstance(tx, LedgerTransaction) + + # It doesn't matter what order these wall events go in as because + # we have a pydantic valiator that sorts them + wall_events = [wall3, wall1, wall2] + # shuffle(wall_events) + session = Session(started=wall1.started, user=user, wall_events=wall_events) + status, status_code_1 = session.determine_session_status() + assert status == Status.COMPLETE + assert status_code_1 == StatusCode1.COMPLETE + + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + assert thl_net == Decimal("4.00") + assert commission_amount == Decimal("0.20") + assert bp_pay == Decimal("3.80") + assert user_pay == Decimal("1.52") + + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": now + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + + tx = thl_lm.create_tx_bp_adjustment(session=session, created=wall1.started) + assert isinstance(tx, LedgerTransaction) + + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(product=user.product) + assert 228 == thl_lm.get_account_balance(account=bp_wallet_account) + + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + assert 152 == thl_lm.get_account_balance(account=user_account) + + revenue = thl_lm.get_account_task_complete_revenue() + assert 0 == thl_lm.get_account_balance(account=revenue) + + bp_commission_account = thl_lm.get_account_or_create_bp_commission( + product=user.product + ) + assert 20 == thl_lm.get_account_balance(account=bp_commission_account) + + # the total (4.00) = 2.28 + 1.52 + .20 + assert thl_lm.check_ledger_balanced() + + # This should do nothing (since we haven't adjusted any wall events) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session) + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + # recon $1 survey. + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=now + timedelta(hours=1), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + # -$1.00 b/c the MP took the $1 back, but we haven't yet taken the BP payment back + assert -100 == thl_lm.get_account_balance(revenue) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + + # running this twice b/c it should do nothing the 2nd time + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session) + assert ( + "create_transaction_bp_adjustment. No transactions needed." in caplog.text + ) + + assert 228 - 57 == thl_lm.get_account_balance(bp_wallet_account) + assert 152 - 38 == thl_lm.get_account_balance(user_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # unrecon the $1 survey + wall1.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=now + timedelta(hours=2), + ) + tx = thl_lm.create_tx_task_adjustment(wall=wall1, user=user) + assert isinstance(tx, LedgerTransaction) + + new_status, new_payout, new_user_payout = ( + session.determine_new_status_and_payouts() + ) + print(new_status, new_payout, new_user_payout, session.adjusted_payout) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + + assert 228 - 57 + 57 == thl_lm.get_account_balance(bp_wallet_account) + assert 152 - 38 + 38 == thl_lm.get_account_balance(user_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 + 5 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # make the $2 failure into a complete also + wall3.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=wall3.cpi, + adjusted_timestamp=now + timedelta(hours=2), + ) + thl_lm.create_tx_task_adjustment(wall3, user) + new_status, new_payout, new_user_payout = ( + session.determine_new_status_and_payouts() + ) + print(new_status, new_payout, new_user_payout, session.adjusted_payout) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + assert 228 - 57 + 57 + 114 == thl_lm.get_account_balance(bp_wallet_account) + assert 152 - 38 + 38 + 76 == thl_lm.get_account_balance(user_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 5 + 5 + 10 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + def test_create_transaction_bp_adjustment_cpi_adjustment( + self, + user_factory, + product_user_wallet_no, + create_main_accounts, + delete_ledger_db, + caplog, + thl_lm, + lm, + utc_hour_ago, + currency, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_user_wallet_no) + + wall1 = Wall( + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall1, user=user, created=wall1.started + ) + assert isinstance(tx, LedgerTransaction) + + wall2 = Wall( + user_id=user.user_id, + source=Source.FULL_CIRCLE, + req_survey_id="yyy", + req_cpi=Decimal("3.00"), + session_id=1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + started=utc_hour_ago, + finished=utc_hour_ago + timedelta(seconds=1), + ) + tx = thl_lm.create_tx_task_complete( + wall=wall2, user=user, created=wall2.started + ) + assert isinstance(tx, LedgerTransaction) + + session = Session(started=wall1.started, user=user, wall_events=[wall1, wall2]) + status, status_code_1 = session.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = session.determine_payments() + session.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + thl_lm.create_tx_bp_payment(session, created=wall1.started) + + revenue = thl_lm.get_account_task_complete_revenue() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_commission_account = thl_lm.get_account_or_create_bp_commission(user.product) + assert 380 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # cpi adjustment $1 -> $.60. + wall1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("0.60"), + adjusted_timestamp=utc_hour_ago + timedelta(minutes=30), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + + # -$0.40 b/c the MP took $0.40 back, but we haven't yet taken the BP payment back + assert -40 == thl_lm.get_account_balance(revenue) + session.adjust_status() + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + thl_lm.create_tx_bp_adjustment(session) + + # running this twice b/c it should do nothing the 2nd time + print( + session.get_status_after_adjustment(), + session.get_payout_after_adjustment(), + session.get_user_payout_after_adjustment(), + ) + with caplog.at_level(logging.INFO): + thl_lm.create_tx_bp_adjustment(session) + assert "create_transaction_bp_adjustment." in caplog.text + assert "No transactions needed." in caplog.text + + assert 380 - 38 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 20 - 2 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # adjust it to failure + wall1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session) + assert 300 - (300 * 0.05) == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 300 * 0.05 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # and then back to cpi adj again, but this time for more than the orig amount + wall1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("2.00"), + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment(wall1, user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session) + assert 500 - (500 * 0.05) == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(revenue) + assert 500 * 0.05 == thl_lm.get_account_balance(bp_commission_account) + assert thl_lm.check_ledger_balanced() + + # And adjust again + wall1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("3.00"), + adjusted_timestamp=utc_hour_ago + timedelta(minutes=45), + ) + thl_lm.create_tx_task_adjustment(wall=wall1, user=user) + session.adjust_status() + thl_lm.create_tx_bp_adjustment(session=session) + assert 600 - (600 * 0.05) == thl_lm.get_account_balance( + account=bp_wallet_account + ) + assert 0 == thl_lm.get_account_balance(account=revenue) + assert 600 * 0.05 == thl_lm.get_account_balance(account=bp_commission_account) + assert thl_lm.check_ledger_balanced() diff --git a/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py b/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py new file mode 100644 index 0000000..1e7146a --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_lm_tx__user_payouts.py @@ -0,0 +1,505 @@ +import logging +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionFlagAlreadyExistsError, + LedgerTransactionConditionFailedError, +) +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.payout import UserPayoutEvent +from test_utils.managers.ledger.conftest import create_main_accounts + + +class TestLedgerManagerAMT: + + def test_create_transaction_amt_ass_request( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_amt_true) + + # debit_account_uuid nothing checks they match the ledger ... todo? + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_HIT, + amount=5, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag_key = f"test:user_payout:{pe.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + + # User has $0 in their wallet. They are allowed amt_assignment payouts until -$1.00 + thl_lm.create_tx_user_payout_request(user=user, payout_event=pe) + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + thl_lm.create_tx_user_payout_request( + user=user, payout_event=pe, skip_flag_check=False + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + thl_lm.create_tx_user_payout_request( + user=user, payout_event=pe, skip_flag_check=True + ) + pe2 = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_HIT, + amount=96, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + flag_key = f"test:user_payout:{pe2.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + # 96 cents would put them over the -$1.00 limit + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + thl_lm.create_tx_user_payout_request(user, payout_event=pe2) + + # But they could do 0.95 cents + pe2.amount = 95 + thl_lm.create_tx_user_payout_request( + user, payout_event=pe2, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + product=user.product + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user) + + assert 0 == lm.get_account_balance(account=bp_wallet_account) + assert 0 == lm.get_account_balance(account=cash) + assert 100 == lm.get_account_balance(account=bp_pending_account) + assert -100 == lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + assert -5 == thl_lm.get_account_filtered_balance( + account=user_wallet_account, + metadata_key="payoutevent", + metadata_value=pe.uuid, + ) + + assert -95 == thl_lm.get_account_filtered_balance( + account=user_wallet_account, + metadata_key="payoutevent", + metadata_value=pe2.uuid, + ) + + def test_create_transaction_amt_ass_complete( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + user: User = user_factory(product=product_amt_true) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_HIT, + amount=5, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:request" + lm.redis_client.delete(flag) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:complete" + lm.redis_client.delete(flag) + + # User has $0 in their wallet. They are allowed amt_assignment payouts until -$1.00 + thl_lm.create_tx_user_payout_request(user, payout_event=pe) + thl_lm.create_tx_user_payout_complete(user, payout_event=pe) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + + # BP wallet pays the 1cent fee + assert -1 == thl_lm.get_account_balance(bp_wallet_account) + assert -5 == thl_lm.get_account_balance(cash) + assert -1 == thl_lm.get_account_balance(bp_amt_expense_account) + assert 0 == thl_lm.get_account_balance(bp_pending_account) + assert -5 == lm.get_account_balance(user_wallet_account) + assert thl_lm.check_ledger_balanced() + + def test_create_transaction_amt_bonus( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_BONUS, + amount=34, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:request" + lm.redis_client.delete(flag) + flag = f"ledger-manager:transaction_flag:test:user_payout:{pe.uuid}:complete" + lm.redis_client.delete(flag) + + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # User has $0 in their wallet. No amt bonus allowed + thl_lm.create_tx_user_payout_request(user, payout_event=pe) + + thl_lm.create_tx_user_bonus( + user, + amount=Decimal(5), + ref_uuid="e703830dec124f17abed2d697d8d7701", + description="Bribe", + skip_flag_check=True, + ) + pe.amount = 101 + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=False + ) + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=False + ) + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + # duplicate, even if amount changed + pe.amount = 200 + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=False + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # duplicate + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + pe.uuid = "533364150de4451198e5774e221a2acb" + pe.amount = 9900 + with pytest.raises(expected_exception=ValueError): + # Trying to complete payout with no pending tx + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # trying to payout $99 with only a $5 balance + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + assert -500 + round(-101 * 0.20) == thl_lm.get_account_balance( + bp_wallet_account + ) + assert -101 == lm.get_account_balance(cash) + assert -20 == lm.get_account_balance(bp_amt_expense_account) + assert 0 == lm.get_account_balance(bp_pending_account) + assert 500 - 101 == lm.get_account_balance(user_wallet_account) + assert lm.check_ledger_balanced() is True + + def test_create_transaction_amt_bonus_cancel( + self, + user_factory, + product_amt_true, + create_main_accounts, + caplog, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_amt_true) + + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_BONUS, + amount=101, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + + thl_lm.create_tx_user_bonus( + user, + amount=Decimal(5), + ref_uuid="c44f4da2db1d421ebc6a5e5241ca4ce6", + description="Bribe", + skip_flag_check=True, + ) + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + thl_lm.create_tx_user_payout_cancelled( + user, payout_event=pe, skip_flag_check=True + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + with caplog.at_level(logging.WARNING): + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + assert "trying to complete payout that was already cancelled" in caplog.text + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + assert -500 == thl_lm.get_account_balance(account=bp_wallet_account) + assert 0 == thl_lm.get_account_balance(account=cash) + assert 0 == thl_lm.get_account_balance(account=bp_amt_expense_account) + assert 0 == thl_lm.get_account_balance(account=bp_pending_account) + assert 500 == thl_lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + + pe2 = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.AMT_BONUS, + amount=200, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + thl_lm.create_tx_user_payout_request( + user, payout_event=pe2, skip_flag_check=True + ) + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe2, skip_flag_check=True + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + with caplog.at_level(logging.WARNING): + thl_lm.create_tx_user_payout_cancelled( + user, payout_event=pe2, skip_flag_check=True + ) + assert "trying to cancel payout that was already completed" in caplog.text + + +class TestLedgerManagerTango: + + def test_create_transaction_tango_request( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + + # debit_account_uuid nothing checks they match the ledger ... todo? + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.TANGO, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag_key = f"test:user_payout:{pe.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + thl_lm.create_tx_user_bonus( + user, + amount=Decimal(6), + ref_uuid="e703830dec124f17abed2d697d8d7701", + description="Bribe", + skip_flag_check=True, + ) + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + user.product + ) + bp_tango_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="tango" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user) + assert -600 == thl_lm.get_account_balance(bp_wallet_account) + assert 0 == thl_lm.get_account_balance(cash) + assert 0 == thl_lm.get_account_balance(bp_tango_expense_account) + assert 500 == thl_lm.get_account_balance(bp_pending_account) + assert 600 - 500 == thl_lm.get_account_balance(user_wallet_account) + assert thl_lm.check_ledger_balanced() + + thl_lm.create_tx_user_payout_complete( + user, payout_event=pe, skip_flag_check=True + ) + assert -600 - round(500 * 0.035) == thl_lm.get_account_balance( + bp_wallet_account + ) + assert -500, thl_lm.get_account_balance(cash) + assert round(-500 * 0.035) == thl_lm.get_account_balance( + bp_tango_expense_account + ) + assert 0 == lm.get_account_balance(bp_pending_account) + assert 100 == lm.get_account_balance(user_wallet_account) + assert lm.check_ledger_balanced() + + +class TestLedgerManagerPaypal: + + def test_create_transaction_paypal_request( + self, + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + now = datetime.now(tz=timezone.utc) - timedelta(hours=1) + user: User = user_factory(product=product_amt_true) + + # debit_account_uuid nothing checks they match the ledger ... todo? + pe = UserPayoutEvent( + uuid=uuid4().hex, + payout_type=PayoutType.PAYPAL, + amount=500, + cashout_method_uuid=uuid4().hex, + debit_account_uuid=uuid4().hex, + ) + flag_key = f"test:user_payout:{pe.uuid}:request" + flag_name = f"ledger-manager:transaction_flag:{flag_key}" + lm.redis_client.delete(flag_name) + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(6), + ref_uuid="e703830dec124f17abed2d697d8d7701", + description="Bribe", + skip_flag_check=True, + ) + + thl_lm.create_tx_user_payout_request( + user, payout_event=pe, skip_flag_check=True + ) + + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + product=user.product + ) + bp_paypal_expense_account = thl_lm.get_account_or_create_bp_expense( + product=user.product, expense_name="paypal" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user) + assert -600 == lm.get_account_balance(account=bp_wallet_account) + assert 0 == lm.get_account_balance(account=cash) + assert 0 == lm.get_account_balance(account=bp_paypal_expense_account) + assert 500 == lm.get_account_balance(account=bp_pending_account) + assert 600 - 500 == lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + + thl_lm.create_tx_user_payout_complete( + user=user, payout_event=pe, skip_flag_check=True, fee_amount=Decimal("0.50") + ) + assert -600 - 50 == thl_lm.get_account_balance(bp_wallet_account) + assert -500 == thl_lm.get_account_balance(cash) + assert -50 == thl_lm.get_account_balance(bp_paypal_expense_account) + assert 0 == thl_lm.get_account_balance(bp_pending_account) + assert 100 == thl_lm.get_account_balance(user_wallet_account) + assert thl_lm.check_ledger_balanced() + + +class TestLedgerManagerBonus: + + def test_create_transaction_bonus( + self, + user_factory, + product_user_wallet_yes, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + ): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_user_wallet_yes) + + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(5), + ref_uuid="8d0aaf612462448a9ebdd57fab0fc660", + description="Bribe", + skip_flag_check=True, + ) + cash = thl_lm.get_account_cash() + bp_wallet_account = thl_lm.get_account_or_create_bp_wallet(user.product) + bp_pending_account = thl_lm.get_or_create_bp_pending_payout_account( + product=user.product + ) + bp_amt_expense_account = thl_lm.get_account_or_create_bp_expense( + user.product, expense_name="amt" + ) + user_wallet_account = thl_lm.get_account_or_create_user_wallet(user=user) + + assert -500 == lm.get_account_balance(account=bp_wallet_account) + assert 0 == lm.get_account_balance(account=cash) + assert 0 == lm.get_account_balance(account=bp_amt_expense_account) + assert 0 == lm.get_account_balance(account=bp_pending_account) + assert 500 == lm.get_account_balance(account=user_wallet_account) + assert thl_lm.check_ledger_balanced() + + with pytest.raises(expected_exception=LedgerTransactionFlagAlreadyExistsError): + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(5), + ref_uuid="8d0aaf612462448a9ebdd57fab0fc660", + description="Bribe", + skip_flag_check=False, + ) + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(5), + ref_uuid="8d0aaf612462448a9ebdd57fab0fc660", + description="Bribe", + skip_flag_check=True, + ) diff --git a/tests/managers/thl/test_ledger/test_thl_pem.py b/tests/managers/thl/test_ledger/test_thl_pem.py new file mode 100644 index 0000000..5fb9e7d --- /dev/null +++ b/tests/managers/thl/test_ledger/test_thl_pem.py @@ -0,0 +1,251 @@ +import uuid +from random import randint +from uuid import uuid4, UUID + +import pytest + +from generalresearch.currency import USDCent +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.payout import BrokerageProductPayoutEvent +from generalresearch.models.thl.product import Product +from generalresearch.models.thl.wallet.cashout_method import ( + CashoutRequestInfo, +) + + +class TestThlPayoutEventManager: + + def test_get_by_uuid(self, brokerage_product_payout_event_manager, thl_lm): + """This validates that the method raises an exception if it + fails. There are plenty of other tests that use this method so + it seems silly to duplicate it here again + """ + + with pytest.raises(expected_exception=AssertionError) as excinfo: + brokerage_product_payout_event_manager.get_by_uuid(pe_uuid=uuid4().hex) + assert "expected 1 result, got 0" in str(excinfo.value) + + def test_filter_by( + self, + product_factory, + usd_cent, + bp_payout_event_factory, + thl_lm, + brokerage_product_payout_event_manager, + ): + from generalresearch.models.thl.payout import UserPayoutEvent + + N_PRODUCTS = randint(3, 10) + N_PAYOUT_EVENTS = randint(3, 10) + amounts = [] + products = [] + + for x_idx in range(N_PRODUCTS): + product: Product = product_factory() + thl_lm.get_account_or_create_bp_wallet(product=product) + products.append(product) + brokerage_product_payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm + ) + + for y_idx in range(N_PAYOUT_EVENTS): + pe = bp_payout_event_factory(product=product, usd_cent=usd_cent) + amounts.append(int(usd_cent)) + assert isinstance(pe, BrokerageProductPayoutEvent) + + # We just added Payout Events for Products, now go ahead and + # query for them + accounts = thl_lm.get_accounts_bp_wallet_for_products( + product_uuids=[i.uuid for i in products] + ) + res = brokerage_product_payout_event_manager.filter_by( + debit_account_uuids=[i.uuid for i in accounts] + ) + + assert len(res) == (N_PRODUCTS * N_PAYOUT_EVENTS) + assert sum([i.amount for i in res]) == sum(amounts) + + def test_get_bp_payout_events_for_product( + self, + product_factory, + usd_cent, + bp_payout_event_factory, + brokerage_product_payout_event_manager, + thl_lm, + ): + from generalresearch.models.thl.payout import UserPayoutEvent + + N_PRODUCTS = randint(3, 10) + N_PAYOUT_EVENTS = randint(3, 10) + amounts = [] + products = [] + + for x_idx in range(N_PRODUCTS): + product: Product = product_factory() + products.append(product) + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm + ) + + for y_idx in range(N_PAYOUT_EVENTS): + pe = bp_payout_event_factory(product=product, usd_cent=usd_cent) + amounts.append(usd_cent) + assert isinstance(pe, BrokerageProductPayoutEvent) + + # We just added 5 Payouts for a specific Product, now go + # ahead and query for them + res = brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.id] + ) + + assert len(res) == N_PAYOUT_EVENTS + + # Now that all the Payouts for all the Products have been added, go + # ahead and query for them + res = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[i.uuid for i in products] + ) + ) + + assert len(res) == (N_PRODUCTS * N_PAYOUT_EVENTS) + assert sum([i.amount for i in res]) == sum(amounts) + + @pytest.mark.skip + def test_get_payout_detail(self, user_payout_event_manager): + """This fails because the description coming back is None, but then + it tries to return a PayoutEvent which validates that the + description can't be None + """ + from generalresearch.models.thl.payout import ( + UserPayoutEvent, + PayoutType, + ) + + rand_amount = randint(a=99, b=999) + + pe = user_payout_event_manager.create( + debit_account_uuid=uuid4().hex, + account_reference_type="str-type-random", + account_reference_uuid=uuid4().hex, + cashout_method_uuid=uuid4().hex, + description="Best payout !", + amount=rand_amount, + status=PayoutStatus.PENDING, + ext_ref_id="123", + payout_type=PayoutType.CASH_IN_MAIL, + request_data={"foo": 123}, + order_data={}, + ) + + res = user_payout_event_manager.get_payout_detail(pe_uuid=pe.uuid) + assert isinstance(res, CashoutRequestInfo) + + # def test_filter_by(self): + # raise NotImplementedError + + def test_create(self, user_payout_event_manager): + from generalresearch.models.thl.payout import UserPayoutEvent + + # Confirm the creation method returns back an instance. + pe = user_payout_event_manager.create_dummy() + assert isinstance(pe, UserPayoutEvent) + + # Now query the DB for that PayoutEvent to confirm it was actually + # saved. + res = user_payout_event_manager.get_by_uuid(pe_uuid=pe.uuid) + assert isinstance(res, UserPayoutEvent) + assert UUID(res.uuid) + + # Confirm they're the same + # assert pe.model_dump_json() == res2.model_dump_json() + assert res.description is None + + # def test_update(self): + # raise NotImplementedError + + def test_create_bp_payout( + self, + product, + delete_ledger_db, + create_main_accounts, + thl_lm, + brokerage_product_payout_event_manager, + lm, + ): + from generalresearch.models.thl.payout import UserPayoutEvent + + delete_ledger_db() + create_main_accounts() + + account_bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + rand_amount = randint(a=99, b=999) + + # Save a Brokerage Product Payout, so we have something in the + # Payout Event table and the respective ledger TX and Entry rows for it + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(rand_amount), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert isinstance(pe, BrokerageProductPayoutEvent) + + # Now try to query for it! + res = thl_lm.get_tx_bp_payouts(account_uuids=[account_bp_wallet.uuid]) + assert len(res) == 1 + res = thl_lm.get_tx_bp_payouts(account_uuids=[uuid4().hex]) + assert len(res) == 0 + + # Confirm it added to the users balance. The amount is negative because + # money was sent to the Brokerage Product, but they didn't have + # any activity that earned them money + bal = lm.get_account_balance(account=account_bp_wallet) + assert rand_amount == bal * -1 + + +class TestBPPayoutEvent: + + def test_get_bp_bp_payout_events_for_products( + self, + product_factory, + bp_payout_event_factory, + usd_cent, + delete_ledger_db, + create_main_accounts, + brokerage_product_payout_event_manager, + thl_lm, + ): + delete_ledger_db() + create_main_accounts() + + N_PAYOUT_EVENTS = randint(3, 10) + amounts = [] + + product: Product = product_factory() + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + for y_idx in range(N_PAYOUT_EVENTS): + bp_payout_event_factory(product=product, usd_cent=usd_cent) + amounts.append(usd_cent) + + # Fetch using the _bp_bp_ approach, so we have an + # array of BPPayoutEvents + bp_bp_res = ( + brokerage_product_payout_event_manager.get_bp_bp_payout_events_for_products( + thl_ledger_manager=thl_lm, product_uuids=[product.uuid] + ) + ) + assert isinstance(bp_bp_res, list) + assert sum(amounts) == sum([i.amount for i in bp_bp_res]) + for i in bp_bp_res: + assert isinstance(i, BrokerageProductPayoutEvent) + assert isinstance(i.amount, int) + assert isinstance(i.amount_usd, USDCent) + assert isinstance(i.amount_usd_str, str) + assert i.amount_usd_str[0] == "$" diff --git a/tests/managers/thl/test_ledger/test_user_txs.py b/tests/managers/thl/test_ledger/test_user_txs.py new file mode 100644 index 0000000..d81b244 --- /dev/null +++ b/tests/managers/thl/test_ledger/test_user_txs.py @@ -0,0 +1,288 @@ +from datetime import timedelta, datetime, timezone +from decimal import Decimal +from uuid import uuid4 + +from generalresearch.managers.thl.payout import ( + AMT_ASSIGNMENT_CASHOUT_METHOD, + AMT_BONUS_CASHOUT_METHOD, +) +from generalresearch.managers.thl.user_compensate import user_compensate +from generalresearch.models.thl.definitions import ( + Status, + WallAdjustedStatus, +) +from generalresearch.models.thl.ledger import ( + UserLedgerTransactionTypesSummary, + UserLedgerTransactionTypeSummary, + TransactionType, +) +from generalresearch.models.thl.session import Session +from generalresearch.models.thl.user import User +from generalresearch.models.thl.wallet import PayoutType + + +def test_user_txs( + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + session_with_tx_factory, + adj_to_fail_with_tx_factory, + adj_to_complete_with_tx_factory, + session_factory, + user_payout_event_manager, + utc_now, +): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + account = thl_lm.get_account_or_create_user_wallet(user) + print(f"{account.uuid=}") + + s: Session = session_with_tx_factory(user=user, wall_req_cpi=Decimal("1.00")) + + bribe_uuid = user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + ) + + pe = user_payout_event_manager.create( + uuid=uuid4().hex, + debit_account_uuid=account.uuid, + cashout_method_uuid=AMT_ASSIGNMENT_CASHOUT_METHOD, + amount=5, + created=utc_now, + payout_type=PayoutType.AMT_HIT, + request_data=dict(), + ) + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + ) + pe = user_payout_event_manager.create( + uuid=uuid4().hex, + debit_account_uuid=account.uuid, + cashout_method_uuid=AMT_BONUS_CASHOUT_METHOD, + amount=127, + created=utc_now, + payout_type=PayoutType.AMT_BONUS, + request_data=dict(), + ) + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + ) + + wall = s.wall_events[-1] + adj_to_fail_with_tx_factory(session=s, created=wall.finished) + + # And a fail -> complete adjustment + s_fail: Session = session_factory( + user=user, + wall_count=1, + final_status=Status.FAIL, + wall_req_cpi=Decimal("2.00"), + ) + adj_to_complete_with_tx_factory(session=s_fail, created=utc_now) + + # txs = thl_lm.get_tx_filtered_by_account(account.uuid) + # print(len(txs), txs) + txs = thl_lm.get_user_txs(user) + assert len(txs.transactions) == 6 + assert txs.total == 6 + assert txs.page == 1 + assert txs.size == 50 + + # print(len(txs.transactions), txs) + d = txs.model_dump_json() + # print(d) + + descriptions = {x.description for x in txs.transactions} + assert descriptions == { + "Compensation Bonus", + "HIT Bonus", + "HIT Reward", + "Task Adjustment", + "Task Complete", + } + amounts = {x.amount for x in txs.transactions} + assert amounts == {-127, 100, 38, -38, -5, 76} + + assert txs.summary == UserLedgerTransactionTypesSummary( + bp_adjustment=UserLedgerTransactionTypeSummary( + entry_count=2, min_amount=-38, max_amount=76, total_amount=76 - 38 + ), + bp_payment=UserLedgerTransactionTypeSummary( + entry_count=1, min_amount=38, max_amount=38, total_amount=38 + ), + user_bonus=UserLedgerTransactionTypeSummary( + entry_count=1, min_amount=100, max_amount=100, total_amount=100 + ), + user_payout_request=UserLedgerTransactionTypeSummary( + entry_count=2, min_amount=-127, max_amount=-5, total_amount=-132 + ), + ) + tx_adj_c = [ + tx for tx in txs.transactions if tx.tx_type == TransactionType.BP_ADJUSTMENT + ] + assert sorted([tx.amount for tx in tx_adj_c]) == [-38, 76] + + +def test_user_txs_pagination( + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + session_with_tx_factory, + adj_to_fail_with_tx_factory, + user_payout_event_manager, + utc_now, +): + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + account = thl_lm.get_account_or_create_user_wallet(user) + print(f"{account.uuid=}") + + for _ in range(12): + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + + txs = thl_lm.get_user_txs(user, page=1, size=5) + assert len(txs.transactions) == 5 + assert txs.total == 12 + assert txs.page == 1 + assert txs.size == 5 + assert txs.summary.user_bonus.total_amount == 1200 + assert txs.summary.user_bonus.entry_count == 12 + + # Skip to the 3rd page. We made 12, so there are 2 left + txs = thl_lm.get_user_txs(user, page=3, size=5) + assert len(txs.transactions) == 2 + assert txs.total == 12 + assert txs.page == 3 + assert txs.summary.user_bonus.total_amount == 1200 + assert txs.summary.user_bonus.entry_count == 12 + + # Should be empty, not fail + txs = thl_lm.get_user_txs(user, page=4, size=5) + assert len(txs.transactions) == 0 + assert txs.total == 12 + assert txs.page == 4 + assert txs.summary.user_bonus.total_amount == 1200 + assert txs.summary.user_bonus.entry_count == 12 + + # Test filtering. We should pull back only this one + now = datetime.now(tz=timezone.utc) + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + txs = thl_lm.get_user_txs(user, page=1, size=5, time_start=now) + assert len(txs.transactions) == 1 + assert txs.total == 1 + assert txs.page == 1 + # And the summary is restricted to this time range also! + assert txs.summary.user_bonus.total_amount == 100 + assert txs.summary.user_bonus.entry_count == 1 + + # And filtering with 0 results + now = datetime.now(tz=timezone.utc) + txs = thl_lm.get_user_txs(user, page=1, size=5, time_start=now) + assert len(txs.transactions) == 0 + assert txs.total == 0 + assert txs.page == 1 + assert txs.pages == 0 + # And the summary is restricted to this time range also! + assert txs.summary.user_bonus.total_amount == None + assert txs.summary.user_bonus.entry_count == 0 + + +def test_user_txs_rolling_balance( + user_factory, + product_amt_true, + create_main_accounts, + thl_lm, + lm, + delete_ledger_db, + session_with_tx_factory, + adj_to_fail_with_tx_factory, + user_payout_event_manager, +): + """ + Creates 3 $1.00 bonuses (postive), + then 1 cashout (negative), $1.50 + then 3 more $1.00 bonuses. + Note: pagination + rolling balance will BREAK if txs have + identical timestamps. In practice, they do not. + """ + delete_ledger_db() + create_main_accounts() + + user: User = user_factory(product=product_amt_true) + account = thl_lm.get_account_or_create_user_wallet(user) + + for _ in range(3): + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + pe = user_payout_event_manager.create( + uuid=uuid4().hex, + debit_account_uuid=account.uuid, + cashout_method_uuid=AMT_BONUS_CASHOUT_METHOD, + amount=150, + payout_type=PayoutType.AMT_BONUS, + request_data=dict(), + ) + thl_lm.create_tx_user_payout_request( + user=user, + payout_event=pe, + ) + for _ in range(3): + user_compensate( + ledger_manager=thl_lm, + user=user, + amount_int=100, + skip_flag_check=True, + ) + + txs = thl_lm.get_user_txs(user, page=1, size=10) + assert txs.transactions[0].balance_after == 100 + assert txs.transactions[1].balance_after == 200 + assert txs.transactions[2].balance_after == 300 + assert txs.transactions[3].balance_after == 150 + assert txs.transactions[4].balance_after == 250 + assert txs.transactions[5].balance_after == 350 + assert txs.transactions[6].balance_after == 450 + + # Ascending order, get 2nd page, make sure the balances include + # the previous txs. (will return last 3 txs) + txs = thl_lm.get_user_txs(user, page=2, size=4) + assert len(txs.transactions) == 3 + assert txs.transactions[0].balance_after == 250 + assert txs.transactions[1].balance_after == 350 + assert txs.transactions[2].balance_after == 450 + + # Descending order, get 1st page. Will + # return most recent 3 txs in desc order + txs = thl_lm.get_user_txs(user, page=1, size=3, order_by="-created") + assert len(txs.transactions) == 3 + assert txs.transactions[0].balance_after == 450 + assert txs.transactions[1].balance_after == 350 + assert txs.transactions[2].balance_after == 250 diff --git a/tests/managers/thl/test_ledger/test_wallet.py b/tests/managers/thl/test_ledger/test_wallet.py new file mode 100644 index 0000000..a0abd7c --- /dev/null +++ b/tests/managers/thl/test_ledger/test_wallet.py @@ -0,0 +1,78 @@ +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.product import ( + UserWalletConfig, + PayoutConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, +) +from generalresearch.models.thl.user import User + + +@pytest.fixture() +def schrute_product(product_manager): + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=True, amt=False), + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.4), + ), + payout_format="{payout:,.0f} Schrute Bucks", + ), + ) + + +class TestGetUserWalletBalance: + def test_get_user_wallet_balance_non_managed(self, user, thl_lm): + with pytest.raises( + AssertionError, + match="Can't get wallet balance on non-managed account.", + ): + thl_lm.get_user_wallet_balance(user=user) + + def test_get_user_wallet_balance_managed_0( + self, schrute_product, user_factory, thl_lm + ): + assert ( + schrute_product.payout_config.payout_format == "{payout:,.0f} Schrute Bucks" + ) + user: User = user_factory(schrute_product) + balance = thl_lm.get_user_wallet_balance(user=user) + assert balance == 0 + balance_string = user.product.format_payout_format(Decimal(balance) / 100) + assert balance_string == "0 Schrute Bucks" + redeemable_balance = thl_lm.get_user_redeemable_wallet_balance( + user=user, user_wallet_balance=balance + ) + assert redeemable_balance == 0 + redeemable_balance_string = user.product.format_payout_format( + Decimal(redeemable_balance) / 100 + ) + assert redeemable_balance_string == "0 Schrute Bucks" + + def test_get_user_wallet_balance_managed( + self, schrute_product, user_factory, thl_lm, session_with_tx_factory + ): + user: User = user_factory(schrute_product) + thl_lm.create_tx_user_bonus( + user=user, + amount=Decimal(1), + ref_uuid=uuid4().hex, + description="cheese", + ) + session_with_tx_factory(user=user, wall_req_cpi=Decimal("1.23")) + + # This product has a payout xform of 40% and commission of 5% + # 1.23 * 0.05 = 0.06 of commission + # 1.17 of payout * 0.40 = 0.47 of user pay and (1.17-0.47) 0.70 bp pay + balance = thl_lm.get_user_wallet_balance(user=user) + assert balance == 47 + 100 # plus the $1 bribe + + redeemable_balance = thl_lm.get_user_redeemable_wallet_balance( + user=user, user_wallet_balance=balance + ) + assert redeemable_balance == 20 + 100 diff --git a/tests/managers/thl/test_maxmind.py b/tests/managers/thl/test_maxmind.py new file mode 100644 index 0000000..c588c58 --- /dev/null +++ b/tests/managers/thl/test_maxmind.py @@ -0,0 +1,273 @@ +import json +import logging +from typing import Callable + +import geoip2.models +import pytest +from faker import Faker +from faker.providers.address.en_US import Provider as USAddressProvider + +from generalresearch.managers.thl.ipinfo import GeoIpInfoManager +from generalresearch.managers.thl.maxmind import MaxmindManager +from generalresearch.managers.thl.maxmind.basic import ( + MaxmindBasicManager, +) +from generalresearch.models.thl.ipinfo import ( + GeoIPInformation, + normalize_ip, +) +from generalresearch.models.thl.maxmind.definitions import UserType + +fake = Faker() + +US_STATES = {x.lower() for x in USAddressProvider.states} + +IP_v4_INDIA = "106.203.146.157" +IP_v6_INDIA = "2402:3a80:4649:de3f:0:24:74aa:b601" +IP_v4_US = "174.218.60.101" +IP_v6_US = "2600:1700:ece0:9410:55d:faf3:c15d:6e4" +IP_v6_US_SAME_64 = "2600:1700:ece0:9410:55d:faf3:c15d:aaaa" + + +@pytest.fixture(scope="session") +def delete_ipinfo(thl_web_rw) -> Callable: + def _delete_ipinfo(ip): + thl_web_rw.execute_write( + query="DELETE FROM thl_geoname WHERE geoname_id IN (SELECT geoname_id FROM thl_ipinformation WHERE ip = %s);", + params=[ip], + ) + thl_web_rw.execute_write( + query="DELETE FROM thl_ipinformation WHERE ip = %s;", + params=[ip], + ) + + return _delete_ipinfo + + +class TestMaxmindBasicManager: + + def test_init(self, maxmind_basic_manager): + + assert isinstance(maxmind_basic_manager, MaxmindBasicManager) + + def test_get_basic_ip_information(self, maxmind_basic_manager): + ip = IP_v4_INDIA + maxmind_basic_manager.run_update_geoip_db() + + res1 = maxmind_basic_manager.get_basic_ip_information(ip_address=ip) + assert isinstance(res1, geoip2.models.Country) + assert res1.country.iso_code == "IN" + assert res1.country.name == "India" + + res2 = maxmind_basic_manager.get_basic_ip_information( + ip_address=fake.ipv4_private() + ) + assert res2 is None + + def test_get_country_iso_from_ip_geoip2db(self, maxmind_basic_manager): + ip = IP_v4_INDIA + maxmind_basic_manager.run_update_geoip_db() + + res1 = maxmind_basic_manager.get_country_iso_from_ip_geoip2db(ip=ip) + assert res1 == "in" + + res2 = maxmind_basic_manager.get_country_iso_from_ip_geoip2db( + ip=fake.ipv4_private() + ) + assert res2 is None + + def test_get_basic_ip_information_ipv6(self, maxmind_basic_manager): + ip = IP_v6_INDIA + maxmind_basic_manager.run_update_geoip_db() + + res1 = maxmind_basic_manager.get_basic_ip_information(ip_address=ip) + assert isinstance(res1, geoip2.models.Country) + assert res1.country.iso_code == "IN" + assert res1.country.name == "India" + + +class TestMaxmindManager: + + def test_init(self, thl_web_rr, thl_redis_config, maxmind_manager: MaxmindManager): + instance = MaxmindManager(pg_config=thl_web_rr, redis_config=thl_redis_config) + assert isinstance(instance, MaxmindManager) + assert isinstance(maxmind_manager, MaxmindManager) + + def test_create_basic( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in India, and so it should only do the basic lookup + ip = IP_v4_INDIA + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + + maxmind_manager.run_ip_information(ip, force_insights=False) + # Check that it is in the cache and in mysql + res = geoipinfo_manager.get_cache(ip) + assert res.ip == ip + assert res.basic + res = geoipinfo_manager.get_mysql(ip) + assert res.ip == ip + assert res.basic + + def test_create_basic_ipv6( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in India, and so it should only do the basic lookup + ip = IP_v6_INDIA + normalized_ip, lookup_prefix = normalize_ip(ip) + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + delete_ipinfo(normalized_ip) + geoipinfo_manager.clear_cache(normalized_ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_cache(normalized_ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(normalized_ip) is None + + maxmind_manager.run_ip_information(ip, force_insights=False) + + # Check that it is in the cache + res = geoipinfo_manager.get_cache(ip) + # The looked up IP (/128) is returned, + assert res.ip == ip + assert res.lookup_prefix == "/64" + assert res.basic + + # ... but the normalized version was stored (/64) + assert geoipinfo_manager.get_cache_raw(ip) is None + res = json.loads(geoipinfo_manager.get_cache_raw(normalized_ip)) + assert res["ip"] == normalized_ip + + # Check mysql + res = geoipinfo_manager.get_mysql(ip) + assert res.ip == ip + assert res.lookup_prefix == "/64" + assert res.basic + with pytest.raises(AssertionError): + geoipinfo_manager.get_mysql_raw(ip) + res = geoipinfo_manager.get_mysql_raw(normalized_ip) + assert res["ip"] == normalized_ip + + def test_create_insights( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in the US, so it should do insights + ip = IP_v4_US + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + + res1 = maxmind_manager.run_ip_information(ip, force_insights=False) + assert isinstance(res1, GeoIPInformation) + + # Check that it is in the cache and in mysql + res2 = geoipinfo_manager.get_cache(ip) + assert isinstance(res2, GeoIPInformation) + assert res2.ip == ip + assert not res2.basic + + res3 = geoipinfo_manager.get_mysql(ip) + assert isinstance(res3, GeoIPInformation) + assert res3.ip == ip + assert not res3.basic + assert res3.is_anonymous is False + assert res3.subdivision_1_name.lower() in US_STATES + # this might change ... + assert res3.user_type == UserType.CELLULAR + + assert res1 == res2 == res3, "runner, cache, mysql all return same instance" + + def test_create_insights_ipv6( + self, + maxmind_manager: MaxmindManager, + geoipinfo_manager: GeoIpInfoManager, + delete_ipinfo, + ): + # This is (currently) an IP in the US, so it should do insights + ip = IP_v6_US + normalized_ip, lookup_prefix = normalize_ip(ip) + delete_ipinfo(ip) + geoipinfo_manager.clear_cache(ip) + delete_ipinfo(normalized_ip) + geoipinfo_manager.clear_cache(normalized_ip) + assert geoipinfo_manager.get_cache(ip) is None + assert geoipinfo_manager.get_cache(normalized_ip) is None + assert geoipinfo_manager.get_mysql_if_exists(ip) is None + assert geoipinfo_manager.get_mysql_if_exists(normalized_ip) is None + + res1 = maxmind_manager.run_ip_information(ip, force_insights=False) + assert isinstance(res1, GeoIPInformation) + assert res1.lookup_prefix == "/64" + + # Check that it is in the cache and in mysql + res2 = geoipinfo_manager.get_cache(ip) + assert isinstance(res2, GeoIPInformation) + assert res2.ip == ip + assert not res2.basic + + res3 = geoipinfo_manager.get_mysql(ip) + assert isinstance(res3, GeoIPInformation) + assert res3.ip == ip + assert not res3.basic + assert res3.is_anonymous is False + assert res3.subdivision_1_name.lower() in US_STATES + # this might change ... + assert res3.user_type == UserType.RESIDENTIAL + + assert res1 == res2 == res3, "runner, cache, mysql all return same instance" + + def test_get_or_create_ip_information(self, maxmind_manager): + ip = IP_v4_US + + res1 = maxmind_manager.get_or_create_ip_information(ip_address=ip) + assert isinstance(res1, GeoIPInformation) + + res2 = maxmind_manager.get_or_create_ip_information( + ip_address=fake.ipv4_private() + ) + assert res2 is None + + def test_get_or_create_ip_information_ipv6( + self, maxmind_manager, delete_ipinfo, geoipinfo_manager, caplog + ): + ip = IP_v6_US + normalized_ip, lookup_prefix = normalize_ip(ip) + delete_ipinfo(normalized_ip) + geoipinfo_manager.clear_cache(normalized_ip) + + with caplog.at_level(logging.INFO): + res1 = maxmind_manager.get_or_create_ip_information(ip_address=ip) + assert isinstance(res1, GeoIPInformation) + assert res1.ip == ip + # It looks up in insight using the normalize IP! + assert f"get_insights_ip_information: {normalized_ip}" in caplog.text + + # And it should NOT do the lookup again with an ipv6 in the same /64 block! + ip = IP_v6_US_SAME_64 + caplog.clear() + with caplog.at_level(logging.INFO): + res2 = maxmind_manager.get_or_create_ip_information(ip_address=ip) + assert isinstance(res2, GeoIPInformation) + assert res2.ip == ip + assert "get_insights_ip_information" not in caplog.text + + def test_run_ip_information(self, maxmind_manager): + ip = IP_v4_US + + res = maxmind_manager.run_ip_information(ip_address=ip) + assert isinstance(res, GeoIPInformation) + assert res.country_name == "United States" + assert res.country_iso == "us" diff --git a/tests/managers/thl/test_payout.py b/tests/managers/thl/test_payout.py new file mode 100644 index 0000000..31087b8 --- /dev/null +++ b/tests/managers/thl/test_payout.py @@ -0,0 +1,1269 @@ +import logging +import os +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import choice as rand_choice, randint +from typing import Optional +from uuid import uuid4 + +import pandas as pd +import pytest + +from generalresearch.currency import USDCent +from generalresearch.managers.thl.ledger_manager.exceptions import ( + LedgerTransactionConditionFailedError, +) +from generalresearch.managers.thl.payout import UserPayoutEventManager +from generalresearch.models.thl.definitions import PayoutStatus +from generalresearch.models.thl.ledger import LedgerEntry, Direction +from generalresearch.models.thl.payout import BusinessPayoutEvent +from generalresearch.models.thl.payout import UserPayoutEvent +from generalresearch.models.thl.wallet import PayoutType +from generalresearch.models.thl.ledger import LedgerAccount + +logger = logging.getLogger() + +cashout_method_uuid = uuid4().hex + + +class TestPayout: + + def test_get_by_uuid_and_create( + self, + user, + user_payout_event_manager: UserPayoutEventManager, + thl_lm, + utc_now, + ): + + user_account: LedgerAccount = thl_lm.get_account_or_create_user_wallet( + user=user + ) + + pe1: UserPayoutEvent = user_payout_event_manager.create( + debit_account_uuid=user_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=100, + created=utc_now, + ) + + # these get added by the query + pe1.account_reference_type = "user" + pe1.account_reference_uuid = user.uuid + # pe1.description = "PayPal" + + pe2 = user_payout_event_manager.get_by_uuid(pe_uuid=pe1.uuid) + + assert pe1 == pe2 + + def test_update(self, user, user_payout_event_manager, lm, thl_lm, utc_now): + from generalresearch.models.thl.definitions import PayoutStatus + from generalresearch.models.thl.wallet import PayoutType + + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + + pe1 = user_payout_event_manager.create( + status=PayoutStatus.PENDING, + debit_account_uuid=user_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=100, + created=utc_now, + ) + user_payout_event_manager.update( + payout_event=pe1, + status=PayoutStatus.APPROVED, + order_data={"foo": "bar"}, + ext_ref_id="abc", + ) + pe = user_payout_event_manager.get_by_uuid(pe_uuid=pe1.uuid) + assert pe.status == PayoutStatus.APPROVED + assert pe.order_data == {"foo": "bar"} + + with pytest.raises(expected_exception=AssertionError) as cm: + user_payout_event_manager.update( + payout_event=pe, status=PayoutStatus.PENDING + ) + assert "status APPROVED can only be" in str(cm.value) + + def test_create_bp_payout( + self, + user, + thl_web_rr, + user_payout_event_manager, + lm, + thl_lm, + product, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + ): + delete_ledger_db() + create_main_accounts() + from generalresearch.models.thl.ledger import LedgerAccount + + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + with pytest.raises(expected_exception=LedgerTransactionConditionFailedError): + # wallet balance failure + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + ) + + # (we don't have a special method for this) Put money in the BP's account + amount_cents = 100 + cash_account: LedgerAccount = thl_lm.get_account_cash() + bp_wallet: LedgerAccount = thl_lm.get_account_or_create_bp_wallet( + product=product + ) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid=cash_account.uuid, + amount=amount_cents, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid=bp_wallet.uuid, + amount=amount_cents, + ), + ] + + lm.create_tx(entries=entries) + assert 100 == lm.get_account_balance(account=bp_wallet) + + # Then run it again for $1.00 + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=False, + skip_one_per_day_check=False, + ) + assert 0 == lm.get_account_balance(account=bp_wallet) + + # Run again should without balance check, should still fail due to day check + with pytest.raises(LedgerTransactionConditionFailedError): + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=False, + ) + + # And then we can run again skip both checks + pe = brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert -100 == lm.get_account_balance(account=bp_wallet) + + pe = brokerage_product_payout_event_manager.get_by_uuid(pe.uuid) + txs = lm.get_tx_filtered_by_metadata( + metadata_key="event_payout", metadata_value=pe.uuid + ) + + assert 1 == len(txs) + + def test_create_bp_payout_quick_dupe( + self, + user, + product, + thl_web_rw, + brokerage_product_payout_event_manager, + thl_lm, + lm, + utc_now, + create_main_accounts, + ): + thl_lm.get_account_or_create_bp_wallet(product=product) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + created=utc_now, + ) + + with pytest.raises(ValueError) as cm: + brokerage_product_payout_event_manager.create_bp_payout_event( + thl_ledger_manager=thl_lm, + product=product, + amount=USDCent(100), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + created=utc_now, + ) + assert "Payout event already exists!" in str(cm.value) + + def test_filter( + self, + thl_web_rw, + thl_lm, + lm, + product, + user, + user_payout_event_manager, + utc_now, + ): + from generalresearch.models.thl.definitions import PayoutStatus + from generalresearch.models.thl.wallet import PayoutType + + user_account = thl_lm.get_account_or_create_user_wallet(user=user) + bp_account = thl_lm.get_account_or_create_bp_wallet(product=product) + + user_payout_event_manager.create( + status=PayoutStatus.PENDING, + debit_account_uuid=user_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=100, + created=utc_now, + ) + + user_payout_event_manager.create( + status=PayoutStatus.PENDING, + debit_account_uuid=bp_account.uuid, + payout_type=PayoutType.PAYPAL, + cashout_method_uuid=cashout_method_uuid, + amount=200, + created=utc_now, + ) + + pes = user_payout_event_manager.filter_by( + reference_uuid=user.uuid, created=utc_now + ) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + debit_account_uuids=[bp_account.uuid], created=utc_now + ) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + debit_account_uuids=[bp_account.uuid], amount=123 + ) + assert 0 == len(pes) + + pes = user_payout_event_manager.filter_by( + product_ids=[user.product_id], bp_user_ids=["x"] + ) + assert 0 == len(pes) + + pes = user_payout_event_manager.filter_by(product_ids=[user.product_id]) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + cashout_types=[PayoutType.PAYPAL], + bp_user_ids=[user.product_user_id], + ) + assert 1 == len(pes) + + pes = user_payout_event_manager.filter_by( + statuses=[PayoutStatus.FAILED], + cashout_types=[PayoutType.PAYPAL], + bp_user_ids=[user.product_user_id], + ) + assert 0 == len(pes) + + pes = user_payout_event_manager.filter_by( + statuses=[PayoutStatus.PENDING], + cashout_types=[PayoutType.PAYPAL], + bp_user_ids=[user.product_user_id], + ) + assert 1 == len(pes) + + +class TestPayoutEventManager: + + def test_set_account_lookup_table( + self, payout_event_manager, thl_redis_config, thl_lm, delete_ledger_db + ): + delete_ledger_db() + rc = thl_redis_config.create_redis_client() + rc.delete("pem:account_to_product") + rc.delete("pem:product_to_account") + N = 5 + + for idx in range(N): + thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=uuid4().hex) + + res = rc.hgetall(name="pem:account_to_product") + assert len(res.items()) == 0 + + res = rc.hgetall(name="pem:product_to_account") + assert len(res.items()) == 0 + + payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm, + ) + + res = rc.hgetall(name="pem:account_to_product") + assert len(res.items()) == N + + res = rc.hgetall(name="pem:product_to_account") + assert len(res.items()) == N + + thl_lm.get_account_or_create_bp_wallet_by_uuid(product_uuid=uuid4().hex) + payout_event_manager.set_account_lookup_table( + thl_lm=thl_lm, + ) + + res = rc.hgetall(name="pem:account_to_product") + assert len(res.items()) == N + 1 + + res = rc.hgetall(name="pem:product_to_account") + assert len(res.items()) == N + 1 + + +class TestBusinessPayoutEventManager: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "5d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return timedelta(days=10) + + def test_base( + self, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + thl_lm, + thl_web_rr, + product_factory, + bp_payout_factory, + business, + ): + delete_ledger_db() + create_main_accounts() + + from generalresearch.models.thl.product import Product + + p1: Product = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + bp_payout_factory( + product=p1, + amount=USDCent(1), + ext_ref_id=None, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(1), + ext_ref_id=ach_id1, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(25), + ext_ref_id=ach_id1, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=p1, + amount=USDCent(50), + ext_ref_id=ach_id2, + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + + assert len(business.payouts) == 3 + assert business.payouts_total == sum([pe.amount for pe in business.payouts]) + assert business.payouts[0].created > business.payouts[1].created + assert len(business.payouts[0].bp_payouts) == 1 + assert len(business.payouts[1].bp_payouts) == 2 + + assert business.payouts[0].ext_ref_id == ach_id2 + assert business.payouts[1].ext_ref_id == ach_id1 + assert business.payouts[2].ext_ref_id is None + + def test_update_ext_reference_ids( + self, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + thl_lm, + thl_web_rr, + product_factory, + bp_payout_factory, + delete_df_collection, + user_factory, + ledger_collection, + session_with_tx_factory, + pop_ledger_merge, + client_no_amm, + mnt_filepath, + lm, + product_manager, + start, + business, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # $250.00 to work with + for idx in range(1, 10): + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("25.00"), + started=start + timedelta(days=1, minutes=idx), + ) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + with pytest.raises(expected_exception=Warning) as cm: + business_payout_event_manager.update_ext_reference_ids( + new_value=ach_id2, + current_value=ach_id1, + ) + assert "No event_payouts found to UPDATE" in str(cm) + + # We must build the balance to issue ACH/Wire + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + res = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_01), + pm=product_manager, + thl_lm=thl_lm, + transaction_id=ach_id1, + ) + assert isinstance(res, BusinessPayoutEvent) + + # Okay, now that there is a payout_event, let's try to update the + # ext_reference_id + business_payout_event_manager.update_ext_reference_ids( + new_value=ach_id2, + current_value=ach_id1, + ) + + res = business_payout_event_manager.filter_by(ext_ref_id=ach_id1) + assert len(res) == 0 + + res = business_payout_event_manager.filter_by(ext_ref_id=ach_id2) + assert len(res) == 1 + + def test_delete_failed_business_payout( + self, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + thl_lm, + thl_web_rr, + product_factory, + bp_payout_factory, + currency, + delete_df_collection, + user_factory, + ledger_collection, + session_with_tx_factory, + pop_ledger_merge, + client_no_amm, + mnt_filepath, + lm, + product_manager, + start, + business, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p1) + business_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # $250.00 to work with + for idx in range(1, 10): + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("25.00"), + started=start + timedelta(days=1, minutes=idx), + ) + + # We must build the balance to issue ACH/Wire + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + ach_id1 = uuid4().hex + + res = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_01), + pm=product_manager, + thl_lm=thl_lm, + transaction_id=ach_id1, + ) + assert isinstance(res, BusinessPayoutEvent) + + # (1) Confirm the initial Event Payout, Tx, TxMeta, TxEntry all exist + event_payouts = business_payout_event_manager.filter_by(ext_ref_id=ach_id1) + event_payout_uuids = [i.uuid for i in event_payouts] + assert len(event_payout_uuids) == 1 + tags = [f"{currency.value}:bp_payout:{x}" for x in event_payout_uuids] + transactions = thl_lm.get_txs_by_tags(tags=tags) + assert len(transactions) == 1 + tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions) + assert len(tx_metadata_ids) == 2 + tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions) + assert len(tx_entries) == 2 + + # (2) Delete! + business_payout_event_manager.delete_failed_business_payout( + ext_ref_id=ach_id1, thl_lm=thl_lm + ) + + # (3) Confirm the initial Event Payout, Tx, TxMeta, TxEntry have + # all been deleted + res = business_payout_event_manager.filter_by(ext_ref_id=ach_id1) + assert len(res) == 0 + + # Note: b/c the event_payout shouldn't exist anymore, we are taking + # the tag strings and transactions from when they did.. + res = thl_lm.get_txs_by_tags(tags=tags) + assert len(res) == 0 + + tx_metadata_ids = thl_lm.get_tx_metadata_ids_by_txs(transactions=transactions) + assert len(tx_metadata_ids) == 0 + tx_entries = thl_lm.get_tx_entries_by_txs(transactions=transactions) + assert len(tx_entries) == 0 + + def test_recoup_empty(self, business_payout_event_manager): + res = {uuid4().hex: USDCent(0) for i in range(100)} + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + with pytest.raises(expected_exception=ValueError) as cm: + business_payout_event_manager.recoup_proportional( + df=df, target_amount=USDCent(1) + ) + assert "Total available amount is empty, cannot recoup" in str(cm) + + def test_recoup_exceeds(self, business_payout_event_manager): + from random import randint + + res = {uuid4().hex: USDCent(randint(a=0, b=1_000_00)) for i in range(100)} + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + avail_balance = USDCent(int(df.available_balance.sum())) + + with pytest.raises(expected_exception=ValueError) as cm: + business_payout_event_manager.recoup_proportional( + df=df, target_amount=avail_balance + USDCent(1) + ) + assert " exceeds total available " in str(cm) + + def test_recoup(self, business_payout_event_manager): + from random import randint, random + + res = {uuid4().hex: USDCent(randint(a=0, b=1_000_00)) for i in range(100)} + df = pd.DataFrame.from_dict(res, orient="index").reset_index() + df.columns = ["product_id", "available_balance"] + + avail_balance = USDCent(int(df.available_balance.sum())) + random_recoup_amount = USDCent(1 + int(int(avail_balance) * random() * 0.5)) + + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=random_recoup_amount + ) + + assert isinstance(res, pd.DataFrame) + assert res.weight.sum() == pytest.approx(1) + assert res.deduction.sum() == random_recoup_amount + assert res.remaining_balance.sum() == avail_balance - random_recoup_amount + + def test_recoup_loop(self, business_payout_event_manager, request): + # TODO: Generate this file at random + fp = os.path.join( + request.config.rootpath, "data/pytest_recoup_proportional.csv" + ) + df = pd.read_csv(fp, index_col=0) + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=USDCent(1416089) + ) + + res = res[res["remaining_balance"] > 0] + + assert int(res.deduction.sum()) == 1416089 + + def test_recoup_loop_single_profitable_account(self, business_payout_event_manager): + res = [{"product_id": uuid4().hex, "available_balance": 0} for i in range(1000)] + for x in range(100): + item = rand_choice(res) + item["available_balance"] = randint(8, 12) + + df = pd.DataFrame(res) + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=USDCent(500) + ) + # res = res[res["remaining_balance"] > 0] + assert int(res.deduction.sum()) == 500 + + def test_recoup_loop_assertions(self, business_payout_event_manager): + df = pd.DataFrame( + [ + { + "product_id": uuid4().hex, + "available_balance": randint(0, 999_999), + } + for i in range(10_000) + ] + ) + available_balance = int(df.available_balance.sum()) + + # Exact amount + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=available_balance + ) + assert res.remaining_balance.sum() == 0 + assert int(res.deduction.sum()) == available_balance + + # Slightly less + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=available_balance - 1 + ) + assert res.remaining_balance.sum() == 1 + assert int(res.deduction.sum()) == available_balance - 1 + + # Slightly less + with pytest.raises(expected_exception=Exception) as cm: + res = business_payout_event_manager.recoup_proportional( + df=df, target_amount=available_balance + 1 + ) + + # Don't pull anything + res = business_payout_event_manager.recoup_proportional(df=df, target_amount=0) + assert res.remaining_balance.sum() == available_balance + assert int(res.deduction.sum()) == 0 + + def test_distribute_amount(self, business_payout_event_manager): + import io + + df = pd.read_csv( + io.StringIO( + "product_id,available_balance,weight,raw_deduction,deduction,remainder,remaining_balance\n0,15faf,11768,0.0019298788807489663,222.3374860933269,223,0.33748609332690194,11545\n5,793e3,202,3.312674489388946e-05,3.8164660257352168,3,0.8164660257352168,199\n6,c703b,22257,0.0036500097084321667,420.5103184890531,421,0.510318489053077,21836\n13,72a70,1424,0.00023352715212326036,26.90419614181658,27,0.9041961418165805,1397\n14,86173,45634,0.007483692457860156,862.1812406851528,863,0.18124068515282943,44771\n17,4f230,143676,0.02356197128403199,2714.5275876907576,2715,0.5275876907576276,140961\n18,e1ee6,129,2.1155198471840298e-05,2.437248105543777,2,0.43724810554377713,127\n22,4524a,85613,0.014040000052478012,1617.5203260458868,1618,0.5203260458868044,83995\n25,4f5e2,30282,0.004966059845924558,572.1298227292765,573,0.12982272927649774,29709\n28,0f3b3,135,2.213916119146078e-05,2.5506084825458135,2,0.5506084825458135,133\n29,15c6f,1226,0.00020105638237578454,23.163303700749385,23,0.16330370074938472,1203\n31,c0c04,376,6.166166376288335e-05,7.103916958794265,7,0.10391695879426521,369\n33,2934c,37649,0.006174202071831903,711.3174722916099,712,0.31747229160987445,36937\n38,5585d,16471,0.0027011416591448184,311.1931282667562,312,0.19312826675621864,16159\n42,0a749,663,0.00010872788051806293,12.526321658724994,12,0.5263216587249939,651\n43,9e336,322,5.280599928629904e-05,6.08367356577594,6,0.08367356577593998,316\n46,043a5,11656,0.001911511576649384,220.22142572262223,221,0.2214257226222287,11435\n47,d6f4e,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n48,3e123,3012,0.0004939492852494804,56.90690925502214,57,0.9069092550221427,2955\n49,76ccc,76,1.24635277818594e-05,1.4358981086924578,1,0.4358981086924578,75\n51,8acb9,7710,0.0012643920947123155,145.66808444761645,146,0.6680844476164509,7564\n56,fef3e,212,3.476668275992359e-05,4.005399987405277,4,0.005399987405277251,208\n57,6d7c2,455709,0.07473344449925481,8609.890673870148,8610,0.8906738701480208,447099\n58,f51f6,257,4.2146403157077186e-05,4.855602814920548,4,0.8556028149205481,253\n61,06acf,84310,0.013826316148533765,1592.902230840278,1593,0.90223084027798,82717\n62,6eca7,40,6.559751464136527e-06,0.755735846680241,0,0.755735846680241,40\n68,4a415,1955,0.00032060785280967275,36.93658950649678,37,0.9365895064967802,1918\n69,57a16,409,6.707345872079598e-05,7.727399032305463,7,0.7273990323054633,402\n70,c3ef6,593,9.724831545582401e-05,11.203783927034571,11,0.20378392703457138,582\n71,385da,825,0.00013529487394781586,15.587051837779969,15,0.5870518377799687,810\n72,a8435,748,0.00012266735237935304,14.132260332920506,14,0.13226033292050587,734\n75,c9374,263383,0.043193175496966774,4976.199362654548,4977,0.19936265454816748,258406\n76,7fcc7,136,2.230315497806419e-05,2.569501878712819,2,0.5695018787128192,134\n77,26aec,356,5.838178803081509e-05,6.726049035454145,6,0.7260490354541451,350\n78,76cc5,413,6.772943386720964e-05,7.802972616973489,7,0.8029726169734888,406\n82,9476a,13973,0.002291485180209492,263.99742464157515,264,0.9974246415751509,13709\n85,ee099,2397,0.0003930931064883814,45.28747061231344,46,0.2874706123134416,2351\n87,24bd5,122295,0.020055620132664414,2310.567884244002,2311,0.5678842440020162,119984\n91,1fa8f,4,6.559751464136527e-07,0.0755735846680241,0,0.0755735846680241,4\n92,53b5a,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n93,f4f9e,201,3.296275110728605e-05,3.797572629568211,3,0.7975726295682111,198\n95,ff5d7,21317,0.0034958555490249587,402.75052609206745,403,0.7505260920674459,20914\n96,6290d,80,1.3119502928273054e-05,1.511471693360482,1,0.5114716933604819,79\n100,9c34b,1870,0.0003066683809483826,35.33065083230127,36,0.33065083230126646,1834\n101,d32a9,11577,0.0018985560675077143,218.72884742542874,219,0.7288474254287394,11358\n102,a6001,8981,0.0014728281974852537,169.6815909758811,170,0.6815909758811074,8811\n106,0a1ee,9,1.4759440794307186e-06,0.17004056550305424,0,0.17004056550305424,9\n108,8ad36,51,8.363683116774072e-06,0.9635632045173074,0,0.9635632045173074,51\n111,389ab,75,1.2299533995255988e-05,1.417004712525452,1,0.4170047125254519,74\n114,be86b,4831,0.000792253983081089,91.2739968828061,92,0.27399688280610235,4739\n118,99e96,271,4.444231616952497e-05,5.120110361258632,5,0.12011036125863228,266\n121,3b729,12417,0.0020363108482545815,234.59930020571383,235,0.5993002057138312,12182\n122,f16e2,2697,0.0004422912424694053,50.95548946241525,51,0.955489462415251,2646\n125,241c5,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n127,523af,569,9.33124645773421e-05,10.750342419026428,10,0.7503424190264276,559\n130,dce2e,31217,0.005119394036398749,589.7951481454271,590,0.7951481454271061,30627\n133,214b8,37360,0.006126807867503516,705.8572807993451,706,0.8572807993450624,36654\n136,03c88,35,5.739782531119461e-06,0.6612688658452109,0,0.6612688658452109,35\n137,08021,44828,0.007351513465857806,846.9531633745461,847,0.9531633745460795,43981\n144,bc3d9,1174,0.00019252870547240705,22.180847100065073,22,0.18084710006507265,1152\n148,bac2f,3745,0.0006141567308297824,70.75576864543757,71,0.7557686454375698,3674\n150,d9e69,8755,0.0014357656017128823,165.41168344213776,166,0.41168344213775754,8589\n151,36b18,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n152,e15d7,5259,0.0008624433237473498,99.36037044228468,100,0.3603704422846761,5159\n155,bad23,16504,0.002706553454102731,311.81661034026746,312,0.8166103402674594,16192\n159,9f77f,2220,0.00036406620625957726,41.943339490753374,42,0.9433394907533739,2178\n160,945ab,131,2.1483186045047126e-05,2.4750348978777894,2,0.4750348978777894,129\n161,3e2dc,354,5.8053800457608265e-05,6.688262243120133,6,0.6882622431201328,348\n162,62811,1608,0.0002637020088582884,30.38058103654569,31,0.38058103654568853,1577\n164,59a3d,2541,0.0004167082117592729,48.00811966036231,49,0.008119660362311265,2492\n165,871ca,8,1.3119502928273053e-06,0.1511471693360482,0,0.1511471693360482,8\n167,55d38,683,0.00011200775625013119,12.904189582065115,12,0.9041895820651149,671\n169,1e5af,6,9.83962719620479e-07,0.11336037700203613,0,0.11336037700203613,6\n170,b2901,15994,0.0026229166229349904,302.18097829509435,303,0.1809782950943486,15691\n174,dc880,2435,0.00039932487037931105,46.00541966665967,47,0.005419666659669531,2388\n176,d189e,34579,0.005670741146959424,653.3147460589013,654,0.31474605890127805,33925\n177,d35f5,41070,0.006735224815802179,775.9517805789375,776,0.9517805789374734,40294\n178,bd7a6,12563,0.0020602539410986796,237.35773604609668,238,0.35773604609667586,12325\n180,662cd,9675,0.0015866398853880224,182.79360791578327,183,0.7936079157832694,9492\n181,e995f,1011,0.00016579771825605072,19.101223524843093,19,0.1012235248430926,992\n189,0faf8,626,0.00010266011041373665,11.827266000545771,11,0.8272660005457713,615\n190,1fd84,213,3.4930676546527005e-05,4.024293383572283,4,0.02429338357228339,209\n193,c9b44,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n194,bbd8d,16686,0.0027364003232645522,315.25520844266254,316,0.2552084426625356,16370\n196,d4945,2654,0.00043523950964545856,50.14307342723399,51,0.1430734272339933,2603\n197,bbde8,3043,0.0004990330926341863,57.49260453619934,58,0.4926045361993374,2985\n200,579ad,20833,0.0034164825563089067,393.6061223472365,394,0.6061223472365214,20439\n202,2f932,15237,0.0024987733264762065,287.8786773966708,288,0.8786773966708097,14949\n208,b649d,103551,0.01698172059657004,1956.430066489641,1957,0.43006648964092165,101594\n209,0f939,26025,0.0042679382963538275,491.7006352463318,492,0.7006352463317853,25533\n211,638d8,6218,0.0010197133651000231,117.47913736644347,118,0.47913736644346727,6100\n215,d2a81,8301,0.0013613124225949327,156.834081582317,157,0.8340815823169976,8144\n216,62293,4,6.559751464136527e-07,0.0755735846680241,0,0.0755735846680241,4\n218,c8ae9,2829,0.00046393842230105583,53.44941775646004,54,0.44941775646004345,2775\n219,83a9f,6556,0.0010751432649719768,123.8651052708915,124,0.8651052708915046,6432\n221,d256a,72,1.1807552635445749e-05,1.360324524024434,1,0.36032452402443393,71\n222,6fdc2,7,1.1479565062238922e-06,0.1322537731690422,0,0.1322537731690422,7\n224,56b88,146928,0.02409527907806629,2775.9689120258613,2776,0.9689120258613002,144152\n230,f50f8,5798,0.0009508359747265895,109.54391097630092,110,0.5439109763009213,5688\n231,fa3be,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n232,94934,537381,0.08812714503872877,10152.952125621865,10153,0.9521256218649796,527228\n234,4e20c,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n235,a5d68,31101,0.005100370757152753,587.6035141900543,588,0.603514190054284,30513\n236,e5a29,3208,0.0005260920674237494,60.61001490375533,61,0.610014903755328,3147\n237,0ce0f,294,4.821417326140347e-05,5.554658473099771,5,0.5546584730997708,289\n240,66d2b,18633,0.0030556962257813976,352.04065077982324,353,0.04065077982323828,18280\n244,1bd17,1815,0.0002976487226851949,34.29151404311594,35,0.29151404311593865,1780\n245,32aca,224,3.673460819916455e-05,4.23212074140935,4,0.23212074140935002,220\n247,cbf8e,4747,0.0007784785050064023,89.6869516047776,90,0.6869516047776045,4657\n249,8f24b,807633,0.13244679385587438,15258.930226547576,15259,0.9302265475762397,792374\n251,c97c3,20526,0.0033661364638216586,387.80584972396565,388,0.8058497239656504,20138\n252,88fee,13821,0.0022665581246457734,261.12562842419027,262,0.12562842419026765,13559\n253,a9ad3,178,2.9190894015407545e-05,3.3630245177270726,3,0.36302451772707256,175\n254,83738,104,1.705535380675497e-05,1.9649132013686266,1,0.9649132013686266,103\n255,21f6c,6288,0.001031192930162262,118.8016750981339,119,0.8016750981338987,6169\n256,97e28,6,9.83962719620479e-07,0.11336037700203613,0,0.11336037700203613,6\n257,7f689,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n258,e7a50,28031,0.004596909832280275,529.6007879573459,530,0.600787957345915,27501\n259,2eb98,1,1.6399378660341317e-07,0.018893396167006023,0,0.018893396167006023,1\n260,349ba,19518,0.0032008307269254183,368.76130638762356,369,0.7613063876235628,19149\n261,a3d04,235,3.853853985180209e-05,4.439948099246416,4,0.43994809924641576,231\n262,40971,2249,0.0003688220260710762,42.49124797959655,43,0.491247979596551,2206\n264,4f588,6105,0.0010011820672138373,115.34418359957178,116,0.34418359957177813,5989\n269,f182e,1020,0.00016727366233548145,19.271264090346147,19,0.27126409034614696,1001\n270,3798e,5168,0.0008475198891664393,97.64107139108714,98,0.641071391087138,5070\n274,81dc3,274,4.493429752933521e-05,5.176790549759651,5,0.1767905497596507,269\n285,8520b,2,3.2798757320682633e-07,0.03778679233401205,0,0.03778679233401205,2\n287,37b89,3742,0.0006136647494699721,70.69908845693654,71,0.6990884569365363,3671\n291,63706,4740,0.0007773305485001784,89.55469783160856,90,0.5546978316085642,4650\n293,21c9d,241,3.9522502571422574e-05,4.553308476248452,4,0.5533084762484517,237\n296,3a42a,610,0.00010003620982808204,11.524971661873675,11,0.5249716618736748,599\n302,31a8f,148533,0.02435848910556477,2806.292812873906,2807,0.2928128739058593,145726\n303,38467,2399,0.0003934210940615882,45.32525740464745,46,0.3252574046474521,2353\n304,a49a6,26573,0.004357806891412498,502.0542163458511,503,0.05421634585110269,26070\n305,3c4e7,2286,0.0003748897961754025,43.19030363777577,44,0.19030363777577008,2242\n306,63986,34132,0.005597435924347699,644.8693979722497,645,0.8693979722496579,33487\n307,640f4,50,8.199689330170658e-06,0.9446698083503012,0,0.9446698083503012,50\n309,ba81a,3015,0.0004944412666092907,56.96358944352316,57,0.963589443523162,2958\n310,b7e8e,11409,0.0018710051113583408,215.55475686937172,216,0.5547568693717153,11193\n311,98301,5694,0.0009337806209198346,107.5789977749323,108,0.5789977749323043,5586\n312,e70bd,19,3.11588194546485e-06,0.35897452717311445,0,0.35897452717311445,19\n314,adf9c,5023,0.0008237407901089443,94.90152894687125,95,0.9015289468712524,4928\n316,ab5f4,55,9.019658263187724e-06,1.0391367891853314,1,0.039136789185331367,54\n318,d4a4c,2242,0.0003676740695648523,42.358994206427504,43,0.35899420642750357,2199\n319,b4062,222,3.640662062595773e-05,4.194333949075338,4,0.19433394907533774,218\n321,cb691,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n322,b83ed,3716,0.0006094009110182833,70.20786015659439,71,0.20786015659439272,3645\n323,8759e,5964,0.0009780589433027562,112.68021474002393,113,0.6802147400239278,5851\n325,2f217,3625,0.0005944774764373728,68.48856110539684,69,0.4885611053968404,3556\n326,d683f,28156,0.004617409055605701,531.9624624782216,532,0.962462478221596,27624\n327,97cf7,928,0.00015218623396796743,17.533071642981593,17,0.5330716429815929,911\n328,75135,2841,0.0004659063477402968,53.67613851046411,54,0.6761385104641136,2787\n329,d7c10,29913,0.0049055461386678986,565.1581595436512,566,0.15815954365120888,29347\n330,8598e,66,1.082358991582527e-05,1.2469641470223976,1,0.24696414702239755,65\n331,5e72b,10825,0.0017752327399819475,204.52101350784022,205,0.5210135078402232,10620\n332,30bff,975,0.00015989394193832783,18.42106126283087,18,0.42106126283087164,957\n333,f79d7,696,0.00011413967547597557,13.149803732236194,13,0.1498037322361938,683\n334,d0a0d,2121,0.0003478308213858393,40.07289327021978,41,0.07289327021977954,2080\n335,22ec5,3942,0.0006464635067906547,74.47776769033774,75,0.4777676903377426,3867\n336,efeac,515,8.445680010075779e-05,9.730099026008103,9,0.7300990260081033,506\n337,c3854,105882,0.017363990113142592,2000.4705729549316,2001,0.4705729549316402,103881\n338,cd2a3,42924,0.007039269296164907,810.9801370725665,811,0.9801370725665493,42113\n339,d7333,5089,0.0008345643800247697,96.14849309389366,97,0.14849309389366283,4992\n340,7d48c,33,5.411794957912635e-06,0.6234820735111988,0,0.6234820735111988,33\n341,6a148,100,1.6399378660341316e-05,1.8893396167006025,1,0.8893396167006025,99\n342,ba8ff,400,6.559751464136527e-05,7.55735846680241,7,0.5573584668024099,393\n343,38810,3194,0.0005237961544113017,60.34550735741724,61,0.34550735741724026,3133\n344,04d6e,43,7.051732823946766e-06,0.812416035181259,0,0.812416035181259,43\n345,d4f05,1394,0.00022860733852515797,26.337394256806398,27,0.3373942568063981,1367\n346,53f71,1325,0.00021729176724952245,25.033749921282983,26,0.0337499212829826,1299\n347,e4d06,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n348,8ab28,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n349,f21e0,65,1.0659596129221857e-05,1.2280707508553916,1,0.22807075085539164,64\n350,02fc2,2202,0.0003611143181007158,41.603258359747265,42,0.6032583597472652,2160\n351,310c0,18370,0.0030125658599047,347.07168758790067,348,0.07168758790066931,18022\n352,fceae,169,2.7714949935976826e-05,3.192983952224018,3,0.1929839522240182,166\n353,25c52,256,4.198240937047377e-05,4.836709418753542,4,0.836709418753542,252\n354,2e0cc,7029,0.0011527123260353911,132.80168165788535,133,0.8016816578853536,6896\n355,4baa9,405,6.641748357438233e-05,7.65182544763744,7,0.6518254476374397,398\n356,f2cb1,2236,0.00036669010684523185,42.24563382942547,43,0.24563382942547207,2193\n357,c4f9f,69,1.1315571275635508e-05,1.3036443355234157,1,0.30364433552341574,68\n358,0fe2b,33,5.411794957912635e-06,0.6234820735111988,0,0.6234820735111988,33\n359,ba2cc,911,0.0001493983395957094,17.21188390814249,17,0.21188390814248947,894\n360,22c5c,84,1.3775478074686707e-05,1.587045278028506,1,0.587045278028506,83\n361,b0369,42,6.8877390373433535e-06,0.793522639014253,0,0.793522639014253,42\n362,a7d03,12344,0.002024339301832532,233.22008228552235,234,0.2200822855223521,12110\n363,f6d7e,11703,0.0019192192846197442,221.1094153424715,222,0.1094153424714932,11481\n364,c281d,50376,0.008261350993933542,951.7737253090955,952,0.7737253090955392,49424\n365,dfbd6,3932,0.0006448235689246206,74.2888337286677,75,0.2888337286676972,3857\n366,a34ba,3909,0.0006410517118327421,73.85428561682654,74,0.8542856168265445,3835\n367,157ee,253,4.149042801066353e-05,4.780029230252524,4,0.7800292302525236,249\n369,33a29,6954,0.0011404127920401352,131.3846769453599,132,0.3846769453598995,6822\n370,2bd95,11279,0.001849685919099897,213.09861536766095,214,0.09861536766095469,11065\n371,138b7,2094,0.00034340298914754716,39.56277157371061,40,0.5627715737106129,2054\n372,eea6f,3884,0.0006369518671676568,73.3819507126514,74,0.38195071265140257,3810\n373,8f39b,1046,0.00017153750078717018,19.7624923906883,19,0.7624923906883012,1027\n374,bacfb,773,0.00012676719704443837,14.604595237095657,14,0.6045952370956567,759\n375,23403,855,0.00014021468754591825,16.15385372279015,16,0.1538537227901493,839\n376,0ad32,1630,0.00026730987216356345,30.79623575221982,31,0.7962357522198182,1599\n377,6be35,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n378,12402,2994,0.000490997397090619,56.566828124016034,57,0.566828124016034,2937\n379,bae61,40980,0.006720465375007872,774.251374923907,775,0.25137492390695115,40205\n380,384f7,23545,0.003861233705577363,444.84501275215683,445,0.8450127521568334,23100\n381,03ed6,115,1.8859285459392513e-05,2.1727405592056925,2,0.17274055920569253,113\n382,cc4c8,3529,0.0005787340729234451,66.67479507336427,67,0.6747950733642654,3462\n383,263fa,4019,0.0006590910283591176,75.93255919519721,76,0.9325591951972143,3943\n384,704a1,11504,0.001886584521085665,217.3496295052373,218,0.34962950523728864,11286\n385,a0efb,7960,0.0013053905413631687,150.39143348936796,151,0.3914334893679552,7809\n386,b3197,981,0.00016087790465794832,18.53442163983291,18,0.5344216398329102,963\n387,1c91d,174,2.8534918868993892e-05,3.2874509330590485,3,0.28745093305904845,171\n388,1afff,797,0.0001307030479229203,15.058036745103802,15,0.05803674510380219,782\n389,20304,10605,0.0017391541069291968,200.3644663510989,201,0.3644663510989119,10404\n390,638fd,79,1.295550914166964e-05,1.4925782971934758,1,0.4925782971934758,78\n391,8c258,11,1.8039316526375448e-06,0.20782735783706627,0,0.20782735783706627,11\n392,c1847,20,3.2798757320682634e-06,0.3778679233401205,0,0.3778679233401205,20\n393,d72ec,2445,0.0004009648082453452,46.19435362832973,47,0.19435362832972913,2398\n394,d83e3,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n395,04a4f,60,9.83962719620479e-06,1.1336037700203614,1,0.1336037700203614,59\n396,94c2e,52,8.527676903377485e-06,0.9824566006843133,0,0.9824566006843133,52\n397,2fc45,8,1.3119502928273053e-06,0.1511471693360482,0,0.1511471693360482,8\n398,e7986,30,4.919813598102395e-06,0.5668018850101807,0,0.5668018850101807,30\n399,612a5,20340,0.003335633619513424,384.2916780369025,385,0.2916780369025105,19955\n400,f9e3a,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n401,75d14,667,0.00010938385566447659,12.601895243393018,12,0.6018952433930185,655\n402,e02b6,86,1.4103465647893532e-05,1.624832070362518,1,0.6248320703625181,85\n403,b4c46,479,7.855302378303491e-05,9.049936763995886,9,0.04993676399588587,470\n404,f9fb4,299,4.9034142194420536e-05,5.6491254539348015,5,0.6491254539348015,294\n405,08461,34,5.575788744516048e-06,0.6423754696782048,0,0.6423754696782048,34\n406,42032,124,2.0335229538823232e-05,2.342781124708747,2,0.3427811247087469,122\n407,b7ea6,2861,0.0004691862234723651,54.05400643380424,55,0.05400643380423986,2806\n408,cf91f,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n409,64f3d,11,1.8039316526375448e-06,0.20782735783706627,0,0.20782735783706627,11\n410,49d78,408,6.690946493419258e-05,7.708505636138459,7,0.708505636138459,401\n411,aa802,3907,0.0006407237242595352,73.81649882449253,74,0.8164988244925269,3833\n412,7feff,450,7.379720397153592e-05,8.50202827515271,8,0.5020282751527105,442\n413,08bc7,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n414,c7cad,16,2.6239005856546107e-06,0.3022943386720964,0,0.3022943386720964,16\n415,ec21a,84,1.3775478074686707e-05,1.587045278028506,1,0.587045278028506,83\n416,a1ba3,5,8.199689330170659e-07,0.09446698083503012,0,0.09446698083503012,5\n417,7f276,240,3.935850878481916e-05,4.534415080081446,4,0.5344150800814456,236\n419,e86cf,188,3.0830831881441676e-05,3.5519584793971326,3,0.5519584793971326,185\n420,4f7b0,14,2.2959130124477845e-06,0.2645075463380844,0,0.2645075463380844,14\n421,c61f3,363,5.952974453703898e-05,6.858302808623187,6,0.8583028086231872,357\n422,0c672,32,5.247801171309221e-06,0.6045886773441927,0,0.6045886773441927,32\n423,9f766,14028,0.00230050483847268,265.0365614307605,266,0.03656143076051421,13762\n424,2f990,4092,0.0006710625747811667,77.31177711538865,78,0.3117771153886508,4014\n425,660cb,1043,0.00017104551942735994,19.705812202187285,19,0.7058122021872855,1024\n426,0045c,81,1.3283496714876466e-05,1.5303650895274878,1,0.5303650895274878,80\n427,934e9,818,0.00013414691744159196,15.454798064610927,15,0.45479806461092664,803\n428,820eb,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n429,8736b,14,2.2959130124477845e-06,0.2645075463380844,0,0.2645075463380844,14\n430,3de8b,18,2.9518881588614373e-06,0.3400811310061085,0,0.3400811310061085,18\n431,16140,21927,0.0035958917588530407,414.2754977539411,415,0.2754977539411243,21512\n432,44e80,17490,0.0028682513276936964,330.44549896093537,331,0.44549896093536745,17159\n434,c10c0,58575,0.009605936050294927,1106.680680482378,1107,0.6806804823779657,57468\n435,05214,129705,0.021270814091395706,2450.5679498415166,2451,0.5679498415165654,127254\n436,d70bb,130766,0.021444811498981926,2470.61384317471,2471,0.6138431747099276,128295\n437,a7eba,155,2.5419036923529042e-05,2.928476405885934,2,0.928476405885934,153\n438,2afa7,603,9.888825332185814e-05,11.392717888704633,11,0.39271788870463276,592\n439,ff2e5,1855,0.00030420847414933144,35.04724988979618,36,0.04724988979617706,1819\n440,b784f,148,2.4271080417305148e-05,2.7962226327168915,2,0.7962226327168915,146\n441,c2ce4,16469,0.0027008136715716115,311.15534147442224,312,0.15534147442224366,16157\n442,02ab5,491,8.052094922227586e-05,9.276657517999958,9,0.27665751799995775,482\n443,56013,65781,0.010787675276559121,1242.8264932618233,1243,0.8264932618233161,64538\n444,b4b2c,2429,0.0003983409076596906,45.89205928965763,46,0.8920592896576309,2383\n445,298a3,231,3.788256470538844e-05,4.364374514578392,4,0.3643745145783921,227\n446,3df28,3918,0.0006425276559121728,74.0243261823296,75,0.024326182329602375,3843\n447,99b56,594,9.741230924242742e-05,11.222677323201578,11,0.2226773232015784,583\n448,7c7e3,285,4.6738229181972755e-05,5.384617907596717,5,0.3846179075967173,280\n449,230a9,29334,0.004810593736224522,554.2188831629547,555,0.21888316295473942,28779\n450,11677,332788,0.05457516425617666,6287.4955236256,6288,0.4955236256000717,326500\n451,2e832,12,1.967925439240958e-06,0.22672075400407227,0,0.22672075400407227,12\n452,b961f,105,1.7219347593358382e-05,1.9838065975356325,1,0.9838065975356325,104\n453,caff0,2624,0.00043031969604735617,49.57627154222381,50,0.5762715422238074,2574\n454,c0a9c,76675,0.012574223587816704,1448.6511511051867,1449,0.6511511051867274,75226\n455,91d53,2145,0.00035176667226432124,40.52633477822792,41,0.5263347782279197,2104\n456,77494,1404,0.0002302472763911921,26.526328218476458,27,0.5263282184764577,1377\n457,9346a,160758,0.026363313146791495,3037.2645810155545,3038,0.2645810155545405,157720\n458,48d36,690,0.00011315571275635509,13.036443355234157,13,0.03644335523415698,677\n459,de96c,20110,0.003297915048594639,379.9461969184912,380,0.9461969184911823,19730\n460,1eca8,1577,0.00025861820147358256,29.7948857553685,30,0.7948857553685009,1547\n461,29179,1443,0.0002366430340687252,27.26317066898969,28,0.2631706689896909,1415\n462,f6416,12243,0.0020077759293855874,231.31184927265477,232,0.3118492726547686,12011\n463,eb5ab,24654,0.004043102814920548,465.79778910136656,466,0.7977891013665612,24188\n464,bb24e,275,4.5098291315938624e-05,5.195683945926657,5,0.19568394592665683,270\n465,089a1,1595,0.000261570089632444,30.13496688637461,31,0.13496688637460963,1564\n466,8f9ad,8142,0.00133523741052499,153.83003159176306,154,0.8300315917630599,7988\n467,01aa2,67287,0.011034649919183862,1271.2799478893344,1272,0.2799478893343803,66015\n468,b8427,76545,0.012552904395558262,1446.1950096034761,1447,0.19500960347613727,75098\n469,4ca48,20684,0.003392047482104998,390.7910063183526,391,0.7910063183525722,20293\n470,a6fe0,673,0.00011036781838409706,12.715255620395054,12,0.7152556203950535,661\n471,8131b,524,8.59327441801885e-05,9.900139591511158,9,0.9001395915111576,515\n472,9ed2f,2084,0.000341763051281513,39.37383761204055,40,0.3738376120405533,2044\n473,62851,47540,0.007796264615126262,898.1920537794664,899,0.19205377946639146,46641\n474,c8364,85303,0.013989161978630954,1611.663373234115,1612,0.6633732341149425,83691\n475,63dcc,5712,0.000936732509078696,107.9190789059384,108,0.9190789059384059,5604\n476,53490,53415,0.008759728111421314,1009.1907562606268,1010,0.19075626062681295,52405\n477,ea1ca,9513,0.0015600728919582694,179.7328777367283,180,0.7328777367283124,9333\n478,b98cd,86047,0.014111173355863893,1625.7200599823675,1626,0.720059982367502,84421\n479,5cc45,458,7.510915426436323e-05,8.65317544448876,8,0.6531754444887596,450\n480,935e4,1124,0.0001843290161422364,21.23617729171477,21,0.2361772917147711,1103\n481,25830,51,8.363683116774072e-06,0.9635632045173074,0,0.9635632045173074,51\n482,e1bce,401,6.576150842796868e-05,7.576251862969416,7,0.576251862969416,394\n483,522f3,25806,0.00423202365708768,487.5629814857574,488,0.5629814857574047,25318\n484,a091d,297,4.870615462121371e-05,5.611338661600789,5,0.6113386616007892,292\n485,5d15a,3101,0.0005085447322571842,58.588421513885685,59,0.5884215138856845,3042\n486,3807a,606,9.938023468166839e-05,11.449398077205652,11,0.44939807720565206,595\n487,6fc74,14406,0.00236249448980877,272.1782651818888,273,0.17826518188877571,14133\n488,c4b61,186,3.0502844308234848e-05,3.5141716870631203,3,0.5141716870631203,183\n489,1d557,935,0.0001533341904741913,17.665325416150633,17,0.6653254161506332,918\n490,4810f,267,4.378634102311132e-05,5.044536776590609,5,0.04453677659060862,262\n492,20171,1303,0.00021368390394424736,24.61809520560885,24,0.6180952056088493,1279\n493,645a8,507,8.314484980793048e-05,9.578951856672054,9,0.5789518566720542,498\n494,32855,5841,0.0009578877075505363,110.3563270114822,111,0.35632701148219326,5730\n495,9f0b8,1978,0.00032437970990155125,37.37113761833792,38,0.3711376183379187,1940\n496,bfa97,171,2.804293750918365e-05,3.23077074455803,3,0.23077074455803004,168\n497,c7ae8,39,6.3957576775331136e-06,0.7368424505132349,0,0.7368424505132349,39\n498,0443b,107,1.754733516656521e-05,2.0215933898696448,2,0.021593389869644763,105\n499,8b6d1,36,5.9037763177228745e-06,0.680162262012217,0,0.680162262012217,36\n500,0df28,803,0.00013168701064254076,15.171397122105837,15,0.17139712210583724,788\n501,b8da7,1772,0.00029059698986124815,33.479098007934674,34,0.47909800793467383,1738\n502,be111,30,4.919813598102395e-06,0.5668018850101807,0,0.5668018850101807,30\n503,fff89,160,2.6239005856546107e-05,3.022943386720964,3,0.022943386720963854,157\n504,69f64,6756,0.0011079420222926595,127.64378450429271,128,0.6437845042927108,6628\n505,cd02f,1582,0.00025943817040659966,29.889352736203534,30,0.8893527362035343,1552\n506,49604,1205,0.00019761251285711287,22.76654238124226,22,0.7665423812422603,1183\n507,f5741,2674,0.0004385193853775268,50.52094135057411,51,0.5209413505741125,2623\n508,35c93,984,0.00016136988601775856,18.59110182833393,18,0.5911018283339295,966\n509,7af3c,15579,0.0025548592014945737,294.34021888578684,295,0.34021888578683956,15284\n510,b061f,733,0.00012020744558030186,13.848859390415416,13,0.8488593904154165,720\n511,10963,255,4.181841558387036e-05,4.817816022586537,4,0.8178160225865367,251\n512,7429f,596,9.774029681563425e-05,11.26046411553559,11,0.2604641155355907,585\n513,8e06f,114,1.86952916727891e-05,2.153847163038687,2,0.15384716303868684,112\n514,adfd4,10604,0.0017389901131425933,200.3455729549319,201,0.345572954931896,10403\n515,f5fd5,4215,0.0006912338105333865,79.63566484393039,80,0.6356648439303854,4135\n516,565d2,285,4.6738229181972755e-05,5.384617907596717,5,0.3846179075967173,280\n517,8f6cf,119,1.9515260605806166e-05,2.2483141438737166,2,0.24831414387371664,117\n518,4f72e,3024,0.0004959172106887215,57.13363000902623,58,0.13363000902622701,2966\n519,d9867,1147,0.00018810087323411492,21.670725403555913,21,0.6707254035559131,1126\n520,2c9f1,82,1.344749050147988e-05,1.549258485694494,1,0.549258485694494,81\n521,dccb0,2004,0.00032864354835324,37.86236591868007,38,0.8623659186800694,1966\n522,6e283,2322,0.00038079357249312537,43.87046589978799,44,0.8704658997879875,2278\n523,43593,923,0.00015136626503495037,17.438604662146563,17,0.43860466214656313,906\n524,17b6e,903,0.0001480863893028821,17.06073673880644,17,0.06073673880644037,886\n525,72ae9,357,5.85457818174185e-05,6.74494243162115,6,0.7449424316211504,351\n526,d902d,873,0.0001431665757047797,16.49393485379626,16,0.49393485379626156,857\n527,64690,318,5.215002413988539e-05,6.008099981107916,6,0.00809998110791632,312\n528,f476f,25,4.099844665085329e-06,0.4723349041751506,0,0.4723349041751506,25\n529,82153,127,2.0827210898633473e-05,2.3994613132097653,2,0.3994613132097653,125\n530,de876,128,2.0991204685236885e-05,2.418354709376771,2,0.418354709376771,126\n531,0f108,641,0.00010512001721278785,12.110666943050862,12,0.11066694305086244,629\n532,c4fd2,1020,0.00016727366233548145,19.271264090346147,19,0.27126409034614696,1001\n533,5a3a4,147,2.4107086630701735e-05,2.7773292365498854,2,0.7773292365498854,145\n534,59b0c,157,2.5747024496735867e-05,2.966263198219946,2,0.9662631982199459,155\n535,1b5c6,24,3.935850878481916e-06,0.45344150800814453,0,0.45344150800814453,24\n536,433d0,129,2.1155198471840298e-05,2.437248105543777,2,0.43724810554377713,127\n537,69d01,1266,0.00020761613383992107,23.919039547429627,23,0.9190395474296267,1243\n538,d05da,277,4.542627888914545e-05,5.233470738260669,5,0.2334707382606691,272\n539,c552a,924,0.00015153025882155377,17.45749805831357,17,0.4574980583135684,907\n540,ede77,196,3.214278217426898e-05,3.703105648733181,3,0.7031056487331808,193\n541,5fb34,336,5.510191229874683e-05,6.348181112114024,6,0.34818111211402414,330\n542,3c8b0,303,4.969011734083419e-05,5.724699038602826,5,0.724699038602826,298\n543,69faf,48,7.871701756963832e-06,0.9068830160162891,0,0.9068830160162891,48\n544,4282e,40,6.559751464136527e-06,0.755735846680241,0,0.755735846680241,40\n545,c8c9d,21,3.4438695186716767e-06,0.3967613195071265,0,0.3967613195071265,21\n" + ) + ) + + for x in range(100_01, df.remaining_balance.sum(), 12345): + amount = USDCent(x) + res = business_payout_event_manager.distribute_amount(df=df, amount=amount) + assert isinstance(res, pd.Series) + assert res.sum() == amount + + def test_ach_payment_min_amount( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + ): + """Test having a Business with three products.. one that lost money + and two that gained money. Ensure that the Business balance + reflects that to compensate for the Product in the negative and only + assigns Brokerage Product payments from the 2 accounts that have + a positive balance. + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p1) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("5.00"), + started=start + timedelta(days=1), + ) + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("5.00"), + started=start + timedelta(days=6), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=u1.product, + amount=USDCent(475), # 95% of $5.00 + created=start + timedelta(days=1, minutes=1), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + with pytest.raises(expected_exception=AssertionError) as cm: + business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(500), + pm=product_manager, + thl_lm=thl_lm, + ) + assert "Must issue Supplier Payouts at least $100 minimum." in str(cm) + + def test_ach_payment( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + rm_ledger_collection, + rm_pop_ledger_merge, + ): + """Test having a Business with three products.. one that lost money + and two that gained money. Ensure that the Business balance + reflects that to compensate for the Product in the negative and only + assigns Brokerage Product payments from the 2 accounts that have + a positive balance. + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + p3: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + u3: User = user_factory(product=p3) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p3) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + # Product 1: Complete, Payout, Recon.. + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("5.00"), + started=start + timedelta(days=1), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=u1.product, + amount=USDCent(475), # 95% of $5.00 + ext_ref_id=ach_id1, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + adj_to_fail_with_tx_factory( + session=s1, + created=start + timedelta(days=1, minutes=2), + ) + + # Product 2: Complete x10 + for idx in range(15): + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("7.50"), + started=start + timedelta(days=1, hours=2, minutes=1 + idx), + ) + + # Product 3: Complete x5 + for idx in range(10): + session_with_tx_factory( + user=u3, + wall_req_cpi=Decimal("7.50"), + started=start + timedelta(days=1, hours=3, minutes=1 + idx), + ) + + # Now that we paid out the business, let's confirm the updated balances + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + bb1 = business.balance + pb1 = bb1.product_balances[0] + pb2 = bb1.product_balances[1] + pb3 = bb1.product_balances[2] + assert bb1.payout == 25 * 712 + 475 # $7.50 * .95% = $7.125 = $7.12 + assert bb1.adjustment == -475 + assert bb1.net == ((25 * 7.12 + 4.75) - 4.75) * 100 + # The balance is lower because the single $4.75 payout + assert bb1.balance == bb1.net - 475 + assert ( + bb1.available_balance + == (pb2.available_balance + pb3.available_balance) - 475 + ) + + assert pb1.available_balance == 0 + assert pb2.available_balance == 8010 + assert pb3.available_balance == 5340 + + assert bb1.recoup_usd_str == "$4.75" + assert pb1.recoup_usd_str == "$4.75" + assert pb2.recoup_usd_str == "$0.00" + assert pb3.recoup_usd_str == "$0.00" + + assert business.payouts is None + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 1 + assert business.payouts[0].ext_ref_id == ach_id1 + + bp1 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(bb1.available_balance), + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=1, hours=5), + ) + assert isinstance(bp1, BusinessPayoutEvent) + assert len(bp1.bp_payouts) == 2 + assert bp1.bp_payouts[0].status == PayoutStatus.COMPLETE + assert bp1.bp_payouts[1].status == PayoutStatus.COMPLETE + bp1_tx = brokerage_product_payout_event_manager.check_for_ledger_tx( + thl_ledger_manager=thl_lm, + payout_event=bp1.bp_payouts[0], + product_id=bp1.bp_payouts[0].product_id, + amount=bp1.bp_payouts[0].amount, + ) + assert bp1_tx + + bp2_tx = brokerage_product_payout_event_manager.check_for_ledger_tx( + thl_ledger_manager=thl_lm, + payout_event=bp1.bp_payouts[1], + product_id=bp1.bp_payouts[1].product_id, + amount=bp1.bp_payouts[1].amount, + ) + assert bp2_tx + + # Now that we paid out the business, let's confirm the updated balances + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 2 + assert len(business.payouts[0].bp_payouts) == 2 + assert len(business.payouts[1].bp_payouts) == 1 + + bb2 = business.balance + + # Okay os we have the balance before, and after the Business Payout + # of bb1.available_balance worth.. + assert bb1.payout == bb2.payout + assert bb1.adjustment == bb2.adjustment + assert bb1.net == bb2.net + assert bb1.available_balance > bb2.available_balance + + # This is the ultimate test. Confirm that the second time we get the + # Business balance, it is equal to the first time we Business balance + # minus the amount that was just paid out across any children + # Brokerage Products. + # + # This accounts for all the net positive and net negative children + # Brokerage Products under the Business in thi + assert bb2.balance == bb1.balance - bb1.available_balance + + def test_ach_payment_partial_amount( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + rm_ledger_collection, + rm_pop_ledger_merge, + ): + """There are valid instances when we want issue a ACH or Wire to a + Business, but not for the full Available Balance amount in their + account. + + To test this, we'll create a Business with multiple Products, and + a cumulative Available Balance of $100 (for example), but then only + issue a payout of $75 (for example). We want to confirm the sum + of the Product payouts equals the $75 number and isn't greedy and + takes the full $100 amount that is available to the Business. + + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + p3: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + u3: User = user_factory(product=p3) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p3) + + # Product 1, 2, 3: Complete, and Payout multiple times. + for idx in range(5): + for u in [u1, u2, u3]: + session_with_tx_factory( + user=u, + wall_req_cpi=Decimal("50.00"), + started=start + timedelta(days=1, hours=2, minutes=1 + idx), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + # Now that we paid out the business, let's confirm the updated balances + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + + # Confirm the initial amounts. + assert len(business.payouts) == 0 + bb1 = business.balance + assert bb1.payout == 3 * 5 * 4750 + assert bb1.adjustment == 0 + assert bb1.payout == bb1.net + + assert bb1.balance_usd_str == "$712.50" + assert bb1.available_balance_usd_str == "$534.39" + + for x in range(2): + assert bb1.product_balances[x].balance == 5 * 4750 + assert bb1.product_balances[x].available_balance_usd_str == "$178.13" + + assert business.payouts_total_str == "$0.00" + assert business.balance.payment_usd_str == "$0.00" + assert business.balance.available_balance_usd_str == "$534.39" + + # This is the important part, even those the Business has $534.39 + # available to it, we are only trying to issue out a $250.00 ACH or + # Wire to the Business + bp1 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(250_00), + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=1, hours=3), + ) + assert isinstance(bp1, BusinessPayoutEvent) + assert len(bp1.bp_payouts) == 3 + + # Now that we paid out the business, let's confirm the updated + # balances. Clear and rebuild the parquet files. + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # Now rebuild and confirm the payouts, balance.payment, and the + # balance.available_balance are reflective of having a $250 ACH/Wire + # sent to the Business + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + assert len(business.payouts) == 1 + assert len(business.payouts[0].bp_payouts) == 3 + assert business.payouts_total_str == "$250.00" + assert business.balance.payment_usd_str == "$250.00" + assert business.balance.available_balance_usd_str == "$346.88" + + def test_ach_tx_id_reference( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + payout_event_manager, + brokerage_product_payout_event_manager, + business_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + product_manager, + rm_ledger_collection, + rm_pop_ledger_merge, + ): + + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + p3: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + u3: User = user_factory(product=p3) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p3) + + ach_id1 = uuid4().hex + ach_id2 = uuid4().hex + + for idx in range(15): + for iidx, u in enumerate([u1, u2, u3]): + session_with_tx_factory( + user=u, + wall_req_cpi=Decimal("7.50"), + started=start + timedelta(days=1, hours=1 + iidx, minutes=1 + idx), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + bp1 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_01), + transaction_id=ach_id1, + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=2, hours=1), + ) + + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + bp2 = business_payout_event_manager.create_from_ach_or_wire( + business=business, + amount=USDCent(100_02), + transaction_id=ach_id2, + pm=product_manager, + thl_lm=thl_lm, + created=start + timedelta(days=4, hours=1), + ) + + assert isinstance(bp1, BusinessPayoutEvent) + assert isinstance(bp2, BusinessPayoutEvent) + + rm_ledger_collection() + rm_pop_ledger_merge() + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + business.prebuild_payouts( + thl_pg_config=thl_web_rr, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + ) + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert business.payouts[0].ext_ref_id == ach_id2 + assert business.payouts[1].ext_ref_id == ach_id1 diff --git a/tests/managers/thl/test_product.py b/tests/managers/thl/test_product.py new file mode 100644 index 0000000..78d5dde --- /dev/null +++ b/tests/managers/thl/test_product.py @@ -0,0 +1,362 @@ +from uuid import uuid4 + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.product import ( + Product, + SourceConfig, + UserCreateConfig, + SourcesConfig, + UserHealthConfig, + ProfilingConfig, + SupplyPolicy, + SupplyConfig, +) +from test_utils.models.conftest import product_factory + + +class TestProductManagerGetMethods: + def test_get_by_uuid(self, product_manager): + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + + instance = product_manager.get_by_uuid(product_uuid=product.id) + assert isinstance(instance, Product) + # self.assertEqual(instance.model_dump(mode="json"), product.model_dump(mode="json")) + assert instance.id == product.id + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuid(product_uuid="abc123") + assert "invalid uuid" in str(cm.value) + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuid(product_uuid=uuid4().hex) + assert "product not found" in str(cm.value) + + def test_get_by_uuids(self, product_manager): + cnt = 5 + + product_uuids = [uuid4().hex for idx in range(cnt)] + for product_id in product_uuids: + product_manager.create_dummy( + product_id=product_id, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + + res = product_manager.get_by_uuids(product_uuids=product_uuids) + assert isinstance(res, list) + assert cnt == len(res) + for instance in res: + assert isinstance(instance, Product) + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuids(product_uuids=product_uuids + [uuid4().hex]) + assert "incomplete product response" in str(cm.value) + + with pytest.raises(AssertionError) as cm: + product_manager.get_by_uuids(product_uuids=product_uuids + ["abc123"]) + assert "invalid uuid" in str(cm.value) + + def test_get_by_uuid_if_exists(self, product_manager): + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + instance = product_manager.get_by_uuid_if_exists(product_uuid=product.id) + assert isinstance(instance, Product) + + instance = product_manager.get_by_uuid_if_exists(product_uuid="abc123") + assert instance == None + + def test_get_by_uuids_if_exists(self, product_manager): + product_uuids = [uuid4().hex for _ in range(2)] + for product_id in product_uuids: + product_manager.create_dummy( + product_id=product_id, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + ) + + res = product_manager.get_by_uuids_if_exists(product_uuids=product_uuids) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + res = product_manager.get_by_uuids_if_exists( + product_uuids=product_uuids + [uuid4().hex] + ) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + # # This will raise an error b/c abc123 isn't a uuid. + # res = product_manager.get_by_uuids_if_exists( + # product_uuids=product_uuids + ["abc123"] + # ) + # assert isinstance(res, list) + # assert 2 == len(res) + # for instance in res: + # assert isinstance(instance, Product) + + def test_get_by_business_ids(self, product_manager): + business_ids = [uuid4().hex for i in range(5)] + + product_manager.fetch_uuids(business_uuids=business_ids) + + for business_id in business_ids: + product_manager.create( + product_id=uuid4().hex, + team_id=None, + business_id=business_id, + redirect_url="https://www.example.com", + name=f"Test Product ID #{uuid4().hex[:6]}", + user_create_config=None, + ) + + +class TestProductManagerCreation: + + def test_base(self, product_manager): + instance = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"New Test Product {uuid4().hex[:6]}", + ) + + assert isinstance(instance, Product) + + +class TestProductManagerCreate: + + def test_create_simple(self, product_manager): + # Always required: product_id, team_id, name, redirect_url + # Required internally - if not passed use default: harmonizer_domain, + # commission_pct, sources + product_id = uuid4().hex + team_id = uuid4().hex + business_id = uuid4().hex + + product_manager.create( + product_id=product_id, + team_id=team_id, + business_id=business_id, + name=f"Test Product ID #{uuid4().hex[:6]}", + redirect_url="https://www.example.com", + ) + + instance = product_manager.get_by_uuid(product_uuid=product_id) + + assert team_id == instance.team_id + assert business_id == instance.business_id + + +class TestProductManager: + sources = [ + SourceConfig.model_validate(x) + for x in [ + {"name": "d", "active": True}, + { + "name": "f", + "active": False, + "banned_countries": ["ps", "in", "in"], + }, + { + "name": "s", + "active": True, + "supplier_id": "3488", + "allow_pii_only_buyers": True, + "allow_unhashed_buyers": True, + }, + {"name": "e", "active": True, "withhold_profiling": True}, + ] + ] + + def test_get_by_uuid1(self, product_manager, team, product, product_factory): + p1 = product_factory(team=team) + instance = product_manager.get_by_uuid(product_uuid=p1.uuid) + assert instance.id == p1.id + + # No Team and no user_create_config + assert instance.team_id == team.uuid + + # user_create_config can't be None, so ensure the default was set. + assert isinstance(instance.user_create_config, UserCreateConfig) + assert 0 == instance.user_create_config.min_hourly_create_limit + assert instance.user_create_config.max_hourly_create_limit is None + + def test_get_by_uuid2(self, product_manager, product_factory): + p2 = product_factory() + instance = product_manager.get_by_uuid(p2.id) + assert instance.id, p2.id + + # Team and default user_create_config + assert instance.team_id is not None + assert instance.team_id == p2.team_id + + assert 0 == instance.user_create_config.min_hourly_create_limit + assert instance.user_create_config.max_hourly_create_limit is None + + def test_get_by_uuid3(self, product_manager, product_factory): + p3 = product_factory() + instance = product_manager.get_by_uuid(p3.id) + assert instance.id == p3.id + + # Team and default user_create_config + assert instance.team_id is not None + assert instance.team_id == p3.team_id + + assert ( + p3.user_create_config.min_hourly_create_limit + == instance.user_create_config.min_hourly_create_limit + ) + assert instance.user_create_config.max_hourly_create_limit is None + assert not instance.user_wallet_config.enabled + + def test_sources(self, product_manager): + user_defined = [SourceConfig(name=Source.DYNATA, active=False)] + sources_config = SourcesConfig(user_defined=user_defined) + p = product_manager.create_dummy(sources_config=sources_config) + + p2 = product_manager.get_by_uuid(p.id) + + assert p == p2 + assert p2.sources_config.user_defined == user_defined + + # Assert d is off and everything else is on + dynata = p2.sources_dict[Source.DYNATA] + assert not dynata.active + assert all(x.active is True for x in p2.sources if x.name != Source.DYNATA) + + def test_global_sources(self, product_manager): + sources_config = SupplyConfig( + policies=[ + SupplyPolicy( + name=Source.DYNATA, + active=True, + address=["https://www.example.com"], + distribute_harmonizer_active=True, + ) + ] + ) + p1 = product_manager.create_dummy(sources_config=sources_config) + p2 = product_manager.get_by_uuid(p1.id) + assert p1 == p2 + + p1.sources_config.policies.append( + SupplyPolicy( + name=Source.CINT, + active=True, + address=["https://www.example.com"], + distribute_harmonizer_active=True, + ) + ) + product_manager.update(p1) + p2 = product_manager.get_by_uuid(p1.id) + assert p1 == p2 + + def test_user_health_config(self, product_manager): + p = product_manager.create_dummy( + user_health_config=UserHealthConfig(banned_countries=["ng", "in"]) + ) + + p2 = product_manager.get_by_uuid(p.id) + + assert p == p2 + assert p2.user_health_config.banned_countries == ["in", "ng"] + assert p2.user_health_config.allow_ban_iphist + + def test_profiling_config(self, product_manager): + p = product_manager.create_dummy( + profiling_config=ProfilingConfig(max_questions=1) + ) + p2 = product_manager.get_by_uuid(p.id) + + assert p == p2 + assert p2.profiling_config.max_questions == 1 + + # def test_user_create_config(self): + # # -- Product 1 --- + # instance1 = PM.get_by_uuid(self.product_id1) + # self.assertEqual(instance1.user_create_config.min_hourly_create_limit, 0) + # + # self.assertEqual(instance1.user_create_config.max_hourly_create_limit, None) + # self.assertEqual(60, instance1.user_create_config.clip_hourly_create_limit(60)) + # + # # -- Product 2 --- + # instance2 = PM.get_by_uuid(self.product2.id) + # self.assertEqual(instance2.user_create_config.min_hourly_create_limit, 200) + # self.assertEqual(instance2.user_create_config.max_hourly_create_limit, None) + # + # self.assertEqual(200, instance2.user_create_config.clip_hourly_create_limit(60)) + # self.assertEqual(300, instance2.user_create_config.clip_hourly_create_limit(300)) + # + # def test_create_and_cache(self): + # PM.uuid_cache.clear() + # + # # Hit it once, should hit mysql + # with self.assertLogs(level='INFO') as cm: + # p = PM.get_by_uuid(product_id3) + # self.assertEqual(cm.output, [f"INFO:root:Product.get_by_uuid:{product_id3}"]) + # self.assertIsNotNone(p) + # self.assertEqual(p.id, product_id3) + # + # # Hit it again, should be pulled from cachetools cache + # with self.assertLogs(level='INFO') as cm: + # logger.info("nothing") + # p = PM.get_by_uuid(product_id3) + # self.assertEqual(cm.output, [f"INFO:root:nothing"]) + # self.assertEqual(len(cm.output), 1) + # self.assertIsNotNone(p) + # self.assertEqual(p.id, product_id3) + + +class TestProductManagerUpdate: + + def test_update(self, product_manager): + p = product_manager.create_dummy() + p.name = "new name" + p.enabled = False + p.user_create_config = UserCreateConfig(min_hourly_create_limit=200) + p.sources_config = SourcesConfig( + user_defined=[SourceConfig(name=Source.DYNATA, active=False)] + ) + product_manager.update(new_product=p) + # We cleared the cache in the update + # PM.id_cache.clear() + p2 = product_manager.get_by_uuid(p.id) + + assert p2.name == "new name" + assert not p2.enabled + assert p2.user_create_config.min_hourly_create_limit == 200 + assert not p2.sources_dict[Source.DYNATA].active + + +class TestProductManagerCacheClear: + + def test_cache_clear(self, product_manager): + p = product_manager.create_dummy() + product_manager.get_by_uuid(product_uuid=p.id) + product_manager.get_by_uuid(product_uuid=p.id) + product_manager.pg_config.execute_write( + query=""" + UPDATE userprofile_brokerageproduct + SET name = 'test-6d9a5ddfd' + WHERE id = %s""", + params=[p.id], + ) + + product_manager.cache_clear(p.id) + + # Calling this with or without kwargs hits different internal keys in the cache! + p2 = product_manager.get_by_uuid(product_uuid=p.id) + assert p2.name == "test-6d9a5ddfd" + p2 = product_manager.get_by_uuid(product_uuid=p.id) + assert p2.name == "test-6d9a5ddfd" diff --git a/tests/managers/thl/test_product_prod.py b/tests/managers/thl/test_product_prod.py new file mode 100644 index 0000000..7b4f677 --- /dev/null +++ b/tests/managers/thl/test_product_prod.py @@ -0,0 +1,82 @@ +import logging +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.product import Product +from test_utils.models.conftest import product_factory + +logger = logging.getLogger() + + +class TestProductManagerGetMethods: + + def test_get_by_uuid(self, product_manager, product_factory): + # Just test that we load properly + for p in [product_factory(), product_factory(), product_factory()]: + instance = product_manager.get_by_uuid(product_uuid=p.id) + assert isinstance(instance, Product) + assert instance.id == p.id + assert instance.uuid == p.uuid + + with pytest.raises(expected_exception=AssertionError) as cm: + product_manager.get_by_uuid(product_uuid=uuid4().hex) + assert "product not found" in str(cm.value) + + def test_get_by_uuids(self, product_manager, product_factory): + products = [product_factory(), product_factory(), product_factory()] + cnt = len(products) + res = product_manager.get_by_uuids(product_uuids=[p.id for p in products]) + assert isinstance(res, list) + assert cnt == len(res) + for instance in res: + assert isinstance(instance, Product) + + with pytest.raises(expected_exception=AssertionError) as cm: + product_manager.get_by_uuids( + product_uuids=[p.id for p in products] + [uuid4().hex] + ) + assert "incomplete product response" in str(cm.value) + with pytest.raises(expected_exception=AssertionError) as cm: + product_manager.get_by_uuids( + product_uuids=[p.id for p in products] + ["abc123"] + ) + assert "invalid uuid passed" in str(cm.value) + + def test_get_by_uuid_if_exists(self, product_factory, product_manager): + products = [product_factory(), product_factory(), product_factory()] + + instance = product_manager.get_by_uuid_if_exists(product_uuid=products[0].id) + assert isinstance(instance, Product) + + instance = product_manager.get_by_uuid_if_exists(product_uuid="abc123") + assert instance is None + + def test_get_by_uuids_if_exists(self, product_manager, product_factory): + products = [product_factory(), product_factory(), product_factory()] + + res = product_manager.get_by_uuids_if_exists( + product_uuids=[p.id for p in products[:2]] + ) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + res = product_manager.get_by_uuids_if_exists( + product_uuids=[p.id for p in products[:2]] + [uuid4().hex] + ) + assert isinstance(res, list) + assert 2 == len(res) + for instance in res: + assert isinstance(instance, Product) + + +class TestProductManagerGetAll: + + @pytest.mark.skip(reason="TODO") + def test_get_ALL_by_ids(self, product_manager): + products = product_manager.get_all(rand_limit=50) + logger.info(f"Fetching {len(products)} product uuids") + # todo: once timebucks stops spamming broken accounts, fetch more + pass diff --git a/tests/managers/thl/test_profiling/__init__.py b/tests/managers/thl/test_profiling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/thl/test_profiling/test_question.py b/tests/managers/thl/test_profiling/test_question.py new file mode 100644 index 0000000..998466e --- /dev/null +++ b/tests/managers/thl/test_profiling/test_question.py @@ -0,0 +1,49 @@ +from uuid import uuid4 + +from generalresearch.managers.thl.profiling.question import QuestionManager +from generalresearch.models import Source + + +class TestQuestionManager: + + def test_get_multi_upk(self, question_manager: QuestionManager, upk_data): + qs = question_manager.get_multi_upk( + question_ids=[ + "8a22de34f985476aac85e15547100db8", + "0565f87d4bf044298ba169de1339ff7e", + "b2b32d68403647e3a87e778a6348d34c", + uuid4().hex, + ] + ) + assert len(qs) == 3 + + def test_get_questions_ranked(self, question_manager: QuestionManager, upk_data): + qs = question_manager.get_questions_ranked(country_iso="mx", language_iso="spa") + assert len(qs) >= 40 + assert qs[0].importance.task_score > qs[40].importance.task_score + assert all(q.country_iso == "mx" and q.language_iso == "spa" for q in qs) + + def test_lookup_by_property(self, question_manager: QuestionManager, upk_data): + q = question_manager.lookup_by_property( + property_code="i:industry", country_iso="us", language_iso="eng" + ) + assert q.source == Source.INNOVATE + + q.explanation_template = "You work in the {answer} industry." + q.explanation_fragment_template = "you work in the {answer} industry" + question_manager.update_question_explanation(q) + + q = question_manager.lookup_by_property( + property_code="i:industry", country_iso="us", language_iso="eng" + ) + assert q.explanation_template + + def test_filter_by_property(self, question_manager: QuestionManager, upk_data): + lookup = [ + ("i:industry", "us", "eng"), + ("i:industry", "mx", "eng"), + ("m:age", "us", "eng"), + (f"m:{uuid4().hex}", "us", "eng"), + ] + qs = question_manager.filter_by_property(lookup) + assert len(qs) == 3 diff --git a/tests/managers/thl/test_profiling/test_schema.py b/tests/managers/thl/test_profiling/test_schema.py new file mode 100644 index 0000000..ae61527 --- /dev/null +++ b/tests/managers/thl/test_profiling/test_schema.py @@ -0,0 +1,44 @@ +from generalresearch.models.thl.profiling.upk_property import PropertyType + + +class TestUpkSchemaManager: + + def test_get_props_info(self, upk_schema_manager, upk_data): + props = upk_schema_manager.get_props_info() + assert ( + len(props) == 16955 + ) # ~ 70 properties x each country they are available in + + gender = [ + x + for x in props + if x.country_iso == "us" + and x.property_id == "73175402104741549f21de2071556cd7" + ] + assert len(gender) == 1 + gender = gender[0] + assert len(gender.allowed_items) == 3 + assert gender.allowed_items[0].label == "female" + assert gender.allowed_items[1].label == "male" + assert gender.prop_type == PropertyType.UPK_ITEM + assert gender.categories[0].label == "Demographic" + + age = [ + x + for x in props + if x.country_iso == "us" + and x.property_id == "94f7379437874076b345d76642d4ce6d" + ] + assert len(age) == 1 + age = age[0] + assert age.allowed_items is None + assert age.prop_type == PropertyType.UPK_NUMERICAL + assert age.gold_standard + + cars = [ + x + for x in props + if x.country_iso == "us" and x.property_label == "household_auto_type" + ][0] + assert not cars.gold_standard + assert cars.categories[0].label == "Autos & Vehicles" diff --git a/tests/managers/thl/test_profiling/test_uqa.py b/tests/managers/thl/test_profiling/test_uqa.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/managers/thl/test_profiling/test_uqa.py @@ -0,0 +1 @@ + diff --git a/tests/managers/thl/test_profiling/test_user_upk.py b/tests/managers/thl/test_profiling/test_user_upk.py new file mode 100644 index 0000000..53bb8fe --- /dev/null +++ b/tests/managers/thl/test_profiling/test_user_upk.py @@ -0,0 +1,59 @@ +from datetime import datetime, timezone + +from generalresearch.managers.thl.profiling.user_upk import UserUpkManager + +now = datetime.now(tz=timezone.utc) +base = { + "country_iso": "us", + "language_iso": "eng", + "timestamp": now, +} +upk_ans_dict = [ + {"pred": "gr:gender", "obj": "gr:male"}, + {"pred": "gr:age_in_years", "obj": "43"}, + {"pred": "gr:home_postal_code", "obj": "33143"}, + {"pred": "gr:ethnic_group", "obj": "gr:caucasians"}, + {"pred": "gr:ethnic_group", "obj": "gr:asian"}, +] +for a in upk_ans_dict: + a.update(base) + + +class TestUserUpkManager: + + def test_user_upk_empty(self, user_upk_manager: UserUpkManager, upk_data, user): + res = user_upk_manager.get_user_upk_mysql(user_id=user.user_id) + assert len(res) == 0 + + def test_user_upk(self, user_upk_manager: UserUpkManager, upk_data, user): + for x in upk_ans_dict: + x["user_id"] = user.user_id + user_upk = user_upk_manager.populate_user_upk_from_dict(upk_ans_dict) + user_upk_manager.set_user_upk(upk_ans=user_upk) + + d = user_upk_manager.get_user_upk_simple(user_id=user.user_id) + assert d["gender"] == "male" + assert d["age_in_years"] == 43 + assert d["home_postal_code"] == "33143" + assert d["ethnic_group"] == {"caucasians", "asian"} + + # Change my answers. age 43->44, gender male->female, + # ethnic->remove asian, add black_or_african_american + for x in upk_ans_dict: + if x["pred"] == "age_in_years": + x["obj"] = "44" + if x["pred"] == "gender": + x["obj"] = "female" + upk_ans_dict[-1]["obj"] = "black_or_african_american" + user_upk = user_upk_manager.populate_user_upk_from_dict(upk_ans_dict) + user_upk_manager.set_user_upk(upk_ans=user_upk) + + d = user_upk_manager.get_user_upk_simple(user_id=user.user_id) + assert d["gender"] == "female" + assert d["age_in_years"] == 44 + assert d["home_postal_code"] == "33143" + assert d["ethnic_group"] == {"caucasians", "black_or_african_american"} + + age, gender = user_upk_manager.get_age_gender(user_id=user.user_id) + assert age == 44 + assert gender == "female" diff --git a/tests/managers/thl/test_session_manager.py b/tests/managers/thl/test_session_manager.py new file mode 100644 index 0000000..6bedc2b --- /dev/null +++ b/tests/managers/thl/test_session_manager.py @@ -0,0 +1,137 @@ +from datetime import timedelta +from decimal import Decimal +from uuid import uuid4 + +from faker import Faker + +from generalresearch.models import DeviceType +from generalresearch.models.legacy.bucket import Bucket +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + SessionStatusCode2, +) +from test_utils.models.conftest import user + +fake = Faker() + + +class TestSessionManager: + def test_create_session(self, session_manager, user, utc_hour_ago): + bucket = Bucket( + loi_min=timedelta(seconds=60), + loi_max=timedelta(seconds=120), + user_payout_min=Decimal("1"), + user_payout_max=Decimal("2"), + ) + + s1 = session_manager.create( + started=utc_hour_ago, + user=user, + country_iso="us", + device_type=DeviceType.MOBILE, + ip=fake.ipv4_public(), + bucket=bucket, + url_metadata={"foo": "bar"}, + uuid_id=uuid4().hex, + ) + + assert s1.id is not None + s2 = session_manager.get_from_uuid(session_uuid=s1.uuid) + assert s1 == s2 + + def test_finish_with_status(self, session_manager, user, utc_hour_ago): + uuid_1 = uuid4().hex + session = session_manager.create( + started=utc_hour_ago, user=user, uuid_id=uuid_1 + ) + session_manager.finish_with_status( + session=session, + status=Status.FAIL, + status_code_1=StatusCode1.SESSION_CONTINUE_FAIL, + status_code_2=SessionStatusCode2.USER_IS_BLOCKED, + ) + + s2 = session_manager.get_from_uuid(session_uuid=uuid_1) + assert s2.status == Status.FAIL + assert s2.status_code_1 == StatusCode1.SESSION_CONTINUE_FAIL + assert s2.status_code_2 == SessionStatusCode2.USER_IS_BLOCKED + + +class TestSessionManagerFilter: + + def test_base(self, session_manager, user, utc_now): + uuid_id = uuid4().hex + session_manager.create(started=utc_now, user=user, uuid_id=uuid_id) + res = session_manager.filter(limit=1) + assert len(res) != 0 + assert isinstance(res, list) + assert res[0].uuid == uuid_id + + def test_user(self, session_manager, user, utc_hour_ago): + session_manager.create(started=utc_hour_ago, user=user, uuid_id=uuid4().hex) + session_manager.create(started=utc_hour_ago, user=user, uuid_id=uuid4().hex) + + res = session_manager.filter(user=user) + assert len(res) == 2 + + def test_product( + self, product_factory, user_factory, session_manager, user, utc_hour_ago + ): + from generalresearch.models.thl.session import Session + from generalresearch.models.thl.user import User + + p1 = product_factory() + + for n in range(5): + u = user_factory(product=p1) + session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex) + + res = session_manager.filter( + product_uuids=[p1.uuid], started_since=utc_hour_ago + ) + assert isinstance(res[0], Session) + assert isinstance(res[0].user, User) + assert len(res) == 5 + + def test_team( + self, + product_factory, + user_factory, + team, + session_manager, + user, + utc_hour_ago, + thl_web_rr, + ): + p1 = product_factory(team=team) + + for n in range(5): + u = user_factory(product=p1) + session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex) + + team.prefetch_products(thl_pg_config=thl_web_rr) + assert len(team.product_uuids) == 1 + res = session_manager.filter(product_uuids=team.product_uuids) + assert len(res) == 5 + + def test_business( + self, + product_factory, + business, + user_factory, + session_manager, + user, + utc_hour_ago, + thl_web_rr, + ): + p1 = product_factory(business=business) + + for n in range(5): + u = user_factory(product=p1) + session_manager.create(started=utc_hour_ago, user=u, uuid_id=uuid4().hex) + + business.prefetch_products(thl_pg_config=thl_web_rr) + assert len(business.product_uuids) == 1 + res = session_manager.filter(product_uuids=business.product_uuids) + assert len(res) == 5 diff --git a/tests/managers/thl/test_survey.py b/tests/managers/thl/test_survey.py new file mode 100644 index 0000000..58c4577 --- /dev/null +++ b/tests/managers/thl/test_survey.py @@ -0,0 +1,376 @@ +import uuid +from datetime import datetime, timezone +from decimal import Decimal + +import pytest + +from generalresearch.models import Source +from generalresearch.models.legacy.bucket import ( + SurveyEligibilityCriterion, + TopNPlusBucket, + DurationSummary, + PayoutSummary, +) +from generalresearch.models.thl.profiling.user_question_answer import ( + UserQuestionAnswer, +) +from generalresearch.models.thl.survey.model import ( + Survey, + SurveyStat, + SurveyCategoryModel, + SurveyEligibilityDefinition, +) + + +@pytest.fixture(scope="session") +def surveys_fixture(): + return [ + Survey(source=Source.TESTING, survey_id="a", buyer_code="buyer1"), + Survey(source=Source.TESTING, survey_id="b", buyer_code="buyer2"), + Survey(source=Source.TESTING, survey_id="c", buyer_code="buyer2"), + ] + + +ssa = SurveyStat( + survey_source=Source.TESTING, + survey_survey_id="a", + survey_is_live=True, + quota_id="__all__", + cpi=Decimal(1), + country_iso="us", + version=1, + complete_too_fast_cutoff=300, + prescreen_conv_alpha=10, + prescreen_conv_beta=10, + conv_alpha=10, + conv_beta=10, + dropoff_alpha=10, + dropoff_beta=10, + completion_time_mu=1, + completion_time_sigma=0.4, + mobile_eligible_alpha=10, + mobile_eligible_beta=10, + desktop_eligible_alpha=10, + desktop_eligible_beta=10, + tablet_eligible_alpha=10, + tablet_eligible_beta=10, + long_fail_rate=1, + user_report_coeff=1, + recon_likelihood=0, + score_x0=0, + score_x1=1, + score=100, +) +ssb = ssa.model_dump() +ssb["survey_source"] = Source.TESTING +ssb["survey_survey_id"] = "b" +ssb["completion_time_mu"] = 2 +ssb["score"] = 90 +ssb = SurveyStat.model_validate(ssb) + + +class TestSurvey: + + def test( + self, + delete_buyers_surveys, + buyer_manager, + survey_manager, + surveys_fixture, + ): + survey_manager.create_or_update(surveys_fixture) + survey_ids = {s.survey_id for s in surveys_fixture} + res = survey_manager.filter_by_natural_key( + source=Source.TESTING, survey_ids=survey_ids + ) + assert len(res) == len(surveys_fixture) + assert res[0].id is not None + surveys2 = surveys_fixture.copy() + surveys2.append( + Survey(source=Source.TESTING, survey_id="d", buyer_code="buyer2") + ) + survey_manager.create_or_update(surveys2) + survey_ids = {s.survey_id for s in surveys2} + res2 = survey_manager.filter_by_natural_key( + source=Source.TESTING, survey_ids=survey_ids + ) + assert res2[0].id == res[0].id + assert res2[0] == res[0] + assert len(res2) == len(surveys2) + + def test_category(self, survey_manager): + survey1 = Survey(id=562289, survey_id="a", source=Source.TESTING) + survey2 = Survey(id=562290, survey_id="a", source=Source.TESTING) + categories = list(survey_manager.category_manager.categories.values()) + sc = [SurveyCategoryModel(category=c, strength=1 / 2) for c in categories[:2]] + sc2 = [SurveyCategoryModel(category=c, strength=1 / 2) for c in categories[2:4]] + survey1.categories = sc + survey2.categories = sc2 + surveys = [survey1, survey2] + survey_manager.update_surveys_categories(surveys) + + def test_survey_eligibility( + self, survey_manager, upk_data, question_manager, uqa_manager + ): + bucket = TopNPlusBucket( + id="c82cf98c578a43218334544ab376b00e", + contents=[], + duration=DurationSummary(max=1, min=1, q1=1, q2=1, q3=1), + quality_score=1, + payout=PayoutSummary(max=1, min=1, q1=1, q2=1, q3=1), + uri="https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142", + ) + + survey1 = Survey( + survey_id="a", + source=Source.TESTING, + eligibility_criteria=SurveyEligibilityDefinition( + # "5d6d9f3c03bb40bf9d0a24f306387d7c", # gr:gender + # "c1309f099ab84a39b01200a56dac65cf", # d:600 + # "90e86550ddf84b08a9f7f5372dd9651b", # i:gender + # "1a8216ddb09440a8bc748cf8ca89ecec", # i:adhoc_13126 + property_codes=("i:adhoc_13126", "i:gender", "i:something"), + ), + ) + # might have a couple surveys in the bucket ... merge them all together + qualifying_questions = set(survey1.eligibility_criteria.property_codes) + + uqas = [ + UserQuestionAnswer( + question_id="5d6d9f3c03bb40bf9d0a24f306387d7c", + answer=("1",), + calc_answers={"i:gender": ("1",), "d:1": ("2",)}, + country_iso="us", + language_iso="eng", + property_code="gr:gender", + ), + UserQuestionAnswer( + question_id="c1309f099ab84a39b01200a56dac65cf", + answer=("50796", "50784"), + country_iso="us", + language_iso="eng", + property_code="d:600", + calc_answers={"d:600": ("50796", "50784")}, + ), + UserQuestionAnswer( + question_id="1a8216ddb09440a8bc748cf8ca89ecec", + answer=("3", "4"), + country_iso="us", + language_iso="eng", + property_code="i:adhoc_13126", + calc_answers={"i:adhoc_13126": ("3", "4")}, + ), + ] + uqad = dict() + for uqa in uqas: + for k, v in uqa.calc_answers.items(): + if k in qualifying_questions: + uqad[k] = uqa + uqad[uqa.property_code] = uqa + + question_ids = {uqa.question_id for uqa in uqad.values()} + qs = question_manager.get_multi_upk(question_ids) + # Sort question by LOWEST task_count (rarest) + qs = sorted(qs, key=lambda x: x.importance.task_count if x.importance else 0) + # qd = {q.id: q for q in qs} + + q = [x for x in qs if x.ext_question_id == "i:adhoc_13126"][0] + q.explanation_template = "You have been diagnosed with: {answer}." + q = [x for x in qs if x.ext_question_id == "gr:gender"][0] + q.explanation_template = "Your gender is {answer}." + + ecs = [] + for q in qs: + answer_code = uqad[q.ext_question_id].answer + answer_label = tuple([q.choices_text_lookup[ans] for ans in answer_code]) + explanation = None + if q.explanation_template: + explanation = q.explanation_template.format( + answer=", ".join(answer_label) + ) + sec = SurveyEligibilityCriterion( + question_id=q.id, + question_text=q.text, + qualifying_answer=answer_code, + qualifying_answer_label=answer_label, + explanation=explanation, + property_code=q.ext_question_id, + ) + print(sec) + ecs.append(sec) + bucket.eligibility_criteria = tuple(ecs) + print(bucket) + + +class TestSurveyStat: + def test( + self, + delete_buyers_surveys, + surveystat_manager, + survey_manager, + surveys_fixture, + ): + survey_manager.create_or_update(surveys_fixture) + ss = [ssa, ssb] + surveystat_manager.update_or_create(ss) + keys = [s.unique_key for s in ss] + res = surveystat_manager.filter_by_unique_keys(keys) + assert len(res) == 2 + assert res[0].conv_alpha == 10 + + ssa.conv_alpha = 11 + surveystat_manager.update_or_create([ssa]) + res = surveystat_manager.filter_by_unique_keys([ssa.unique_key]) + assert len(res) == 1 + assert res[0].conv_alpha == 11 + + @pytest.mark.skip() + def test_big( + self, + # delete_buyers_surveys, + surveystat_manager, + survey_manager, + surveys_fixture, + ): + survey = surveys_fixture[0].model_copy() + surveys = [] + for idx in range(20_000): + s = survey.model_copy() + s.survey_id = uuid.uuid4().hex + surveys.append(s) + + survey_manager.create_or_update(surveys) + # Realistically, we're going to have like, say 20k surveys + # and 99% of them will be updated each time + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey__survey_id = survey.survey_id + ss.quota_id = "__all__" + ss.score = 10 + survey_stats.append(ss) + print(len(survey_stats)) + print(survey_stats[12].natural_key, survey_stats[2000].natural_key) + print(f"----a-----: {datetime.now().isoformat()}") + res = surveystat_manager.update_or_create(survey_stats) + print(f"----b-----: {datetime.now().isoformat()}") + assert len(res) == 20_000 + return + + # 1,000 of the 20,000 are "new" + now = datetime.now(tz=timezone.utc) + for s in ss[:1000]: + s.survey__survey_id = "b" + s.updated_at = now + # 18,000 need to be updated (scores update or whatever) + for s in ss[1000:-1000]: + s.score = 20 + s.conv_alpha = 20 + s.conv_beta = 20 + s.updated_at = now + # and 1,000 don't change + print(f"----c-----: {datetime.now().isoformat()}") + res2 = surveystat_manager.update_or_create(ss) + print(f"----d-----: {datetime.now().isoformat()}") + assert len(res2) == 20_000 + + def test_ymsp( + self, + delete_buyers_surveys, + surveys_fixture, + survey_manager, + surveystat_manager, + ): + source = Source.TESTING + survey = surveys_fixture[0].model_copy() + surveys = [] + for idx in range(100): + s = survey.model_copy() + s.survey_id = uuid.uuid4().hex + surveys.append(s) + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey_survey_id = survey.survey_id + survey_stats.append(ss) + + surveystat_manager.update_surveystats_for_source( + source=source, surveys=surveys, survey_stats=survey_stats + ) + # UPDATE ------- + since = datetime.now(tz=timezone.utc) + print(f"{since=}") + + # 10 survey disappear + surveys = surveys[10:] + + # and 2 new ones are created + for idx in range(2): + s = survey.model_copy() + s.survey_id = uuid.uuid4().hex + surveys.append(s) + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey_survey_id = survey.survey_id + survey_stats.append(ss) + surveystat_manager.update_surveystats_for_source( + source=source, surveys=surveys, survey_stats=survey_stats + ) + + live_surveys = survey_manager.filter_by_source_live(source=source) + assert len(live_surveys) == 92 # 100 - 10 + 2 + + ss = surveystat_manager.filter_by_updated_since(since=since) + assert len(ss) == 102 # 92 existing + 10 not live + + ss = surveystat_manager.filter_by_live() + assert len(ss) == 92 + + def test_filter( + self, + delete_buyers_surveys, + surveys_fixture, + survey_manager, + surveystat_manager, + ): + surveys = [] + survey = surveys_fixture[0].model_copy() + survey.source = Source.TESTING + survey.survey_id = "a" + surveys.append(survey.model_copy()) + survey.source = Source.TESTING + survey.survey_id = "b" + surveys.append(survey.model_copy()) + survey.source = Source.TESTING2 + survey.survey_id = "b" + surveys.append(survey.model_copy()) + survey.source = Source.TESTING2 + survey.survey_id = "c" + surveys.append(survey.model_copy()) + # 4 surveys t:a, t:b, u:b, u:c + + survey_stats = [] + for survey in surveys: + ss = ssa.model_copy() + ss.survey_survey_id = survey.survey_id + ss.survey_source = survey.source + survey_stats.append(ss) + + surveystat_manager.update_surveystats_for_source( + source=Source.TESTING, surveys=surveys[:2], survey_stats=survey_stats[:2] + ) + surveystat_manager.update_surveystats_for_source( + source=Source.TESTING2, surveys=surveys[2:], survey_stats=survey_stats[2:] + ) + + survey_keys = [f"{s.source.value}:{s.survey_id}" for s in surveys] + res = surveystat_manager.filter(survey_keys=survey_keys, min_score=0.01) + assert len(res) == 4 + res = survey_manager.filter_by_keys(survey_keys) + assert len(res) == 4 + + res = surveystat_manager.filter(survey_keys=survey_keys[:2]) + assert len(res) == 2 + res = survey_manager.filter_by_keys(survey_keys[:2]) + assert len(res) == 2 diff --git a/tests/managers/thl/test_survey_penalty.py b/tests/managers/thl/test_survey_penalty.py new file mode 100644 index 0000000..4c7dc08 --- /dev/null +++ b/tests/managers/thl/test_survey_penalty.py @@ -0,0 +1,101 @@ +import uuid + +import pytest +from cachetools.keys import _HashedTuple + +from generalresearch.models import Source +from generalresearch.models.thl.survey.penalty import ( + BPSurveyPenalty, + TeamSurveyPenalty, +) + + +@pytest.fixture +def product_uuid() -> str: + # Nothing touches the db here, we don't need actual products or teams + return uuid.uuid4().hex + + +@pytest.fixture +def team_uuid() -> str: + # Nothing touches the db here, we don't need actual products or teams + return uuid.uuid4().hex + + +@pytest.fixture +def penalties(product_uuid, team_uuid): + return [ + BPSurveyPenalty( + source=Source.TESTING, survey_id="a", penalty=0.1, product_id=product_uuid + ), + BPSurveyPenalty( + source=Source.TESTING, survey_id="b", penalty=0.2, product_id=product_uuid + ), + TeamSurveyPenalty( + source=Source.TESTING, survey_id="a", penalty=1, team_id=team_uuid + ), + TeamSurveyPenalty( + source=Source.TESTING, survey_id="c", penalty=1, team_id=team_uuid + ), + # Source.TESTING:a is different from Source.TESTING2:a ! + BPSurveyPenalty( + source=Source.TESTING2, survey_id="a", penalty=0.5, product_id=product_uuid + ), + # For a random BP, should not do anything + BPSurveyPenalty( + source=Source.TESTING, survey_id="b", penalty=1, product_id=uuid.uuid4().hex + ), + ] + + +class TestSurveyPenalty: + def test(self, surveypenalty_manager, penalties, product_uuid, team_uuid): + surveypenalty_manager.set_penalties(penalties) + + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_uuid + ) + assert res == {"t:a": 1.0, "t:b": 0.2, "t:c": 1, "u:a": 0.5} + + # We can update penalties for a marketplace and not erase them for another. + # But remember, marketplace is batched, so it'll overwrite the previous + # values for that marketplace + penalties = [ + BPSurveyPenalty( + source=Source.TESTING2, + survey_id="b", + penalty=0.1, + product_id=product_uuid, + ) + ] + surveypenalty_manager.set_penalties(penalties) + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_uuid + ) + assert res == {"t:a": 1.0, "t:b": 0.2, "t:c": 1, "u:b": 0.1} + + # Team id doesn't exist, so it should return the product's penalties + team_id_random = uuid.uuid4().hex + surveypenalty_manager.cache.clear() + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_id_random + ) + assert res == {"t:a": 0.1, "t:b": 0.2, "u:b": 0.1} + + # Assert it is cached (no redis lookup needed) + assert surveypenalty_manager.cache.currsize == 1 + res = surveypenalty_manager.get_penalties_for( + product_id=product_uuid, team_id=team_id_random + ) + assert res == {"t:a": 0.1, "t:b": 0.2, "u:b": 0.1} + assert surveypenalty_manager.cache.currsize == 1 + cached_key = tuple(list(list(surveypenalty_manager.cache.keys())[0])[1:]) + assert cached_key == tuple( + ["product_id", product_uuid, "team_id", team_id_random] + ) + + # Both don't exist, return nothing + res = surveypenalty_manager.get_penalties_for( + product_id=uuid.uuid4().hex, team_id=uuid.uuid4().hex + ) + assert res == {} diff --git a/tests/managers/thl/test_task_adjustment.py b/tests/managers/thl/test_task_adjustment.py new file mode 100644 index 0000000..839bbe1 --- /dev/null +++ b/tests/managers/thl/test_task_adjustment.py @@ -0,0 +1,346 @@ +import logging +from random import randint + +import pytest +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallAdjustedStatus, +) + + +@pytest.fixture() +def session_complete(session_with_tx_factory, user): + return session_with_tx_factory( + user=user, final_status=Status.COMPLETE, wall_req_cpi=Decimal("1.23") + ) + + +@pytest.fixture() +def session_complete_with_wallet(session_with_tx_factory, user_with_wallet): + return session_with_tx_factory( + user=user_with_wallet, + final_status=Status.COMPLETE, + wall_req_cpi=Decimal("1.23"), + ) + + +@pytest.fixture() +def session_fail(user, session_manager, wall_manager): + session = session_manager.create_dummy( + started=datetime.now(timezone.utc), user=user + ) + wall1 = wall_manager.create_dummy( + session_id=session.id, + user_id=user.user_id, + source=Source.DYNATA, + req_survey_id="72723", + req_cpi=Decimal("3.22"), + started=datetime.now(timezone.utc), + ) + wall_manager.finish( + wall=wall1, + status=Status.FAIL, + status_code_1=StatusCode1.PS_FAIL, + finished=wall1.started + timedelta(seconds=randint(a=60 * 2, b=60 * 10)), + ) + session.wall_events.append(wall1) + return session + + +class TestHandleRecons: + + def test_complete_to_recon( + self, + session_complete, + thl_lm, + task_adjustment_manager, + wall_manager, + session_manager, + caplog, + ): + print(wall_manager.pg_config.dsn) + mid = session_complete.uuid + wall_uuid = session_complete.wall_events[-1].uuid + s = session_complete + ledger_manager = thl_lm + + revenue_account = ledger_manager.get_account_task_complete_revenue() + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert ( + current_amount == 123 + ), "this is the amount of revenue from this task complete" + + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 117, "this is the amount paid to the BP" + + # Do the work here !! ----v + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + assert ( + len(task_adjustment_manager.filter_by_wall_uuid(wall_uuid=wall_uuid)) == 1 + ) + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0, "after recon, it should be zeroed" + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0, "this is the amount paid to the BP" + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + assert ledger_manager.get_account_balance(commission_account) == 0 + + # Now, say we get the exact same *adjust to incomplete* msg again. It should do nothing! + adjusted_timestamp = datetime.now(tz=timezone.utc) + wall = wall_manager.get_from_uuid(wall_uuid=wall_uuid) + with pytest.raises(match=" is already "): + wall_manager.adjust_status( + wall, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adjusted_timestamp, + adjusted_cpi=Decimal(0), + ) + + session = session_manager.get_from_id(wall.session_id) + user = session.user + with caplog.at_level(logging.INFO): + ledger_manager.create_tx_task_adjustment( + wall, user=user, created=adjusted_timestamp + ) + assert "No transactions needed" in caplog.text + + session.wall_events = wall_manager.get_wall_events(session.id) + session.user.prefetch_product(wall_manager.pg_config) + + with caplog.at_level(logging.INFO, logger="Wall"): + session_manager.adjust_status(session) + assert "is already f" in caplog.text or "is already Status.FAIL" in caplog.text + + with caplog.at_level(logging.INFO, logger="LedgerManager"): + ledger_manager.create_tx_bp_adjustment(session, created=adjusted_timestamp) + assert "No transactions needed" in caplog.text + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0, "after recon, it should be zeroed" + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0, "this is the amount paid to the BP" + + # And if we get an adj to fail, and handle it, it should do nothing at all + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + assert ( + len(task_adjustment_manager.filter_by_wall_uuid(wall_uuid=wall_uuid)) == 1 + ) + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0, "after recon, it should be zeroed" + + def test_fail_to_complete(self, session_fail, thl_lm, task_adjustment_manager): + s = session_fail + mid = session_fail.uuid + wall_uuid = session_fail.wall_events[-1].uuid + ledger_manager = thl_lm + + revenue_account = ledger_manager.get_account_task_complete_revenue() + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", mid + ) + assert ( + current_amount == 0 + ), "this is the amount of revenue from this task complete" + + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0, "this is the amount paid to the BP" + + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + ) + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 322, "after recon, we should be paid" + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 306, "this is the amount paid to the BP" + + # Now reverse it back to fail + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 0 + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0 + + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + assert ledger_manager.get_account_balance(commission_account) == 0 + + def test_complete_already_complete( + self, session_complete, thl_lm, task_adjustment_manager + ): + s = session_complete + mid = session_complete.uuid + wall_uuid = session_complete.wall_events[-1].uuid + ledger_manager = thl_lm + + for _ in range(4): + # just run it 4 times to make sure nothing happens 4 times + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + ) + + revenue_account = ledger_manager.get_account_task_complete_revenue() + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert current_amount == 123 + assert ledger_manager.get_account_balance(commission_account) == 6 + + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 117 + + def test_incomplete_already_incomplete( + self, session_fail, thl_lm, task_adjustment_manager + ): + s = session_fail + mid = session_fail.uuid + wall_uuid = session_fail.wall_events[-1].uuid + ledger_manager = thl_lm + + for _ in range(4): + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + + revenue_account = ledger_manager.get_account_task_complete_revenue() + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + current_amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", mid + ) + assert current_amount == 0 + assert ledger_manager.get_account_balance(commission_account) == 0 + + current_bp_payout = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert current_bp_payout == 0 + + def test_complete_to_recon_user_wallet( + self, + session_complete_with_wallet, + user_with_wallet, + thl_lm, + task_adjustment_manager, + ): + s = session_complete_with_wallet + mid = s.uuid + wall_uuid = s.wall_events[-1].uuid + ledger_manager = thl_lm + + revenue_account = ledger_manager.get_account_task_complete_revenue() + amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert amount == 123, "this is the amount of revenue from this task complete" + + bp_wallet_account = ledger_manager.get_account_or_create_bp_wallet( + s.user.product + ) + user_wallet_account = ledger_manager.get_account_or_create_user_wallet(s.user) + commission_account = ledger_manager.get_account_or_create_bp_commission( + s.user.product + ) + amount = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert amount == 70, "this is the amount paid to the BP" + amount = ledger_manager.get_account_filtered_balance( + user_wallet_account, "thl_session", mid + ) + assert amount == 47, "this is the amount paid to the user" + assert ( + ledger_manager.get_account_balance(commission_account) == 6 + ), "earned commission" + + task_adjustment_manager.handle_single_recon( + ledger_manager=thl_lm, + wall_uuid=wall_uuid, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + ) + + amount = ledger_manager.get_account_filtered_balance( + revenue_account, "thl_wall", wall_uuid + ) + assert amount == 0 + amount = ledger_manager.get_account_filtered_balance( + bp_wallet_account, "thl_session", mid + ) + assert amount == 0 + amount = ledger_manager.get_account_filtered_balance( + user_wallet_account, "thl_session", mid + ) + assert amount == 0 + assert ( + ledger_manager.get_account_balance(commission_account) == 0 + ), "earned commission" diff --git a/tests/managers/thl/test_task_status.py b/tests/managers/thl/test_task_status.py new file mode 100644 index 0000000..55c89c0 --- /dev/null +++ b/tests/managers/thl/test_task_status.py @@ -0,0 +1,696 @@ +import pytest +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +from generalresearch.managers.thl.session import SessionManager +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + Status, + WallAdjustedStatus, + StatusCode1, +) +from generalresearch.models.thl.product import ( + PayoutConfig, + UserWalletConfig, + PayoutTransformation, + PayoutTransformationPercentArgs, +) +from generalresearch.models.thl.session import Session, WallOut +from generalresearch.models.thl.task_status import TaskStatusResponse +from generalresearch.models.thl.user import User + + +start1 = datetime(2023, 2, 1, tzinfo=timezone.utc) +finish1 = start1 + timedelta(minutes=5) +recon1 = start1 + timedelta(days=20) +start2 = datetime(2023, 2, 2, tzinfo=timezone.utc) +finish2 = start2 + timedelta(minutes=5) +start3 = datetime(2023, 2, 3, tzinfo=timezone.utc) +finish3 = start3 + timedelta(minutes=5) + + +@pytest.fixture(scope="session") +def bp1(product_manager): + # user wallet disabled, payout xform NULL + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=False), + payout_config=PayoutConfig(), + ) + + +@pytest.fixture(scope="session") +def bp2(product_manager): + # user wallet disabled, payout xform 40% + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=False), + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.4), + ) + ), + ) + + +@pytest.fixture(scope="session") +def bp3(product_manager): + # user wallet enabled, payout xform 50% + return product_manager.create_dummy( + user_wallet_config=UserWalletConfig(enabled=True), + payout_config=PayoutConfig( + payout_transformation=PayoutTransformation( + f="payout_transformation_percent", + kwargs=PayoutTransformationPercentArgs(pct=0.5), + ) + ), + ) + + +class TestTaskStatus: + + def test_task_status_complete_1( + self, + bp1, + user_factory, + finished_session_factory, + session_manager: SessionManager, + ): + # User Payout xform NULL + user1: User = user_factory(product=bp1) + s1: Session = finished_session_factory( + user=user1, started=start1, wall_req_cpi=Decimal(1), wall_count=2 + ) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s1.uuid, + "product_id": user1.product_id, + "product_user_id": user1.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s1.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "status_code_1": 14, + "payout_transformation": None, + } + ) + w1 = s1.wall_events[0] + wo1 = WallOut( + uuid=w1.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w1.req_survey_id, + req_cpi=w1.req_cpi, + started=w1.started, + survey_id=w1.survey_id, + cpi=w1.cpi, + finished=w1.finished, + status=w1.status, + status_code_1=w1.status_code_1, + ) + w2 = s1.wall_events[1] + wo2 = WallOut( + uuid=w2.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w2.req_survey_id, + req_cpi=w2.req_cpi, + started=w2.started, + survey_id=w2.survey_id, + cpi=w2.cpi, + finished=w2.finished, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + ) + expected_tsr.wall_events = [wo1, wo2] + tsr = session_manager.get_task_status_response(s1.uuid) + assert tsr == expected_tsr + + def test_task_status_complete_2( + self, bp2, user_factory, finished_session_factory, session_manager + ): + # User Payout xform 40% + user2: User = user_factory(product=bp2) + s2: Session = finished_session_factory( + user=user2, started=start1, wall_req_cpi=Decimal(1), wall_count=2 + ) + payout = 95 # 1.00 - 5% commission + user_payout = round(95 * 0.40) # 38 + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s2.uuid, + "product_id": user2.product_id, + "product_user_id": user2.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s2.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": payout, + "user_payout": user_payout, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.38", + "status_code_1": 14, + "status_code_2": None, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + w1 = s2.wall_events[0] + wo1 = WallOut( + uuid=w1.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w1.req_survey_id, + req_cpi=w1.req_cpi, + started=w1.started, + survey_id=w1.survey_id, + cpi=w1.cpi, + finished=w1.finished, + status=w1.status, + status_code_1=w1.status_code_1, + user_cpi=Decimal("0.38"), + user_cpi_string="$0.38", + ) + w2 = s2.wall_events[1] + wo2 = WallOut( + uuid=w2.uuid, + source=Source.TESTING, + buyer_id=None, + req_survey_id=w2.req_survey_id, + req_cpi=w2.req_cpi, + started=w2.started, + survey_id=w2.survey_id, + cpi=w2.cpi, + finished=w2.finished, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + user_cpi=Decimal("0.38"), + user_cpi_string="$0.38", + ) + expected_tsr.wall_events = [wo1, wo2] + + tsr = session_manager.get_task_status_response(s2.uuid) + assert tsr == expected_tsr + + def test_task_status_complete_3( + self, bp3, user_factory, finished_session_factory, session_manager + ): + # Wallet enabled User Payout xform 50% (the response is identical + # to the user wallet disabled w same xform) + user3: User = user_factory(product=bp3) + s3: Session = finished_session_factory( + user=user3, started=start1, wall_req_cpi=Decimal(1) + ) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s3.uuid, + "product_id": user3.product_id, + "product_user_id": user3.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s3.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": 48, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.48", + "kwargs": {}, + "status_code_1": 14, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.5"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s3.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_fail( + self, bp1, user_factory, finished_session_factory, session_manager + ): + # User Payout xform NULL: user payout is None always + user1: User = user_factory(product=bp1) + s1: Session = finished_session_factory( + user=user1, started=start1, final_status=Status.FAIL + ) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s1.uuid, + "product_id": user1.product_id, + "product_user_id": user1.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s1.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": s1.status_code_1.value, + "status_code_2": None, + "adjusted_status": None, + "adjusted_timestamp": None, + "adjusted_payout": None, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s1.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_fail_xform( + self, bp2, user_factory, finished_session_factory, session_manager + ): + # User Payout xform 40%: user_payout is 0 (not None) + + user: User = user_factory(product=bp2) + s: Session = finished_session_factory( + user=user, started=start1, final_status=Status.FAIL + ) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": 0, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.00", + "kwargs": {}, + "status_code_1": s.status_code_1.value, + "status_code_2": None, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_abandon( + self, bp1, user_factory, session_factory, session_manager + ): + # User Payout xform NULL: all payout fields are None + user: User = user_factory(product=bp1) + s = session_factory(user=user, started=start1) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": None, + "adjusted_timestamp": None, + "adjusted_payout": None, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_abandon_xform( + self, bp2, user_factory, session_factory, session_manager + ): + # User Payout xform 40%: all payout fields are None (same as when payout xform is null) + user: User = user_factory(product=bp2) + s = session_factory(user=user, started=start1) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": "${payout/100:.2f}", + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": None, + "adjusted_timestamp": None, + "adjusted_payout": None, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_fail( + self, + bp1, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # Complete -> Fail + # User Payout xform NULL: adjusted_user_* and user_* is still all None + user: User = user_factory(product=bp1) + s: Session = finished_session_factory( + user=user, started=start1, wall_req_cpi=Decimal(1) + ) + wall_manager.adjust_status( + wall=s.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": StatusCode1.COMPLETE.value, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_FAIL.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 0, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_fail_xform( + self, + bp2, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # Complete -> Fail + # User Payout xform 40%: adjusted_user_payout is 0 (not null) + user: User = user_factory(product=bp2) + s: Session = finished_session_factory( + user=user, started=start1, wall_req_cpi=Decimal(1) + ) + wall_manager.adjust_status( + wall=s.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "c", + "payout": 95, + "user_payout": 38, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.38", + "kwargs": {}, + "status_code_1": StatusCode1.COMPLETE.value, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_FAIL.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 0, + "adjusted_user_payout": 0, + "adjusted_user_payout_string": "$0.00", + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_abandon( + self, + bp1, + user_factory, + session_factory, + wall_manager, + session_manager, + ): + # User Payout xform NULL + user: User = user_factory(product=bp1) + s: Session = session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.ABANDON, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_COMPLETE.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_abandon_xform( + self, + bp2, + user_factory, + session_factory, + wall_manager, + session_manager, + ): + # User Payout xform 40% + user: User = user_factory(product=bp2) + s: Session = session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.ABANDON, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": None, + "status": None, + "payout": None, + "user_payout": None, + "payout_format": "${payout/100:.2f}", + "user_payout_string": None, + "kwargs": {}, + "status_code_1": None, + "status_code_2": None, + "adjusted_status": WallAdjustedStatus.ADJUSTED_TO_COMPLETE.value, + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": 38, + "adjusted_user_payout_string": "$0.38", + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_fail( + self, + bp1, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # User Payout xform NULL + user: User = user_factory(product=bp1) + s: Session = finished_session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.FAIL, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": None, + "payout_format": None, + "user_payout_string": None, + "kwargs": {}, + "status_code_1": s.status_code_1.value, + "status_code_2": None, + "adjusted_status": "ac", + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": None, + "adjusted_user_payout_string": None, + "payout_transformation": None, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr + + def test_task_status_adj_complete_from_fail_xform( + self, + bp2, + user_factory, + finished_session_factory, + wall_manager, + session_manager, + ): + # User Payout xform 40% + user: User = user_factory(product=bp2) + s: Session = finished_session_factory( + user=user, + started=start1, + wall_req_cpi=Decimal(1), + wall_count=2, + final_status=Status.FAIL, + ) + w = s.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w.cpi, + adjusted_timestamp=recon1, + ) + session_manager.adjust_status(s) + expected_tsr = TaskStatusResponse.model_validate( + { + "tsid": s.uuid, + "product_id": user.product_id, + "product_user_id": user.product_user_id, + "started": start1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "finished": s.finished.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "status": "f", + "payout": 0, + "user_payout": 0, + "payout_format": "${payout/100:.2f}", + "user_payout_string": "$0.00", + "kwargs": {}, + "status_code_1": s.status_code_1.value, + "status_code_2": None, + "adjusted_status": "ac", + "adjusted_timestamp": recon1.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "adjusted_payout": 95, + "adjusted_user_payout": 38, + "adjusted_user_payout_string": "$0.38", + "payout_transformation": { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.4"}, + }, + } + ) + tsr = session_manager.get_task_status_response(s.uuid) + # Not bothering with wall events ... + tsr.wall_events = None + assert tsr == expected_tsr diff --git a/tests/managers/thl/test_user_manager/__init__.py b/tests/managers/thl/test_user_manager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/managers/thl/test_user_manager/test_base.py b/tests/managers/thl/test_user_manager/test_base.py new file mode 100644 index 0000000..3c7ee38 --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_base.py @@ -0,0 +1,274 @@ +import logging +from datetime import datetime, timezone +from random import randint +from uuid import uuid4 + +import pytest + +from generalresearch.managers.thl.user_manager import ( + UserCreateNotAllowedError, + get_bp_user_create_limit_hourly, +) +from generalresearch.managers.thl.user_manager.rate_limit import ( + RateLimitItemPerHourConstantKey, +) +from generalresearch.models.thl.product import UserCreateConfig, Product +from generalresearch.models.thl.user import User +from test_utils.models.conftest import ( + user, + product, + user_manager, + product_manager, +) + +logger = logging.getLogger() + + +class TestUserManager: + + def test_copying_lru_cache(self, user_manager, user): + # Before adding the deepcopy_return decorator, this would fail b/c the returned user + # is mutable, and it would mutate in the cache + # user_manager = self.get_user_manager() + + user_manager.clear_user_inmemory_cache(user) + u = user_manager.get_user(user_id=user.user_id) + assert not u.blocked + + u.blocked = True + u = user_manager.get_user(user_id=user.user_id) + assert not u.blocked + + def test_get_user_no_inmemory(self, user, user_manager): + user_manager.clear_user_inmemory_cache(user) + user_manager.get_user.__wrapped__.cache_clear() + u = user_manager.get_user(user_id=user.user_id) + # this should hit mysql + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 0, cache_info + assert cache_info.misses == 1, cache_info + + # this should hit the lru cache + u = user_manager.get_user(user_id=user.user_id) + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 1, cache_info + assert cache_info.misses == 1, cache_info + + def test_get_user_with_inmemory(self, user_manager, user): + # user_manager = self.get_user_manager() + + user_manager.set_user_inmemory_cache(user) + user_manager.get_user.__wrapped__.cache_clear() + u = user_manager.get_user(user_id=user.user_id) + # this should hit inmemory cache + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 0, cache_info + assert cache_info.misses == 1, cache_info + + # this should hit the lru cache + u = user_manager.get_user(user_id=user.user_id) + assert u == user + + cache_info = user_manager.get_user.__wrapped__.cache_info() + assert cache_info.hits == 1, cache_info + assert cache_info.misses == 1, cache_info + + +class TestBlockUserManager: + + def test_block_user(self, product, user_manager): + product_user_id = f"user-{uuid4().hex[:10]}" + + # mysql_user_manager to skip user creation limit check + user: User = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + assert not user.blocked + + # get_user to make sure caches are populated + user = user_manager.get_user( + product_id=product.id, product_user_id=product_user_id + ) + assert not user.blocked + + assert user_manager.block_user(user) is True + assert user.blocked + + user = user_manager.get_user( + product_id=product.id, product_user_id=product_user_id + ) + assert user.blocked + + user = user_manager.get_user(user_id=user.user_id) + assert user.blocked + + def test_block_user_whitelist(self, product, user_manager, thl_web_rw): + product_user_id = f"user-{uuid4().hex[:10]}" + + # mysql_user_manager to skip user creation limit check + user: User = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + assert not user.blocked + + now = datetime.now(tz=timezone.utc) + # Adds user to whitelist + thl_web_rw.execute_write( + """ + INSERT INTO userprofile_userstat + (user_id, key, value, date) + VALUES (%(user_id)s, 'USER_HEALTH.access_control', 1, %(date)s) + ON CONFLICT (user_id, key) DO UPDATE SET value=1""", + params={"user_id": user.user_id, "date": now}, + ) + assert user_manager.is_whitelisted(user) + assert user_manager.block_user(user) is False + assert not user.blocked + + +class TestCreateUserManager: + + def test_create_user(self, product_manager, thl_web_rw, user_manager): + product: Product = product_manager.create_dummy( + user_create_config=UserCreateConfig( + min_hourly_create_limit=10, max_hourly_create_limit=69 + ), + ) + + product_user_id = f"user-{uuid4().hex[:10]}" + + user: User = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + assert isinstance(user, User) + assert user.product_id == product.id + assert user.product_user_id == product_user_id + + assert user.user_id is not None + assert user.uuid is not None + + # make sure thl_user row is created + res_thl_user = thl_web_rw.execute_sql_query( + query=f""" + SELECT * + FROM thl_user AS u + WHERE u.id = %s + """, + params=[user.user_id], + ) + + assert len(res_thl_user) == 1 + + u2 = user_manager.get_user( + product_id=product.id, product_user_id=product_user_id + ) + assert u2.user_id == user.user_id + assert u2.uuid == user.uuid + + def test_create_user_integrity_error(self, product_manager, user_manager, caplog): + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + user_create_config=UserCreateConfig( + min_hourly_create_limit=10, max_hourly_create_limit=69 + ), + ) + + product_user_id = f"user-{uuid4().hex[:10]}" + rand_msg = f"log-{uuid4().hex}" + + with caplog.at_level(logging.INFO): + logger.info(rand_msg) + user1 = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + + assert len(caplog.records) == 1 + assert caplog.records[0].getMessage() == rand_msg + + # Should cause a constraint error, triggering a lookup instead + with caplog.at_level(logging.INFO): + user2 = user_manager.mysql_user_manager.create_user( + product_id=product.id, product_user_id=product_user_id + ) + + assert len(caplog.records) == 3 + assert caplog.records[0].getMessage() == rand_msg + assert ( + caplog.records[1].getMessage() + == f"mysql_user_manager.create_user_new integrity error: {product.id} {product_user_id}" + ) + assert ( + caplog.records[2].getMessage() + == f"get_user_from_mysql: {product.id}, {product_user_id}, None, None" + ) + + assert user1 == user2 + + def test_raise_allow_user_create(self, product_manager, user_manager): + rand_num = randint(25, 200) + product: Product = product_manager.create_dummy( + product_id=uuid4().hex, + team_id=uuid4().hex, + name=f"Test Product ID #{uuid4().hex[:6]}", + user_create_config=UserCreateConfig( + min_hourly_create_limit=rand_num, + max_hourly_create_limit=rand_num, + ), + ) + + instance: Product = user_manager.product_manager.get_by_uuid( + product_uuid=product.id + ) + + # get_bp_user_create_limit_hourly is dynamically generated, make sure + # we use this value in our tests and not the + # UserCreateConfig.max_hourly_create_limit value + rl_value: int = get_bp_user_create_limit_hourly(product=instance) + + # This is a randomly generated product_id, which means we'll always + # use the global defaults + assert rand_num == instance.user_create_config.min_hourly_create_limit + assert rand_num == instance.user_create_config.max_hourly_create_limit + assert rand_num == rl_value + + rl = RateLimitItemPerHourConstantKey(rl_value) + assert str(rl) == f"{rand_num} per 1 hour" + + key = rl.key_for("thl-grpc", "allow_user_create", instance.id) + assert key == f"LIMITER/thl-grpc/allow_user_create/{instance.id}" + + # make sure we clear the key or subsequent tests will fail + user_manager.user_manager_limiter.storage.clear(key=key) + + n = 0 + with pytest.raises(expected_exception=UserCreateNotAllowedError) as cm: + for n, _ in enumerate(range(rl_value + 5)): + user_manager.user_manager_limiter.raise_allow_user_create( + product=product + ) + assert rl_value == n + + +class TestUserManagerMethods: + + def test_audit_log(self, user_manager, user, audit_log_manager): + from generalresearch.models.thl.userhealth import AuditLog + + res = audit_log_manager.filter_by_user_id(user_id=user.user_id) + assert len(res) == 0 + + msg = uuid4().hex + user_manager.audit_log(user=user, level=30, event_type=msg) + + res = audit_log_manager.filter_by_user_id(user_id=user.user_id) + assert len(res) == 1 + assert isinstance(res[0], AuditLog) + assert res[0].event_type == msg diff --git a/tests/managers/thl/test_user_manager/test_mysql.py b/tests/managers/thl/test_user_manager/test_mysql.py new file mode 100644 index 0000000..0313bbf --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_mysql.py @@ -0,0 +1,25 @@ +from test_utils.models.conftest import user, user_manager + + +class TestUserManagerMysqlNew: + + def test_get_notset(self, user_manager): + assert ( + user_manager.mysql_user_manager.get_user_from_mysql(user_id=-3105) is None + ) + + def test_get_user_id(self, user, user_manager): + assert ( + user_manager.mysql_user_manager.get_user_from_mysql(user_id=user.user_id) + == user + ) + + def test_get_uuid(self, user, user_manager): + u = user_manager.mysql_user_manager.get_user_from_mysql(user_uuid=user.uuid) + assert u == user + + def test_get_ubp(self, user, user_manager): + u = user_manager.mysql_user_manager.get_user_from_mysql( + product_id=user.product_id, product_user_id=user.product_user_id + ) + assert u == user diff --git a/tests/managers/thl/test_user_manager/test_redis.py b/tests/managers/thl/test_user_manager/test_redis.py new file mode 100644 index 0000000..a69519e --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_redis.py @@ -0,0 +1,80 @@ +import pytest + +from generalresearch.managers.base import Permission + + +class TestUserManagerRedis: + + def test_get_notset(self, user_manager, user): + user_manager.clear_user_inmemory_cache(user=user) + assert user_manager.redis_user_manager.get_user(user_id=user.user_id) is None + + def test_get_user_id(self, user_manager, user): + user_manager.redis_user_manager.set_user(user=user) + + assert user_manager.redis_user_manager.get_user(user_id=user.user_id) == user + + def test_get_uuid(self, user_manager, user): + user_manager.redis_user_manager.set_user(user=user) + + assert user_manager.redis_user_manager.get_user(user_uuid=user.uuid) == user + + def test_get_ubp(self, user_manager, user): + user_manager.redis_user_manager.set_user(user=user) + + assert ( + user_manager.redis_user_manager.get_user( + product_id=user.product_id, product_user_id=user.product_user_id + ) + == user + ) + + @pytest.mark.skip(reason="TODO") + def test_set(self): + # I mean, the sets are implicitly tested by the get tests above. no point + pass + + def test_get_with_cache_prefix(self, settings, user, thl_web_rw, thl_web_rr): + """ + Confirm the prefix functionality is working; we do this so it + is easier to migrate between any potentially breaking versions + if we don't want any broken keys; not as important after + pydantic usage... + """ + from generalresearch.managers.thl.user_manager.user_manager import ( + UserManager, + ) + + um1 = UserManager( + pg_config=thl_web_rw, + pg_config_rr=thl_web_rr, + sql_permissions=[Permission.UPDATE, Permission.CREATE], + redis=settings.redis, + redis_timeout=settings.redis_timeout, + ) + + um2 = UserManager( + pg_config=thl_web_rw, + pg_config_rr=thl_web_rr, + sql_permissions=[Permission.UPDATE, Permission.CREATE], + redis=settings.redis, + redis_timeout=settings.redis_timeout, + cache_prefix="user-lookup-v2", + ) + + um1.get_or_create_user( + product_id=user.product_id, product_user_id=user.product_user_id + ) + um2.get_or_create_user( + product_id=user.product_id, product_user_id=user.product_user_id + ) + + res1 = um1.redis_user_manager.client.get(f"user-lookup:user_id:{user.user_id}") + assert res1 is not None + + res2 = um2.redis_user_manager.client.get( + f"user-lookup-v2:user_id:{user.user_id}" + ) + assert res2 is not None + + assert res1 == res2 diff --git a/tests/managers/thl/test_user_manager/test_user_fetch.py b/tests/managers/thl/test_user_manager/test_user_fetch.py new file mode 100644 index 0000000..a4b3d57 --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_user_fetch.py @@ -0,0 +1,48 @@ +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.user import User +from test_utils.models.conftest import product, user_manager, user_factory + + +class TestUserManagerFetch: + + def test_fetch(self, user_factory, product, user_manager): + user1: User = user_factory(product=product) + user2: User = user_factory(product=product) + res = user_manager.fetch_by_bpuids( + product_id=product.uuid, + product_user_ids=[user1.product_user_id, user2.product_user_id], + ) + assert len(res) == 2 + + res = user_manager.fetch(user_ids=[user1.user_id, user2.user_id]) + assert len(res) == 2 + + res = user_manager.fetch(user_uuids=[user1.uuid, user2.uuid]) + assert len(res) == 2 + + # filter including bogus values + res = user_manager.fetch(user_uuids=[user1.uuid, uuid4().hex]) + assert len(res) == 1 + + res = user_manager.fetch(user_uuids=[uuid4().hex]) + assert len(res) == 0 + + def test_fetch_invalid(self, user_manager): + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_uuids=[], user_ids=None) + assert "Must pass ONE of user_ids, user_uuids" in str(e.value) + + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_uuids=uuid4().hex) + assert "must pass a collection of user_uuids" in str(e.value) + + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_uuids=[uuid4().hex], user_ids=[1, 2, 3]) + assert "Must pass ONE of user_ids, user_uuids" in str(e.value) + + with pytest.raises(AssertionError) as e: + user_manager.fetch(user_ids=list(range(501))) + assert "limit 500 user_ids" in str(e.value) diff --git a/tests/managers/thl/test_user_manager/test_user_metadata.py b/tests/managers/thl/test_user_manager/test_user_metadata.py new file mode 100644 index 0000000..91dc16a --- /dev/null +++ b/tests/managers/thl/test_user_manager/test_user_metadata.py @@ -0,0 +1,88 @@ +from uuid import uuid4 + +import pytest + +from generalresearch.models.thl.user_profile import UserMetadata +from test_utils.models.conftest import user, user_manager, user_factory + + +class TestUserMetadataManager: + + def test_get_notset(self, user, user_manager, user_metadata_manager): + # The row in the db won't exist. It just returns the default obj with everything None (except for the user_id) + um1 = user_metadata_manager.get(user_id=user.user_id) + assert um1 == UserMetadata(user_id=user.user_id) + + def test_create(self, user_factory, product, user_metadata_manager): + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + email_address = f"{uuid4().hex}@example.com" + um = UserMetadata(user_id=u1.user_id, email_address=email_address) + # This happens in the model itself, nothing to do with the manager (a model_validator) + assert um.email_sha256 is not None + + user_metadata_manager.update(um) + um2 = user_metadata_manager.get(email_address=email_address) + assert um == um2 + + def test_create_no_email(self, product, user_factory, user_metadata_manager): + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + um = UserMetadata(user_id=u1.user_id) + assert um.email_sha256 is None + + user_metadata_manager.update(um) + um2 = user_metadata_manager.get(user_id=u1.user_id) + assert um == um2 + + def test_update(self, product, user_factory, user_metadata_manager): + from generalresearch.models.thl.user import User + + u: User = user_factory(product=product) + + email_address = f"{uuid4().hex}@example1.com" + um = UserMetadata(user_id=u.user_id, email_address=email_address) + user_metadata_manager.update(user_metadata=um) + + um.email_address = email_address.replace("example1", "example2") + user_metadata_manager.update(user_metadata=um) + + um2 = user_metadata_manager.get(email_address=um.email_address) + assert um2.email_address != email_address + + assert um2 == UserMetadata( + user_id=u.user_id, + email_address=email_address.replace("example1", "example2"), + ) + + def test_filter(self, user_factory, product, user_metadata_manager): + from generalresearch.models.thl.user import User + + user1: User = user_factory(product=product) + user2: User = user_factory(product=product) + + email_address = f"{uuid4().hex}@example.com" + res = user_metadata_manager.filter(email_addresses=[email_address]) + assert len(res) == 0 + + # Create 2 user metadata with the same email address + user_metadata_manager.update( + user_metadata=UserMetadata( + user_id=user1.user_id, email_address=email_address + ) + ) + user_metadata_manager.update( + user_metadata=UserMetadata( + user_id=user2.user_id, email_address=email_address + ) + ) + + res = user_metadata_manager.filter(email_addresses=[email_address]) + assert len(res) == 2 + + with pytest.raises(expected_exception=ValueError) as e: + res = user_metadata_manager.get(email_address=email_address) + assert "More than 1 result returned!" in str(e.value) diff --git a/tests/managers/thl/test_user_streak.py b/tests/managers/thl/test_user_streak.py new file mode 100644 index 0000000..7728f9f --- /dev/null +++ b/tests/managers/thl/test_user_streak.py @@ -0,0 +1,225 @@ +import copy +from datetime import datetime, timezone, timedelta, date +from decimal import Decimal +from zoneinfo import ZoneInfo + +import pytest + +from generalresearch.managers.thl.user_streak import compute_streaks_from_days +from generalresearch.models.thl.definitions import StatusCode1, Status +from generalresearch.models.thl.user_streak import ( + UserStreak, + StreakState, + StreakPeriod, + StreakFulfillment, +) + + +def test_compute_streaks_from_days(): + days = [ + date(2026, 1, 1), + date(2026, 1, 4), + date(2026, 1, 5), + date(2026, 1, 6), + date(2026, 1, 8), + date(2026, 1, 9), + date(2026, 2, 11), + date(2026, 2, 12), + ] + + # Active + today = date(2026, 2, 12) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today) + assert res == (2, 3, StreakState.ACTIVE, date(2026, 2, 12)) + + # At Risk + today = date(2026, 2, 13) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today) + assert res == (2, 3, StreakState.AT_RISK, date(2026, 2, 12)) + + # Broken + today = date(2026, 2, 14) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.DAY, today=today) + assert res == (0, 3, StreakState.BROKEN, date(2026, 2, 12)) + + # Monthly, active + today = date(2026, 2, 14) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today) + assert res == (2, 2, StreakState.ACTIVE, date(2026, 2, 1)) + + # monthly, at risk + today = date(2026, 3, 1) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today) + assert res == (2, 2, StreakState.AT_RISK, date(2026, 2, 1)) + + # monthly, broken + today = date(2026, 4, 1) + res = compute_streaks_from_days(days, "us", period=StreakPeriod.MONTH, today=today) + assert res == (0, 2, StreakState.BROKEN, date(2026, 2, 1)) + + +@pytest.fixture +def broken_active_streak(user): + return [ + UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.ACTIVE, + country_iso="us", + user_id=user.user_id, + last_fulfilled_period_start=date(2025, 2, 11), + current_streak=0, + longest_streak=1, + state=StreakState.BROKEN, + ), + UserStreak( + period=StreakPeriod.WEEK, + fulfillment=StreakFulfillment.ACTIVE, + country_iso="us", + user_id=user.user_id, + last_fulfilled_period_start=date(2025, 2, 10), + current_streak=0, + longest_streak=1, + state=StreakState.BROKEN, + ), + UserStreak( + period=StreakPeriod.MONTH, + fulfillment=StreakFulfillment.ACTIVE, + country_iso="us", + user_id=user.user_id, + last_fulfilled_period_start=date(2025, 2, 1), + current_streak=0, + longest_streak=1, + state=StreakState.BROKEN, + ), + ] + + +def create_session_fail(session_manager, start, user): + session = session_manager.create_dummy(started=start, country_iso="us", user=user) + session_manager.finish_with_status( + session, + finished=start + timedelta(minutes=1), + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + ) + + +def create_session_complete(session_manager, start, user): + session = session_manager.create_dummy(started=start, country_iso="us", user=user) + session_manager.finish_with_status( + session, + finished=start + timedelta(minutes=1), + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal(1), + ) + + +def test_user_streak_empty(user_streak_manager, user): + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + assert streaks == [] + + +def test_user_streaks_active_broken( + user_streak_manager, user, session_manager, broken_active_streak +): + # Testing active streak, but broken (not today or yesterday) + start1 = datetime(2025, 2, 12, tzinfo=timezone.utc) + end1 = start1 + timedelta(minutes=1) + + # abandon counts as inactive + session = session_manager.create_dummy(started=start1, country_iso="us", user=user) + streak = user_streak_manager.get_user_streaks(user_id=user.user_id) + assert streak == [] + + # session start fail counts as inactive + session_manager.finish_with_status( + session, + finished=end1, + status=Status.FAIL, + status_code_1=StatusCode1.SESSION_START_QUALITY_FAIL, + ) + streak = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + assert streak == [] + + # Create another as a real failure + create_session_fail(session_manager, start1, user) + + # This is going to be the day before b/c midnight utc is 8pm US + last_active_day = start1.astimezone(ZoneInfo("America/New_York")).date() + assert last_active_day == date(2025, 2, 11) + expected_streaks = broken_active_streak.copy() + streaks = user_streak_manager.get_user_streaks(user_id=user.user_id) + assert streaks == expected_streaks + + # Create another the next day + start2 = start1 + timedelta(days=1) + create_session_fail(session_manager, start2, user) + + last_active_day = start2.astimezone(ZoneInfo("America/New_York")).date() + expected_streaks = copy.deepcopy(broken_active_streak) + expected_streaks[0].longest_streak = 2 + expected_streaks[0].last_fulfilled_period_start = last_active_day + + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + assert streaks == expected_streaks + + +def test_user_streak_complete_active(user_streak_manager, user, session_manager): + """Testing active streak that is today""" + + # They completed yesterday NY time. Today isn't over so streak is pending + start1 = datetime.now(tz=ZoneInfo("America/New_York")) - timedelta(days=1) + create_session_complete(session_manager, start1.astimezone(tz=timezone.utc), user) + + last_complete_day = start1.date() + expected_streak = UserStreak( + longest_streak=1, + current_streak=1, + state=StreakState.AT_RISK, + last_fulfilled_period_start=last_complete_day, + country_iso="us", + user_id=user.user_id, + fulfillment=StreakFulfillment.COMPLETE, + period=StreakPeriod.DAY, + ) + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + streak = [ + s + for s in streaks + if s.fulfillment == StreakFulfillment.COMPLETE and s.period == StreakPeriod.DAY + ][0] + assert streak == expected_streak + + # And now they complete today + start2 = datetime.now(tz=ZoneInfo("America/New_York")) + create_session_complete(session_manager, start2.astimezone(tz=timezone.utc), user) + last_complete_day = start2.date() + expected_streak = UserStreak( + longest_streak=2, + current_streak=2, + state=StreakState.ACTIVE, + last_fulfilled_period_start=last_complete_day, + country_iso="us", + user_id=user.user_id, + fulfillment=StreakFulfillment.COMPLETE, + period=StreakPeriod.DAY, + ) + + streaks = user_streak_manager.get_user_streaks( + user_id=user.user_id, country_iso="us" + ) + streak = [ + s + for s in streaks + if s.fulfillment == StreakFulfillment.COMPLETE and s.period == StreakPeriod.DAY + ][0] + assert streak == expected_streak diff --git a/tests/managers/thl/test_userhealth.py b/tests/managers/thl/test_userhealth.py new file mode 100644 index 0000000..1cda8de --- /dev/null +++ b/tests/managers/thl/test_userhealth.py @@ -0,0 +1,367 @@ +from datetime import timezone, datetime +from uuid import uuid4 + +import faker +import pytest + +from generalresearch.managers.thl.userhealth import ( + IPRecordManager, + UserIpHistoryManager, +) +from generalresearch.models.thl.ipinfo import GeoIPInformation +from generalresearch.models.thl.user_iphistory import ( + IPRecord, +) +from generalresearch.models.thl.userhealth import AuditLogLevel, AuditLog + +fake = faker.Faker() + + +class TestAuditLog: + + def test_init(self, thl_web_rr, audit_log_manager): + from generalresearch.managers.thl.userhealth import AuditLogManager + + alm = AuditLogManager(pg_config=thl_web_rr) + + assert isinstance(alm, AuditLogManager) + assert isinstance(audit_log_manager, AuditLogManager) + assert alm.pg_config.db == thl_web_rr.db + assert audit_log_manager.pg_config.db == thl_web_rr.db + + @pytest.mark.parametrize( + argnames="level", + argvalues=list(AuditLogLevel), + ) + def test_create(self, audit_log_manager, user, level): + instance = audit_log_manager.create( + user_id=user.user_id, level=level, event_type=uuid4().hex + ) + assert isinstance(instance, AuditLog) + assert instance.id != 1 + + def test_get_by_id(self, audit_log, audit_log_manager): + from generalresearch.models.thl.userhealth import AuditLog + + with pytest.raises(expected_exception=Exception) as cm: + audit_log_manager.get_by_id(auditlog_id=999_999_999_999) + assert "No AuditLog with id of " in str(cm.value) + + assert isinstance(audit_log, AuditLog) + res = audit_log_manager.get_by_id(auditlog_id=audit_log.id) + assert isinstance(res, AuditLog) + assert res.id == audit_log.id + assert res.created.tzinfo == timezone.utc + + def test_filter_by_product( + self, + user_factory, + product_factory, + audit_log_factory, + audit_log_manager, + ): + p1 = product_factory() + p2 = product_factory() + + audit_log_factory(user_id=user_factory(product=p1).user_id) + audit_log_factory(user_id=user_factory(product=p1).user_id) + audit_log_factory(user_id=user_factory(product=p1).user_id) + + res = audit_log_manager.filter_by_product(product=p2) + assert isinstance(res, list) + assert len(res) == 0 + + res = audit_log_manager.filter_by_product(product=p1) + assert isinstance(res, list) + assert len(res) == 3 + + audit_log_factory(user_id=user_factory(product=p2).user_id) + res = audit_log_manager.filter_by_product(product=p2) + assert isinstance(res, list) + assert isinstance(res[0], AuditLog) + assert len(res) == 1 + + def test_filter_by_user_id( + self, user_factory, product, audit_log_factory, audit_log_manager + ): + u1 = user_factory(product=product) + u2 = user_factory(product=product) + + audit_log_factory(user_id=u1.user_id) + audit_log_factory(user_id=u1.user_id) + audit_log_factory(user_id=u1.user_id) + + res = audit_log_manager.filter_by_user_id(user_id=u1.user_id) + assert isinstance(res, list) + assert len(res) == 3 + + res = audit_log_manager.filter_by_user_id(user_id=u2.user_id) + assert isinstance(res, list) + assert len(res) == 0 + + audit_log_factory(user_id=u2.user_id) + + res = audit_log_manager.filter_by_user_id(user_id=u2.user_id) + assert isinstance(res, list) + assert isinstance(res[0], AuditLog) + assert len(res) == 1 + + def test_filter( + self, + user_factory, + product_factory, + audit_log_factory, + audit_log_manager, + ): + p1 = product_factory() + p2 = product_factory() + p3 = product_factory() + + u1 = user_factory(product=p1) + u2 = user_factory(product=p2) + u3 = user_factory(product=p3) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[]) + assert "must pass at least 1 user_id" in str(cm.value) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[u1, u2, u3]) + assert "must pass user_id as int" in str(cm.value) + + res = audit_log_manager.filter(user_ids=[u1.user_id, u2.user_id, u3.user_id]) + assert isinstance(res, list) + assert len(res) == 0 + + audit_log_factory(user_id=u1.user_id) + + res = audit_log_manager.filter(user_ids=[u1.user_id, u2.user_id, u3.user_id]) + assert isinstance(res, list) + assert isinstance(res[0], AuditLog) + assert len(res) == 1 + + def test_filter_count( + self, + user_factory, + product_factory, + audit_log_factory, + audit_log_manager, + ): + p1 = product_factory() + p2 = product_factory() + p3 = product_factory() + + u1 = user_factory(product=p1) + u2 = user_factory(product=p2) + u3 = user_factory(product=p3) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[]) + assert "must pass at least 1 user_id" in str(cm.value) + + with pytest.raises(expected_exception=AssertionError) as cm: + audit_log_manager.filter(user_ids=[u1, u2, u3]) + assert "must pass user_id as int" in str(cm.value) + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id, u2.user_id, u3.user_id] + ) + assert isinstance(res, int) + assert res == 0 + + audit_log_factory(user_id=u1.user_id, level=20) + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id, u2.user_id, u3.user_id] + ) + assert isinstance(res, int) + assert res == 1 + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id, u2.user_id, u3.user_id], + created_after=datetime.now(tz=timezone.utc), + ) + assert isinstance(res, int) + assert res == 0 + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id], event_type_like="offerwall-enter.%%" + ) + assert res == 1 + + audit_log_factory(user_id=u1.user_id, level=50) + res = audit_log_manager.filter_count( + user_ids=[u1.user_id], + event_type_like="offerwall-enter.%%", + level_ge=10, + ) + assert res == 2 + + res = audit_log_manager.filter_count( + user_ids=[u1.user_id], event_type_like="poop.%", level_ge=10 + ) + assert res == 0 + + +class TestIPRecordManager: + + def test_init(self, thl_web_rr, thl_redis_config, ip_record_manager): + instance = IPRecordManager(pg_config=thl_web_rr, redis_config=thl_redis_config) + assert isinstance(instance, IPRecordManager) + assert isinstance(ip_record_manager, IPRecordManager) + + def test_create(self, ip_record_manager, user, ip_information): + instance = ip_record_manager.create_dummy( + user_id=user.user_id, ip=ip_information.ip + ) + assert isinstance(instance, IPRecord) + + assert isinstance(instance.forwarded_ips, list) + assert isinstance(instance.forwarded_ip_records[0], IPRecord) + assert isinstance(instance.forwarded_ips[0], str) + + assert instance.created == instance.forwarded_ip_records[0].created + + ipr1 = ip_record_manager.filter_ip_records(filter_ips=[instance.ip]) + assert isinstance(ipr1, list) + assert instance.model_dump_json() == ipr1[0].model_dump_json() + + def test_prefetch_info( + self, + ip_record_factory, + ip_information_factory, + ip_geoname, + user, + thl_web_rr, + thl_redis_config, + ): + + ip = fake.ipv4_public() + ip_information_factory(ip=ip, geoname=ip_geoname) + ipr: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip) + + assert ipr.information is None + assert len(ipr.forwarded_ip_records) >= 1 + fipr = ipr.forwarded_ip_records[0] + assert fipr.information is None + + ipr.prefetch_ipinfo( + pg_config=thl_web_rr, + redis_config=thl_redis_config, + include_forwarded=True, + ) + assert isinstance(ipr.information, GeoIPInformation) + assert ipr.information.ip == ipr.ip == ip + assert fipr.information is None, "the ipinfo doesn't exist in the db yet" + + ip_information_factory(ip=fipr.ip, geoname=ip_geoname) + ipr.prefetch_ipinfo( + pg_config=thl_web_rr, + redis_config=thl_redis_config, + include_forwarded=True, + ) + assert fipr.information is not None + + +@pytest.mark.usefixtures("user_iphistory_manager_clear_cache") +class TestUserIpHistoryManager: + def test_init(self, thl_web_rr, thl_redis_config, user_iphistory_manager): + instance = UserIpHistoryManager( + pg_config=thl_web_rr, redis_config=thl_redis_config + ) + assert isinstance(instance, UserIpHistoryManager) + assert isinstance(user_iphistory_manager, UserIpHistoryManager) + + def test_latest_record( + self, + user_iphistory_manager, + user, + ip_record_factory, + ip_information_factory, + ip_geoname, + ): + ip = fake.ipv4_public() + ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True) + ipr1: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip) + + ipr = user_iphistory_manager.get_user_latest_ip_record(user=user) + assert ipr.ip == ipr1.ip + assert ipr.is_anonymous + assert ipr.information.lookup_prefix == "/32" + + ip = fake.ipv6() + ip_information_factory(ip=ip, geoname=ip_geoname) + ipr2: IPRecord = ip_record_factory(user_id=user.user_id, ip=ip) + + ipr = user_iphistory_manager.get_user_latest_ip_record(user=user) + assert ipr.ip == ipr2.ip + assert ipr.information.lookup_prefix == "/64" + assert ipr.information is not None + assert not ipr.is_anonymous + + country_iso = user_iphistory_manager.get_user_latest_country(user=user) + assert country_iso == ip_geoname.country_iso + + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert iph.ips[0].information is not None + assert iph.ips[1].information is not None + assert iph.ips[0].country_iso == country_iso + assert iph.ips[0].is_anonymous + assert iph.ips[0].ip == ipr1.ip + assert iph.ips[1].ip == ipr2.ip + + def test_virgin(self, user, user_iphistory_manager, ip_record_factory): + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 0 + + ip_record_factory(user_id=user.user_id, ip=fake.ipv4_public()) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + + def test_out_of_order( + self, + ip_record_factory, + user, + user_iphistory_manager, + ip_information_factory, + ip_geoname, + ): + # Create the user-ip association BEFORE the ip even exists in the ipinfo table + ip = fake.ipv4_public() + ip_record_factory(user_id=user.user_id, ip=ip) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is None + assert not ipr.is_anonymous + + ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is not None + assert ipr.is_anonymous + + def test_out_of_order_ipv6( + self, + ip_record_factory, + user, + user_iphistory_manager, + ip_information_factory, + ip_geoname, + ): + # Create the user-ip association BEFORE the ip even exists in the ipinfo table + ip = fake.ipv6() + ip_record_factory(user_id=user.user_id, ip=ip) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is None + assert not ipr.is_anonymous + + ip_information_factory(ip=ip, geoname=ip_geoname, is_anonymous=True) + iph = user_iphistory_manager.get_user_ip_history(user_id=user.user_id) + assert len(iph.ips) == 1 + ipr = iph.ips[0] + assert ipr.information is not None + assert ipr.is_anonymous diff --git a/tests/managers/thl/test_wall_manager.py b/tests/managers/thl/test_wall_manager.py new file mode 100644 index 0000000..ee44e23 --- /dev/null +++ b/tests/managers/thl/test_wall_manager.py @@ -0,0 +1,283 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.session import ( + ReportValue, + Status, + StatusCode1, +) +from test_utils.models.conftest import user, session + + +class TestWallManager: + + @pytest.mark.parametrize("wall_count", [1, 2, 5, 10, 50, 99]) + def test_get_wall_events( + self, wall_manager, session_factory, user, wall_count, utc_hour_ago + ): + from generalresearch.models.thl.session import Session + + s1: Session = session_factory( + user=user, wall_count=wall_count, started=utc_hour_ago + ) + + assert len(s1.wall_events) == wall_count + assert len(wall_manager.get_wall_events(session_id=s1.id)) == wall_count + + db_wall_events = wall_manager.get_wall_events(session_id=s1.id) + assert [w.uuid for w in s1.wall_events] == [w.uuid for w in db_wall_events] + assert [w.source for w in s1.wall_events] == [w.source for w in db_wall_events] + assert [w.buyer_id for w in s1.wall_events] == [ + w.buyer_id for w in db_wall_events + ] + assert [w.req_survey_id for w in s1.wall_events] == [ + w.req_survey_id for w in db_wall_events + ] + assert [w.started for w in s1.wall_events] == [ + w.started for w in db_wall_events + ] + + assert sum([w.req_cpi for w in s1.wall_events]) == sum( + [w.req_cpi for w in db_wall_events] + ) + assert sum([w.cpi for w in s1.wall_events]) == sum( + [w.cpi for w in db_wall_events] + ) + + assert [w.session_id for w in s1.wall_events] == [ + w.session_id for w in db_wall_events + ] + assert [w.user_id for w in s1.wall_events] == [ + w.user_id for w in db_wall_events + ] + assert [w.survey_id for w in s1.wall_events] == [ + w.survey_id for w in db_wall_events + ] + + assert [w.finished for w in s1.wall_events] == [ + w.finished for w in db_wall_events + ] + + def test_get_wall_events_list_input( + self, wall_manager, session_factory, user, utc_hour_ago + ): + from generalresearch.models.thl.session import Session + + session_ids = [] + for idx in range(10): + s: Session = session_factory(user=user, wall_count=5, started=utc_hour_ago) + session_ids.append(s.id) + + session_ids.sort() + res = wall_manager.get_wall_events(session_ids=session_ids) + + assert isinstance(res, list) + assert len(res) == 50 + + res1 = list(set([w.session_id for w in res])) + res1.sort() + + assert session_ids == res1 + + def test_create_wall(self, wall_manager, session_manager, user, session): + w = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=datetime.now(tz=timezone.utc), + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + + assert w is not None + w2 = wall_manager.get_from_uuid(wall_uuid=w.uuid) + assert w == w2 + + def test_report_wall_abandon( + self, wall_manager, session_manager, user, session, utc_hour_ago + ): + w1 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago, + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + wall_manager.report( + wall=w1, + report_value=ReportValue.REASON_UNKNOWN, + report_timestamp=utc_hour_ago + timedelta(minutes=1), + ) + w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid) + + # I Reported a session with no status. It should be marked as an abandon with a finished ts + assert ReportValue.REASON_UNKNOWN == w2.report_value + assert Status.ABANDON == w2.status + assert utc_hour_ago + timedelta(minutes=1) == w2.finished + assert w2.report_notes is None + + # There is nothing stopping it from being un-abandoned... + finished = w1.started + timedelta(minutes=10) + wall_manager.finish( + wall=w1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=finished, + ) + w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid) + assert ReportValue.REASON_UNKNOWN == w2.report_value + assert w2.report_notes is None + assert finished == w2.finished + assert Status.COMPLETE == w2.status + # the status and finished get updated + + def test_report_wall( + self, wall_manager, session_manager, user, session, utc_hour_ago + ): + w1 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago, + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + + finish_ts = utc_hour_ago + timedelta(minutes=10) + report_ts = utc_hour_ago + timedelta(minutes=11) + wall_manager.finish( + wall=w1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=finish_ts, + ) + + # I Reported a session that was already completed. Only update the report values + wall_manager.report( + wall=w1, + report_value=ReportValue.TECHNICAL_ERROR, + report_timestamp=report_ts, + report_notes="This survey blows!", + ) + w2 = wall_manager.get_from_uuid(wall_uuid=w1.uuid) + + assert ReportValue.TECHNICAL_ERROR == w2.report_value + assert finish_ts == w2.finished + assert Status.COMPLETE == w2.status + assert "This survey blows!" == w2.report_notes + + def test_filter_wall_attempts( + self, wall_manager, session_manager, user, session, utc_hour_ago + ): + res = wall_manager.filter_wall_attempts(user_id=user.user_id) + assert len(res) == 0 + w1 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago, + source=Source.DYNATA, + buyer_id="123", + req_survey_id="456", + req_cpi=Decimal("1"), + ) + res = wall_manager.filter_wall_attempts(user_id=user.user_id) + assert len(res) == 1 + w2 = wall_manager.create( + session_id=session.id, + user_id=user.user_id, + uuid_id=uuid4().hex, + started=utc_hour_ago + timedelta(minutes=1), + source=Source.DYNATA, + buyer_id="123", + req_survey_id="555", + req_cpi=Decimal("1"), + ) + res = wall_manager.filter_wall_attempts(user_id=user.user_id) + assert len(res) == 2 + + +class TestWallCacheManager: + + def test_get_attempts_none(self, wall_cache_manager, user): + attempts = wall_cache_manager.get_attempts(user.user_id) + assert len(attempts) == 0 + + def test_get_wall_events( + self, wall_cache_manager, wall_manager, session_manager, user + ): + start1 = datetime.now(timezone.utc) - timedelta(hours=3) + start2 = datetime.now(timezone.utc) - timedelta(hours=2) + start3 = datetime.now(timezone.utc) - timedelta(hours=1) + + session = session_manager.create_dummy(started=start1, user=user) + wall1 = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=start1, + req_cpi=Decimal("1.23"), + req_survey_id="11111", + source=Source.DYNATA, + ) + # The flag never got set, so no results! + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 0 + + wall_cache_manager.set_flag(user_id=user.user_id) + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 1 + + wall2 = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=start2, + req_cpi=Decimal("1.23"), + req_survey_id="22222", + source=Source.DYNATA, + ) + + # We haven't set the flag, so the cache won't update! + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 1 + + # Now set the flag + wall_cache_manager.set_flag(user_id=user.user_id) + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + assert len(attempts) == 2 + # It is in desc order + assert attempts[0].req_survey_id == "22222" + assert attempts[1].req_survey_id == "11111" + + # Test the trim. Fill up the cache with 6000 events, then add another, + # and it should be first in the list, with only 5k others + attempts10000 = [attempts[0]] * 6000 + wall_cache_manager.update_attempts_redis_(attempts10000, user_id=user.user_id) + + session = session_manager.create_dummy(started=start3, user=user) + wall3 = wall_manager.create_dummy( + session_id=session.id, + user_id=session.user_id, + started=start3, + req_cpi=Decimal("1.23"), + req_survey_id="33333", + source=Source.DYNATA, + ) + wall_cache_manager.set_flag(user_id=user.user_id) + attempts = wall_cache_manager.get_attempts(user_id=user.user_id) + + redis_key = wall_cache_manager.get_cache_key_(user_id=user.user_id) + assert wall_cache_manager.redis_client.llen(redis_key) == 5000 + + assert len(attempts) == 5000 + assert attempts[0].req_survey_id == "33333" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/admin/__init__.py b/tests/models/admin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/admin/test_report_request.py b/tests/models/admin/test_report_request.py new file mode 100644 index 0000000..cf4c405 --- /dev/null +++ b/tests/models/admin/test_report_request.py @@ -0,0 +1,163 @@ +from datetime import timezone, datetime + +import pandas as pd +import pytest +from pydantic import ValidationError + + +class TestReportRequest: + def test_base(self, utc_60days_ago): + from generalresearch.models.admin.request import ( + ReportRequest, + ReportType, + ) + + rr = ReportRequest() + + assert isinstance(rr.start, datetime), "rr.start incorrect type" + assert isinstance(rr.start_floor, datetime), "rr.start_floor incorrect type" + + assert rr.report_type == ReportType.POP_SESSION + assert rr.start != rr.start_floor, "rr.start != rr.start_floor" + assert rr.start_floor.tzinfo == timezone.utc, "rr.start_floor.tzinfo not utc" + + rr1 = ReportRequest.model_validate( + { + "start": datetime( + year=datetime.now().year, + month=1, + day=1, + hour=0, + minute=30, + second=25, + microsecond=35, + tzinfo=timezone.utc, + ), + "interval": "1h", + } + ) + + assert isinstance(rr1.start, datetime), "rr1.start incorrect type" + assert isinstance(rr1.start_floor, datetime), "rr1.start_floor incorrect type" + + rr2 = ReportRequest.model_validate( + { + "start": datetime( + year=datetime.now().year, + month=1, + day=1, + hour=6, + minute=30, + second=25, + microsecond=35, + tzinfo=timezone.utc, + ), + "interval": "1d", + } + ) + + assert isinstance(rr2.start, datetime), "rr2.start incorrect type" + assert isinstance(rr2.start_floor, datetime), "rr2.start_floor incorrect type" + + assert rr1.start != rr2.start, "rr1.start != rr2.start" + assert rr1.start_floor == rr2.start_floor, "rr1.start_floor == rr2.start_floor" + + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) + + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) = + # ReportRequest(report_type=, + # index0='started', index1='product_id', + # start=datetime.datetime(2025, 7, 9, 0, 46, 23, 145756, tzinfo=datetime.timezone.utc), + # end=datetime.datetime(2025, 9, 7, 0, 46, 23, 149195, tzinfo=datetime.timezone.utc), + # interval='1h', include_open_bucket=True, + # start_floor=datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc)).start_floor + + # datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc) = + # ReportRequest(report_type=, + # index0='started', index1='product_id', + # start=datetime.datetime(2025, 7, 9, 0, 46, 23, 145756, tzinfo=datetime.timezone.utc), + # end=datetime.datetime(2025, 9, 7, 0, 46, 23, 149267, tzinfo=datetime.timezone.utc), + # interval='1d', include_open_bucket=True, + # start_floor=datetime.datetime(2025, 7, 9, 0, 0, tzinfo=datetime.timezone.utc)).start_floor + + def test_start_end_range(self, utc_90days_ago, utc_30days_ago): + from generalresearch.models.admin.request import ReportRequest + + with pytest.raises(expected_exception=ValidationError) as cm: + ReportRequest.model_validate( + {"start": utc_30days_ago, "end": utc_90days_ago} + ) + + with pytest.raises(expected_exception=ValidationError) as cm: + ReportRequest.model_validate( + { + "start": datetime(year=1990, month=1, day=1), + "end": datetime(year=1950, month=1, day=1), + } + ) + + def test_start_end_range_tz(self): + from generalresearch.models.admin.request import ReportRequest + from zoneinfo import ZoneInfo + + pacific_tz = ZoneInfo("America/Los_Angeles") + + with pytest.raises(expected_exception=ValidationError) as cm: + ReportRequest.model_validate( + { + "start": datetime(year=2000, month=1, day=1, tzinfo=pacific_tz), + "end": datetime(year=2000, month=6, day=1, tzinfo=pacific_tz), + } + ) + + def test_start_floor_naive(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert rr.start_floor_naive.tzinfo is None + + def test_end_naive(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert rr.end_naive.tzinfo is None + + def test_pd_interval(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert isinstance(rr.pd_interval, pd.Interval) + + def test_interval_timedelta(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert isinstance(rr.interval_timedelta, pd.Timedelta) + + def test_buckets(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + + assert isinstance(rr.buckets(), pd.DatetimeIndex) + + def test_bucket_ranges(self): + from generalresearch.models.admin.request import ReportRequest + + rr = ReportRequest() + assert isinstance(rr.bucket_ranges(), list) + + rr = ReportRequest.model_validate( + { + "interval": "1d", + "start": datetime(year=2000, month=1, day=1, tzinfo=timezone.utc), + "end": datetime(year=2000, month=1, day=10, tzinfo=timezone.utc), + } + ) + + assert len(rr.bucket_ranges()) == 10 diff --git a/tests/models/custom_types/__init__.py b/tests/models/custom_types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/custom_types/test_aware_datetime.py b/tests/models/custom_types/test_aware_datetime.py new file mode 100644 index 0000000..a23413c --- /dev/null +++ b/tests/models/custom_types/test_aware_datetime.py @@ -0,0 +1,82 @@ +import logging +from datetime import datetime, timezone +from typing import Optional + +import pytest +import pytz +from pydantic import BaseModel, ValidationError, Field + +from generalresearch.models.custom_types import AwareDatetimeISO + +logger = logging.getLogger() + + +class AwareDatetimeISOModel(BaseModel): + dt_optional: Optional[AwareDatetimeISO] = Field(default=None) + dt: AwareDatetimeISO + + +class TestAwareDatetimeISO: + def test_str(self): + dt = "2023-10-10T01:01:01.0Z" + t = AwareDatetimeISOModel(dt=dt, dt_optional=dt) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + t = AwareDatetimeISOModel(dt=dt, dt_optional=None) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + def test_dt(self): + dt = datetime(2023, 10, 10, 1, 1, 1, tzinfo=timezone.utc) + t = AwareDatetimeISOModel(dt=dt, dt_optional=dt) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + t = AwareDatetimeISOModel(dt=dt, dt_optional=None) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + dt = datetime(2023, 10, 10, 1, 1, 1, microsecond=123, tzinfo=timezone.utc) + t = AwareDatetimeISOModel(dt=dt, dt_optional=dt) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + t = AwareDatetimeISOModel(dt=dt, dt_optional=None) + AwareDatetimeISOModel.model_validate_json(t.model_dump_json()) + + def test_no_tz(self): + dt = datetime(2023, 10, 10, 1, 1, 1) + + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=None) + + dt = "2023-10-10T01:01:01.0" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=None) + + def test_non_utc_tz(self): + dt = datetime( + year=2023, + month=10, + day=10, + hour=1, + second=1, + minute=1, + tzinfo=pytz.timezone("US/Central"), + ) + + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + + def test_invalid_format(self): + dt = "2023-10-10T01:01:01Z" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + + dt = "2023-10-10T01:01:01" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + dt = "2023-10-10" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=dt, dt_optional=dt) + + def test_required(self): + dt = "2023-10-10T01:01:01.0Z" + with pytest.raises(expected_exception=ValidationError): + AwareDatetimeISOModel(dt=None, dt_optional=dt) diff --git a/tests/models/custom_types/test_dsn.py b/tests/models/custom_types/test_dsn.py new file mode 100644 index 0000000..b37f2c4 --- /dev/null +++ b/tests/models/custom_types/test_dsn.py @@ -0,0 +1,112 @@ +from typing import Optional +from uuid import uuid4 + +import pytest +from pydantic import BaseModel, ValidationError, Field +from pydantic import MySQLDsn +from pydantic_core import Url + +from generalresearch.models.custom_types import DaskDsn, SentryDsn + + +# --- Test Pydantic Models --- + + +class SettingsModel(BaseModel): + dask: Optional["DaskDsn"] = Field(default=None) + sentry: Optional["SentryDsn"] = Field(default=None) + db: Optional["MySQLDsn"] = Field(default=None) + + +# --- Pytest themselves --- + + +class TestDaskDsn: + + def test_base(self): + from dask.distributed import Client + + m = SettingsModel(dask="tcp://dask-scheduler.internal") + + assert m.dask.scheme == "tcp" + assert m.dask.host == "dask-scheduler.internal" + assert m.dask.port == 8786 + + with pytest.raises(expected_exception=TypeError) as cm: + Client(m.dask) + assert "Scheduler address must be a string or a Cluster instance" in str( + cm.value + ) + + # todo: this requires vpn connection. maybe do this part with a localhost dsn + # client = Client(str(m.dask)) + # self.assertIsInstance(client, Client) + + def test_str(self): + m = SettingsModel(dask="tcp://dask-scheduler.internal") + assert isinstance(m.dask, Url) + assert "tcp://dask-scheduler.internal:8786" == str(m.dask) + + def test_auth(self): + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="tcp://test:password@dask-scheduler.internal") + assert "User & Password are not supported" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="tcp://test:@dask-scheduler.internal") + assert "User & Password are not supported" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="tcp://:password@dask-scheduler.internal") + assert "User & Password are not supported" in str(cm.value) + + def test_invalid_schema(self): + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="dask-scheduler.internal") + assert "relative URL without a base" in str(cm.value) + + # I look forward to the day we use infiniband interfaces + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(dask="ucx://dask-scheduler.internal") + assert "URL scheme should be 'tcp'" in str(cm.value) + + def test_port(self): + m = SettingsModel(dask="tcp://dask-scheduler.internal") + assert m.dask.port == 8786 + + +class TestSentryDsn: + def test_base(self): + m = SettingsModel( + sentry=f"https://{uuid4().hex}@12345.ingest.us.sentry.io/9876543" + ) + + assert m.sentry.scheme == "https" + assert m.sentry.host == "12345.ingest.us.sentry.io" + assert m.sentry.port == 443 + + def test_str(self): + test_url: str = f"https://{uuid4().hex}@12345.ingest.us.sentry.io/9876543" + m = SettingsModel(sentry=test_url) + assert isinstance(m.sentry, Url) + assert test_url == str(m.sentry) + + def test_auth(self): + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel( + sentry="https://0123456789abc:password@12345.ingest.us.sentry.io/9876543" + ) + assert "Sentry password is not supported" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(sentry="https://test:@12345.ingest.us.sentry.io/9876543") + assert "Sentry user key seems bad" in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + SettingsModel(sentry="https://:password@12345.ingest.us.sentry.io/9876543") + assert "Sentry URL requires a user key" in str(cm.value) + + def test_port(self): + test_url: str = f"https://{uuid4().hex}@12345.ingest.us.sentry.io/9876543" + m = SettingsModel(sentry=test_url) + assert m.sentry.port == 443 diff --git a/tests/models/custom_types/test_therest.py b/tests/models/custom_types/test_therest.py new file mode 100644 index 0000000..13e9bae --- /dev/null +++ b/tests/models/custom_types/test_therest.py @@ -0,0 +1,42 @@ +import json +from uuid import UUID + +import pytest +from pydantic import TypeAdapter, ValidationError + + +class TestAll: + + def test_comma_sep_str(self): + from generalresearch.models.custom_types import AlphaNumStrSet + + t = TypeAdapter(AlphaNumStrSet) + assert {"a", "b", "c"} == t.validate_python(["a", "b", "c"]) + assert '"a,b,c"' == t.dump_json({"c", "b", "a"}).decode() + assert '""' == t.dump_json(set()).decode() + assert {"a", "b", "c"} == t.validate_json('"c,b,a"') + assert set() == t.validate_json('""') + + with pytest.raises(ValidationError): + t.validate_python({"", "b", "a"}) + + with pytest.raises(ValidationError): + t.validate_python({""}) + + with pytest.raises(ValidationError): + t.validate_json('",b,a"') + + def test_UUIDStrCoerce(self): + from generalresearch.models.custom_types import UUIDStrCoerce + + t = TypeAdapter(UUIDStrCoerce) + uuid_str = "18e70590176e49c693b07682f3c112be" + assert uuid_str == t.validate_python("18e70590-176e-49c6-93b0-7682f3c112be") + assert uuid_str == t.validate_python( + UUID("18e70590-176e-49c6-93b0-7682f3c112be") + ) + assert ( + json.dumps(uuid_str) + == t.dump_json("18e70590176e49c693b07682f3c112be").decode() + ) + assert uuid_str == t.validate_json('"18e70590-176e-49c6-93b0-7682f3c112be"') diff --git a/tests/models/custom_types/test_uuid_str.py b/tests/models/custom_types/test_uuid_str.py new file mode 100644 index 0000000..91af9ae --- /dev/null +++ b/tests/models/custom_types/test_uuid_str.py @@ -0,0 +1,51 @@ +from typing import Optional +from uuid import uuid4 + +import pytest +from pydantic import BaseModel, ValidationError, Field + +from generalresearch.models.custom_types import UUIDStr + + +class UUIDStrModel(BaseModel): + uuid_optional: Optional[UUIDStr] = Field(default_factory=lambda: uuid4().hex) + uuid: UUIDStr + + +class TestUUIDStr: + def test_str(self): + v = "58889cd67f9f4c699b25437112dce638" + + t = UUIDStrModel(uuid=v, uuid_optional=v) + UUIDStrModel.model_validate_json(t.model_dump_json()) + + t = UUIDStrModel(uuid=v, uuid_optional=None) + t2 = UUIDStrModel.model_validate_json(t.model_dump_json()) + + assert t2.uuid_optional is None + assert t2.uuid == v + + def test_uuid(self): + v = uuid4() + + with pytest.raises(ValidationError) as cm: + UUIDStrModel(uuid=v, uuid_optional=None) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + UUIDStrModel(uuid="58889cd67f9f4c699b25437112dce638", uuid_optional=v) + assert "Input should be a valid string" in str(cm.value) + + def test_invalid_format(self): + v = "x" + with pytest.raises(ValidationError): + UUIDStrModel(uuid=v, uuid_optional=None) + + with pytest.raises(ValidationError): + UUIDStrModel(uuid="58889cd67f9f4c699b25437112dce638", uuid_optional=v) + + def test_required(self): + v = "58889cd67f9f4c699b25437112dce638" + + with pytest.raises(ValidationError): + UUIDStrModel(uuid=None, uuid_optional=v) diff --git a/tests/models/dynata/__init__.py b/tests/models/dynata/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/dynata/test_eligbility.py b/tests/models/dynata/test_eligbility.py new file mode 100644 index 0000000..23437f5 --- /dev/null +++ b/tests/models/dynata/test_eligbility.py @@ -0,0 +1,324 @@ +from datetime import datetime, timezone + + +class TestEligibility: + + def test_evaluate_task_criteria(self): + from generalresearch.models.dynata.survey import ( + DynataQuotaGroup, + DynataFilterGroup, + DynataSurvey, + DynataRequirements, + ) + + filters = [[["a", "b"], ["c", "d"]], [["e"], ["f"]]] + filters = [DynataFilterGroup.model_validate(f) for f in filters] + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": True, + } + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": [], "status": "OPEN"}] + ) + ] + task = DynataSurvey.model_validate( + { + "survey_id": "1", + "filters": filters, + "quotas": quotas, + "allowed_devices": set("1"), + "calculation_type": "COMPLETES", + "client_id": "", + "country_iso": "us", + "language_iso": "eng", + "group_id": "g1", + "project_id": "p1", + "status": "OPEN", + "project_exclusions": set(), + "created": datetime.now(tz=timezone.utc), + "category_exclusions": set(), + "category_ids": set(), + "cpi": 1, + "days_in_field": 0, + "expected_count": 0, + "order_number": "", + "live_link": "", + "bid_ir": 0.5, + "bid_loi": 500, + "requirements": DynataRequirements(), + } + ) + assert task.determine_eligibility(criteria_evaluation) + + # task status + task.status = "CLOSED" + assert not task.determine_eligibility(criteria_evaluation) + task.status = "OPEN" + + # one quota with no space left (count = 0) + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 0, "condition_hashes": [], "status": "OPEN"}] + ) + ] + task.quotas = quotas + assert not task.determine_eligibility(criteria_evaluation) + + # we pass 'a' and 'b' + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}] + ) + ] + task.quotas = quotas + assert task.determine_eligibility(criteria_evaluation) + + # make 'f' false, we still pass the 2nd filtergroup b/c 'e' is True + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": False, + } + assert task.determine_eligibility(criteria_evaluation) + + # make 'e' false, we don't pass the 2nd filtergroup + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": False, + "f": False, + } + assert not task.determine_eligibility(criteria_evaluation) + + # We fail quota 'c','d', but we pass 'a','b', so we pass the first quota group + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": True, + } + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ) + ] + task.quotas = quotas + assert task.determine_eligibility(criteria_evaluation) + + # we pass the first qg, but then fall into a full 2nd qg + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ), + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["f"], "status": "CLOSED"}] + ), + ] + task.quotas = quotas + assert not task.determine_eligibility(criteria_evaluation) + + def test_soft_pair(self): + from generalresearch.models.dynata.survey import ( + DynataQuotaGroup, + DynataFilterGroup, + DynataSurvey, + DynataRequirements, + ) + + filters = [[["a", "b"], ["c", "d"]], [["e"], ["f"]]] + filters = [DynataFilterGroup.model_validate(f) for f in filters] + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": True, + "f": True, + } + quotas = [ + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": [], "status": "OPEN"}] + ) + ] + task = DynataSurvey.model_validate( + { + "survey_id": "1", + "filters": filters, + "quotas": quotas, + "allowed_devices": set("1"), + "calculation_type": "COMPLETES", + "client_id": "", + "country_iso": "us", + "language_iso": "eng", + "group_id": "g1", + "project_id": "p1", + "status": "OPEN", + "project_exclusions": set(), + "created": datetime.now(tz=timezone.utc), + "category_exclusions": set(), + "category_ids": set(), + "cpi": 1, + "days_in_field": 0, + "expected_count": 0, + "order_number": "", + "live_link": "", + "bid_ir": 0.5, + "bid_loi": 500, + "requirements": DynataRequirements(), + } + ) + assert task.passes_filters(criteria_evaluation) + passes, condition_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes + + # make 'e' & 'f' None, we don't pass the 2nd filtergroup + criteria_evaluation = { + "a": True, + "b": True, + "c": True, + "d": False, + "e": None, + "f": None, + } + assert not task.passes_filters(criteria_evaluation) + passes, conditional_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes is None + assert {"e", "f"} == conditional_hashes + + # 1st filtergroup unknown + criteria_evaluation = { + "a": True, + "b": None, + "c": None, + "d": None, + "e": None, + "f": None, + } + assert not task.passes_filters(criteria_evaluation) + passes, conditional_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes is None + assert {"b", "c", "d", "e", "f"} == conditional_hashes + + # 1st filtergroup unknown, 2nd cell False + criteria_evaluation = { + "a": True, + "b": None, + "c": None, + "d": False, + "e": None, + "f": None, + } + assert not task.passes_filters(criteria_evaluation) + passes, conditional_hashes = task.passes_filters_soft(criteria_evaluation) + assert passes is None + assert {"b", "e", "f"} == conditional_hashes + + # we pass the first qg, unknown 2nd + criteria_evaluation = { + "a": True, + "b": True, + "c": None, + "d": False, + "e": None, + "f": None, + } + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ), + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["f"], "status": "OPEN"}] + ), + ] + task.quotas = quotas + passes, conditional_hashes = task.passes_quotas_soft(criteria_evaluation) + assert passes is None + assert {"f"} == conditional_hashes + + # both quota groups unknown + criteria_evaluation = { + "a": True, + "b": None, + "c": None, + "d": False, + "e": None, + "g": None, + } + quotas = [ + DynataQuotaGroup.model_validate( + [ + {"count": 100, "condition_hashes": ["a", "b"], "status": "OPEN"}, + {"count": 100, "condition_hashes": ["c", "d"], "status": "CLOSED"}, + ] + ), + DynataQuotaGroup.model_validate( + [{"count": 100, "condition_hashes": ["g"], "status": "OPEN"}] + ), + ] + task.quotas = quotas + passes, conditional_hashes = task.passes_quotas_soft(criteria_evaluation) + assert passes is None + assert {"b", "g"} == conditional_hashes + + passes, conditional_hashes = task.determine_eligibility_soft( + criteria_evaluation + ) + assert passes is None + assert {"b", "e", "f", "g"} == conditional_hashes + + # def x(self): + # # ---- + # c1 = DynataCondition(question_id='gender', values=['male'], value_type=ConditionValueType.LIST) # 718f759 + # c2 = DynataCondition(question_id='age', values=['18-24'], value_type=ConditionValueType.RANGE) # 7a7b290 + # obj1 = DynataFilterObject(cells=[c1.criterion_hash, c2.criterion_hash]) + # + # c3 = DynataCondition(question_id='gender', values=['female'], value_type=ConditionValueType.LIST) # 38fa4e1 + # c4 = DynataCondition(question_id='age', values=['35-45'], value_type=ConditionValueType.RANGE) # e4f06fa + # obj2 = DynataFilterObject(cells=[c3.criterion_hash, c4.criterion_hash]) + # + # grp1 = DynataFilterGroup(objects=[obj1, obj2]) + # + # # ----- + # c5 = DynataCondition(question_id='ethnicity', values=['white'], value_type=ConditionValueType.LIST) # eb9b9a4 + # obj3 = DynataFilterObject(cells=[c5.criterion_hash]) + # + # c6 = DynataCondition(question_id='ethnicity', values=['black'], value_type=ConditionValueType.LIST) # 039fe2d + # obj4 = DynataFilterObject(cells=[c6.criterion_hash]) + # + # grp2 = DynataFilterGroup(objects=[obj3, obj4]) + # # ----- + # q1 = DynataQuota(count=5, status=DynataStatus.OPEN, + # condition_hashes=[c1.criterion_hash, c2.criterion_hash]) + # q2 = DynataQuota(count=10, status=DynataStatus.CLOSED, + # condition_hashes=[c3.criterion_hash, c4.criterion_hash]) + # qg1 = DynataQuotaGroup(cells=[q1, q2]) + # # ---- + # + # s = DynataSurvey(survey_id='123', status=DynataStatus.OPEN, country_iso='us', + # language_iso='eng', group_id='123', client_id='123', project_id='12', + # filters=[grp1, grp2], + # quotas=[qg1]) + # ce = {'718f759': True, '7a7b290': True, 'eb9b9a4': True} + # s.passes_filters(ce) + # s.passes_quotas(ce) diff --git a/tests/models/dynata/test_survey.py b/tests/models/dynata/test_survey.py new file mode 100644 index 0000000..ad953a3 --- /dev/null +++ b/tests/models/dynata/test_survey.py @@ -0,0 +1,164 @@ +class TestDynataCondition: + + def test_condition_create(self): + from generalresearch.models.dynata.survey import DynataCondition + + cell = { + "tag": "90606986-5508-461b-a821-216e9a72f1a0", + "attribute_id": 120, + "negate": False, + "kind": "VALUE", + "value": "45398", + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"120": {"45398"}}) + assert not c.evaluate_criterion({"120": {"11111"}}) + + cell = { + "tag": "aa7169c0-cb34-499a-aadd-31e0013df8fd", + "attribute_id": 231302, + "negate": False, + "operator": "OR", + "kind": "LIST", + "list": ["514802", "514804", "514808", "514810"], + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"231302": {"514804", "123445"}}) + assert not c.evaluate_criterion({"231302": {"123445"}}) + + cell = { + "tag": "aa7169c0-cb34-499a-aadd-31e0013df8fd", + "attribute_id": 231302, + "negate": False, + "operator": "AND", + "kind": "LIST", + "list": ["514802", "514804"], + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"231302": {"514802", "514804"}}) + assert not c.evaluate_criterion({"231302": {"514802"}}) + + cell = { + "tag": "75a36c67-0328-4c1b-a4dd-67d34688ff68", + "attribute_id": 80, + "negate": False, + "kind": "RANGE", + "range": {"from": 18, "to": 99}, + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"80": {"20"}}) + assert not c.evaluate_criterion({"80": {"120"}}) + + cell = { + "tag": "dd64b622-ed10-4a3b-e1h8-a4e63b59vha2", + "attribute_id": 83, + "negate": False, + "kind": "INEFFABLE", + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"83": {"20"}}) + + cell = { + "tag": "kei35kkjj-d00k-52kj-b3j4-a4jinx9832", + "attribute_id": 8, + "negate": False, + "kind": "ANSWERED", + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"8": {"20"}}) + assert not c.evaluate_criterion({"81": {"20"}}) + + def test_condition_range(self): + from generalresearch.models.dynata.survey import DynataCondition + + cell = { + "tag": "75a36c67-0328-4c1b-a4dd-67d34688ff68", + "attribute_id": 80, + "negate": False, + "kind": "RANGE", + "range": {"from": 18, "to": None}, + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"80": {"20"}}) + + def test_recontact(self): + from generalresearch.models.dynata.survey import DynataCondition + + cell = { + "tag": "d559212d-7984-4239-89c2-06c29588d79e", + "attribute_id": 238384, + "negate": False, + "operator": "OR", + "kind": "INVITE_COLLECTIONS", + "invite_collections": ["621041", "621042"], + } + c = DynataCondition.from_api(cell) + assert c.evaluate_criterion({"80": {"20"}}, user_groups={"621041", "a"}) + + +class TestDynataSurvey: + pass + + # def test_survey_eligibility(self): + # d = {'survey_id': 29333264, 'survey_name': '#29333264', 'survey_status': 22, + # 'field_end_date': datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + # 'category': 'Exciting New', 'category_code': 232, + # 'crtd_on': datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + # 'mod_on': datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + # 'soft_launch': False, 'click_balancing': 0, 'price_type': 1, 'pii': False, + # 'buyer_message': '', 'buyer_id': 4726, 'incl_excl': 0, + # 'cpi': Decimal('1.20000'), 'last_complete_date': None, 'project_last_complete_date': None, + # 'quotas': [], 'qualifications': [], + # 'country_iso': 'fr', 'language_iso': 'fre', 'overall_ir': 0.4, 'overall_loi': 600, + # 'last_block_ir': None, 'last_block_loi': None, 'survey_exclusions': set(), 'exclusion_period': 0} + # s = DynataSurvey.from_api(d) + # s.qualifications = ['a', 'b', 'c'] + # s.quotas = [ + # SpectrumQuota(remaining_count=10, condition_hashes=['a', 'b']), + # SpectrumQuota(remaining_count=0, condition_hashes=['d']), + # SpectrumQuota(remaining_count=10, condition_hashes=['e']) + # ] + # + # self.assertTrue(s.passes_qualifications({'a': True, 'b': True, 'c': True})) + # self.assertFalse(s.passes_qualifications({'a': True, 'b': True, 'c': False})) + # + # # we do NOT match a full quota, so we pass + # self.assertTrue(s.passes_quotas({'a': True, 'b': True, 'd': False})) + # # We dont pass any + # self.assertFalse(s.passes_quotas({})) + # # we only pass a full quota + # self.assertFalse(s.passes_quotas({'d': True})) + # # we only dont pass a full quota, but we haven't passed any open + # self.assertFalse(s.passes_quotas({'d': False})) + # # we pass a quota, but also pass a full quota, so fail + # self.assertFalse(s.passes_quotas({'e': True, 'd': True})) + # # we pass a quota, but are unknown in a full quota, so fail + # self.assertFalse(s.passes_quotas({'e': True})) + # + # # # Soft Pair + # self.assertEqual((True, set()), s.passes_qualifications_soft({'a': True, 'b': True, 'c': True})) + # self.assertEqual((False, set()), s.passes_qualifications_soft({'a': True, 'b': True, 'c': False})) + # self.assertEqual((None, set('c')), s.passes_qualifications_soft({'a': True, 'b': True, 'c': None})) + # + # # we do NOT match a full quota, so we pass + # self.assertEqual((True, set()), s.passes_quotas_soft({'a': True, 'b': True, 'd': False})) + # # We dont pass any + # self.assertEqual((None, {'a', 'b', 'd', 'e'}), s.passes_quotas_soft({})) + # # we only pass a full quota + # self.assertEqual((False, set()), s.passes_quotas_soft({'d': True})) + # # we only dont pass a full quota, but we haven't passed any open + # self.assertEqual((None, {'a', 'b', 'e'}), s.passes_quotas_soft({'d': False})) + # # we pass a quota, but also pass a full quota, so fail + # self.assertEqual((False, set()), s.passes_quotas_soft({'e': True, 'd': True})) + # # we pass a quota, but are unknown in a full quota, so fail + # self.assertEqual((None, {'d'}), s.passes_quotas_soft({'e': True})) + # + # self.assertEqual(True, s.determine_eligibility({'a': True, 'b': True, 'c': True, 'd': False})) + # self.assertEqual(False, s.determine_eligibility({'a': True, 'b': True, 'c': False, 'd': False})) + # self.assertEqual(False, s.determine_eligibility({'a': True, 'b': True, 'c': None, 'd': False})) + # self.assertEqual((True, set()), s.determine_eligibility_soft({'a': True, 'b': True, 'c': True, 'd': False})) + # self.assertEqual((False, set()), s.determine_eligibility_soft({'a': True, 'b': True, 'c': False, 'd': False})) + # self.assertEqual((None, set('c')), s.determine_eligibility_soft({'a': True, 'b': True, 'c': None, + # 'd': False})) + # self.assertEqual((None, {'c', 'd'}), s.determine_eligibility_soft({'a': True, 'b': True, 'c': None, + # 'd': None})) diff --git a/tests/models/gr/__init__.py b/tests/models/gr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/gr/test_authentication.py b/tests/models/gr/test_authentication.py new file mode 100644 index 0000000..e906d8c --- /dev/null +++ b/tests/models/gr/test_authentication.py @@ -0,0 +1,313 @@ +import binascii +import json +import os +from datetime import datetime, timezone +from random import randint +from uuid import uuid4 + +import pytest + +SSO_ISSUER = "" + + +class TestGRUser: + + def test_init(self, gr_user): + from generalresearch.models.gr.authentication import GRUser + + assert isinstance(gr_user, GRUser) + assert not gr_user.is_superuser + + assert gr_user.teams is None + assert gr_user.businesses is None + assert gr_user.products is None + + @pytest.mark.skip(reason="TODO") + def test_businesses(self): + pass + + def test_teams(self, gr_user, membership, gr_db, gr_redis_config): + from generalresearch.models.gr.team import Team + + assert gr_user.teams is None + + gr_user.prefetch_teams(pg_config=gr_db, redis_config=gr_redis_config) + + assert isinstance(gr_user.teams, list) + assert len(gr_user.teams) == 1 + assert isinstance(gr_user.teams[0], Team) + + def test_prefetch_team_duplicates( + self, + gr_user_token, + gr_user, + membership, + product_factory, + membership_factory, + team, + thl_web_rr, + gr_redis_config, + gr_db, + ): + product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + gr_user.prefetch_teams( + pg_config=gr_db, + redis_config=gr_redis_config, + ) + + assert len(gr_user.teams) == 1 + + def test_products( + self, + gr_user, + product_factory, + team, + membership, + gr_db, + thl_web_rr, + gr_redis_config, + ): + from generalresearch.models.thl.product import Product + + assert gr_user.products is None + + # Create a new Team membership, and then create a Product that + # is part of that team + membership.prefetch_team(pg_config=gr_db, redis_config=gr_redis_config) + p: Product = product_factory(team=team) + assert p.id_int + assert team.uuid == membership.team.uuid + assert p.team_id == team.uuid + assert p.team_uuid == membership.team.uuid + assert gr_user.id == membership.user_id + + gr_user.prefetch_products( + pg_config=gr_db, + thl_pg_config=thl_web_rr, + redis_config=gr_redis_config, + ) + assert isinstance(gr_user.products, list) + assert len(gr_user.products) == 1 + assert isinstance(gr_user.products[0], Product) + + +class TestGRUserMethods: + + def test_cache_key(self, gr_user, gr_redis): + assert isinstance(gr_user.cache_key, str) + assert ":" in gr_user.cache_key + assert str(gr_user.id) in gr_user.cache_key + + def test_to_redis( + self, + gr_user, + gr_redis, + team, + business, + product_factory, + membership_factory, + ): + product_factory(team=team, business=business) + membership_factory(team=team, gr_user=gr_user) + + res = gr_user.to_redis() + assert isinstance(res, str) + + from generalresearch.models.gr.authentication import GRUser + + instance = GRUser.from_redis(res) + assert isinstance(instance, GRUser) + + def test_set_cache( + self, + gr_user, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + gr_redis_config, + ): + assert gr_redis.get(name=gr_user.cache_key) is None + assert gr_redis.get(name=f"{gr_user.cache_key}:team_uuids") is None + assert gr_redis.get(name=f"{gr_user.cache_key}:business_uuids") is None + assert gr_redis.get(name=f"{gr_user.cache_key}:product_uuids") is None + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + + assert gr_redis.get(name=gr_user.cache_key) is not None + assert gr_redis.get(name=f"{gr_user.cache_key}:team_uuids") is not None + assert gr_redis.get(name=f"{gr_user.cache_key}:business_uuids") is not None + assert gr_redis.get(name=f"{gr_user.cache_key}:product_uuids") is not None + + def test_set_cache_gr_user( + self, + gr_user, + gr_user_token, + gr_redis, + gr_redis_config, + gr_db, + thl_web_rr, + product_factory, + team, + membership_factory, + thl_redis_config, + ): + from generalresearch.models.gr.authentication import GRUser + + p1 = product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + + res: str = gr_redis.get(name=gr_user.cache_key) + gru2 = GRUser.from_redis(res) + + assert gr_user.model_dump_json( + exclude={"businesses", "teams", "products"} + ) == gru2.model_dump_json(exclude={"businesses", "teams", "products"}) + + gru2.prefetch_products( + pg_config=gr_db, + thl_pg_config=thl_web_rr, + redis_config=thl_redis_config, + ) + assert gru2.product_uuids == [p1.uuid] + + def test_set_cache_team_uuids( + self, + gr_user, + membership, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + gr_redis_config, + ): + product_factory(team=team) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + res = json.loads(gr_redis.get(name=f"{gr_user.cache_key}:team_uuids")) + assert len(res) == 1 + assert gr_user.team_uuids == res + + @pytest.mark.skip + def test_set_cache_business_uuids( + self, + gr_user, + membership, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + business, + team, + gr_redis_config, + ): + product_factory(team=team, business=business) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + res = json.loads(gr_redis.get(name=f"{gr_user.cache_key}:business_uuids")) + assert len(res) == 1 + assert gr_user.business_uuids == res + + def test_set_cache_product_uuids( + self, + gr_user, + membership, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + gr_redis_config, + ): + product_factory(team=team) + + gr_user.set_cache( + pg_config=gr_db, thl_web_rr=thl_web_rr, redis_config=gr_redis_config + ) + res = json.loads(gr_redis.get(name=f"{gr_user.cache_key}:product_uuids")) + assert len(res) == 1 + assert gr_user.product_uuids == res + + +class TestGRToken: + + @pytest.fixture + def gr_token(self, gr_user): + from generalresearch.models.gr.authentication import GRToken + + now = datetime.now(tz=timezone.utc) + token = binascii.hexlify(os.urandom(20)).decode() + + gr_token = GRToken(key=token, created=now, user_id=gr_user.id) + + return gr_token + + def test_init(self, gr_token): + from generalresearch.models.gr.authentication import GRToken + + assert isinstance(gr_token, GRToken) + assert gr_token.created + + def test_user(self, gr_token, gr_db, gr_redis_config): + from generalresearch.models.gr.authentication import GRUser + + assert gr_token.user is None + + gr_token.prefetch_user(pg_config=gr_db, redis_config=gr_redis_config) + + assert isinstance(gr_token.user, GRUser) + + def test_auth_header(self, gr_token): + assert isinstance(gr_token.auth_header, dict) + + +class TestClaims: + + def test_init(self): + from generalresearch.models.gr.authentication import Claims + + d = { + "iss": SSO_ISSUER, + "sub": f"{uuid4().hex}{uuid4().hex}", + "aud": uuid4().hex, + "exp": randint(a=1_500_000_000, b=2_000_000_000), + "iat": randint(a=1_500_000_000, b=2_000_000_000), + "auth_time": randint(a=1_500_000_000, b=2_000_000_000), + "acr": "goauthentik.io/providers/oauth2/default", + "amr": ["pwd", "mfa"], + "sid": f"{uuid4().hex}{uuid4().hex}", + "email": "max@g-r-l.com", + "email_verified": True, + "name": "Max Nanis", + "given_name": "Max Nanis", + "preferred_username": "nanis", + "nickname": "nanis", + "groups": [ + "authentik Admins", + "Developers", + "Systems Admin", + "Customer Support", + "admin", + ], + "azp": uuid4().hex, + "uid": uuid4().hex, + } + instance = Claims.model_validate(d) + + assert isinstance(instance, Claims) diff --git a/tests/models/gr/test_business.py b/tests/models/gr/test_business.py new file mode 100644 index 0000000..9a7718d --- /dev/null +++ b/tests/models/gr/test_business.py @@ -0,0 +1,1432 @@ +import os +from datetime import datetime, timedelta, timezone +from decimal import Decimal +from typing import Optional +from uuid import uuid4 + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) +from pytest import approx + +from generalresearch.currency import USDCent +from generalresearch.models.thl.finance import ( + ProductBalances, + BusinessBalances, +) + +# from test_utils.incite.conftest import mnt_filepath +from test_utils.managers.conftest import ( + business_bank_account_manager, + lm, + thl_lm, +) + + +class TestBusinessBankAccount: + + def test_init(self, business, business_bank_account_manager): + from generalresearch.models.gr.business import ( + BusinessBankAccount, + TransferMethod, + ) + + instance = business_bank_account_manager.create( + business_id=business.id, + uuid=uuid4().hex, + transfer_method=TransferMethod.ACH, + ) + assert isinstance(instance, BusinessBankAccount) + + def test_business(self, business_bank_account, business, gr_db, gr_redis_config): + from generalresearch.models.gr.business import Business + + assert business_bank_account.business is None + + business_bank_account.prefetch_business( + pg_config=gr_db, redis_config=gr_redis_config + ) + assert isinstance(business_bank_account.business, Business) + assert business_bank_account.business.uuid == business.uuid + + +class TestBusinessAddress: + + def test_init(self, business_address): + from generalresearch.models.gr.business import BusinessAddress + + assert isinstance(business_address, BusinessAddress) + + +class TestBusinessContact: + + def test_init(self): + from generalresearch.models.gr.business import BusinessContact + + bc = BusinessContact(name="abc", email="test@abc.com") + assert isinstance(bc, BusinessContact) + + +class TestBusiness: + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_init(self, business): + from generalresearch.models.gr.business import Business + + assert isinstance(business, Business) + assert isinstance(business.id, int) + assert isinstance(business.uuid, str) + + def test_str_and_repr( + self, + business, + product_factory, + thl_web_rr, + lm, + thl_lm, + business_payout_event_manager, + bp_payout_factory, + start, + user_factory, + session_with_tx_factory, + pop_ledger_merge, + client_no_amm, + ledger_collection, + mnt_filepath, + create_main_accounts, + ): + create_main_accounts() + p1 = product_factory(business=business) + u1 = user_factory(product=p1) + p2 = product_factory(business=business) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + + res1 = repr(business) + + assert business.uuid in res1 + assert " "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + @pytest.mark.skip + def test_product_ordering(self): + # Assert that the order of business.balance.product_balances is always + # consistent and in the same order based off product.created ASC + pass + + def test_single_product( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p1) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("1.25"), + started=start + timedelta(days=2), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert isinstance(business.balance, BusinessBalances) + assert business.balance.payout == 190 + assert business.balance.adjustment == 0 + assert business.balance.net == 190 + assert business.balance.retainer == 47 + assert business.balance.available_balance == 143 + + assert len(business.balance.product_balances) == 1 + pb = business.balance.product_balances[0] + assert isinstance(pb, ProductBalances) + assert pb.balance == business.balance.balance + assert pb.available_balance == business.balance.available_balance + assert pb.adjustment_percent == 0.0 + + def test_multi_product( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("1.25"), + started=start + timedelta(days=2), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + assert isinstance(business.balance, BusinessBalances) + assert business.balance.payout == 190 + assert business.balance.balance == 190 + assert business.balance.adjustment == 0 + assert business.balance.net == 190 + assert business.balance.retainer == 46 + assert business.balance.available_balance == 144 + + assert len(business.balance.product_balances) == 2 + + pb1 = business.balance.product_balances[0] + pb2 = business.balance.product_balances[1] + assert isinstance(pb1, ProductBalances) + assert pb1.product_id == u1.product_id + assert isinstance(pb2, ProductBalances) + assert pb2.product_id == u2.product_id + + for pb in [pb1, pb2]: + assert pb.balance != business.balance.balance + assert pb.available_balance != business.balance.available_balance + assert pb.adjustment_percent == 0.0 + + assert pb1.product_id in [u1.product_id, u2.product_id] + assert pb1.payout == 71 + assert pb1.adjustment == 0 + assert pb1.expense == 0 + assert pb1.net == 71 + assert pb1.retainer == 17 + assert pb1.available_balance == 54 + + assert pb2.product_id in [u1.product_id, u2.product_id] + assert pb2.payout == 119 + assert pb2.adjustment == 0 + assert pb2.expense == 0 + assert pb2.net == 119 + assert pb2.retainer == 29 + assert pb2.available_balance == 90 + + def test_multi_product_multi_payout( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("1.25"), + started=start + timedelta(days=2), + ) + + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=u1.product, + amount=USDCent(5), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=u2.product, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + assert business.balance.payout == 190 + assert business.balance.net == 190 + + assert business.balance.balance == 135 + + def test_multi_product_multi_payout_adjustment( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + task_adj_collection, + pop_ledger_merge, + wall_manager, + session_manager, + adj_to_fail_with_tx_factory, + delete_df_collection, + ): + """ + - Product 1 $2.50 Complete + - Product 2 $2.50 Complete + - $2.50 Payout on Product 1 + - $0.50 Payout on Product 2 + - Product 3 $2.50 Complete + - Complete -> Failure $2.50 Adjustment on Product 1 + ==== + - Net: $7.50 * .95 = $7.125 + - $2.50 = $2.375 = $2.38 + - $2.50 = $2.375 = $2.38 + - $2.50 = $2.375 = $2.38 + ==== + - $7.14 + - Balance: $2 + """ + + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + delete_df_collection(coll=task_adj_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + u3: User = user_factory(product=product_factory(business=business)) + + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=2), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=u1.product, + amount=USDCent(250), + created=start + timedelta(days=3), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=u2.product, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + adj_to_fail_with_tx_factory(session=s1, created=start + timedelta(days=5)) + + session_with_tx_factory( + user=u3, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=6), + ) + + # Build and prepare the Business with the db transactions now in place + + # This isn't needed for Business Balance... but good to also check + # task_adj_collection.initial_load(client=None, sync=True) + # These are the only two that are needed for Business Balance + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + df = client_no_amm.compute(ledger_collection.ddf(), sync=True) + assert df.shape == (24, 24) + + df = client_no_amm.compute(pop_ledger_merge.ddf(), sync=True) + assert df.shape == (20, 28) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + assert business.balance.payout == 714 + assert business.balance.adjustment == -238 + + assert business.balance.product_balances[0].adjustment == -238 + assert business.balance.product_balances[1].adjustment == 0 + assert business.balance.product_balances[2].adjustment == 0 + + assert business.balance.expense == 0 + assert business.balance.net == 714 - 238 + assert business.balance.balance == business.balance.payout - (250 + 50 + 238) + + predicted_retainer = sum( + [ + pb.balance * 0.25 + for pb in business.balance.product_balances + if pb.balance > 0 + ] + ) + assert business.balance.retainer == approx(predicted_retainer, rel=0.01) + + def test_neg_balance_cache( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + adj_to_fail_with_tx_factory, + thl_web_rr, + lm, + ): + """Test having a Business with two products.. one that lost money + and one that gained money. Ensure that the Business balance + reflects that to compensate for the Product in the negative. + """ + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + p1: Product = product_factory(business=business) + p2: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + u2: User = user_factory(product=p2) + thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_bp_wallet(product=p2) + + # Product 1: Complete, Payout, Recon.. + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=u1.product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + adj_to_fail_with_tx_factory( + session=s1, + created=start + timedelta(days=1, minutes=2), + ) + + # Product 2: Complete, Complete. + s2 = session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1, minutes=3), + ) + s3 = session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1, minutes=4), + ) + + # Finally, process everything: + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + # Check Product 1 + pb1 = business.balance.product_balances[0] + assert pb1.product_id == p1.uuid + assert pb1.payout == 71 + assert pb1.adjustment == -71 + assert pb1.net == 0 + assert pb1.balance == 71 - (71 * 2) + assert pb1.retainer == 0 + assert pb1.available_balance == 0 + + # Check Product 2 + pb2 = business.balance.product_balances[1] + assert pb2.product_id == p2.uuid + assert pb2.payout == 71 * 2 + assert pb2.adjustment == 0 + assert pb2.net == 71 * 2 + assert pb2.balance == (71 * 2) + assert pb2.retainer == pytest.approx((71 * 2) * 0.25, rel=1) + assert pb2.available_balance == 107 + + # Check Business + bb1 = business.balance + assert bb1.payout == (71 * 3) # Raw total of completes + assert bb1.adjustment == -71 # 1 Complete >> Failure + assert bb1.expense == 0 + assert bb1.net == (71 * 3) - 71 # How much the Business actually earned + assert ( + bb1.balance == (71 * 3) - 71 - 71 + ) # 3 completes, but 1 payout and 1 recon leaves only one complete + # worth of activity on the account + assert bb1.retainer == pytest.approx((71 * 2) * 0.25, rel=1) + assert bb1.available_balance_usd_str == "$0.36" + + # Confirm that the debt from the pb1 in the red is covered when + # calculating the Business balance by the profit of pb2 + assert pb2.available_balance + pb1.balance == bb1.available_balance + + def test_multi_product_multi_payout_adjustment_at_timestamp( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + start, + thl_web_rr, + payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + task_adj_collection, + pop_ledger_merge, + wall_manager, + session_manager, + adj_to_fail_with_tx_factory, + delete_df_collection, + ): + """ + This test measures a complex Business situation, but then makes + various assertions based off the query which uses an at_timestamp. + + The goal here is a feature that allows us to look back and see + what the balance was of an account at any specific point in time. + + - Day 1: Product 1 $2.50 Complete + - Total Payout: $2.38 + - Smart Retainer: $0.59 + - Available Balance: $1.79 + - Day 2: Product 2 $2.50 Complete + - Total Payout: $4.76 + - Smart Retainer: $1.18 + - Available Balance: $3.58 + - Day 3: $2.50 Payout on Product 1 + - Total Payout: $4.76 + - Smart Retainer: $0.59 + - Available Balance: $1.67 + - Day 4: $0.50 Payout on Product 2 + - Total Payout: $4.76 + - Smart Retainer: $0.47 + - Available Balance: $1.29 + - Day 5: Product 3 $2.50 Complete + - Total Payout: $7.14 + - Smart Retainer: $1.06 + - Available Balance: $3.08 + - Day 6: Complete -> Failure $2.50 Adjustment on Product 1 + - Total Payout: $7.18 + - Smart Retainer: $1.06 + - Available Balance: $0.70 + """ + + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + delete_df_collection(coll=task_adj_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product_factory(business=business)) + u2: User = user_factory(product=product_factory(business=business)) + u3: User = user_factory(product=product_factory(business=business)) + + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=1), + ) + + session_with_tx_factory( + user=u2, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=2), + ) + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + bp_payout_factory( + product=u1.product, + amount=USDCent(250), + created=start + timedelta(days=3), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + bp_payout_factory( + product=u2.product, + amount=USDCent(50), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + session_with_tx_factory( + user=u3, + wall_req_cpi=Decimal("2.50"), + started=start + timedelta(days=5), + ) + + adj_to_fail_with_tx_factory(session=s1, created=start + timedelta(days=6)) + + # Build and prepare the Business with the db transactions now in place + + # This isn't needed for Business Balance... but good to also check + # task_adj_collection.initial_load(client=None, sync=True) + # These are the only two that are needed for Business Balance + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + df = client_no_amm.compute(ledger_collection.ddf(), sync=True) + assert df.shape == (24, 24) + + df = client_no_amm.compute(pop_ledger_merge.ddf(), sync=True) + assert df.shape == (20, 28) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=1, hours=1), + ) + day1_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=2, hours=1), + ) + day2_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=3, hours=1), + ) + day3_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=4, hours=1), + ) + day4_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=5, hours=1), + ) + day5_bal = business.balance + + business.prebuild_balance( + thl_pg_config=thl_web_rr, + lm=lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + at_timestamp=start + timedelta(days=6, hours=1), + ) + day6_bal = business.balance + + assert day1_bal.payout == 238 + assert day1_bal.retainer == 59 + assert day1_bal.available_balance == 179 + + assert day2_bal.payout == 476 + assert day2_bal.retainer == 118 + assert day2_bal.available_balance == 358 + + assert day3_bal.payout == 476 + assert day3_bal.retainer == 59 + assert day3_bal.available_balance == 167 + + assert day4_bal.payout == 476 + assert day4_bal.retainer == 47 + assert day4_bal.available_balance == 129 + + assert day5_bal.payout == 714 + assert day5_bal.retainer == 106 + assert day5_bal.available_balance == 308 + + assert day6_bal.payout == 714 + assert day6_bal.retainer == 106 + assert day6_bal.available_balance == 70 + + +class TestBusinessMethods: + + @pytest.fixture(scope="function") + def start(self, utc_90days_ago) -> "datetime": + s = utc_90days_ago.replace(microsecond=0) + return s + + @pytest.fixture(scope="function") + def offset(self) -> str: + return "15d" + + @pytest.fixture(scope="function") + def duration( + self, + ) -> Optional["timedelta"]: + return None + + def test_cache_key(self, business, gr_redis): + assert isinstance(business.cache_key, str) + assert ":" in business.cache_key + assert str(business.uuid) in business.cache_key + + def test_set_cache( + self, + business, + gr_redis, + gr_db, + thl_web_rr, + client_no_amm, + mnt_filepath, + lm, + thl_lm, + business_payout_event_manager, + product_factory, + membership_factory, + team, + session_with_tx_factory, + user_factory, + ledger_collection, + pop_ledger_merge, + utc_60days_ago, + delete_ledger_db, + create_main_accounts, + gr_redis_config, + mnt_gr_api_dir, + ): + assert gr_redis.get(name=business.cache_key) is None + + p1 = product_factory(team=team, business=business) + u1 = user_factory(product=p1) + + # Business needs tx & incite to build balance + delete_ledger_db() + create_main_accounts() + thl_lm.get_account_or_create_bp_wallet(product=p1) + session_with_tx_factory(user=u1, started=utc_60days_ago) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + lm=lm, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + pop_ledger=pop_ledger_merge, + mnt_gr_api=mnt_gr_api_dir, + ) + + assert gr_redis.hgetall(name=business.cache_key) is not None + from generalresearch.models.gr.business import Business + + # We're going to pull only a specific year, but make sure that + # it's being assigned to the field regardless + year = datetime.now(tz=timezone.utc).year + res = Business.from_redis( + uuid=business.uuid, + fields=[f"pop_financial:{year}"], + gr_redis_config=gr_redis_config, + ) + assert len(res.pop_financial) > 0 + + def test_set_cache_business( + self, + gr_user, + business, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + membership_factory, + client_no_amm, + mnt_filepath, + lm, + thl_lm, + business_payout_event_manager, + user_factory, + delete_ledger_db, + create_main_accounts, + session_with_tx_factory, + ledger_collection, + team_manager, + pop_ledger_merge, + gr_redis_config, + utc_60days_ago, + mnt_gr_api_dir, + ): + from generalresearch.models.gr.business import Business + + p1 = product_factory(team=team, business=business) + u1 = user_factory(product=p1) + team_manager.add_business(team=team, business=business) + + # Business needs tx & incite to build balance + delete_ledger_db() + create_main_accounts() + thl_lm.get_account_or_create_bp_wallet(product=p1) + session_with_tx_factory(user=u1, started=utc_60days_ago) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + business.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + lm=lm, + thl_lm=thl_lm, + bpem=business_payout_event_manager, + pop_ledger=pop_ledger_merge, + mnt_gr_api=mnt_gr_api_dir, + ) + + # keys: List = Business.required_fields() + ["products", "bp_accounts"] + business2 = Business.from_redis( + uuid=business.uuid, + fields=[ + "id", + "tax_number", + "contact", + "addresses", + "teams", + "products", + "bank_accounts", + "balance", + "payouts_total_str", + "payouts_total", + "payouts", + "pop_financial", + "bp_accounts", + ], + gr_redis_config=gr_redis_config, + ) + + assert business.model_dump_json() == business2.model_dump_json() + assert p1.uuid in [p.uuid for p in business2.products] + assert len(business2.teams) == 1 + assert team.uuid in [t.uuid for t in business2.teams] + + assert business2.balance.payout == 48 + assert business2.balance.balance == 48 + assert business2.balance.net == 48 + assert business2.balance.retainer == 12 + assert business2.balance.available_balance == 36 + assert len(business2.balance.product_balances) == 1 + + assert len(business2.payouts) == 0 + + assert len(business2.bp_accounts) == 1 + assert len(business2.bp_accounts) == len(business2.product_uuids) + + assert len(business2.pop_financial) == 1 + assert business2.pop_financial[0].payout == business2.balance.payout + assert business2.pop_financial[0].net == business2.balance.net + + def test_prebuild_enriched_session_parquet( + self, + event_report_request, + enriched_session_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(business=business) + p2 = product_factory(business=business) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + business.prebuild_enriched_session_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_session=enriched_session_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_session", f"{business.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) + + def test_prebuild_enriched_wall_parquet( + self, + event_report_request, + enriched_session_merge, + enriched_wall_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(business=business) + p2 = product_factory(business=business) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_wall_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + business.prebuild_enriched_wall_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_event", f"{business.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) diff --git a/tests/models/gr/test_team.py b/tests/models/gr/test_team.py new file mode 100644 index 0000000..d728bbe --- /dev/null +++ b/tests/models/gr/test_team.py @@ -0,0 +1,296 @@ +import os +from datetime import timedelta +from decimal import Decimal + +import pandas as pd + + +class TestTeam: + + def test_init(self, team): + from generalresearch.models.gr.team import Team + + assert isinstance(team, Team) + assert isinstance(team.id, int) + assert isinstance(team.uuid, str) + + def test_memberships_none(self, team, gr_user_factory, gr_db): + assert team.memberships is None + + team.prefetch_memberships(pg_config=gr_db) + assert isinstance(team.memberships, list) + assert len(team.memberships) == 0 + + def test_memberships( + self, + team, + membership, + gr_user, + gr_user_factory, + membership_factory, + membership_manager, + gr_db, + ): + assert team.memberships is None + + team.prefetch_memberships(pg_config=gr_db) + assert isinstance(team.memberships, list) + assert len(team.memberships) == 1 + assert team.memberships[0].user_id == gr_user.id + + # Create another new Membership + membership_manager.create(team=team, gr_user=gr_user_factory()) + assert len(team.memberships) == 1 + team.prefetch_memberships(pg_config=gr_db) + assert len(team.memberships) == 2 + + def test_gr_users( + self, team, gr_user_factory, membership_manager, gr_db, gr_redis_config + ): + assert team.gr_users is None + + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert isinstance(team.gr_users, list) + assert len(team.gr_users) == 0 + + # Create a new Membership + membership_manager.create(team=team, gr_user=gr_user_factory()) + assert len(team.gr_users) == 0 + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert len(team.gr_users) == 1 + + # Create another Membership + membership_manager.create(team=team, gr_user=gr_user_factory()) + assert len(team.gr_users) == 1 + team.prefetch_gr_users(pg_config=gr_db, redis_config=gr_redis_config) + assert len(team.gr_users) == 2 + + def test_businesses(self, team, business, team_manager, gr_db, gr_redis_config): + from generalresearch.models.gr.business import Business + + assert team.businesses is None + + team.prefetch_businesses(pg_config=gr_db, redis_config=gr_redis_config) + assert isinstance(team.businesses, list) + assert len(team.businesses) == 0 + + team_manager.add_business(team=team, business=business) + assert len(team.businesses) == 0 + team.prefetch_businesses(pg_config=gr_db, redis_config=gr_redis_config) + assert len(team.businesses) == 1 + assert isinstance(team.businesses[0], Business) + assert team.businesses[0].uuid == business.uuid + + def test_products(self, team, product_factory, thl_web_rr): + from generalresearch.models.thl.product import Product + + assert team.products is None + + team.prefetch_products(thl_pg_config=thl_web_rr) + assert isinstance(team.products, list) + assert len(team.products) == 0 + + product_factory(team=team) + assert len(team.products) == 0 + team.prefetch_products(thl_pg_config=thl_web_rr) + assert len(team.products) == 1 + assert isinstance(team.products[0], Product) + + +class TestTeamMethods: + + def test_cache_key(self, team, gr_redis): + assert isinstance(team.cache_key, str) + assert ":" in team.cache_key + assert str(team.uuid) in team.cache_key + + def test_set_cache( + self, + team, + gr_redis, + gr_db, + thl_web_rr, + gr_redis_config, + client_no_amm, + mnt_filepath, + mnt_gr_api_dir, + enriched_wall_merge, + enriched_session_merge, + ): + assert gr_redis.get(name=team.cache_key) is None + + team.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + enriched_session=enriched_session_merge, + ) + + assert gr_redis.hgetall(name=team.cache_key) is not None + + def test_set_cache_team( + self, + gr_user, + gr_user_token, + gr_redis, + gr_db, + thl_web_rr, + product_factory, + team, + membership_factory, + gr_redis_config, + client_no_amm, + mnt_filepath, + mnt_gr_api_dir, + enriched_wall_merge, + enriched_session_merge, + ): + from generalresearch.models.gr.team import Team + + p1 = product_factory(team=team) + membership_factory(team=team, gr_user=gr_user) + + team.set_cache( + pg_config=gr_db, + thl_web_rr=thl_web_rr, + redis_config=gr_redis_config, + client=client_no_amm, + ds=mnt_filepath, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + enriched_session=enriched_session_merge, + ) + + team2 = Team.from_redis( + uuid=team.uuid, + fields=["id", "memberships", "gr_users", "businesses", "products"], + gr_redis_config=gr_redis_config, + ) + + assert team.model_dump_json() == team2.model_dump_json() + assert p1.uuid in [p.uuid for p in team2.products] + assert len(team2.gr_users) == 1 + assert gr_user.id in [gru.id for gru in team2.gr_users] + + def test_prebuild_enriched_session_parquet( + self, + event_report_request, + enriched_session_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + team, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(team=team) + p2 = product_factory(team=team) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_session_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + team.prebuild_enriched_session_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_session=enriched_session_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_session", f"{team.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) + + def test_prebuild_enriched_wall_parquet( + self, + event_report_request, + enriched_session_merge, + enriched_wall_merge, + client_no_amm, + wall_collection, + session_collection, + thl_web_rr, + session_report_request, + user_factory, + start, + session_factory, + product_factory, + delete_df_collection, + business, + mnt_filepath, + mnt_gr_api_dir, + team, + ): + + delete_df_collection(coll=wall_collection) + delete_df_collection(coll=session_collection) + + p1 = product_factory(team=team) + p2 = product_factory(team=team) + + for p in [p1, p2]: + u = user_factory(product=p) + for i in range(50): + s = session_factory( + user=u, + wall_count=1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(minutes=i, seconds=1), + ) + wall_collection.initial_load(client=None, sync=True) + session_collection.initial_load(client=None, sync=True) + + enriched_wall_merge.build( + client=client_no_amm, + session_coll=session_collection, + wall_coll=wall_collection, + pg_config=thl_web_rr, + ) + + team.prebuild_enriched_wall_parquet( + thl_pg_config=thl_web_rr, + ds=mnt_filepath, + client=client_no_amm, + mnt_gr_api=mnt_gr_api_dir, + enriched_wall=enriched_wall_merge, + ) + + # Now try to read from path + df = pd.read_parquet( + os.path.join(mnt_gr_api_dir, "pop_event", f"{team.file_key}.parquet") + ) + assert isinstance(df, pd.DataFrame) diff --git a/tests/models/innovate/__init__.py b/tests/models/innovate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/innovate/test_question.py b/tests/models/innovate/test_question.py new file mode 100644 index 0000000..330f919 --- /dev/null +++ b/tests/models/innovate/test_question.py @@ -0,0 +1,85 @@ +from generalresearch.models import Source +from generalresearch.models.innovate.question import ( + InnovateQuestion, + InnovateQuestionType, + InnovateQuestionOption, +) +from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionSelectorTE, + UpkQuestion, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestionChoice, +) + + +class TestInnovateQuestion: + + def test_text_entry(self): + + q = InnovateQuestion( + question_id="3", + country_iso="us", + language_iso="eng", + question_key="ZIPCODES", + question_text="postal code", + question_type=InnovateQuestionType.TEXT_ENTRY, + tags=None, + options=None, + is_live=True, + category_id=None, + ) + assert Source.INNOVATE == q.source + assert "i:zipcodes" == q.external_id + assert "zipcodes" == q.internal_id + assert ("zipcodes", "us", "eng") == q._key + + upk = q.to_upk_question() + expected_upk = UpkQuestion( + ext_question_id="i:zipcodes", + type=UpkQuestionType.TEXT_ENTRY, + country_iso="us", + language_iso="eng", + text="postal code", + selector=UpkQuestionSelectorTE.SINGLE_LINE, + choices=None, + ) + assert expected_upk == upk + + def test_mc(self): + + text = "Have you purchased or received any of the following in past 18 months?" + q = InnovateQuestion( + question_key="dynamic_profiling-_1_14715", + country_iso="us", + language_iso="eng", + question_id="14715", + question_text=text, + question_type=InnovateQuestionType.MULTI_SELECT, + tags="Dynamic Profiling- 1", + options=[ + InnovateQuestionOption(id="1", text="aaa", order=0), + InnovateQuestionOption(id="2", text="bbb", order=1), + ], + is_live=True, + category_id=None, + ) + assert "i:dynamic_profiling-_1_14715" == q.external_id + assert "dynamic_profiling-_1_14715" == q.internal_id + assert ("dynamic_profiling-_1_14715", "us", "eng") == q._key + assert 2 == q.num_options + + upk = q.to_upk_question() + expected_upk = UpkQuestion( + ext_question_id="i:dynamic_profiling-_1_14715", + type=UpkQuestionType.MULTIPLE_CHOICE, + country_iso="us", + language_iso="eng", + text=text, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + choices=[ + UpkQuestionChoice(id="1", text="aaa", order=0), + UpkQuestionChoice(id="2", text="bbb", order=1), + ], + ) + assert expected_upk == upk diff --git a/tests/models/legacy/__init__.py b/tests/models/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/legacy/data.py b/tests/models/legacy/data.py new file mode 100644 index 0000000..20f3231 --- /dev/null +++ b/tests/models/legacy/data.py @@ -0,0 +1,265 @@ +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/45b7228a7/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_45b7228a7 = ( + '{"info": {"success": true}, "offerwall": {"availability_count": 9, "buckets": [{"category": [{"adwords_id": ' + 'null, "adwords_label": null, "id": "c82cf98c578a43218334544ab376b00e", "label": "Social Research", "p": 1.0}], ' + '"description": "", "duration": {"max": 719, "min": 72, "q1": 144, "q2": 621, "q3": 650}, ' + '"id": "5503c471d95645dd947080704d3760b3", "name": "", "payout": {"max": 132, "min": 68, "q1": 113, "q2": 124, ' + '"q3": 128}, "quality_score": 1.0, "uri": ' + '"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=5503c471d95645dd947080704d3760b3&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=16fa868", "x": 0, "y": 0}, ' + '{"category": [{"adwords_id": null, "adwords_label": null, "id": "c82cf98c578a43218334544ab376b00e", ' + '"label": "Social Research", "p": 0.6666666666666666}, {"adwords_id": "5000", "adwords_label": "World ' + 'Localities", "id": "cd3f9374ba5d4e5692ee5691320ecc8b", "label": "World Localities", "p": 0.16666666666666666}, ' + '{"adwords_id": "14", "adwords_label": "People & Society", "id": "c8642a1b86d9460cbe8f7e8ae6e56ee4", ' + '"label": "People & Society", "p": 0.16666666666666666}], "description": "", "duration": {"max": 1180, ' + '"min": 144, "q1": 457, "q2": 621, "q3": 1103}, "id": "56f437aa5da443748872390a5cbf6103", "name": "", "payout": {' + '"max": 113, "min": 24, "q1": 40, "q2": 68, "q3": 68}, "quality_score": 0.17667012, ' + '"uri": "https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=56f437aa5da443748872390a5cbf6103&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=16fa868", "x": 1, "y": 0}, ' + '{"category": [{"adwords_id": null, "adwords_label": null, "id": "c82cf98c578a43218334544ab376b00e", ' + '"label": "Social Research", "p": 0.6666666666666666}, {"adwords_id": "5000", "adwords_label": "World ' + 'Localities", "id": "cd3f9374ba5d4e5692ee5691320ecc8b", "label": "World Localities", "p": 0.16666666666666666}, ' + '{"adwords_id": "14", "adwords_label": "People & Society", "id": "c8642a1b86d9460cbe8f7e8ae6e56ee4", ' + '"label": "People & Society", "p": 0.16666666666666666}], "description": "", "duration": {"max": 1180, ' + '"min": 144, "q1": 457, "q2": 1103, "q3": 1128}, "id": "a5b8403e4a4a4ed1a21ef9b3a721ab02", "name": "", ' + '"payout": {"max": 68, "min": 14, "q1": 24, "q2": 40, "q3": 68}, "quality_score": 0.01, ' + '"uri": "https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=a5b8403e4a4a4ed1a21ef9b3a721ab02&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=16fa868", "x": 2, "y": 0}], ' + '"id": "2fba2999baf0423cad0c49eceea4eb33", "payout_format": "${payout/100:.2f}"}}' +) + +RESPONSE_b145b803 = ( + '{"info":{"success":true},"offerwall":{"availability_count":10,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{' + '"id":"b9d6fdb95ae2402dbb8e8673be382f04","id_code":"m:b9d6fdb95ae2402dbb8e8673be382f04","loi":954,"payout":166,' + '"source":"m"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":71,"payout":132,"source":"o"},{"id":"ejqjbv4",' + '"id_code":"o:ejqjbv4","loi":650,"payout":128,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":624,' + '"payout":113,"source":"o"},{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],' + '"currency":"USD","description":"","duration":{"max":954,"min":72,"q1":625,"q2":650,"q3":719},' + '"id":"c190231a7d494012a2f641a89a85e6a6","name":"","payout":{"max":166,"min":113,"q1":124,"q2":128,"q3":132},' + '"quality_score":1.0,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120' + '/?i=c190231a7d494012a2f641a89a85e6a6&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=9dbbf8c","x":0,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":0.5},{"adwords_id":"5000","adwords_label":"World Localities",' + '"id":"cd3f9374ba5d4e5692ee5691320ecc8b","label":"World Localities","p":0.25},{"adwords_id":"14",' + '"adwords_label":"People & Society","id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society",' + '"p":0.25}],"contents":[{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":143,"payout":68,"source":"o"},' + '{"id":"g6xkrbm","id_code":"o:g6xkrbm","loi":536,"payout":68,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g",' + '"loi":457,"payout":68,"source":"o"}],"currency":"USD","description":"","duration":{"max":537,"min":144,"q1":301,' + '"q2":457,"q3":497},"id":"ee12be565f744ef2b194703f3d32f8cd","name":"","payout":{"max":68,"min":68,"q1":68,' + '"q2":68,"q3":68},"quality_score":0.01,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=ee12be565f744ef2b194703f3d32f8cd&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=9dbbf8c","x":1,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":1.0}],"contents":[{"id":"a019660a8bc0411dba19a3e5c5df5b6c",' + '"id_code":"m:a019660a8bc0411dba19a3e5c5df5b6c","loi":1180,"payout":24,"source":"m"},{"id":"ejqa3kw",' + '"id_code":"o:ejqa3kw","loi":143,"payout":68,"source":"o"},{"id":"8c54725047cc4e0590665a034d37e7f5",' + '"id_code":"m:8c54725047cc4e0590665a034d37e7f5","loi":1128,"payout":14,"source":"m"}],"currency":"USD",' + '"description":"","duration":{"max":1180,"min":144,"q1":636,"q2":1128,"q3":1154},' + '"id":"dc2551eb48d84b329fdcb5b2bd60ed71","name":"","payout":{"max":68,"min":14,"q1":19,"q2":24,"q3":46},' + '"quality_score":0.06662835156312356,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=dc2551eb48d84b329fdcb5b2bd60ed71&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=9dbbf8c","x":2,"y":0}],' + '"id":"7cb9eb4c1a5b41a38cfebd13c9c338cb"}}' +) + +# This is a blocked user. Otherwise, the format is identical to RESPONSE_b145b803 +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/1e5f0af8/?bpuid=00051c62-a872-4832-a008 +# -c37ec51d33d3&duration=1200&format=json&country_iso=us +RESPONSE_d48cce47 = ( + '{"info":{"success":true},"offerwall":{"availability_count":0,"buckets":[],' + '"id":"168680387a7f4c8c9cc8e7ab63f502ff","payout_format":"${payout/100:.2f}"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/1e5f0af8/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_1e5f0af8 = ( + '{"info":{"success":true},"offerwall":{"availability_count":9,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{' + '"id":"x94r9bg","id_code":"o:x94r9bg","loi":71,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9",' + '"loi":604,"payout":113,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":473,"payout":128,"source":"o"},' + '{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],"currency":"USD","description":"",' + '"duration":{"max":719,"min":72,"q1":373,"q2":539,"q3":633},"id":"2a4a897a76464af2b85703b72a125da0",' + '"is_recontact":false,"name":"","payout":{"max":132,"min":113,"q1":121,"q2":126,"q3":129},"quality_score":1.0,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142","x":0,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":0.5},{"adwords_id":"5000","adwords_label":"World Localities",' + '"id":"cd3f9374ba5d4e5692ee5691320ecc8b","label":"World Localities","p":0.25},{"adwords_id":"14",' + '"adwords_label":"People & Society","id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society",' + '"p":0.25}],"contents":[{"id":"775ed98f65604dac91b3a60814438829","id_code":"m:775ed98f65604dac91b3a60814438829",' + '"loi":1121,"payout":32,"source":"m"},{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":143,"payout":68,"source":"o"},' + '{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"}],"currency":"USD","description":"",' + '"duration":{"max":1121,"min":144,"q1":301,"q2":457,"q3":789},"id":"0aa83eb711c042e28bb9284e604398ac",' + '"is_recontact":false,"name":"","payout":{"max":68,"min":32,"q1":50,"q2":68,"q3":68},"quality_score":0.01,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=0aa83eb711c042e28bb9284e604398ac&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142","x":1,"y":0},' + '{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social ' + 'Research","p":1.0}],"contents":[{"id":"775ed98f65604dac91b3a60814438829",' + '"id_code":"m:775ed98f65604dac91b3a60814438829","loi":1121,"payout":32,"source":"m"},' + '{"id":"8c54725047cc4e0590665a034d37e7f5","id_code":"m:8c54725047cc4e0590665a034d37e7f5","loi":1128,"payout":14,' + '"source":"m"},{"id":"a019660a8bc0411dba19a3e5c5df5b6c","id_code":"m:a019660a8bc0411dba19a3e5c5df5b6c",' + '"loi":1180,"payout":24,"source":"m"}],"currency":"USD","description":"","duration":{"max":1180,"min":1121,' + '"q1":1125,"q2":1128,"q3":1154},"id":"d3a87b2bb6cf4428a55916bdf65e775e","is_recontact":false,"name":"",' + '"payout":{"max":32,"min":14,"q1":19,"q2":24,"q3":28},"quality_score":0.0825688889614345,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=d3a87b2bb6cf4428a55916bdf65e775e&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142","x":2,"y":0}],' + '"id":"391a54bbe84c4dbfa50b40841201a606"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/5fl8bpv5/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_5fl8bpv5 = ( + '{"info":{"success":true},"offerwall":{"availability_count":9,"buckets":[{' + '"id":"a1097b20f2ae472a9a9ad2987ba3bf95",' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=a1097b20f2ae472a9a9ad2987ba3bf95&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e7baf5e"}],' + '"id":"fddaa544d7ff428a8ccccd0667fdc249","payout_format":"${payout/100:.2f}"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/37d1da64/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_37d1da64 = ( + '{"info":{"success":true},"offerwall":{"availability_count":18,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{' + '"id":"0qvwx4z","id_code":"o:0qvwx4z","loi":700,"payout":106,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg",' + '"loi":72,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},' + '{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"}],"eligibility":"conditional",' + '"id":"2cfc47e8d8c4417cb1f499dbf7e9afb8","loi":700,"missing_questions":["7ca8b59f4c864f80a1a7c7287adfc637"],' + '"payout":106,"uri":null},{"category":[{"adwords_id":null,"adwords_label":null,' + '"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":1.0}],"contents":[{"id":"yxqdnb9",' + '"id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,' + '"payout":132,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"}],' + '"eligibility":"unconditional","id":"8964a87ebbe9433cae0ddce1b34a637a","loi":605,"payout":113,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=8964a87ebbe9433cae0ddce1b34a637a&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"a019660a8bc0411dba19a3e5c5df5b6c","id_code":"m:a019660a8bc0411dba19a3e5c5df5b6c",' + '"loi":1180,"payout":24,"source":"m"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},' + '{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},' + '{"id":"775ed98f65604dac91b3a60814438829","id_code":"m:775ed98f65604dac91b3a60814438829","loi":1121,"payout":32,' + '"source":"m"},{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"ejqjbv4",' + '"id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,' + '"payout":68,"source":"o"},{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],' + '"eligibility":"unconditional","id":"ba98df6041344e818f02873534ae09a3","loi":1180,"payout":24,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=ba98df6041344e818f02873534ae09a3&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"775ed98f65604dac91b3a60814438829","id_code":"m:775ed98f65604dac91b3a60814438829",' + '"loi":1121,"payout":32,"source":"m"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},' + '{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,"payout":113,"source":"o"},{"id":"ejqa3kw",' + '"id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,' + '"payout":128,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},' + '{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],"eligibility":"unconditional",' + '"id":"6456abe0f2584be8a30d9e0e93bef496","loi":1121,"payout":32,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=6456abe0f2584be8a30d9e0e93bef496&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"x94r9bg",' + '"id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,' + '"payout":132,"source":"o"}],"eligibility":"unconditional","id":"e0a0fb27bf174040be7971795a967ce5","loi":144,' + '"payout":68,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=e0a0fb27bf174040be7971795a967ce5&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,"payout":128,"source":"o"},' + '{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg",' + '"loi":72,"payout":132,"source":"o"}],"eligibility":"unconditional","id":"99e3f8cfd367411b822cf11c3b54b558",' + '"loi":474,"payout":128,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709' + "/00ff1d9b71b94bf4b20d22cd56774120/?i=99e3f8cfd367411b822cf11c3b54b558&b=379fb74f-05b2-42dc-b283-47e1c8678b04" + '&66482fb=ec8aa9a"},{"category":[{"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e",' + '"label":"Social Research","p":1.0}],"contents":[{"id":"8c54725047cc4e0590665a034d37e7f5",' + '"id_code":"m:8c54725047cc4e0590665a034d37e7f5","loi":1128,"payout":14,"source":"m"},{"id":"x94r9bg",' + '"id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9","loi":605,' + '"payout":113,"source":"o"},{"id":"775ed98f65604dac91b3a60814438829",' + '"id_code":"m:775ed98f65604dac91b3a60814438829","loi":1121,"payout":32,"source":"m"},{"id":"ejqa3kw",' + '"id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":474,' + '"payout":128,"source":"o"},{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},' + '{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"}],"eligibility":"unconditional",' + '"id":"db7bc77afbb6443e8f35e9f06764fd06","loi":1128,"payout":14,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=db7bc77afbb6443e8f35e9f06764fd06&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":0.5},{"adwords_id":"5000","adwords_label":"World Localities","id":"cd3f9374ba5d4e5692ee5691320ecc8b",' + '"label":"World Localities","p":0.25},{"adwords_id":"14","adwords_label":"People & Society",' + '"id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society","p":0.25}],"contents":[{"id":"yr5od0g",' + '"id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,' + '"payout":132,"source":"o"},{"id":"ejqa3kw","id_code":"o:ejqa3kw","loi":144,"payout":68,"source":"o"}],' + '"eligibility":"unconditional","id":"d5af3c782a354d6a8616e47531eb15e7","loi":457,"payout":68,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=d5af3c782a354d6a8616e47531eb15e7&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=ec8aa9a"},{"category":[{' + '"adwords_id":null,"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research",' + '"p":1.0}],"contents":[{"id":"vyjrv0v","id_code":"o:vyjrv0v","loi":719,"payout":124,"source":"o"},' + '{"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4",' + '"loi":474,"payout":128,"source":"o"}],"eligibility":"unconditional","id":"ceafb146bb7042889b85630a44338a75",' + '"loi":719,"payout":124,"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709' + "/00ff1d9b71b94bf4b20d22cd56774120/?i=ceafb146bb7042889b85630a44338a75&b=379fb74f-05b2-42dc-b283-47e1c8678b04" + '&66482fb=ec8aa9a"}],"id":"37ed0858e0f64329812e2070fb658eb3","question_info":{' + '"7ca8b59f4c864f80a1a7c7287adfc637":{"choices":[{"choice_id":"0","choice_text":"Single, never married",' + '"order":0},{"choice_id":"1","choice_text":"Living with a Partner","order":1},{"choice_id":"2",' + '"choice_text":"Civil Union / Domestic Partnership","order":2},{"choice_id":"3","choice_text":"Married",' + '"order":3},{"choice_id":"4","choice_text":"Separated","order":4},{"choice_id":"5","choice_text":"Divorced",' + '"order":5},{"choice_id":"6","choice_text":"Widowed","order":6}],"country_iso":"us","language_iso":"eng",' + '"question_id":"7ca8b59f4c864f80a1a7c7287adfc637","question_text":"What is your relationship status?",' + '"question_type":"MC","selector":"SA"}}}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/5fa23085/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_5fa23085 = ( + '{"info":{"success":true},"offerwall":{"availability_count":7,"buckets":[{"category":[{"adwords_id":null,' + '"adwords_label":null,"id":"c82cf98c578a43218334544ab376b00e","label":"Social Research","p":0.6666666666666666},' + '{"adwords_id":"5000","adwords_label":"World Localities","id":"cd3f9374ba5d4e5692ee5691320ecc8b","label":"World ' + 'Localities","p":0.16666666666666666},{"adwords_id":"14","adwords_label":"People & Society",' + '"id":"c8642a1b86d9460cbe8f7e8ae6e56ee4","label":"People & Society","p":0.16666666666666666}],"contents":[{' + '"id":"x94r9bg","id_code":"o:x94r9bg","loi":72,"payout":132,"source":"o"},{"id":"yxqdnb9","id_code":"o:yxqdnb9",' + '"loi":605,"payout":113,"source":"o"},{"id":"ejqjbv4","id_code":"o:ejqjbv4","loi":640,"payout":128,"source":"o"},' + '{"id":"yr5od0g","id_code":"o:yr5od0g","loi":457,"payout":68,"source":"o"},{"id":"vyjrv0v","id_code":"o:vyjrv0v",' + '"loi":719,"payout":124,"source":"o"}],"duration":{"max":719,"min":72,"q1":457,"q2":605,"q3":640},' + '"id":"30049c28363b4f689c84fbacf1cc57e2","payout":{"max":132,"min":68,"q1":113,"q2":124,"q3":128},' + '"source":"pollfish","uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120' + '/?i=30049c28363b4f689c84fbacf1cc57e2&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e3c1af1"}],' + '"id":"1c33756fd7ca49fa84ee48e42145e68c"}}' +) + +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/1705e4f8/?bpuid=379fb74f-05b2-42dc-b283 +# -47e1c8678b04&duration=1200&format=json&country_iso=us +RESPONSE_1705e4f8 = ( + '{"info":{"success":true},"offerwall":{"availability_count":7,"buckets":[{"currency":"USD","duration":719,' + '"id":"d06abd6ac75b453a93d0e85e4e391c00","min_payout":124,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=d06abd6ac75b453a93d0e85e4e391c00&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":705,"id":"bcf5ca85a4044e9abed163f72039c7d1","min_payout":123,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=bcf5ca85a4044e9abed163f72039c7d1&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":679,"id":"b85034bd59e04a64ac1be7686a4c906d","min_payout":121,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=b85034bd59e04a64ac1be7686a4c906d&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1180,"id":"324cf03b9ba24ef19284683bf9b62afb","min_payout":24,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=324cf03b9ba24ef19284683bf9b62afb&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1062,"id":"6bbbef3157c346009e677e556ecea7e7","min_payout":17,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=6bbbef3157c346009e677e556ecea7e7&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1128,"id":"bc12ba86b5024cf5b8ade997415190f7","min_payout":14,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=bc12ba86b5024cf5b8ade997415190f7&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1126,"id":"e574e715dcc744dbacf84a8426a0cd37","min_payout":10,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=e574e715dcc744dbacf84a8426a0cd37&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":1302,"id":"a6f562e3371347d19d281d40b3ca317d","min_payout":10,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=a6f562e3371347d19d281d40b3ca317d&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"},{"currency":"USD",' + '"duration":953,"id":"462c8e1fbdad475792a360097c8e740f","min_payout":4,' + '"uri":"https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i' + '=462c8e1fbdad475792a360097c8e740f&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=e09baa2"}],' + '"id":"7f5158dc25174ada89f54e2b26a61b20"}}' +) + +# Blocked user +# https://fsb.generalresearch.com/00ff1d9b71b94bf4b20d22cd56774120/offerwall/0af0f7ec/?bpuid=00051c62-a872-4832-a008 +# -c37ec51d33d3&duration=1200&format=json&country_iso=us +RESPONSE_0af0f7ec = ( + '{"info":{"success":true},"offerwall":{"availability_count":0,"buckets":[],' + '"id":"34e71f4ccdec47b3b0991f5cfda60238","payout_format":"${payout/100:.2f}"}}' +) diff --git a/tests/models/legacy/test_offerwall_parse_response.py b/tests/models/legacy/test_offerwall_parse_response.py new file mode 100644 index 0000000..7fb5315 --- /dev/null +++ b/tests/models/legacy/test_offerwall_parse_response.py @@ -0,0 +1,186 @@ +import json + +from generalresearch.models import Source +from generalresearch.models.legacy.bucket import ( + TopNPlusBucket, + SurveyEligibilityCriterion, + DurationSummary, + PayoutSummary, + BucketTask, +) + + +class TestOfferwallTopNAndStarwall: + def test_45b7228a7(self): + from generalresearch.models.legacy.offerwall import ( + TopNOfferWall, + TopNOfferWallResponse, + StarwallOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_45b7228a7, + ) + + res = json.loads(RESPONSE_45b7228a7) + assert TopNOfferWallResponse.model_validate(res) + offerwall = TopNOfferWall.model_validate(res["offerwall"]) + assert offerwall + offerwall.censor() + # Format is identical to starwall + assert StarwallOfferWallResponse.model_validate(res) + + def test_b145b803(self): + from generalresearch.models.legacy.offerwall import ( + TopNPlusOfferWallResponse, + StarwallPlusOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_b145b803, + ) + + res = json.loads(RESPONSE_b145b803) + assert TopNPlusOfferWallResponse.model_validate(res) + assert StarwallPlusOfferWallResponse.model_validate(res) + + def test_d48cce47(self): + from generalresearch.models.legacy.offerwall import ( + TopNPlusBlockOfferWallResponse, + StarwallPlusBlockOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_b145b803, + RESPONSE_d48cce47, + ) + + res = json.loads(RESPONSE_d48cce47) # this is a blocked user's response + assert TopNPlusBlockOfferWallResponse.model_validate(res) + assert StarwallPlusBlockOfferWallResponse.model_validate(res) + # otherwise it is identical to the plus's response + res = json.loads(RESPONSE_b145b803) + assert TopNPlusBlockOfferWallResponse.model_validate(res) + assert StarwallPlusBlockOfferWallResponse.model_validate(res) + + def test_1e5f0af8(self): + from generalresearch.models.legacy.offerwall import ( + TopNPlusBlockRecontactOfferWallResponse, + StarwallPlusBlockRecontactOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_d48cce47, + RESPONSE_1e5f0af8, + ) + + res = json.loads(RESPONSE_1e5f0af8) + assert TopNPlusBlockRecontactOfferWallResponse.model_validate(res) + assert StarwallPlusBlockRecontactOfferWallResponse.model_validate(res) + + res = json.loads(RESPONSE_d48cce47) # this is a blocked user's response + assert TopNPlusBlockRecontactOfferWallResponse.model_validate(res) + assert StarwallPlusBlockRecontactOfferWallResponse.model_validate(res) + + def test_eligibility_criteria(self): + b = TopNPlusBucket( + id="c82cf98c578a43218334544ab376b00e", + contents=[ + BucketTask( + id="12345", + payout=10, + source=Source.TESTING, + id_code="t:12345", + loi=120, + ) + ], + duration=DurationSummary(max=1, min=1, q1=1, q2=1, q3=1), + quality_score=1, + payout=PayoutSummary(max=1, min=1, q1=1, q2=1, q3=1), + uri="https://task.generalresearch.com/api/v1/52d3f63b2709/00ff1d9b71b94bf4b20d22cd56774120/?i=2a4a897a76464af2b85703b72a125da0&b=379fb74f-05b2-42dc-b283-47e1c8678b04&66482fb=82fe142", + eligibility_criteria=( + SurveyEligibilityCriterion( + question_id="71a367fb71b243dc89f0012e0ec91749", + question_text="what is something", + qualifying_answer=("1",), + qualifying_answer_label=("abc",), + property_code="t:123", + ), + SurveyEligibilityCriterion( + question_id="81a367fb71b243dc89f0012e0ec91749", + question_text="what is something 2", + qualifying_answer=("2",), + qualifying_answer_label=("ddd",), + property_code="t:124", + ), + ), + ) + assert b.eligibility_criteria[0].rank == 0 + assert b.eligibility_criteria[1].rank == 1 + print(b.model_dump_json()) + b.censor() + print(b.model_dump_json()) + + +class TestOfferwallSingle: + def test_5fl8bpv5(self): + from generalresearch.models.legacy.offerwall import ( + SingleEntryOfferWallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_5fl8bpv5, + ) + + res = json.loads(RESPONSE_5fl8bpv5) + assert SingleEntryOfferWallResponse.model_validate(res) + + +class TestOfferwallSoftPair: + def test_37d1da64(self): + from generalresearch.models.legacy.offerwall import ( + SoftPairOfferwallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_37d1da64, + ) + + res = json.loads(RESPONSE_37d1da64) + assert SoftPairOfferwallResponse.model_validate(res) + + +class TestMarketplace: + def test_5fa23085(self): + from generalresearch.models.legacy.offerwall import ( + MarketplaceOfferwallResponse, + ) + + from tests.models.legacy.data import ( + RESPONSE_5fa23085, + ) + + res = json.loads(RESPONSE_5fa23085) + assert MarketplaceOfferwallResponse.model_validate(res) + + +class TestTimebucks: + def test_1705e4f8(self): + from generalresearch.models.legacy.offerwall import ( + TimeBucksOfferwallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_1705e4f8, + ) + + res = json.loads(RESPONSE_1705e4f8) + assert TimeBucksOfferwallResponse.model_validate(res) + + def test_0af0f7ec(self): + from generalresearch.models.legacy.offerwall import ( + TimeBucksBlockOfferwallResponse, + ) + from tests.models.legacy.data import ( + RESPONSE_1705e4f8, + RESPONSE_0af0f7ec, + ) + + res = json.loads(RESPONSE_0af0f7ec) + assert TimeBucksBlockOfferwallResponse.model_validate(res) + + res = json.loads(RESPONSE_1705e4f8) + assert TimeBucksBlockOfferwallResponse.model_validate(res) diff --git a/tests/models/legacy/test_profiling_questions.py b/tests/models/legacy/test_profiling_questions.py new file mode 100644 index 0000000..1afaa6b --- /dev/null +++ b/tests/models/legacy/test_profiling_questions.py @@ -0,0 +1,81 @@ +class TestUpkQuestionResponse: + + def test_init(self): + from generalresearch.models.legacy.questions import UpkQuestionResponse + + s = ( + '{"status": "success", "count": 7, "questions": [{"selector": "SL", "validation": {"patterns": [{' + '"message": "Must input a value between 13 and 120", "pattern": "^(1[01][0-9]|120|1[3-9]|[2-9][' + '0-9])$"}]}, "country_iso": "us", "question_id": "c5a4ef644c374f8994ecb3226b84263e", "language_iso": ' + '"eng", "configuration": {"max_length": 3, "type": "TE"}, "question_text": "What is your age (in ' + 'years)?", "question_type": "TE", "task_score": 20.28987136651298, "task_count": 21131, "p": 1.0}, ' + '{"choices": [{"order": 0, "choice_id": "0", "choice_text": "Male"}, {"order": 1, "choice_id": "1", ' + '"choice_text": "Female"}, {"order": 2, "choice_id": "2", "choice_text": "Other"}], "selector": "SA", ' + '"country_iso": "us", "question_id": "5d6d9f3c03bb40bf9d0a24f306387d7c", "language_iso": "eng", ' + '"question_text": "What is your gender?", "question_type": "MC", "task_score": 16.598347505339095, ' + '"task_count": 4842, "p": 0.8180607558081178}, {"choices": [{"order": 0, "choice_id": "ara", ' + '"choice_text": "Arabic"}, {"order": 1, "choice_id": "zho", "choice_text": "Chinese - Mandarin"}, ' + '{"order": 2, "choice_id": "dut", "choice_text": "Dutch"}, {"order": 3, "choice_id": "eng", ' + '"choice_text": "English"}, {"order": 4, "choice_id": "fre", "choice_text": "French"}, {"order": 5, ' + '"choice_id": "ger", "choice_text": "German"}, {"order": 6, "choice_id": "hat", "choice_text": "Haitian ' + 'Creole"}, {"order": 7, "choice_id": "hin", "choice_text": "Hindi"}, {"order": 8, "choice_id": "ind", ' + '"choice_text": "Indonesian"}, {"order": 9, "choice_id": "ita", "choice_text": "Italian"}, {"order": 10, ' + '"choice_id": "jpn", "choice_text": "Japanese"}, {"order": 11, "choice_id": "kor", "choice_text": ' + '"Korean"}, {"order": 12, "choice_id": "may", "choice_text": "Malay"}, {"order": 13, "choice_id": "pol", ' + '"choice_text": "Polish"}, {"order": 14, "choice_id": "por", "choice_text": "Portuguese"}, {"order": 15, ' + '"choice_id": "pan", "choice_text": "Punjabi"}, {"order": 16, "choice_id": "rus", "choice_text": ' + '"Russian"}, {"order": 17, "choice_id": "spa", "choice_text": "Spanish"}, {"order": 18, "choice_id": ' + '"tgl", "choice_text": "Tagalog/Filipino"}, {"order": 19, "choice_id": "tur", "choice_text": "Turkish"}, ' + '{"order": 20, "choice_id": "vie", "choice_text": "Vietnamese"}, {"order": 21, "choice_id": "zul", ' + '"choice_text": "Zulu"}, {"order": 22, "choice_id": "xxx", "choice_text": "Other"}], "selector": "MA", ' + '"country_iso": "us", "question_id": "f15663d012244d5fa43f5784f7bd1898", "language_iso": "eng", ' + '"question_text": "Which language(s) do you speak fluently at home? (Select all that apply)", ' + '"question_type": "MC", "task_score": 15.835933296975051, "task_count": 147, "p": 0.780484657143325}, ' + '{"selector": "SL", "validation": {"patterns": [{"message": "Must enter a valid zip code: XXXXX", ' + '"pattern": "^[0-9]{5}$"}]}, "country_iso": "us", "question_id": "543de254e9ca4d9faded4377edab82a9", ' + '"language_iso": "eng", "configuration": {"max_length": 5, "min_length": 5, "type": "TE"}, ' + '"question_text": "What is your zip code?", "question_type": "TE", "task_score": 3.9114103408096685, ' + '"task_count": 4116, "p": 0.19277649769949645}, {"selector": "SL", "validation": {"patterns": [{' + '"message": "Must input digits only (range between 1 and 999999)", "pattern": "^[0-9]{1,6}$"}]}, ' + '"country_iso": "us", "question_id": "9ffacedc92584215912062a9d75338fa", "language_iso": "eng", ' + '"configuration": {"max_length": 6, "type": "TE"}, "question_text": "What is your current annual ' + 'household income before taxes (in USD)?", "question_type": "TE", "task_score": 2.4630414657197686, ' + '"task_count": 3267, "p": 0.12139266046727369}, {"choices": [{"order": 0, "choice_id": "0", ' + '"choice_text": "Employed full-time"}, {"order": 1, "choice_id": "1", "choice_text": "Employed ' + 'part-time"}, {"order": 2, "choice_id": "2", "choice_text": "Self-employed full-time"}, {"order": 3, ' + '"choice_id": "3", "choice_text": "Self-employed part-time"}, {"order": 4, "choice_id": "4", ' + '"choice_text": "Active military"}, {"order": 5, "choice_id": "5", "choice_text": "Inactive ' + 'military/Veteran"}, {"order": 6, "choice_id": "6", "choice_text": "Temporarily unemployed"}, ' + '{"order": 7, "choice_id": "7", "choice_text": "Full-time homemaker"}, {"order": 8, "choice_id": "8", ' + '"choice_text": "Retired"}, {"order": 9, "choice_id": "9", "choice_text": "Student"}, {"order": 10, ' + '"choice_id": "10", "choice_text": "Disabled"}], "selector": "SA", "country_iso": "us", "question_id": ' + '"b546d26651f040c9a6900ffb126e7d69", "language_iso": "eng", "question_text": "What is your current ' + 'employment status?", "question_type": "MC", "task_score": 1.6940674222375414, "task_count": 1134, ' + '"p": 0.0834932559027201}, {"choices": [{"order": 0, "choice_id": "0", "choice_text": "No"}, ' + '{"order": 1, "choice_id": "1", "choice_text": "Yes, Mexican"}, {"order": 2, "choice_id": "2", ' + '"choice_text": "Yes, Puerto Rican"}, {"order": 3, "choice_id": "3", "choice_text": "Yes, Cuban"}, ' + '{"order": 4, "choice_id": "4", "choice_text": "Yes, Salvadoran"}, {"order": 5, "choice_id": "5", ' + '"choice_text": "Yes, Dominican"}, {"order": 6, "choice_id": "6", "choice_text": "Yes, Guatemalan"}, ' + '{"order": 7, "choice_id": "7", "choice_text": "Yes, Colombian"}, {"order": 8, "choice_id": "8", ' + '"choice_text": "Yes, Honduran"}, {"order": 9, "choice_id": "9", "choice_text": "Yes, Ecuadorian"}, ' + '{"order": 10, "choice_id": "10", "choice_text": "Yes, Argentinian"}, {"order": 11, "choice_id": "11", ' + '"choice_text": "Yes, Peruvian"}, {"order": 12, "choice_id": "12", "choice_text": "Yes, Nicaraguan"}, ' + '{"order": 13, "choice_id": "13", "choice_text": "Yes, Spaniard"}, {"order": 14, "choice_id": "14", ' + '"choice_text": "Yes, Venezuelan"}, {"order": 15, "choice_id": "15", "choice_text": "Yes, Panamanian"}, ' + '{"order": 16, "choice_id": "16", "choice_text": "Yes, Other"}], "selector": "SA", "country_iso": "us", ' + '"question_id": "7d452b8069c24a1aacbccbf767910345", "language_iso": "eng", "question_text": "Are you of ' + 'Hispanic, Latino, or Spanish origin?", "question_type": "MC", "task_score": 1.6342156553005878, ' + '"task_count": 2516, "p": 0.08054342118687588}], "special_questions": [{"selector": "HIDDEN", ' + '"country_iso": "xx", "question_id": "1d1e2e8380ac474b87fb4e4c569b48df", "language_iso": "xxx", ' + '"question_type": "HIDDEN"}, {"selector": "HIDDEN", "country_iso": "xx", "question_id": ' + '"2fbedb2b9f7647b09ff5e52fa119cc5e", "language_iso": "xxx", "question_type": "HIDDEN"}, {"selector": ' + '"HIDDEN", "country_iso": "xx", "question_id": "4030c52371b04e80b64e058d9c5b82e9", "language_iso": ' + '"xxx", "question_type": "HIDDEN"}, {"selector": "HIDDEN", "country_iso": "xx", "question_id": ' + '"a91cb1dea814480dba12d9b7b48696dd", "language_iso": "xxx", "question_type": "HIDDEN"}, {"selector": ' + '"HIDDEN", "task_count": 40.0, "task_score": 0.45961204566189584, "country_iso": "us", "question_id": ' + '"59f39a785f154752b6435c260cbce3c6", "language_iso": "eng", "question_text": "Core-Based Statistical ' + 'Area (2020)", "question_type": "HIDDEN"}], "consent_questions": []}' + ) + instance = UpkQuestionResponse.model_validate_json(s) + + assert isinstance(instance, UpkQuestionResponse) diff --git a/tests/models/legacy/test_user_question_answer_in.py b/tests/models/legacy/test_user_question_answer_in.py new file mode 100644 index 0000000..253c46e --- /dev/null +++ b/tests/models/legacy/test_user_question_answer_in.py @@ -0,0 +1,304 @@ +import json +from decimal import Decimal +from uuid import uuid4 + +import pytest + + +class TestUserQuestionAnswers: + """This is for the GRS POST submission that may contain multiple + Question+Answer(s) combinations for a single GRS Survey. It is + responsible for making sure the same question isn't submitted + more than once per submission, and other "list validation" + checks that aren't possible on an individual level. + """ + + def test_json_init( + self, + product_manager, + user_manager, + session_manager, + wall_manager, + user_factory, + product, + session_factory, + utc_hour_ago, + ): + from generalresearch.models.thl.session import Session, Wall + from generalresearch.models.thl.user import User + from generalresearch.models import Source + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + u: User = user_factory(product=product) + + s1: Session = session_factory( + user=u, + wall_count=1, + started=utc_hour_ago, + wall_req_cpi=Decimal("0.00"), + wall_source=Source.GRS, + ) + assert isinstance(s1, Session) + w1 = s1.wall_events[0] + assert isinstance(w1, Wall) + + instance = UserQuestionAnswers.model_validate_json( + json.dumps( + { + "product_id": product.uuid, + "product_user_id": u.product_user_id, + "session_id": w1.uuid, + "answers": [ + {"question_id": uuid4().hex, "answer": ["a", "b"]}, + {"question_id": uuid4().hex, "answer": ["a", "b"]}, + ], + } + ) + ) + assert isinstance(instance, UserQuestionAnswers) + + def test_simple_validation_errors( + self, product_manager, user_manager, session_manager, wall_manager + ): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": [{"question_id": uuid4().hex, "answer": ["a", "b"]}], + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "session_id": uuid4().hex, + "answers": [{"question_id": uuid4().hex, "answer": ["a", "b"]}], + } + ) + + # user is validated only if a session_id is passed + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "answers": [{"question_id": uuid4().hex, "answer": ["a", "b"]}], + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": [], + } + ) + + with pytest.raises(ValueError): + answers = [ + {"question_id": uuid4().hex, "answer": ["a"]} for i in range(101) + ] + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": answers, + } + ) + + with pytest.raises(ValueError): + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": "aaa", + } + ) + + def test_no_duplicate_questions(self): + # TODO: depending on if or how many of these types of errors actually + # occur, we could get fancy and just drop one of them. I don't + # think this is worth exploring yet unless we see if it's a problem. + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + consistent_qid = uuid4().hex + with pytest.raises(ValueError) as cm: + UserQuestionAnswers.model_validate( + { + "product_id": uuid4().hex, + "product_user_id": f"test-user-{uuid4().hex[:6]}", + "session_id": uuid4().hex, + "answers": [ + {"question_id": consistent_qid, "answer": ["aaa"]}, + {"question_id": consistent_qid, "answer": ["bbb"]}, + ], + } + ) + + assert "Don't provide answers to duplicate questions" in str(cm.value) + + def test_allow_answer_failures_silent( + self, + product_manager, + user_manager, + session_manager, + wall_manager, + product, + user_factory, + utc_hour_ago, + session_factory, + ): + """ + There are many instances where suppliers may be submitting answers + manually, and they're just totally broken. We want to silently remove + that one QuestionAnswerIn without "loosing" any of the other + QuestionAnswerIn items that they provided. + """ + from generalresearch.models.thl.session import Session, Wall + from generalresearch.models.thl.user import User + from generalresearch.models.legacy.questions import ( + UserQuestionAnswers, + ) + + u: User = user_factory(product=product) + + s1: Session = session_factory(user=u, wall_count=1, started=utc_hour_ago) + assert isinstance(s1, Session) + w1 = s1.wall_events[0] + assert isinstance(w1, Wall) + + data = { + "product_id": product.uuid, + "product_user_id": u.product_user_id, + "session_id": w1.uuid, + "answers": [ + {"question_id": uuid4().hex, "answer": ["aaa"]}, + {"question_id": f"broken-{uuid4().hex[:6]}", "answer": ["bbb"]}, + ], + } + # load via .model_validate() + instance = UserQuestionAnswers.model_validate(data) + assert isinstance(instance, UserQuestionAnswers) + + # One of the QuestionAnswerIn items was invalid, so it was dropped + assert 1 == len(instance.answers) + + # Confirm that this also works via model_validate_json + json_data = json.dumps(data) + instance = UserQuestionAnswers.model_validate_json(json_data) + assert isinstance(instance, UserQuestionAnswers) + + # One of the QuestionAnswerIn items was invalid, so it was dropped + assert 1 == len(instance.answers) + + assert instance.user is None + instance.prefetch_user(um=user_manager) + assert isinstance(instance.user, User) + + +class TestUserQuestionAnswerIn: + """This is for the individual Question+Answer(s) that may come back from + a GRS POST. + """ + + def test_simple_validation_errors(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": f"test-{uuid4().hex[:6]}", "answer": ["123"]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate({"answer": ["123"]}) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": [123]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": [""]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": [" "]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": ["a" * 5_001]} + ) + + with pytest.raises(ValueError): + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": []} + ) + + def test_only_single_answers(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + for qid in { + "2fbedb2b9f7647b09ff5e52fa119cc5e", + "4030c52371b04e80b64e058d9c5b82e9", + "a91cb1dea814480dba12d9b7b48696dd", + "1d1e2e8380ac474b87fb4e4c569b48df", + }: + # This is the UserAgent question which only allows a single answer + with pytest.raises(ValueError) as cm: + UserQuestionAnswerIn.model_validate( + {"question_id": qid, "answer": ["a", "b"]} + ) + + assert "Too many answer values provided" in str(cm.value) + + def test_answer_item_limit(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + answer = [uuid4().hex[:6] for i in range(11)] + with pytest.raises(ValueError) as cm: + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": answer} + ) + assert "List should have at most 10 items after validation" in str(cm.value) + + def test_disallow_duplicate_answer_values(self): + from generalresearch.models.legacy.questions import ( + UserQuestionAnswerIn, + ) + + answer = ["aaa" for i in range(5)] + with pytest.raises(ValueError) as cm: + UserQuestionAnswerIn.model_validate( + {"question_id": uuid4().hex, "answer": answer} + ) diff --git a/tests/models/morning/__init__.py b/tests/models/morning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/morning/test.py b/tests/models/morning/test.py new file mode 100644 index 0000000..cf4982d --- /dev/null +++ b/tests/models/morning/test.py @@ -0,0 +1,199 @@ +from datetime import datetime, timezone + +from generalresearch.models.morning.question import MorningQuestion + +bid = { + "buyer_account_id": "ab180f06-aa2b-4b8b-9b87-1031bfe8b16b", + "buyer_id": "5f3b4daa-6ff0-4826-a551-9d4572ea1c84", + "country_id": "us", + "end_date": "2024-07-19T09:01:13.520243Z", + "exclusions": [ + {"group_id": "66070689-5198-5782-b388-33daa74f3269", "lockout_period": 28} + ], + "id": "5324c2ac-eca8-4ed0-8b0e-042ba3aa2a85", + "language_ids": ["en"], + "name": "Ad-Hoc Survey", + "published_at": "2024-06-19T09:01:13.520243Z", + "quotas": [ + { + "cost_per_interview": 154, + "id": "b8ade883-a83d-4d8e-9ef7-953f4b692bd8", + "qualifications": [ + { + "id": "age", + "response_ids": [ + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + "26", + "27", + "28", + "29", + "30", + "31", + "32", + "33", + "34", + ], + }, + {"id": "gender", "response_ids": ["1"]}, + {"id": "hispanic", "response_ids": ["1"]}, + ], + "statistics": { + "length_of_interview": 1353, + "median_length_of_interview": 1353, + "num_available": 3, + "num_completes": 7, + "num_failures": 0, + "num_in_progress": 4, + "num_over_quotas": 0, + "num_qualified": 27, + "num_quality_terminations": 14, + "num_timeouts": 1, + "qualified_conversion": 30, + }, + } + ], + "state": "active", + "statistics": { + "earnings_per_click": 26, + "estimated_length_of_interview": 1140, + "incidence_rate": 77, + "length_of_interview": 1198, + "median_length_of_interview": 1198, + "num_available": 70, + "num_completes": 360, + "num_entrants": 1467, + "num_failures": 0, + "num_in_progress": 48, + "num_over_quotas": 10, + "num_qualified": 1121, + "num_quality_terminations": 584, + "num_screenouts": 380, + "num_timeouts": 85, + "qualified_conversion": 34, + "system_conversion": 25, + }, + "supplier_exclusive": False, + "survey_type": "ad_hoc", + "timeout": 21600, + "topic_id": "general", +} + +bid = { + "_experimental_single_use_qualifications": [ + { + "id": "electric_car_test", + "name": "Electric Car Test", + "text": "What kind of vehicle do you drive?", + "language_ids": ["en"], + "responses": [{"id": "1", "text": "electric"}, {"id": "2", "text": "gas"}], + "type": "multiple_choice", + } + ], + "buyer_account_id": "0b6f207c-96e1-4dce-b032-566a815ad263", + "buyer_id": "9020f6f3-db41-470a-a5d7-c04fa2da9156", + "closed_at": "2022-01-01T00:00:00Z", + "country_id": "us", + "end_date": "2022-01-01T00:00:00Z", + "exclusions": [ + {"group_id": "0bbae805-5a80-42e3-8d5f-cb056a0f825d", "lockout_period": 7} + ], + "id": "000f09a3-bc25-4adc-a443-a9975800e7ac", + "language_ids": ["en", "es"], + "name": "My Example Survey", + "published_at": "2021-12-30T00:00:00Z", + "quotas": [ + { + "_experimental_single_use_qualifications": [ + {"id": "electric_car_test", "response_ids": ["1"]} + ], + "cost_per_interview": 100, + "id": "6a7d0190-e6ad-4a59-9945-7ba460517f2b", + "qualifications": [ + {"id": "gender", "response_ids": ["1"]}, + {"id": "age", "response_ids": ["18", "19", "20", "21"]}, + ], + "statistics": { + "length_of_interview": 600, + "median_length_of_interview": 600, + "num_available": 500, + "num_completes": 100, + "num_failures": 0, + "num_in_progress": 0, + "num_over_quotas": 0, + "num_qualified": 100, + "num_quality_terminations": 0, + "num_timeouts": 0, + "qualified_conversion": 100, + }, + } + ], + "state": "active", + "statistics": { + "earnings_per_click": 50, + "estimated_length_of_interview": 720, + "incidence_rate": 100, + "length_of_interview": 600, + "median_length_of_interview": 600, + "num_available": 500, + "num_completes": 100, + "num_entrants": 100, + "num_failures": 0, + "num_in_progress": 0, + "num_over_quotas": 0, + "num_qualified": 100, + "num_quality_terminations": 0, + "num_screenouts": 0, + "num_timeouts": 0, + "qualified_conversion": 100, + "system_conversion": 100, + }, + "supplier_exclusive": False, + "survey_type": "ad_hoc", + "timeout": 3600, + "topic_id": "general", +} + +# what gets run in MorningAPI._format_bid +bid["language_isos"] = ("eng",) +bid["country_iso"] = "us" +bid["end_date"] = datetime(2024, 7, 19, 9, 1, 13, 520243, tzinfo=timezone.utc) +bid["published_at"] = datetime(2024, 6, 19, 9, 1, 13, 520243, tzinfo=timezone.utc) +bid.update(bid["statistics"]) +bid["qualified_conversion"] /= 100 +bid["system_conversion"] /= 100 +for quota in bid["quotas"]: + quota.update(quota["statistics"]) + quota["qualified_conversion"] /= 100 + quota["cost_per_interview"] /= 100 +if "_experimental_single_use_qualifications" in bid: + bid["experimental_single_use_qualifications"] = [ + MorningQuestion.from_api(q, bid["country_iso"], "eng") + for q in bid["_experimental_single_use_qualifications"] + ] + + +class TestMorningBid: + + def test_model_validate(self): + from generalresearch.models.morning.survey import MorningBid + + s = MorningBid.model_validate(bid) + d = s.model_dump(mode="json") + d = s.to_mysql() + + def test_manager(self): + # todo: credentials n stuff + pass + # sql_helper = SqlHelper(host="localhost", user="root", password="", db="300large-morning") + # m = MorningSurveyManager(sql_helper=sql_helper) + # s = MorningBid.model_validate(bid) + # m.create(s) + # res = m.get_survey_library()[0] + # MorningBid.model_validate(res) diff --git a/tests/models/precision/__init__.py b/tests/models/precision/__init__.py new file mode 100644 index 0000000..8006fa3 --- /dev/null +++ b/tests/models/precision/__init__.py @@ -0,0 +1,115 @@ +survey_json = { + "cpi": "1.44", + "country_isos": "ca", + "language_isos": "eng", + "country_iso": "ca", + "language_iso": "eng", + "buyer_id": "7047", + "bid_loi": 1200, + "bid_ir": 0.45, + "source": "e", + "used_question_ids": ["age", "country_iso", "gender", "gender_1"], + "survey_id": "0000", + "group_id": "633473", + "status": "open", + "name": "beauty survey", + "survey_guid": "c7f375c5077d4c6c8209ff0b539d7183", + "category_id": "-1", + "global_conversion": None, + "desired_count": 96, + "achieved_count": 0, + "allowed_devices": "1,2,3", + "entry_link": "https://www.opinionetwork.com/survey/entry.aspx?mid=[%MID%]&project=633473&key=%%key%%", + "excluded_surveys": "470358,633286", + "quotas": [ + { + "name": "25-34,Male,Quebec", + "id": "2324110", + "guid": "23b5760d24994bc08de451b3e62e77c7", + "status": "open", + "desired_count": 48, + "achieved_count": 0, + "termination_count": 0, + "overquota_count": 0, + "condition_hashes": ["b41e1a3", "bc89ee8", "4124366", "9f32c61"], + }, + { + "name": "25-34,Female,Quebec", + "id": "2324111", + "guid": "0706f1a88d7e4f11ad847c03012e68d2", + "status": "open", + "desired_count": 48, + "achieved_count": 0, + "termination_count": 4, + "overquota_count": 0, + "condition_hashes": ["b41e1a3", "0cdc304", "500af2c", "9f32c61"], + }, + ], + "conditions": { + "b41e1a3": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "country_iso", + "values": ["ca"], + "criterion_hash": "b41e1a3", + "value_len": 1, + "sizeof": 2, + }, + "bc89ee8": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender", + "values": ["male"], + "criterion_hash": "bc89ee8", + "value_len": 1, + "sizeof": 4, + }, + "4124366": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender_1", + "values": ["male"], + "criterion_hash": "4124366", + "value_len": 1, + "sizeof": 4, + }, + "9f32c61": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "age", + "values": ["25", "26", "27", "28", "29", "30", "31", "32", "33", "34"], + "criterion_hash": "9f32c61", + "value_len": 10, + "sizeof": 20, + }, + "0cdc304": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender", + "values": ["female"], + "criterion_hash": "0cdc304", + "value_len": 1, + "sizeof": 6, + }, + "500af2c": { + "logical_operator": "OR", + "value_type": 1, + "negate": False, + "question_id": "gender_1", + "values": ["female"], + "criterion_hash": "500af2c", + "value_len": 1, + "sizeof": 6, + }, + }, + "expected_end_date": "2024-06-28T10:40:33.000000Z", + "created": None, + "updated": None, + "is_live": True, + "all_hashes": ["0cdc304", "b41e1a3", "9f32c61", "bc89ee8", "4124366", "500af2c"], +} diff --git a/tests/models/precision/test_survey.py b/tests/models/precision/test_survey.py new file mode 100644 index 0000000..ff2d6d1 --- /dev/null +++ b/tests/models/precision/test_survey.py @@ -0,0 +1,88 @@ +class TestPrecisionQuota: + + def test_quota_passes(self): + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + q = s.quotas[0] + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + assert q.matches(ce) + + ce["b41e1a3"] = False + assert not q.matches(ce) + + ce.pop("b41e1a3") + assert not q.matches(ce) + assert not q.matches({}) + + def test_quota_passes_closed(self): + from generalresearch.models.precision import PrecisionStatus + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + q = s.quotas[0] + q.status = PrecisionStatus.CLOSED + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + # We still match, but the quota is not open + assert q.matches(ce) + assert not q.is_open + + +class TestPrecisionSurvey: + + def test_passes(self): + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + assert s.determine_eligibility(ce) + + def test_elig_closed_quota(self): + from generalresearch.models.precision import PrecisionStatus + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + q = s.quotas[0] + q.status = PrecisionStatus.CLOSED + # We match a closed quota + assert not s.determine_eligibility(ce) + + s.quotas[0].status = PrecisionStatus.OPEN + s.quotas[1].status = PrecisionStatus.CLOSED + # Now me match an open quota and dont match the closed quota, so we should be eligible + assert s.determine_eligibility(ce) + + def test_passes_sp(self): + from generalresearch.models.precision import PrecisionStatus + from generalresearch.models.precision.survey import PrecisionSurvey + from tests.models.precision import survey_json + + s = PrecisionSurvey.model_validate(survey_json) + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + passes, hashes = s.determine_eligibility_soft(ce) + + # We don't know if we match the 2nd quota, but it is open so it doesn't matter + assert passes + assert (True, []) == s.quotas[0].matches_soft(ce) + assert (None, ["0cdc304", "500af2c"]) == s.quotas[1].matches_soft(ce) + + # Now don't know if we match either + ce.pop("9f32c61") # age + passes, hashes = s.determine_eligibility_soft(ce) + assert passes is None + assert {"500af2c", "9f32c61", "0cdc304"} == hashes + + ce["9f32c61"] = False + ce["0cdc304"] = False + # We know we don't match either + assert (False, set()) == s.determine_eligibility_soft(ce) + + # We pass 1st quota, 2nd is unknown but closed, so we don't know for sure we pass + ce = {k: True for k in ["b41e1a3", "bc89ee8", "4124366", "9f32c61"]} + s.quotas[1].status = PrecisionStatus.CLOSED + assert (None, {"0cdc304", "500af2c"}) == s.determine_eligibility_soft(ce) diff --git a/tests/models/precision/test_survey_manager.py b/tests/models/precision/test_survey_manager.py new file mode 100644 index 0000000..8532ab0 --- /dev/null +++ b/tests/models/precision/test_survey_manager.py @@ -0,0 +1,63 @@ +# from decimal import Decimal +# +# from datetime import timezone, datetime +# from pymysql import IntegrityError +# from generalresearch.models.precision.survey import PrecisionSurvey +# from tests.models.precision import survey_json + + +# def delete_survey(survey_id: str): +# db_name = sql_helper.db +# # TODO: what is the precision specific db name... +# +# sql_helper.execute_sql_query( +# query=""" +# DELETE FROM `300large-precision`.precision_survey +# WHERE survey_id = %s +# """, +# params=[survey_id], commit=True) +# sql_helper.execute_sql_query(""" +# DELETE FROM `300large-precision`.precision_survey_country WHERE survey_id = %s +# """, [survey_id], commit=True) +# sql_helper.execute_sql_query(""" +# DELETE FROM `300large-precision`.precision_survey_language WHERE survey_id = %s +# """, [survey_id], commit=True) +# +# +# class TestPrecisionSurvey: +# def test_survey_create(self): +# now = datetime.now(tz=timezone.utc) +# s = PrecisionSurvey.model_validate(survey_json) +# self.assertEqual(s.survey_id, '0000') +# delete_survey(s.survey_id) +# +# sm.create(s) +# +# surveys = sm.get_survey_library(updated_since=now) +# self.assertEqual(len(surveys), 1) +# self.assertEqual('0000', surveys[0].survey_id) +# self.assertTrue(s.is_unchanged(surveys[0])) +# +# with self.assertRaises(IntegrityError) as context: +# sm.create(s) +# +# def test_survey_update(self): +# # There's extra complexity here with the country/lang join tables +# now = datetime.now(tz=timezone.utc) +# s = PrecisionSurvey.model_validate(survey_json) +# self.assertEqual(s.survey_id, '0000') +# delete_survey(s.survey_id) +# sm.create(s) +# s.cpi = Decimal('0.50') +# # started out at only 'ca' and 'eng' +# s.country_isos = ['us'] +# s.country_iso = 'us' +# s.language_isos = ['eng', 'spa'] +# s.language_iso = 'eng' +# sm.update([s]) +# surveys = sm.get_survey_library(updated_since=now) +# self.assertEqual(len(surveys), 1) +# s2 = surveys[0] +# self.assertEqual('0000', s2.survey_id) +# self.assertEqual(Decimal('0.50'), s2.cpi) +# self.assertTrue(s.is_unchanged(s2)) diff --git a/tests/models/prodege/__init__.py b/tests/models/prodege/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/prodege/test_survey_participation.py b/tests/models/prodege/test_survey_participation.py new file mode 100644 index 0000000..b85cc91 --- /dev/null +++ b/tests/models/prodege/test_survey_participation.py @@ -0,0 +1,120 @@ +from datetime import timezone, datetime, timedelta + + +class TestProdegeParticipation: + + def test_exclude(self): + from generalresearch.models.prodege import ProdegePastParticipationType + from generalresearch.models.prodege.survey import ( + ProdegePastParticipation, + ProdegeUserPastParticipation, + ) + + now = datetime.now(tz=timezone.utc) + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "exclude", + "in_past_days": 7, + "participation_types": ["complete"], + } + ) + # User has no history, so is eligible + assert pp.is_eligible([]) + + # user abandoned. its a click, not complete, so he's eligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), survey_id="152677146" + ) + ] + assert pp.is_eligible(upps) + + # user completes. ineligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert not pp.is_eligible(upps) + + # user completed. but too long ago + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(days=100), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert pp.is_eligible(upps) + + # remove day filter, should be ineligble again + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "exclude", + "in_past_days": 0, + "participation_types": ["complete"], + } + ) + assert not pp.is_eligible(upps) + + # I almost forgot this.... a "complete" IS ALSO A "click"!!! + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "exclude", + "in_past_days": 0, + "participation_types": ["click"], + } + ) + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert { + ProdegePastParticipationType.COMPLETE, + ProdegePastParticipationType.CLICK, + } == upps[0].participation_types + assert not pp.is_eligible(upps) + + def test_include(self): + from generalresearch.models.prodege.survey import ( + ProdegePastParticipation, + ProdegeUserPastParticipation, + ) + + now = datetime.now(tz=timezone.utc) + pp = ProdegePastParticipation.from_api( + { + "participation_project_ids": [152677146, 152803285], + "filter_type": "include", + "in_past_days": 7, + "participation_types": ["complete"], + } + ) + # User has no history, so is IN-eligible + assert not pp.is_eligible([]) + + # user abandoned. its a click, not complete, so he's INeligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), survey_id="152677146" + ) + ] + assert not pp.is_eligible(upps) + + # user completes, eligible + upps = [ + ProdegeUserPastParticipation( + started=now - timedelta(hours=69), + survey_id="152677146", + ext_status_code_1="1", + ) + ] + assert pp.is_eligible(upps) diff --git a/tests/models/spectrum/__init__.py b/tests/models/spectrum/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/spectrum/test_question.py b/tests/models/spectrum/test_question.py new file mode 100644 index 0000000..ba118d7 --- /dev/null +++ b/tests/models/spectrum/test_question.py @@ -0,0 +1,216 @@ +from datetime import datetime, timezone + +from generalresearch.models import Source +from generalresearch.models.spectrum.question import ( + SpectrumQuestionOption, + SpectrumQuestion, + SpectrumQuestionType, + SpectrumQuestionClass, +) +from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestionChoice, +) + + +class TestSpectrumQuestion: + + def test_parse_from_api_1(self): + + example_1 = { + "qualification_code": 213, + "text": "My household earns approximately $%%213%% per year", + "cat": None, + "desc": "Income", + "type": 5, + "class": 1, + "condition_codes": [], + "format": {"min": 0, "max": 999999, "regex": "/^([0-9]{1,6})$/i"}, + "crtd_on": 1502869927688, + "mod_on": 1706557247467, + } + q = SpectrumQuestion.from_api(example_1, "us", "eng") + + expected_q = SpectrumQuestion( + question_id="213", + country_iso="us", + language_iso="eng", + question_name="Income", + question_text="My household earns approximately $___ per year", + question_type=SpectrumQuestionType.TEXT_ENTRY, + tags=None, + options=None, + class_num=SpectrumQuestionClass.CORE, + created=datetime(2017, 8, 16, 7, 52, 7, 688000, tzinfo=timezone.utc), + is_live=True, + source=Source.SPECTRUM, + category_id=None, + ) + assert "My household earns approximately $___ per year" == q.question_text + assert "213" == q.question_id + assert expected_q == q + q.to_upk_question() + assert "s:213" == q.external_id + + def test_parse_from_api_2(self): + + example_2 = { + "qualification_code": 211, + "text": "I'm a %%211%%", + "cat": None, + "desc": "Gender", + "type": 1, + "class": 1, + "condition_codes": [ + {"id": "111", "text": "Male"}, + {"id": "112", "text": "Female"}, + ], + "format": {"min": None, "max": None, "regex": ""}, + "crtd_on": 1502869927688, + "mod_on": 1706557249817, + } + q = SpectrumQuestion.from_api(example_2, "us", "eng") + expected_q = SpectrumQuestion( + question_id="211", + country_iso="us", + language_iso="eng", + question_name="Gender", + question_text="I'm a", + question_type=SpectrumQuestionType.SINGLE_SELECT, + tags=None, + options=[ + SpectrumQuestionOption(id="111", text="Male", order=0), + SpectrumQuestionOption(id="112", text="Female", order=1), + ], + class_num=SpectrumQuestionClass.CORE, + created=datetime(2017, 8, 16, 7, 52, 7, 688000, tzinfo=timezone.utc), + is_live=True, + source=Source.SPECTRUM, + category_id=None, + ) + assert expected_q == q + q.to_upk_question() + + def test_parse_from_api_3(self): + + example_3 = { + "qualification_code": 220, + "text": "My child is a %%230%% %%221%% old %%220%%", + "cat": None, + "desc": "Child Dependent", + "type": 6, + "class": 4, + "condition_codes": [ + {"id": "111", "text": "Boy"}, + {"id": "112", "text": "Girl"}, + ], + "format": {"min": None, "max": None, "regex": ""}, + "crtd_on": 1502869927688, + "mod_on": 1706556781278, + } + q = SpectrumQuestion.from_api(example_3, "us", "eng") + # This fails because the text has variables from other questions in it + assert q is None + + def test_parse_from_api_4(self): + + example_4 = { + "qualification_code": 1039, + "text": "Do you suffer from any of the following ailments or medical conditions? (Select all that apply) " + " %%1039%%", + "cat": "Ailments, Illness", + "desc": "Standard Ailments", + "type": 3, + "class": 2, + "condition_codes": [ + {"id": "111", "text": "Allergies (Food, Nut, Skin)"}, + {"id": "999", "text": "None of the above"}, + {"id": "130", "text": "Other"}, + { + "id": "129", + "text": "Women's Health Conditions (Reproductive Issues)", + }, + ], + "format": {"min": None, "max": None, "regex": ""}, + "crtd_on": 1502869927688, + "mod_on": 1706557241693, + } + q = SpectrumQuestion.from_api(example_4, "us", "eng") + expected_q = SpectrumQuestion( + question_id="1039", + country_iso="us", + language_iso="eng", + question_name="Standard Ailments", + question_text="Do you suffer from any of the following ailments or medical conditions? (Select all that " + "apply)", + question_type=SpectrumQuestionType.MULTI_SELECT, + tags="Ailments, Illness", + options=[ + SpectrumQuestionOption( + id="111", text="Allergies (Food, Nut, Skin)", order=0 + ), + SpectrumQuestionOption( + id="129", + text="Women's Health Conditions (Reproductive Issues)", + order=1, + ), + SpectrumQuestionOption(id="130", text="Other", order=2), + SpectrumQuestionOption(id="999", text="None of the above", order=3), + ], + class_num=SpectrumQuestionClass.EXTENDED, + created=datetime(2017, 8, 16, 7, 52, 7, 688000, tzinfo=timezone.utc), + is_live=True, + source=Source.SPECTRUM, + category_id=None, + ) + assert expected_q == q + + # todo: we should have something that infers that if the choice text is "None of the above", + # then the choice is exclusive + u = UpkQuestion( + id=None, + ext_question_id="s:1039", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + country_iso="us", + language_iso="eng", + text="Do you suffer from any of the following ailments or medical conditions? (Select all " + "that apply)", + choices=[ + UpkQuestionChoice( + id="111", + text="Allergies (Food, Nut, Skin)", + order=0, + group=None, + exclusive=False, + importance=None, + ), + UpkQuestionChoice( + id="129", + text="Women's Health Conditions (Reproductive Issues)", + order=1, + group=None, + exclusive=False, + importance=None, + ), + UpkQuestionChoice( + id="130", + text="Other", + order=2, + group=None, + exclusive=False, + importance=None, + ), + UpkQuestionChoice( + id="999", + text="None of the above", + order=3, + group=None, + exclusive=False, + importance=None, + ), + ], + ) + assert u == q.to_upk_question() diff --git a/tests/models/spectrum/test_survey.py b/tests/models/spectrum/test_survey.py new file mode 100644 index 0000000..65dec60 --- /dev/null +++ b/tests/models/spectrum/test_survey.py @@ -0,0 +1,413 @@ +from datetime import timezone, datetime +from decimal import Decimal + + +class TestSpectrumCondition: + + def test_condition_create(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + ) + from generalresearch.models.thl.survey.condition import ConditionValueType + + c = SpectrumCondition.from_api( + { + "qualification_code": 212, + "range_sets": [ + {"units": 311, "to": 28, "from": 25}, + {"units": 311, "to": 42, "from": 40}, + ], + } + ) + assert ( + SpectrumCondition( + question_id="212", + values=["25-28", "40-42"], + value_type=ConditionValueType.RANGE, + negate=False, + logical_operator=LogicalOperator.OR, + ) + == c + ) + + # These equal each other b/c age ranges get automatically converted + assert ( + SpectrumCondition( + question_id="212", + values=["25", "26", "27", "28", "40", "41", "42"], + value_type=ConditionValueType.LIST, + negate=False, + logical_operator=LogicalOperator.OR, + ) + == c + ) + + c = SpectrumCondition.from_api( + { + "condition_codes": ["111", "117", "112", "113", "118"], + "qualification_code": 1202, + } + ) + assert ( + SpectrumCondition( + question_id="1202", + values=["111", "112", "113", "117", "118"], + value_type=ConditionValueType.LIST, + negate=False, + logical_operator=LogicalOperator.OR, + ) + == c + ) + + +class TestSpectrumQuota: + + def test_quota_create(self): + from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + SpectrumQuota, + ) + + d = { + "quota_id": "a846b545-4449-4d76-93a2-f8ebdf6e711e", + "quantities": {"currently_open": 57, "remaining": 57, "achieved": 0}, + "criteria": [{"qualification_code": 211, "condition_codes": ["111"]}], + "crtd_on": 1716227282077, + "mod_on": 1716227284146, + "last_complete_date": None, + } + criteria = [SpectrumCondition.from_api(q) for q in d["criteria"]] + d["condition_hashes"] = [x.criterion_hash for x in criteria] + q = SpectrumQuota.from_api(d) + assert SpectrumQuota(remaining_count=57, condition_hashes=["c23c0b9"]) == q + assert q.is_open + + def test_quota_passes(self): + from generalresearch.models.spectrum.survey import ( + SpectrumQuota, + ) + + q = SpectrumQuota(remaining_count=57, condition_hashes=["a"]) + assert q.passes({"a": True}) + assert not q.passes({"a": False}) + assert not q.passes({}) + + # We have to match all + q = SpectrumQuota(remaining_count=57, condition_hashes=["a", "b", "c"]) + assert not q.passes({"a": True, "b": False}) + assert q.passes({"a": True, "b": True, "c": True}) + + # Quota must be open, even if we match + q = SpectrumQuota(remaining_count=0, condition_hashes=["a"]) + assert not q.passes({"a": True}) + + def test_quota_passes_soft(self): + from generalresearch.models.spectrum.survey import ( + SpectrumQuota, + ) + + q = SpectrumQuota(remaining_count=57, condition_hashes=["a", "b", "c"]) + # Pass if we match all + assert (True, set()) == q.matches_soft({"a": True, "b": True, "c": True}) + # Fail if we don't match any + assert (False, set()) == q.matches_soft({"a": True, "b": False, "c": None}) + # Unknown if any are unknown AND we don't fail any + assert (None, {"c", "b"}) == q.matches_soft({"a": True, "b": None, "c": None}) + assert (None, {"a", "c", "b"}) == q.matches_soft( + {"a": None, "b": None, "c": None} + ) + assert (False, set()) == q.matches_soft({"a": None, "b": False, "c": None}) + + +class TestSpectrumSurvey: + def test_survey_create(self): + from generalresearch.models import ( + LogicalOperator, + Source, + TaskCalculationType, + ) + from generalresearch.models.spectrum import SpectrumStatus + from generalresearch.models.spectrum.survey import ( + SpectrumCondition, + SpectrumQuota, + SpectrumSurvey, + ) + from generalresearch.models.thl.survey.condition import ConditionValueType + + # Note: d is the raw response after calling SpectrumAPI.preprocess_survey() on it! + d = { + "survey_id": 29333264, + "survey_name": "Exciting New Survey #29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20000"), + "last_complete_date": None, + "project_last_complete_date": None, + "survey_performance": { + "overall": {"ir": 40, "loi": 10}, + "last_block": {"ir": None, "loi": None}, + }, + "supplier_completes": { + "needed": 495, + "achieved": 0, + "remaining": 495, + "guaranteed_allocation": 0, + "guaranteed_allocation_remaining": 0, + }, + "pds": {"enabled": False, "buyer_name": None}, + "quotas": [ + { + "quota_id": "c2bc961e-4f26-4223-b409-ebe9165cfdf5", + "quantities": { + "currently_open": 491, + "remaining": 495, + "achieved": 0, + }, + "criteria": [ + { + "qualification_code": 212, + "range_sets": [{"units": 311, "to": 64, "from": 18}], + } + ], + "crtd_on": 1716227293496, + "mod_on": 1716229289847, + "last_complete_date": None, + } + ], + "qualifications": [ + { + "range_sets": [{"units": 311, "to": 64, "from": 18}], + "qualification_code": 212, + } + ], + "country_iso": "fr", + "language_iso": "fre", + "bid_ir": 0.4, + "bid_loi": 600, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, + } + s = SpectrumSurvey.from_api(d) + expected_survey = SpectrumSurvey( + cpi=Decimal("1.20000"), + country_isos=["fr"], + language_isos=["fre"], + buyer_id="4726", + source=Source.SPECTRUM, + used_question_ids={"212"}, + survey_id="29333264", + survey_name="Exciting New Survey #29333264", + status=SpectrumStatus.LIVE, + field_end_date=datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + category_code="232", + calculation_type=TaskCalculationType.COMPLETES, + requires_pii=False, + survey_exclusions=set(), + exclusion_period=0, + bid_ir=0.40, + bid_loi=600, + last_block_loi=None, + last_block_ir=None, + overall_loi=None, + overall_ir=None, + project_last_complete_date=None, + country_iso="fr", + language_iso="fre", + include_psids=None, + exclude_psids=None, + qualifications=["77f493d"], + quotas=[SpectrumQuota(remaining_count=491, condition_hashes=["77f493d"])], + conditions={ + "77f493d": SpectrumCondition( + logical_operator=LogicalOperator.OR, + value_type=ConditionValueType.RANGE, + negate=False, + question_id="212", + values=["18-64"], + ) + }, + created_api=datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + modified_api=datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + updated=None, + ) + assert expected_survey.model_dump_json() == s.model_dump_json() + + def test_survey_properties(self): + from generalresearch.models.spectrum.survey import ( + SpectrumSurvey, + ) + + d = { + "survey_id": 29333264, + "survey_name": "#29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20000"), + "last_complete_date": None, + "project_last_complete_date": None, + "quotas": [ + { + "quota_id": "c2bc961e-4f26-4223-b409-ebe9165cfdf5", + "quantities": { + "currently_open": 491, + "remaining": 495, + "achieved": 0, + }, + "criteria": [ + { + "qualification_code": 214, + "range_sets": [{"units": 311, "to": 64, "from": 18}], + } + ], + } + ], + "qualifications": [ + { + "range_sets": [{"units": 311, "to": 64, "from": 18}], + "qualification_code": 212, + }, + {"condition_codes": ["111", "117", "112"], "qualification_code": 1202}, + ], + "country_iso": "fr", + "language_iso": "fre", + "overall_ir": 0.4, + "overall_loi": 600, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, + } + s = SpectrumSurvey.from_api(d) + assert {"212", "1202", "214"} == s.used_question_ids + assert s.is_live + assert s.is_open + assert {"38cea5e", "83955ef", "77f493d"} == s.all_hashes + + def test_survey_eligibility(self): + from generalresearch.models.spectrum.survey import ( + SpectrumQuota, + SpectrumSurvey, + ) + + d = { + "survey_id": 29333264, + "survey_name": "#29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20000"), + "last_complete_date": None, + "project_last_complete_date": None, + "quotas": [], + "qualifications": [], + "country_iso": "fr", + "language_iso": "fre", + "overall_ir": 0.4, + "overall_loi": 600, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, + } + s = SpectrumSurvey.from_api(d) + s.qualifications = ["a", "b", "c"] + s.quotas = [ + SpectrumQuota(remaining_count=10, condition_hashes=["a", "b"]), + SpectrumQuota(remaining_count=0, condition_hashes=["d"]), + SpectrumQuota(remaining_count=10, condition_hashes=["e"]), + ] + + assert s.passes_qualifications({"a": True, "b": True, "c": True}) + assert not s.passes_qualifications({"a": True, "b": True, "c": False}) + + # we do NOT match a full quota, so we pass + assert s.passes_quotas({"a": True, "b": True, "d": False}) + # We dont pass any + assert not s.passes_quotas({}) + # we only pass a full quota + assert not s.passes_quotas({"d": True}) + # we only dont pass a full quota, but we haven't passed any open + assert not s.passes_quotas({"d": False}) + # we pass a quota, but also pass a full quota, so fail + assert not s.passes_quotas({"e": True, "d": True}) + # we pass a quota, but are unknown in a full quota, so fail + assert not s.passes_quotas({"e": True}) + + # # Soft Pair + assert (True, set()) == s.passes_qualifications_soft( + {"a": True, "b": True, "c": True} + ) + assert (False, set()) == s.passes_qualifications_soft( + {"a": True, "b": True, "c": False} + ) + assert (None, set("c")) == s.passes_qualifications_soft( + {"a": True, "b": True, "c": None} + ) + + # we do NOT match a full quota, so we pass + assert (True, set()) == s.passes_quotas_soft({"a": True, "b": True, "d": False}) + # We dont pass any + assert (None, {"a", "b", "d", "e"}) == s.passes_quotas_soft({}) + # we only pass a full quota + assert (False, set()) == s.passes_quotas_soft({"d": True}) + # we only dont pass a full quota, but we haven't passed any open + assert (None, {"a", "b", "e"}) == s.passes_quotas_soft({"d": False}) + # we pass a quota, but also pass a full quota, so fail + assert (False, set()) == s.passes_quotas_soft({"e": True, "d": True}) + # we pass a quota, but are unknown in a full quota, so fail + assert (None, {"d"}) == s.passes_quotas_soft({"e": True}) + + assert s.determine_eligibility({"a": True, "b": True, "c": True, "d": False}) + assert not s.determine_eligibility( + {"a": True, "b": True, "c": False, "d": False} + ) + assert not s.determine_eligibility( + {"a": True, "b": True, "c": None, "d": False} + ) + assert (True, set()) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": True, "d": False} + ) + assert (False, set()) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": False, "d": False} + ) + assert (None, set("c")) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": None, "d": False} + ) + assert (None, {"c", "d"}) == s.determine_eligibility_soft( + {"a": True, "b": True, "c": None, "d": None} + ) diff --git a/tests/models/spectrum/test_survey_manager.py b/tests/models/spectrum/test_survey_manager.py new file mode 100644 index 0000000..582093c --- /dev/null +++ b/tests/models/spectrum/test_survey_manager.py @@ -0,0 +1,130 @@ +import copy +import logging +from datetime import timezone, datetime +from decimal import Decimal + +from pymysql import IntegrityError + + +logger = logging.getLogger() + +example_survey_api_response = { + "survey_id": 29333264, + "survey_name": "#29333264", + "survey_status": 22, + "field_end_date": datetime(2024, 5, 23, 18, 18, 31, tzinfo=timezone.utc), + "category": "Exciting New", + "category_code": 232, + "crtd_on": datetime(2024, 5, 20, 17, 48, 13, tzinfo=timezone.utc), + "mod_on": datetime(2024, 5, 20, 18, 18, 31, tzinfo=timezone.utc), + "soft_launch": False, + "click_balancing": 0, + "price_type": 1, + "pii": False, + "buyer_message": "", + "buyer_id": 4726, + "incl_excl": 0, + "cpi": Decimal("1.20"), + "last_complete_date": None, + "project_last_complete_date": None, + "quotas": [ + { + "quota_id": "c2bc961e-4f26-4223-b409-ebe9165cfdf5", + "quantities": {"currently_open": 491, "remaining": 495, "achieved": 0}, + "criteria": [ + { + "qualification_code": 214, + "range_sets": [{"units": 311, "to": 64, "from": 18}], + } + ], + } + ], + "qualifications": [ + { + "range_sets": [{"units": 311, "to": 64, "from": 18}], + "qualification_code": 212, + }, + {"condition_codes": ["111", "117", "112"], "qualification_code": 1202}, + ], + "country_iso": "fr", + "language_iso": "fre", + "bid_ir": 0.4, + "bid_loi": 600, + "overall_ir": None, + "overall_loi": None, + "last_block_ir": None, + "last_block_loi": None, + "survey_exclusions": set(), + "exclusion_period": 0, +} + + +class TestSpectrumSurvey: + + def test_survey_create(self, settings, spectrum_manager, spectrum_rw): + from generalresearch.models.spectrum.survey import SpectrumSurvey + + assert settings.debug, "CRITICAL: Do not run this on production." + + now = datetime.now(tz=timezone.utc) + spectrum_rw.execute_sql_query( + query=f""" + DELETE FROM `{spectrum_rw.db}`.spectrum_survey + WHERE survey_id = '29333264'""", + commit=True, + ) + + d = example_survey_api_response.copy() + s = SpectrumSurvey.from_api(d) + spectrum_manager.create(s) + + surveys = spectrum_manager.get_survey_library(updated_since=now) + assert len(surveys) == 1 + assert "29333264" == surveys[0].survey_id + assert s.is_unchanged(surveys[0]) + + try: + spectrum_manager.create(s) + except IntegrityError as e: + print(e.args) + + def test_survey_update(self, settings, spectrum_manager, spectrum_rw): + from generalresearch.models.spectrum.survey import SpectrumSurvey + + assert settings.debug, "CRITICAL: Do not run this on production." + + now = datetime.now(tz=timezone.utc) + spectrum_rw.execute_sql_query( + query=f""" + DELETE FROM `{spectrum_rw.db}`.spectrum_survey + WHERE survey_id = '29333264' + """, + commit=True, + ) + d = copy.deepcopy(example_survey_api_response) + s = SpectrumSurvey.from_api(d) + print(s) + + spectrum_manager.create(s) + s.cpi = Decimal("0.50") + spectrum_manager.update([s]) + surveys = spectrum_manager.get_survey_library(updated_since=now) + assert len(surveys) == 1 + assert "29333264" == surveys[0].survey_id + assert Decimal("0.50") == surveys[0].cpi + assert s.is_unchanged(surveys[0]) + + # --- Updating bid/overall/last block + assert 600 == s.bid_loi + assert s.overall_loi is None + assert s.last_block_loi is None + + # now the last block is set + s.bid_loi = None + s.overall_loi = 1000 + s.last_block_loi = 1000 + spectrum_manager.update([s]) + surveys = spectrum_manager.get_survey_library(updated_since=now) + assert 600 == surveys[0].bid_loi + assert 1000 == surveys[0].overall_loi + assert 1000 == surveys[0].last_block_loi diff --git a/tests/models/test_currency.py b/tests/models/test_currency.py new file mode 100644 index 0000000..40cff88 --- /dev/null +++ b/tests/models/test_currency.py @@ -0,0 +1,410 @@ +"""These were taken from the wxet project's first use of this idea. Not all +functionality is the same, but pasting here so the tests are in the +correct spot... +""" + +from decimal import Decimal +from random import randint + +import pytest + + +class TestUSDCentModel: + + def test_construct_int(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDCent(int_val) + assert int_val == instance + + def test_construct_float(self): + from generalresearch.currency import USDCent + + with pytest.warns(expected_warning=Warning) as record: + float_val: float = 10.6789 + instance = USDCent(float_val) + + assert len(record) == 1 + assert "USDCent init with a float. Rounding behavior may be unexpected" in str( + record[0].message + ) + assert instance == USDCent(10) + assert instance == 10 + + def test_construct_decimal(self): + from generalresearch.currency import USDCent + + with pytest.warns(expected_warning=Warning) as record: + decimal_val: Decimal = Decimal("10.0") + instance = USDCent(decimal_val) + + assert len(record) == 1 + assert ( + "USDCent init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDCent(10) + assert instance == 10 + + # Now with rounding + with pytest.warns(Warning) as record: + decimal_val: Decimal = Decimal("10.6789") + instance = USDCent(decimal_val) + + assert len(record) == 1 + assert ( + "USDCent init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDCent(10) + assert instance == 10 + + def test_construct_negative(self): + from generalresearch.currency import USDCent + + with pytest.raises(expected_exception=ValueError) as cm: + USDCent(-1) + assert "USDCent not be less than zero" in str(cm.value) + + def test_operation_add(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDCent(int_val1) + instance2 = USDCent(int_val2) + + assert int_val1 + int_val2 == instance1 + instance2 + + def test_operation_subtract(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val1 = randint(500_000, 999_999) + int_val2 = randint(0, 499_999) + + instance1 = USDCent(int_val1) + instance2 = USDCent(int_val2) + + assert int_val1 - int_val2 == instance1 - instance2 + + def test_operation_subtract_to_neg(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDCent(int_val) + + with pytest.raises(expected_exception=ValueError) as cm: + instance - USDCent(1_000_000) + + assert "USDCent not be less than zero" in str(cm.value) + + def test_operation_multiply(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDCent(int_val1) + instance2 = USDCent(int_val2) + + assert int_val1 * int_val2 == instance1 * instance2 + + def test_operation_div(self): + from generalresearch.currency import USDCent + + with pytest.raises(ValueError) as cm: + USDCent(10) / 2 + assert "Division not allowed for USDCent" in str(cm.value) + + def test_operation_result_type(self): + from generalresearch.currency import USDCent + + int_val = randint(1, 999_999) + instance = USDCent(int_val) + + res_add = instance + USDCent(1) + assert isinstance(res_add, USDCent) + + res_sub = instance - USDCent(1) + assert isinstance(res_sub, USDCent) + + res_multipy = instance * USDCent(2) + assert isinstance(res_multipy, USDCent) + + def test_operation_partner_add(self): + from generalresearch.currency import USDCent + + int_val = randint(1, 999_999) + instance = USDCent(int_val) + + with pytest.raises(expected_exception=AssertionError): + instance + 0.10 + + with pytest.raises(expected_exception=AssertionError): + instance + Decimal(".10") + + with pytest.raises(expected_exception=AssertionError): + instance + "9.9" + + with pytest.raises(expected_exception=AssertionError): + instance + True + + def test_abs(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = abs(randint(0, 999_999)) + instance = abs(USDCent(int_val)) + + assert int_val == instance + + def test_str(self): + from generalresearch.currency import USDCent + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDCent(int_val) + + assert str(int_val) == str(instance) + + def test_operation_result_type_unsupported(self): + """There is no correct answer here, but we at least want to make sure + that a USDCent is returned + """ + from generalresearch.currency import USDCent + + res = USDCent(10) // 1.2 + assert not isinstance(res, USDCent) + assert isinstance(res, float) + + res = USDCent(10) % 1 + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = pow(USDCent(10), 2) + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = pow(USDCent(10), USDCent(2)) + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = float(USDCent(10)) + assert not isinstance(res, USDCent) + assert isinstance(res, float) + + +class TestUSDMillModel: + + def test_construct_int(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDMill(int_val) + assert int_val == instance + + def test_construct_float(self): + from generalresearch.currency import USDMill + + with pytest.warns(expected_warning=Warning) as record: + float_val: float = 10.6789 + instance = USDMill(float_val) + + assert len(record) == 1 + assert "USDMill init with a float. Rounding behavior may be unexpected" in str( + record[0].message + ) + assert instance == USDMill(10) + assert instance == 10 + + def test_construct_decimal(self): + from generalresearch.currency import USDMill + + with pytest.warns(expected_warning=Warning) as record: + decimal_val: Decimal = Decimal("10.0") + instance = USDMill(decimal_val) + + assert len(record) == 1 + assert ( + "USDMill init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDMill(10) + assert instance == 10 + + # Now with rounding + with pytest.warns(expected_warning=Warning) as record: + decimal_val: Decimal = Decimal("10.6789") + instance = USDMill(decimal_val) + + assert len(record) == 1 + assert ( + "USDMill init with a Decimal. Rounding behavior may be unexpected" + in str(record[0].message) + ) + + assert instance == USDMill(10) + assert instance == 10 + + def test_construct_negative(self): + from generalresearch.currency import USDMill + + with pytest.raises(expected_exception=ValueError) as cm: + USDMill(-1) + assert "USDMill not be less than zero" in str(cm.value) + + def test_operation_add(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDMill(int_val1) + instance2 = USDMill(int_val2) + + assert int_val1 + int_val2 == instance1 + instance2 + + def test_operation_subtract(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val1 = randint(500_000, 999_999) + int_val2 = randint(0, 499_999) + + instance1 = USDMill(int_val1) + instance2 = USDMill(int_val2) + + assert int_val1 - int_val2 == instance1 - instance2 + + def test_operation_subtract_to_neg(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDMill(int_val) + + with pytest.raises(expected_exception=ValueError) as cm: + instance - USDMill(1_000_000) + + assert "USDMill not be less than zero" in str(cm.value) + + def test_operation_multiply(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val1 = randint(0, 999_999) + int_val2 = randint(0, 999_999) + + instance1 = USDMill(int_val1) + instance2 = USDMill(int_val2) + + assert int_val1 * int_val2 == instance1 * instance2 + + def test_operation_div(self): + from generalresearch.currency import USDMill + + with pytest.raises(ValueError) as cm: + USDMill(10) / 2 + assert "Division not allowed for USDMill" in str(cm.value) + + def test_operation_result_type(self): + from generalresearch.currency import USDMill + + int_val = randint(1, 999_999) + instance = USDMill(int_val) + + res_add = instance + USDMill(1) + assert isinstance(res_add, USDMill) + + res_sub = instance - USDMill(1) + assert isinstance(res_sub, USDMill) + + res_multipy = instance * USDMill(2) + assert isinstance(res_multipy, USDMill) + + def test_operation_partner_add(self): + from generalresearch.currency import USDMill + + int_val = randint(1, 999_999) + instance = USDMill(int_val) + + with pytest.raises(expected_exception=AssertionError): + instance + 0.10 + + with pytest.raises(expected_exception=AssertionError): + instance + Decimal(".10") + + with pytest.raises(expected_exception=AssertionError): + instance + "9.9" + + with pytest.raises(expected_exception=AssertionError): + instance + True + + def test_abs(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = abs(randint(0, 999_999)) + instance = abs(USDMill(int_val)) + + assert int_val == instance + + def test_str(self): + from generalresearch.currency import USDMill + + for i in range(100): + int_val = randint(0, 999_999) + instance = USDMill(int_val) + + assert str(int_val) == str(instance) + + def test_operation_result_type_unsupported(self): + """There is no correct answer here, but we at least want to make sure + that a USDMill is returned + """ + from generalresearch.currency import USDCent, USDMill + + res = USDMill(10) // 1.2 + assert not isinstance(res, USDMill) + assert isinstance(res, float) + + res = USDMill(10) % 1 + assert not isinstance(res, USDMill) + assert isinstance(res, int) + + res = pow(USDMill(10), 2) + assert not isinstance(res, USDMill) + assert isinstance(res, int) + + res = pow(USDMill(10), USDMill(2)) + assert not isinstance(res, USDCent) + assert isinstance(res, int) + + res = float(USDMill(10)) + assert not isinstance(res, USDMill) + assert isinstance(res, float) + + +class TestNegativeFormatting: + + def test_pos(self): + from generalresearch.currency import format_usd_cent + + assert "-$987.65" == format_usd_cent(-98765) + + def test_neg(self): + from generalresearch.currency import format_usd_cent + + assert "-$123.45" == format_usd_cent(-12345) diff --git a/tests/models/test_device.py b/tests/models/test_device.py new file mode 100644 index 0000000..480e0c0 --- /dev/null +++ b/tests/models/test_device.py @@ -0,0 +1,27 @@ +import pytest + +iphone_ua_string = ( + "Mozilla/5.0 (iPhone; CPU iPhone OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) " + "Version/5.1 Mobile/9B179 Safari/7534.48.3" +) +ipad_ua_string = ( + "Mozilla/5.0(iPad; U; CPU iPhone OS 3_2 like Mac OS X; en-us) AppleWebKit/531.21.10 (KHTML, " + "like Gecko) Version/4.0.4 Mobile/7B314 Safari/531.21.10" +) +windows_ie_ua_string = "Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; Trident/5.0)" +chromebook_ua_string = ( + "Mozilla/5.0 (X11; CrOS i686 0.12.433) AppleWebKit/534.30 (KHTML, like Gecko) " + "Chrome/12.0.742.77 Safari/534.30" +) + + +class TestDeviceUA: + def test_device_ua(self): + from generalresearch.models import DeviceType + from generalresearch.models.device import parse_device_from_useragent + + assert parse_device_from_useragent(iphone_ua_string) == DeviceType.MOBILE + assert parse_device_from_useragent(ipad_ua_string) == DeviceType.TABLET + assert parse_device_from_useragent(windows_ie_ua_string) == DeviceType.DESKTOP + assert parse_device_from_useragent(chromebook_ua_string) == DeviceType.DESKTOP + assert parse_device_from_useragent("greg bot") == DeviceType.UNKNOWN diff --git a/tests/models/test_finance.py b/tests/models/test_finance.py new file mode 100644 index 0000000..888bf49 --- /dev/null +++ b/tests/models/test_finance.py @@ -0,0 +1,929 @@ +from datetime import timezone, timedelta +from itertools import product as iter_product +from random import randint +from uuid import uuid4 + +import pandas as pd +import pytest + +# noinspection PyUnresolvedReferences +from distributed.utils_test import ( + gen_cluster, + client_no_amm, + loop, + loop_in_thread, + cleanup, + cluster_fixture, + client, +) +from faker import Faker + +from generalresearch.incite.schemas.mergers.pop_ledger import ( + numerical_col_names, +) +from generalresearch.models.thl.finance import ( + POPFinancial, + ProductBalances, + BusinessBalances, +) +from test_utils.conftest import delete_df_collection +from test_utils.incite.collections.conftest import ledger_collection +from test_utils.incite.mergers.conftest import pop_ledger_merge +from test_utils.managers.ledger.conftest import ( + create_main_accounts, + session_with_tx_factory, +) + +fake = Faker() + + +class TestProductBalanceInitialize: + + def test_unknown_fields(self): + with pytest.raises(expected_exception=ValueError): + ProductBalances.model_validate( + { + "bp_payment.DEBIT": 1, + } + ) + + def test_payout(self): + val = randint(1, 1_000) + instance = ProductBalances.model_validate({"bp_payment.CREDIT": val}) + assert instance.payout == val + + def test_adjustment(self): + instance = ProductBalances.model_validate( + {"bp_adjustment.CREDIT": 90, "bp_adjustment.DEBIT": 147} + ) + + assert -57 == instance.adjustment + + def test_plug(self): + instance = ProductBalances.model_validate( + { + "bp_adjustment.CREDIT": 1000, + "bp_adjustment.DEBIT": 200, + "plug.DEBIT": 50, + } + ) + assert 750 == instance.adjustment + + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 789, + "bp_adjustment.CREDIT": 23, + "bp_adjustment.DEBIT": 101, + "plug.DEBIT": 17, + } + ) + assert 694 == instance.net + assert 694 == instance.balance + + def test_expense(self): + instance = ProductBalances.model_validate( + {"user_bonus.CREDIT": 0, "user_bonus.DEBIT": 999} + ) + + assert -999 == instance.expense + + def test_payment(self): + instance = ProductBalances.model_validate( + {"bp_payout.CREDIT": 1, "bp_payout.DEBIT": 100} + ) + + assert 99 == instance.payment + + def test_balance(self): + instance = ProductBalances.model_validate( + { + # Payouts from surveys: 1000 + "bp_payment.CREDIT": 1000, + # Adjustments: -200 + "bp_adjustment.CREDIT": 100, + "bp_adjustment.DEBIT": 300, + # Expense: -50 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 50, + # Prior supplier Payouts = 99 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 100, + } + ) + + # Supplier payments aren't considered in the net + assert 750 == instance.net + + # Confirm any Supplier payments are taken out of their balance + assert 651 == instance.balance + + def test_retainer(self): + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 1000, + } + ) + + assert 1000 == instance.balance + assert 250 == instance.retainer + + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 1000, + # 1001 worth of adjustments, making it negative + "bp_adjustment.DEBIT": 1001, + } + ) + + assert -1 == instance.balance + assert 0 == instance.retainer + + def test_available_balance(self): + instance = ProductBalances.model_validate( + { + "bp_payment.CREDIT": 1000, + } + ) + + assert 750 == instance.available_balance + + instance = ProductBalances.model_validate( + { + # Payouts from surveys: $188.37 + "bp_payment.CREDIT": 18_837, + # Adjustments: -$7.53 + $.17 + "bp_adjustment.CREDIT": 17, + "bp_adjustment.DEBIT": 753, + # $.15 of those marketplace Failure >> Completes were never + # actually paid out, so plug those positive adjustments + "plug.DEBIT": 15, + # Expense: -$27.45 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 2_745, + # Prior supplier Payouts = $100 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + + assert 18837 == instance.payout + assert -751 == instance.adjustment + assert 15341 == instance.net + + # Confirm any Supplier payments are taken out of their balance + assert 5341 == instance.balance + assert 1335 == instance.retainer + assert 4006 == instance.available_balance + + def test_json_schema(self): + instance = ProductBalances.model_validate( + { + # Payouts from surveys: 1000 + "bp_payment.CREDIT": 1000, + # Adjustments: -200 + "bp_adjustment.CREDIT": 100, + "bp_adjustment.DEBIT": 300, + # $.80 of those marketplace Failure >> Completes were never + # actually paid out, so plug those positive adjustments + "plug.DEBIT": 80, + # Expense: -50 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 50, + # Prior supplier Payouts = 99 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 100, + } + ) + + assert isinstance(instance.model_json_schema(), dict) + openapi_fields = list(instance.model_json_schema()["properties"].keys()) + + # Ensure the SkipJsonSchema is working.. + assert "mp_payment_credit" not in openapi_fields + assert "mp_payment_debit" not in openapi_fields + assert "mp_adjustment_credit" not in openapi_fields + assert "mp_adjustment_debit" not in openapi_fields + assert "bp_payment_debit" not in openapi_fields + assert "plug_credit" not in openapi_fields + assert "plug_debit" not in openapi_fields + + # Confirm the @property computed fields show up in openapi. I don't + # know how to do that yet... so this is check to confirm they're + # known computed fields for now + computed_fields = list(instance.model_computed_fields.keys()) + assert "payout" in computed_fields + assert "adjustment" in computed_fields + assert "expense" in computed_fields + assert "payment" in computed_fields + assert "net" in computed_fields + assert "balance" in computed_fields + assert "retainer" in computed_fields + assert "available_balance" in computed_fields + + def test_repr(self): + instance = ProductBalances.model_validate( + { + # Payouts from surveys: 1000 + "bp_payment.CREDIT": 1000, + # Adjustments: -200 + "bp_adjustment.CREDIT": 100, + "bp_adjustment.DEBIT": 300, + # $.80 of those marketplace Failure >> Completes were never + # actually paid out, so plug those positive adjustments + "plug.DEBIT": 80, + # Expense: -50 + "user_bonus.CREDIT": 0, + "user_bonus.DEBIT": 50, + # Prior supplier Payouts = 99 + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 100, + } + ) + + assert "Total Adjustment: -$2.80" in str(instance) + + +class TestBusinessBalanceInitialize: + + def test_validate_product_ids(self): + instance1 = ProductBalances.model_validate( + {"bp_payment.CREDIT": 500, "bp_adjustment.DEBIT": 40} + ) + + instance2 = ProductBalances.model_validate( + {"bp_payment.CREDIT": 500, "bp_adjustment.DEBIT": 40} + ) + + with pytest.raises(expected_exception=ValueError) as cm: + BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert "'product_id' must be set for BusinessBalance children" in str(cm.value) + + # Confirm that once you add them, it successfully initializes + instance1.product_id = uuid4().hex + instance2.product_id = uuid4().hex + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert isinstance(instance, BusinessBalances) + + def test_payout(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.DEBIT": 40, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.DEBIT": 40, + } + ) + + # Confirm the base payouts are as expected. + assert instance1.payout == 500 + assert instance2.payout == 500 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.payout == 1_000 + + def test_adjustment(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + } + ) + + # Confirm the base adjustment are as expected. + assert instance1.adjustment == -30 + assert instance2.adjustment == -30 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.adjustment == -60 + + def test_expense(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + # Confirm the base adjustment are as expected. + assert instance1.expense == -4 + assert instance2.expense == -4 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.expense == -8 + + def test_net(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + } + ) + + # Confirm the simple net + assert instance1.net == 466 + assert instance2.net == 466 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.net == 466 * 2 + + def test_payment(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 500, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 40, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.payment == 10_000 + assert instance2.payment == 10_000 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.payment == 20_000 + + def test_balance(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 500, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.balance == 39_506 + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 40_000, + "bp_adjustment.CREDIT": 2_000, + "bp_adjustment.DEBIT": 400, + "plug.DEBIT": 983, + "user_bonus.DEBIT": 392, + "user_bonus.CREDIT": 0, + "bp_payout.CREDIT": 0, + "bp_payout.DEBIT": 8_000, + } + ) + assert instance2.balance == 32_225 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.balance == 39_506 + 32_225 + + def test_retainer(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 500, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.balance == 39_506 + assert instance1.retainer == 9_876 + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 40_000, + "bp_adjustment.CREDIT": 2_000, + "bp_adjustment.DEBIT": 400, + "plug.DEBIT": 983, + "user_bonus.DEBIT": 392, + "user_bonus.CREDIT": 0, + "bp_payout.CREDIT": 0, + "bp_payout.DEBIT": 8_000, + } + ) + assert instance2.balance == 32_225 + assert instance2.retainer == 8_056 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.retainer == 9_876 + 8_056 + + def test_available_balance(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.CREDIT": 20, + "bp_adjustment.DEBIT": 500, + "plug.DEBIT": 10, + "user_bonus.DEBIT": 5, + "user_bonus.CREDIT": 1, + "bp_payout.CREDIT": 1, + "bp_payout.DEBIT": 10_001, + } + ) + assert instance1.balance == 39_506 + assert instance1.retainer == 9_876 + assert instance1.available_balance == 39_506 - 9_876 + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 40_000, + "bp_adjustment.CREDIT": 2_000, + "bp_adjustment.DEBIT": 400, + "plug.DEBIT": 983, + "user_bonus.DEBIT": 392, + "user_bonus.CREDIT": 0, + "bp_payout.CREDIT": 0, + "bp_payout.DEBIT": 8_000, + } + ) + assert instance2.balance == 32_225 + assert instance2.retainer == 8_056 + assert instance2.available_balance == 32_225 - 8_056 + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert instance.retainer == 9_876 + 8_056 + assert instance.available_balance == instance.balance - (9_876 + 8_056) + + def test_negative_net(self): + instance1 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 50_001, + "bp_payout.DEBIT": 4_999, + } + ) + assert 50_000 == instance1.payout + assert -50_001 == instance1.adjustment + assert 4_999 == instance1.payment + + assert -1 == instance1.net + assert -5_000 == instance1.balance + assert 0 == instance1.available_balance + + instance2 = ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 10_000, + "bp_payout.DEBIT": 10_000, + } + ) + assert 50_000 == instance2.payout + assert -10_000 == instance2.adjustment + assert 10_000 == instance2.payment + + assert 40_000 == instance2.net + assert 30_000 == instance2.balance + assert 22_500 == instance2.available_balance + + # Now confirm that they're correct in the BusinessBalance + instance = BusinessBalances.model_validate( + {"product_balances": [instance1, instance2]} + ) + assert 100_000 == instance.payout + assert -60_001 == instance.adjustment + assert 14_999 == instance.payment + + assert 39_999 == instance.net + assert 25_000 == instance.balance + + # Compare the retainers together. We can't just calculate the retainer + # on the Business.balance because it'll be "masked" by any Products + # that have a negative balance and actually reduce the Business's + # retainer as a whole. Therefore, we need to sum together each of the + # retainers from the child Products + assert 0 == instance1.retainer + assert 7_500 == instance2.retainer + assert 6_250 == instance.balance * 0.25 + assert 6_250 != instance.retainer + assert 7_500 == instance.retainer + assert 25_000 - 7_500 == instance.available_balance + + def test_str(self): + instance = BusinessBalances.model_validate( + { + "product_balances": [ + ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 50_001, + "bp_payout.DEBIT": 4_999, + } + ), + ProductBalances.model_validate( + { + "product_id": uuid4().hex, + "bp_payment.CREDIT": 50_000, + "bp_adjustment.DEBIT": 10_000, + "bp_payout.DEBIT": 10_000, + } + ), + ] + } + ) + + assert "Products: 2" in str(instance) + assert "Total Adjustment: -$600.01" in str(instance) + assert "Available Balance: $175.00" in str(instance) + + def test_from_json(self): + s = '{"product_balances":[{"product_id":"7485124190274248bc14132755c8fc3b","bp_payment_credit":1184,"adjustment_credit":0,"adjustment_debit":0,"supplier_credit":0,"supplier_debit":0,"user_bonus_credit":0,"user_bonus_debit":0,"payout":1184,"adjustment":0,"expense":0,"net":1184,"payment":0,"balance":1184,"retainer":296,"available_balance":888,"adjustment_percent":0.0}],"payout":1184,"adjustment":0,"expense":0,"net":1184,"payment":0,"balance":1184,"retainer":296,"available_balance":888,"adjustment_percent":0.0}' + instance = BusinessBalances.model_validate_json(s) + + assert instance.payout == 1184 + assert instance.available_balance == 888 + assert instance.retainer == 296 + assert len(instance.product_balances) == 1 + assert instance.adjustment_percent == 0.0 + assert instance.expense == 0 + + p = instance.product_balances[0] + assert p.payout == 1184 + assert p.available_balance == 888 + assert p.retainer == 296 + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "2D"], + [timedelta(days=2), timedelta(days=5)], + ) + ), +) +class TestProductFinanceData: + + def test_base( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + session_with_tx_factory, + product, + user_factory, + start, + duration, + delete_df_collection, + thl_lm, + create_main_accounts, + ): + from generalresearch.models.thl.user import User + + # -- Build & Setup + # assert ledger_collection.start is None + # assert ledger_collection.offset is None + u: User = user_factory(product=product, created=ledger_collection.start) + + for item in ledger_collection.items: + + for s_idx in range(3): + rand_item_time = fake.date_time_between( + start_date=item.start, + end_date=item.finish, + tzinfo=timezone.utc, + ) + session_with_tx_factory(started=rand_item_time, user=u) + + item.initial_load(overwrite=True) + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + # -- + account = thl_lm.get_account_or_create_bp_wallet(product=u.product) + + ddf = pop_ledger_merge.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx", "account_id"], + filters=[ + ("account_id", "==", account.uuid), + ("time_idx", ">=", start), + ("time_idx", "<", start + duration), + ], + ) + + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + assert not df.empty + + # -- + + df = df.groupby([pd.Grouper(key="time_idx", freq="D"), "account_id"]).sum() + res = POPFinancial.list_from_pandas(df, accounts=[account]) + + assert isinstance(res, list) + assert isinstance(res[0], POPFinancial) + + # On this, we can assert all products are the same, and that there are + # no overlapping time intervals + assert 1 == len(set(list([i.product_id for i in res]))) + assert len(res) == len(set(list([i.time for i in res]))) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "2D"], + [timedelta(days=2), timedelta(days=5)], + ) + ), +) +class TestPOPFinancialData: + + def test_base( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + mnt_filepath, + user_factory, + product, + start, + duration, + create_main_accounts, + session_with_tx_factory, + session_manager, + thl_lm, + delete_df_collection, + delete_ledger_db, + ): + # -- Build & Setup + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + # assert ledger_collection.start is None + # assert ledger_collection.offset is None + + users = [] + for idx in range(5): + u = user_factory(product=product) + + for item in ledger_collection.items: + rand_item_time = fake.date_time_between( + start_date=item.start, + end_date=item.finish, + tzinfo=timezone.utc, + ) + + session_with_tx_factory(started=rand_item_time, user=u) + item.initial_load(overwrite=True) + + users.append(u) + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + item_finishes = [i.finish for i in ledger_collection.items] + item_finishes.sort(reverse=True) + last_item_finish = item_finishes[0] + + accounts = [] + for user in users: + account = thl_lm.get_account_or_create_bp_wallet(product=u.product) + accounts.append(account) + account_ids = [a.uuid for a in accounts] + + # -- + + ddf = pop_ledger_merge.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["time_idx", "account_id"], + filters=[ + ("account_id", "in", account_ids), + ("time_idx", ">=", start), + ("time_idx", "<", last_item_finish), + ], + ) + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + df = df.groupby([pd.Grouper(key="time_idx", freq="D"), "account_id"]).sum() + res = POPFinancial.list_from_pandas(df, accounts=accounts) + + assert isinstance(res, list) + for i in res: + assert isinstance(i, POPFinancial) + + # This does not return the AccountID, it's the Product ID + assert i.product_id in [u.product_id for u in users] + + # 1 Product, multiple Users + assert len(users) == len(accounts) + + # We group on days, and duration is a parameter to parametrize + assert isinstance(duration, timedelta) + + # -- Teardown + delete_df_collection(ledger_collection) + + +@pytest.mark.parametrize( + argnames="offset, duration", + argvalues=list( + iter_product( + ["12h", "1D"], + [timedelta(days=2), timedelta(days=3)], + ) + ), +) +class TestBusinessBalanceData: + def test_from_pandas( + self, + client_no_amm, + ledger_collection, + pop_ledger_merge, + user_factory, + product, + create_main_accounts, + session_factory, + thl_lm, + session_manager, + start, + thl_web_rr, + duration, + delete_df_collection, + delete_ledger_db, + session_with_tx_factory, + offset, + rm_ledger_collection, + ): + from generalresearch.models.thl.user import User + from generalresearch.models.thl.ledger import LedgerAccount + + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + rm_ledger_collection() + + for idx in range(5): + u: User = user_factory(product=product, created=ledger_collection.start) + + for item in ledger_collection.items: + item_time = fake.date_time_between( + start_date=item.start, + end_date=item.finish, + tzinfo=timezone.utc, + ) + session_with_tx_factory(started=item_time, user=u) + item.initial_load(overwrite=True) + + # Confirm any of the items are archived + assert ledger_collection.progress.has_archive.eq(True).all() + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + # assert pop_ledger_merge.progress.has_archive.eq(True).all() + + account: LedgerAccount = thl_lm.get_account_or_create_bp_wallet(product=product) + + ddf = pop_ledger_merge.ddf( + force_rr_latest=False, + include_partial=True, + columns=numerical_col_names + ["account_id"], + filters=[("account_id", "in", [account.uuid])], + ) + ddf = ddf.groupby("account_id").sum() + df: pd.DataFrame = client_no_amm.compute(collections=ddf, sync=True) + + assert isinstance(df, pd.DataFrame) + + instance = BusinessBalances.from_pandas( + input_data=df, accounts=[account], thl_pg_config=thl_web_rr + ) + balance: int = thl_lm.get_account_balance(account=account) + + assert instance.balance == balance + assert instance.net == balance + assert instance.payout == balance + + assert instance.payment == 0 + assert instance.adjustment == 0 + assert instance.adjustment_percent == 0.0 + + assert instance.expense == 0 + + # Cleanup + delete_ledger_db() + delete_df_collection(coll=ledger_collection) diff --git a/tests/models/thl/__init__.py b/tests/models/thl/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/models/thl/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/models/thl/question/__init__.py b/tests/models/thl/question/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/thl/question/test_question_info.py b/tests/models/thl/question/test_question_info.py new file mode 100644 index 0000000..945ee7a --- /dev/null +++ b/tests/models/thl/question/test_question_info.py @@ -0,0 +1,146 @@ +from generalresearch.models.thl.profiling.upk_property import ( + UpkProperty, + ProfilingInfo, +) + + +class TestQuestionInfo: + + def test_init(self): + + s = ( + '[{"property_label": "hispanic", "cardinality": "*", "prop_type": "i", "country_iso": "us", ' + '"property_id": "05170ae296ab49178a075cab2a2073a6", "item_id": "7911ec1468b146ee870951f8ae9cbac1", ' + '"item_label": "panamanian", "gold_standard": 1, "options": [{"id": "c358c11e72c74fa2880358f1d4be85ab", ' + '"label": "not_hispanic"}, {"id": "b1d6c475770849bc8e0200054975dc9c", "label": "yes_hispanic"}, ' + '{"id": "bd1eb44495d84b029e107c188003c2bd", "label": "other_hispanic"}, ' + '{"id": "f290ad5e75bf4f4ea94dc847f57c1bd3", "label": "mexican"}, ' + '{"id": "49f50f2801bd415ea353063bfc02d252", "label": "puerto_rican"}, ' + '{"id": "dcbe005e522f4b10928773926601f8bf", "label": "cuban"}, ' + '{"id": "467ef8ddb7ac4edb88ba9ef817cbb7e9", "label": "salvadoran"}, ' + '{"id": "3c98e7250707403cba2f4dc7b877c963", "label": "dominican"}, ' + '{"id": "981ee77f6d6742609825ef54fea824a8", "label": "guatemalan"}, ' + '{"id": "81c8057b809245a7ae1b8a867ea6c91e", "label": "colombian"}, ' + '{"id": "513656d5f9e249fa955c3b527d483b93", "label": "honduran"}, ' + '{"id": "afc8cddd0c7b4581bea24ccd64db3446", "label": "ecuadorian"}, ' + '{"id": "61f34b36e80747a89d85e1eb17536f84", "label": "argentinian"}, ' + '{"id": "5330cfa681d44aa8ade3a6d0ea198e44", "label": "peruvian"}, ' + '{"id": "e7bceaffd76e486596205d8545019448", "label": "nicaraguan"}, ' + '{"id": "b7bbb2ebf8424714962e6c4f43275985", "label": "spanish"}, ' + '{"id": "8bf539785e7a487892a2f97e52b1932d", "label": "venezuelan"}, ' + '{"id": "7911ec1468b146ee870951f8ae9cbac1", "label": "panamanian"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "ethnic_group", "cardinality": "*", "prop_type": ' + '"i", "country_iso": "us", "property_id": "15070958225d4132b7f6674fcfc979f6", "item_id": ' + '"64b7114cf08143949e3bcc3d00a5d8a0", "item_label": "other_ethnicity", "gold_standard": 1, "options": [{' + '"id": "a72e97f4055e4014a22bee4632cbf573", "label": "caucasians"}, ' + '{"id": "4760353bc0654e46a928ba697b102735", "label": "black_or_african_american"}, ' + '{"id": "20ff0a2969fa4656bbda5c3e0874e63b", "label": "asian"}, ' + '{"id": "107e0a79e6b94b74926c44e70faf3793", "label": "native_hawaiian_or_other_pacific_islander"}, ' + '{"id": "900fa12691d5458c8665bf468f1c98c1", "label": "native_americans"}, ' + '{"id": "64b7114cf08143949e3bcc3d00a5d8a0", "label": "other_ethnicity"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "educational_attainment", "cardinality": "?", ' + '"prop_type": "i", "country_iso": "us", "property_id": "2637783d4b2b4075b93e2a156e16e1d8", "item_id": ' + '"934e7b81d6744a1baa31bbc51f0965d5", "item_label": "other_education", "gold_standard": 1, "options": [{' + '"id": "df35ef9e474b4bf9af520aa86630202d", "label": "3rd_grade_completion"}, ' + '{"id": "83763370a1064bd5ba76d1b68c4b8a23", "label": "8th_grade_completion"}, ' + '{"id": "f0c25a0670c340bc9250099dcce50957", "label": "not_high_school_graduate"}, ' + '{"id": "02ff74c872bd458983a83847e1a9f8fd", "label": "high_school_completion"}, ' + '{"id": "ba8beb807d56441f8fea9b490ed7561c", "label": "vocational_program_completion"}, ' + '{"id": "65373a5f348a410c923e079ddbb58e9b", "label": "some_college_completion"}, ' + '{"id": "2d15d96df85d4cc7b6f58911fdc8d5e2", "label": "associate_academic_degree_completion"}, ' + '{"id": "497b1fedec464151b063cd5367643ffa", "label": "bachelors_degree_completion"}, ' + '{"id": "295133068ac84424ae75e973dc9f2a78", "label": "some_graduate_completion"}, ' + '{"id": "e64f874faeff4062a5aa72ac483b4b9f", "label": "masters_degree_completion"}, ' + '{"id": "cbaec19a636d476385fb8e7842b044f5", "label": "doctorate_degree_completion"}, ' + '{"id": "934e7b81d6744a1baa31bbc51f0965d5", "label": "other_education"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "household_spoken_language", "cardinality": "*", ' + '"prop_type": "i", "country_iso": "us", "property_id": "5a844571073d482a96853a0594859a51", "item_id": ' + '"62b39c1de141422896ad4ab3c4318209", "item_label": "dut", "gold_standard": 1, "options": [{"id": ' + '"f65cd57b79d14f0f8460761ce41ec173", "label": "ara"}, {"id": "6d49de1f8f394216821310abd29392d9", ' + '"label": "zho"}, {"id": "be6dc23c2bf34c3f81e96ddace22800d", "label": "eng"}, ' + '{"id": "ddc81f28752d47a3b1c1f3b8b01a9b07", "label": "fre"}, {"id": "2dbb67b29bd34e0eb630b1b8385542ca", ' + '"label": "ger"}, {"id": "a747f96952fc4b9d97edeeee5120091b", "label": "hat"}, ' + '{"id": "7144b04a3219433baac86273677551fa", "label": "hin"}, {"id": "e07ff3e82c7149eaab7ea2b39ee6a6dc", ' + '"label": "ita"}, {"id": "b681eff81975432ebfb9f5cc22dedaa3", "label": "jpn"}, ' + '{"id": "5cb20440a8f64c9ca62fb49c1e80cdef", "label": "kor"}, {"id": "171c4b77d4204bc6ac0c2b81e38a10ff", ' + '"label": "pan"}, {"id": "8c3ec18e6b6c4a55a00dd6052e8e84fb", "label": "pol"}, ' + '{"id": "3ce074d81d384dd5b96f1fb48f87bf01", "label": "por"}, {"id": "6138dc951990458fa88a666f6ddd907b", ' + '"label": "rus"}, {"id": "e66e5ecc07df4ebaa546e0b436f034bd", "label": "spa"}, ' + '{"id": "5a981b3d2f0d402a96dd2d0392ec2fcb", "label": "tgl"}, {"id": "b446251bd211403487806c4d0a904981", ' + '"label": "vie"}, {"id": "92fb3ee337374e2db875fb23f52eed46", "label": "xxx"}, ' + '{"id": "8b1f590f12f24cc1924d7bdcbe82081e", "label": "ind"}, {"id": "bf3f4be556a34ff4b836420149fd2037", ' + '"label": "tur"}, {"id": "87ca815c43ba4e7f98cbca98821aa508", "label": "zul"}, ' + '{"id": "0adbf915a7a64d67a87bb3ce5d39ca54", "label": "may"}, {"id": "62b39c1de141422896ad4ab3c4318209", ' + '"label": "dut"}], "category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", ' + '"path": "/Demographic", "adwords_vertical_id": null}]}, {"property_label": "gender", "cardinality": ' + '"?", "prop_type": "i", "country_iso": "us", "property_id": "73175402104741549f21de2071556cd7", ' + '"item_id": "093593e316344cd3a0ac73669fca8048", "item_label": "other_gender", "gold_standard": 1, ' + '"options": [{"id": "b9fc5ea07f3a4252a792fd4a49e7b52b", "label": "male"}, ' + '{"id": "9fdb8e5e18474a0b84a0262c21e17b56", "label": "female"}, ' + '{"id": "093593e316344cd3a0ac73669fca8048", "label": "other_gender"}], "category": [{"id": ' + '"4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "age_in_years", "cardinality": "?", "prop_type": ' + '"n", "country_iso": "us", "property_id": "94f7379437874076b345d76642d4ce6d", "item_id": null, ' + '"item_label": null, "gold_standard": 1, "category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", ' + '"label": "Demographic", "path": "/Demographic", "adwords_vertical_id": null}]}, {"property_label": ' + '"children_age_gender", "cardinality": "*", "prop_type": "i", "country_iso": "us", "property_id": ' + '"e926142fcea94b9cbbe13dc7891e1e7f", "item_id": "b7b8074e95334b008e8958ccb0a204f1", "item_label": ' + '"female_18", "gold_standard": 1, "options": [{"id": "16a6448ec24c48d4993d78ebee33f9b4", ' + '"label": "male_under_1"}, {"id": "809c04cb2e3b4a3bbd8077ab62cdc220", "label": "female_under_1"}, ' + '{"id": "295e05bb6a0843bc998890b24c99841e", "label": "no_children"}, ' + '{"id": "142cb948d98c4ae8b0ef2ef10978e023", "label": "male_0"}, ' + '{"id": "5a5c1b0e9abc48a98b3bc5f817d6e9d0", "label": "male_1"}, ' + '{"id": "286b1a9afb884bdfb676dbb855479d1e", "label": "male_2"}, ' + '{"id": "942ca3cda699453093df8cbabb890607", "label": "male_3"}, ' + '{"id": "995818d432f643ec8dd17e0809b24b56", "label": "male_4"}, ' + '{"id": "f38f8b57f25f4cdea0f270297a1e7a5c", "label": "male_5"}, ' + '{"id": "975df709e6d140d1a470db35023c432d", "label": "male_6"}, ' + '{"id": "f60bd89bbe0f4e92b90bccbc500467c2", "label": "male_7"}, ' + '{"id": "6714ceb3ed5042c0b605f00b06814207", "label": "male_8"}, ' + '{"id": "c03c2f8271d443cf9df380e84b4dea4c", "label": "male_9"}, ' + '{"id": "11690ee0f5a54cb794f7ddd010d74fa2", "label": "male_10"}, ' + '{"id": "17bef9a9d14b4197b2c5609fa94b0642", "label": "male_11"}, ' + '{"id": "e79c8338fe28454f89ccc78daf6f409a", "label": "male_12"}, ' + '{"id": "3a4f87acb3fa41f4ae08dfe2858238c1", "label": "male_13"}, ' + '{"id": "36ffb79d8b7840a7a8cb8d63bbc8df59", "label": "male_14"}, ' + '{"id": "1401a508f9664347aee927f6ec5b0a40", "label": "male_15"}, ' + '{"id": "6e0943c5ec4a4f75869eb195e3eafa50", "label": "male_16"}, ' + '{"id": "47d4b27b7b5242758a9fff13d3d324cf", "label": "male_17"}, ' + '{"id": "9ce886459dd44c9395eb77e1386ab181", "label": "female_0"}, ' + '{"id": "6499ccbf990d4be5b686aec1c7353fd8", "label": "female_1"}, ' + '{"id": "d85ceaa39f6d492abfc8da49acfd14f2", "label": "female_2"}, ' + '{"id": "18edb45c138e451d8cb428aefbb80f9c", "label": "female_3"}, ' + '{"id": "bac6f006ed9f4ccf85f48e91e99fdfd1", "label": "female_4"}, ' + '{"id": "5a6a1a8ad00c4ce8be52dcb267b034ff", "label": "female_5"}, ' + '{"id": "6bff0acbf6364c94ad89507bcd5f4f45", "label": "female_6"}, ' + '{"id": "d0d56a0a6b6f4516a366a2ce139b4411", "label": "female_7"}, ' + '{"id": "bda6028468044b659843e2bef4db2175", "label": "female_8"}, ' + '{"id": "dbb6d50325464032b456357b1a6e5e9c", "label": "female_9"}, ' + '{"id": "b87a93d7dc1348edac5e771684d63fb8", "label": "female_10"}, ' + '{"id": "11449d0d98f14e27ba47de40b18921d7", "label": "female_11"}, ' + '{"id": "16156501e97b4263962cbbb743840292", "label": "female_12"}, ' + '{"id": "04ee971c89a345cc8141a45bce96050c", "label": "female_13"}, ' + '{"id": "e818d310bfbc4faba4355e5d2ed49d4f", "label": "female_14"}, ' + '{"id": "440d25e078924ba0973163153c417ed6", "label": "female_15"}, ' + '{"id": "78ff804cc9b441c5a524bd91e3d1f8bf", "label": "female_16"}, ' + '{"id": "4b04d804d7d84786b2b1c22e4ed440f5", "label": "female_17"}, ' + '{"id": "28bc848cd3ff44c3893c76bfc9bc0c4e", "label": "male_18"}, ' + '{"id": "b7b8074e95334b008e8958ccb0a204f1", "label": "female_18"}], "category": [{"id": ' + '"e18ba6e9d51e482cbb19acf2e6f505ce", "label": "Parenting", "path": "/People & Society/Family & ' + 'Relationships/Family/Parenting", "adwords_vertical_id": "58"}]}, {"property_label": "home_postal_code", ' + '"cardinality": "?", "prop_type": "x", "country_iso": "us", "property_id": ' + '"f3b32ebe78014fbeb1ed6ff77d6338bf", "item_id": null, "item_label": null, "gold_standard": 1, ' + '"category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", "label": "Demographic", "path": "/Demographic", ' + '"adwords_vertical_id": null}]}, {"property_label": "household_income", "cardinality": "?", "prop_type": ' + '"n", "country_iso": "us", "property_id": "ff5b1d4501d5478f98de8c90ef996ac1", "item_id": null, ' + '"item_label": null, "gold_standard": 1, "category": [{"id": "4fd8381d5a1c4409ab007ca254ced084", ' + '"label": "Demographic", "path": "/Demographic", "adwords_vertical_id": null}]}]' + ) + instance_list = ProfilingInfo.validate_json(s) + + assert isinstance(instance_list, list) + for i in instance_list: + assert isinstance(i, UpkProperty) diff --git a/tests/models/thl/question/test_user_info.py b/tests/models/thl/question/test_user_info.py new file mode 100644 index 0000000..0bbbc78 --- /dev/null +++ b/tests/models/thl/question/test_user_info.py @@ -0,0 +1,32 @@ +from generalresearch.models.thl.profiling.user_info import UserInfo + + +class TestUserInfo: + + def test_init(self): + + s = ( + '{"user_profile_knowledge": [], "marketplace_profile_knowledge": [{"source": "d", "question_id": ' + '"1", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "pr", ' + '"question_id": "3", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": ' + '"h", "question_id": "60", "answer": ["58"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "c", "question_id": "43", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "s", "question_id": "211", "answer": ["111"], "created": ' + '"2023-11-07T16:41:05.234096Z"}, {"source": "s", "question_id": "1843", "answer": ["111"], ' + '"created": "2023-11-07T16:41:05.234096Z"}, {"source": "h", "question_id": "13959", "answer": [' + '"244155"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", "question_id": "33092", ' + '"answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", "question_id": "gender", ' + '"answer": ["10682"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "e", "question_id": ' + '"gender", "answer": ["male"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "f", ' + '"question_id": "gender", "answer": ["male"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": ' + '"i", "question_id": "gender", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "c", "question_id": "137510", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, ' + '{"source": "m", "question_id": "gender", "answer": ["1"], "created": ' + '"2023-11-07T16:41:05.234096Z"}, {"source": "o", "question_id": "gender", "answer": ["male"], ' + '"created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", "question_id": "gender_plus", "answer": [' + '"7657644"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "i", "question_id": ' + '"gender_plus", "answer": ["1"], "created": "2023-11-07T16:41:05.234096Z"}, {"source": "c", ' + '"question_id": "income_level", "answer": ["9071"], "created": "2023-11-07T16:41:05.234096Z"}]}' + ) + instance = UserInfo.model_validate_json(s) + assert isinstance(instance, UserInfo) diff --git a/tests/models/thl/test_adjustments.py b/tests/models/thl/test_adjustments.py new file mode 100644 index 0000000..15d01d0 --- /dev/null +++ b/tests/models/thl/test_adjustments.py @@ -0,0 +1,688 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.session import ( + Wall, + Status, + StatusCode1, + WallAdjustedStatus, + SessionAdjustedStatus, +) + +started1 = datetime(2023, 1, 1, tzinfo=timezone.utc) +started2 = datetime(2023, 1, 1, 0, 10, 0, tzinfo=timezone.utc) +finished1 = started1 + timedelta(minutes=10) +finished2 = started2 + timedelta(minutes=10) + +adj_ts = datetime(2023, 2, 2, tzinfo=timezone.utc) +adj_ts2 = datetime(2023, 2, 3, tzinfo=timezone.utc) +adj_ts3 = datetime(2023, 2, 4, tzinfo=timezone.utc) + + +class TestProductAdjustments: + + @pytest.mark.parametrize("payout", [".6", "1", "1.8", "2", "500.0000"]) + def test_determine_bp_payment_no_rounding(self, product_factory, payout): + p1 = product_factory(commission_pct=Decimal("0.05")) + res = p1.determine_bp_payment(thl_net=Decimal(payout)) + assert isinstance(res, Decimal) + assert res == Decimal(payout) * Decimal("0.95") + + @pytest.mark.parametrize("payout", [".01", ".05", ".5"]) + def test_determine_bp_payment_rounding(self, product_factory, payout): + p1 = product_factory(commission_pct=Decimal("0.05")) + res = p1.determine_bp_payment(thl_net=Decimal(payout)) + assert isinstance(res, Decimal) + assert res != Decimal(payout) * Decimal("0.95") + + +class TestSessionAdjustments: + + def test_status_complete(self, session_factory, user): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + # Confirm only the last Wall Event is a complete + assert not s1.wall_events[0].status == Status.COMPLETE + assert s1.wall_events[1].status == Status.COMPLETE + + # Confirm the Session is marked as finished and the simple brokerage + # payout calculation is correct. + status, status_code_1 = s1.determine_session_status() + assert status == Status.COMPLETE + assert status_code_1 == StatusCode1.COMPLETE + + +class TestAdjustments: + + def test_finish_with_status(self, session_factory, user, session_manager): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + status, status_code_1 = s1.determine_session_status() + payout = user.product.determine_bp_payment(Decimal(1)) + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + payout=payout, + finished=finished2, + ) + + assert Decimal("0.95") == payout + + def test_never_adjusted(self, session_factory, user, session_manager): + s1 = session_factory( + user=user, + wall_count=5, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Confirm walls and Session are never adjusted in anyway + for w in s1.wall_events: + w: Wall + assert w.adjusted_status is None + assert w.adjusted_timestamp is None + assert w.adjusted_cpi is None + + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_timestamp is None + + def test_adjustment_wall_values( + self, session_factory, user, session_manager, wall_manager + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=5, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + w: Wall = s1.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Original Session and Wall status is still the same, but the Adjusted + # values have changed + assert s1.status == Status.COMPLETE + assert s1.adjusted_status is None + assert s1.adjusted_timestamp is None + assert s1.adjusted_payout is None + assert s1.adjusted_user_payout is None + + assert w.status == Status.COMPLETE + assert w.status_code_1 == StatusCode1.COMPLETE + assert w.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + assert w.adjusted_cpi == Decimal(0) + assert w.adjusted_timestamp == adj_ts + + # Because the Product doesn't have the Wallet mode enabled, the + # user_payout fields should always be None + assert not user.product.user_wallet_config.enabled + assert s1.adjusted_user_payout is None + + def test_adjustment_session_values( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + wall_source=Source.DYNATA, + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + wall_manager.adjust_status( + wall=s1.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + assert s1.status == Status.COMPLETE # Original status should remain + assert s1.adjusted_status == SessionAdjustedStatus.ADJUSTED_TO_FAIL + assert s1.adjusted_payout == Decimal(0) + assert s1.adjusted_timestamp == adj_ts + + # Because the Product doesn't have the Wallet mode enabled, the + # user_payout fields should always be None + assert not user.product.user_wallet_config.enabled + assert s1.adjusted_user_payout is None + + def test_double_adjustment_session_values( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + w: Wall = s1.wall_events[-1] + wall_manager.adjust_status( + wall=w, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + + # Let's take that back again! Buyers love to do this. + # So now we're going to "un-reconcile" the last Wall Event which has + # already gone from a Complete >> Failure + wall_manager.adjust_status( + wall=w, adjusted_status=None, adjusted_timestamp=adj_ts2 + ) + assert w.adjusted_status is None + assert w.adjusted_cpi is None + assert w.adjusted_timestamp == adj_ts2 + + # Once the wall was unreconciled, "refresh" the Session again + assert s1.adjusted_status is not None + session_manager.adjust_status(session=s1) + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_timestamp == adj_ts2 + assert s1.adjusted_user_payout is None + + def test_double_adjustment_sm_vs_db_values( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(1), + wall_source=Source.DYNATA, + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + wall_manager.adjust_status( + wall=s1.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + + # Let's take that back again! Buyers love to do this. + # So now we're going to "un-reconcile" the last Wall Event which has + # already gone from a Complete >> Failure + # Once the wall was unreconciled, "refresh" the Session again + wall_manager.adjust_status( + wall=s1.wall_events[-1], adjusted_status=None, adjusted_timestamp=adj_ts2 + ) + session_manager.adjust_status(session=s1) + + # Confirm that the sessions wall attributes are still aligned with + # what comes back directly from the database + db_wall_events = wall_manager.get_wall_events(session_id=s1.id) + for idx in range(len(s1.wall_events)): + w_sm: Wall = s1.wall_events[idx] + w_db: Wall = db_wall_events[idx] + + assert w_sm.uuid == w_db.uuid + assert w_sm.session_id == w_db.session_id + assert w_sm.status == w_db.status + assert w_sm.status_code_1 == w_db.status_code_1 + assert w_sm.status_code_2 == w_db.status_code_2 + + assert w_sm.elapsed == w_db.elapsed + + # Decimal("1.000000") vs Decimal(1) - based on mysql or postgres + assert pytest.approx(w_sm.cpi) == w_db.cpi + assert pytest.approx(w_sm.req_cpi) == w_db.req_cpi + + assert w_sm.model_dump_json( + exclude={"cpi", "req_cpi"} + ) == w_db.model_dump_json(exclude={"cpi", "req_cpi"}) + + def test_double_adjustment_double_completes( + self, wall_manager, session_manager, session_factory, user + ): + # Completed Session with 2 wall events + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpi=Decimal(2), + wall_source=Source.DYNATA, + final_status=Status.COMPLETE, + started=started1, + ) + + session_manager.finish_with_status( + session=s1, + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + payout=Decimal("0.95"), + finished=finished2, + ) + + # Change the last wall event to a Failure + wall_manager.adjust_status( + wall=s1.wall_events[-1], + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_timestamp=adj_ts, + ) + + # Refresh the Session with the new Wall Adjustment considerations, + session_manager.adjust_status(session=s1) + + # Let's take that back again! Buyers love to do this. + # So now we're going to "un-reconcile" the last Wall Event which has + # already gone from a Complete >> Failure + # Once the wall was unreconciled, "refresh" the Session again + wall_manager.adjust_status( + wall=s1.wall_events[-1], adjusted_status=None, adjusted_timestamp=adj_ts2 + ) + session_manager.adjust_status(session=s1) + + # Reassign them - we already validated they're equal in previous + # tests so this is safe to do. + s1.wall_events = wall_manager.get_wall_events(session_id=s1.id) + + # The First Wall event was originally a Failure, now let's also set + # that as a complete, so now both Wall Events will b a + # complete (Fail >> Adj to Complete, Complete >> Adj to Fail >> Adj to Complete) + w1: Wall = s1.wall_events[0] + assert w1.status == Status.FAIL + assert w1.adjusted_status is None + assert w1.adjusted_cpi is None + assert w1.adjusted_timestamp is None + + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_timestamp=adj_ts3, + ) + + assert w1.status == Status.FAIL # original status doesn't change + assert w1.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_COMPLETE + assert w1.adjusted_cpi == w1.cpi + assert w1.adjusted_timestamp == adj_ts3 + + session_manager.adjust_status(s1) + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("3.80") == s1.adjusted_payout + assert s1.adjusted_user_payout is None + assert adj_ts3 == s1.adjusted_timestamp + + def test_complete_to_fail( + self, session_factory, user, session_manager, wall_manager, utc_hour_ago + ): + s1 = session_factory( + user=user, + wall_count=1, + wall_req_cpi=Decimal("1"), + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + + status, status_code_1 = s1.determine_session_status() + assert status == Status.COMPLETE + + thl_net = Decimal(sum(w.cpi for w in s1.wall_events if w.is_visible_complete())) + payout = user.product.determine_bp_payment(thl_net=thl_net) + + session_manager.finish_with_status( + session=s1, + status=status, + status_code_1=status_code_1, + finished=utc_hour_ago + timedelta(minutes=10), + payout=payout, + user_payout=None, + ) + + w1 = s1.wall_events[0] + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts, + ) + assert w1.adjusted_status == WallAdjustedStatus.ADJUSTED_TO_FAIL + assert w1.adjusted_cpi == Decimal(0) + + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.FAIL == new_status + assert Decimal(0) == new_payout + + assert not user.product.user_wallet_config.enabled + assert new_user_payout is None + + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_FAIL == s1.adjusted_status + assert Decimal(0) == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + assert s1.adjusted_user_payout is None + + # cpi adjustment + w1.update( + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("0.69"), + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.CPI_ADJUSTMENT == w1.adjusted_status + assert Decimal("0.69") == w1.adjusted_cpi + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.COMPLETE == new_status + assert Decimal("0.66") == new_payout + + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.33") == new_user_payout + assert new_user_payout is None + + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.66") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.33") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + # adjust cpi again + wall_manager.adjust_status( + wall=w1, + adjusted_status=WallAdjustedStatus.CPI_ADJUSTMENT, + adjusted_cpi=Decimal("0.50"), + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.CPI_ADJUSTMENT == w1.adjusted_status + assert Decimal("0.50") == w1.adjusted_cpi + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.COMPLETE == new_status + assert Decimal("0.48") == new_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.24") == new_user_payout + assert new_user_payout is None + + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.48") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.24") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + def test_complete_to_fail_to_complete(self, user, session_factory, utc_hour_ago): + # Setup: Complete, then adjust it to fail + s1 = session_factory( + user=user, + wall_count=1, + wall_req_cpi=Decimal("1"), + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + w1 = s1.wall_events[0] + + status, status_code_1 = s1.determine_session_status() + thl_net, commission_amount, bp_pay, user_pay = s1.determine_payments() + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=10), + "payout": bp_pay, + "user_payout": user_pay, + } + ) + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts, + ) + s1.adjust_status() + + # Test: Adjust back to complete + w1.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=adj_ts, + ) + assert w1.adjusted_status is None + assert w1.adjusted_cpi is None + assert adj_ts == w1.adjusted_timestamp + + new_status, new_payout, new_user_payout = s1.determine_new_status_and_payouts() + assert Status.COMPLETE == new_status + assert Decimal("0.95") == new_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == new_user_payout + assert new_user_payout is None + + s1.adjust_status() + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_user_payout is None + + def test_complete_to_fail_to_complete_adj( + self, user, session_factory, utc_hour_ago + ): + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(2)], + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + + w1 = s1.wall_events[0] + w2 = s1.wall_events[1] + + status, status_code_1 = s1.determine_session_status() + thl_net = Decimal(sum(w.cpi for w in s1.wall_events if w.is_visible_complete())) + payout = user.product.determine_bp_payment(thl_net=thl_net) + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=25), + "payout": payout, + "user_payout": None, + } + ) + + # Test. Adjust first fail to complete. Now we have 2 completes. + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=adj_ts, + ) + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("2.85") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("1.42") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + # Now we have [Fail, Complete ($2)] -> [Complete ($1), Fail] + w2.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts2, + ) + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.95") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + def test_complete_to_fail_to_complete_adj1( + self, user, session_factory, utc_hour_ago + ): + # Same as test_complete_to_fail_to_complete_adj but in opposite order + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(2)], + final_status=Status.COMPLETE, + started=utc_hour_ago, + ) + + w1 = s1.wall_events[0] + w2 = s1.wall_events[1] + + status, status_code_1 = s1.determine_session_status() + thl_net = Decimal(sum(w.cpi for w in s1.wall_events if w.is_visible_complete())) + payout = user.product.determine_bp_payment(thl_net) + s1.update( + **{ + "status": status, + "status_code_1": status_code_1, + "finished": utc_hour_ago + timedelta(minutes=25), + "payout": payout, + "user_payout": None, + } + ) + + # Test. Adjust complete to fail. Now we have 2 fails. + w2.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_FAIL, + adjusted_cpi=0, + adjusted_timestamp=adj_ts, + ) + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_FAIL == s1.adjusted_status + assert Decimal(0) == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal(0) == s.adjusted_user_payout + assert s1.adjusted_user_payout is None + # Now we have [Fail, Complete ($2)] -> [Complete ($1), Fail] + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=adj_ts2, + ) + s1.adjust_status() + assert SessionAdjustedStatus.PAYOUT_ADJUSTMENT == s1.adjusted_status + assert Decimal("0.95") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == s.adjusted_user_payout + assert s1.adjusted_user_payout is None + + def test_fail_to_complete_to_fail(self, user, session_factory, utc_hour_ago): + # End with an abandon + s1 = session_factory( + user=user, + wall_count=2, + wall_req_cpis=[Decimal(1), Decimal(2)], + final_status=Status.ABANDON, + started=utc_hour_ago, + ) + + w1 = s1.wall_events[0] + w2 = s1.wall_events[1] + + # abandon adjust to complete + w2.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w2.cpi, + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.ADJUSTED_TO_COMPLETE == w2.adjusted_status + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_COMPLETE == s1.adjusted_status + assert Decimal("1.90") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.95") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None + + # back to fail + w2.update( + adjusted_status=None, + adjusted_cpi=None, + adjusted_timestamp=adj_ts, + ) + assert w2.adjusted_status is None + s1.adjust_status() + assert s1.adjusted_status is None + assert s1.adjusted_payout is None + assert s1.adjusted_user_payout is None + + # other is now complete + w1.update( + adjusted_status=WallAdjustedStatus.ADJUSTED_TO_COMPLETE, + adjusted_cpi=w1.cpi, + adjusted_timestamp=adj_ts, + ) + assert WallAdjustedStatus.ADJUSTED_TO_COMPLETE == w1.adjusted_status + s1.adjust_status() + assert SessionAdjustedStatus.ADJUSTED_TO_COMPLETE == s1.adjusted_status + assert Decimal("0.95") == s1.adjusted_payout + assert not user.product.user_wallet_config.enabled + # assert Decimal("0.48") == s1.adjusted_user_payout + assert s1.adjusted_user_payout is None diff --git a/tests/models/thl/test_bucket.py b/tests/models/thl/test_bucket.py new file mode 100644 index 0000000..0aa5843 --- /dev/null +++ b/tests/models/thl/test_bucket.py @@ -0,0 +1,201 @@ +from datetime import timedelta +from decimal import Decimal + +import pytest +from pydantic import ValidationError + + +class TestBucket: + + def test_raises_payout(self): + from generalresearch.models.legacy.bucket import Bucket + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=123) + assert "Must pass a Decimal" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=Decimal(1 / 3)) + assert "Must have 2 or fewer decimal places" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=Decimal(10000)) + assert "should be less than 1000" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(user_payout_min=Decimal(1), user_payout_max=Decimal("0.01")) + assert "user_payout_min should be <= user_payout_max" in str(e.value) + + def test_raises_loi(self): + from generalresearch.models.legacy.bucket import Bucket + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(loi_min=123) + assert "Input should be a valid timedelta" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(loi_min=timedelta(seconds=9999)) + assert "should be less than 90 minutes" in str(e.value) + + with pytest.raises(ValidationError) as e: + Bucket(loi_min=timedelta(seconds=0)) + assert "should be greater than 0" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket(loi_min=timedelta(seconds=10), loi_max=timedelta(seconds=9)) + assert "loi_min should be <= loi_max" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + Bucket( + loi_min=timedelta(seconds=10), + loi_max=timedelta(seconds=90), + loi_q1=timedelta(seconds=20), + ) + assert "loi_q1, q2, and q3 should all be set" in str(e.value) + with pytest.raises(expected_exception=ValidationError) as e: + Bucket( + loi_min=timedelta(seconds=10), + loi_max=timedelta(seconds=90), + loi_q1=timedelta(seconds=200), + loi_q2=timedelta(seconds=20), + loi_q3=timedelta(seconds=12), + ) + assert "loi_q1 should be <= loi_q2" in str(e.value) + + def test_parse_1(self): + from generalresearch.models.legacy.bucket import Bucket + + b1 = Bucket.parse_from_offerwall({"payout": {"min": 123}}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=None, + ) + assert b_exp == b1 + + b2 = Bucket.parse_from_offerwall({"payout": {"min": 123, "max": 230}}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=None, + loi_max=None, + ) + assert b_exp == b2 + + b3 = Bucket.parse_from_offerwall( + {"payout": {"min": 123, "max": 230}, "duration": {"min": 600, "max": 1800}} + ) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=timedelta(seconds=600), + loi_max=timedelta(seconds=1800), + ) + assert b_exp == b3 + + b4 = Bucket.parse_from_offerwall( + { + "payout": {"max": 80, "min": 28, "q1": 43, "q2": 43, "q3": 56}, + "duration": {"max": 1172, "min": 266, "q1": 746, "q2": 918, "q3": 1002}, + } + ) + b_exp = Bucket( + user_payout_min=Decimal("0.28"), + user_payout_max=Decimal("0.80"), + user_payout_q1=Decimal("0.43"), + user_payout_q2=Decimal("0.43"), + user_payout_q3=Decimal("0.56"), + loi_min=timedelta(seconds=266), + loi_max=timedelta(seconds=1172), + loi_q1=timedelta(seconds=746), + loi_q2=timedelta(seconds=918), + loi_q3=timedelta(seconds=1002), + ) + assert b_exp == b4 + + def test_parse_2(self): + from generalresearch.models.legacy.bucket import Bucket + + b1 = Bucket.parse_from_offerwall({"min_payout": 123}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=None, + ) + assert b_exp == b1 + + b2 = Bucket.parse_from_offerwall({"min_payout": 123, "max_payout": 230}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=None, + loi_max=None, + ) + assert b_exp == b2 + + b3 = Bucket.parse_from_offerwall( + { + "min_payout": 123, + "max_payout": 230, + "min_duration": 600, + "max_duration": 1800, + } + ) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=Decimal("2.30"), + loi_min=timedelta(seconds=600), + loi_max=timedelta(seconds=1800), + ) + assert b_exp, b3 + + b4 = Bucket.parse_from_offerwall( + { + "min_payout": 28, + "max_payout": 99, + "min_duration": 205, + "max_duration": 1113, + "q1_payout": 43, + "q2_payout": 43, + "q3_payout": 46, + "q1_duration": 561, + "q2_duration": 891, + "q3_duration": 918, + } + ) + b_exp = Bucket( + user_payout_min=Decimal("0.28"), + user_payout_max=Decimal("0.99"), + user_payout_q1=Decimal("0.43"), + user_payout_q2=Decimal("0.43"), + user_payout_q3=Decimal("0.46"), + loi_min=timedelta(seconds=205), + loi_max=timedelta(seconds=1113), + loi_q1=timedelta(seconds=561), + loi_q2=timedelta(seconds=891), + loi_q3=timedelta(seconds=918), + ) + assert b_exp == b4 + + def test_parse_3(self): + from generalresearch.models.legacy.bucket import Bucket + + b1 = Bucket.parse_from_offerwall({"payout": 123}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=None, + ) + assert b_exp == b1 + + b2 = Bucket.parse_from_offerwall({"payout": 123, "duration": 1800}) + b_exp = Bucket( + user_payout_min=Decimal("1.23"), + user_payout_max=None, + loi_min=None, + loi_max=timedelta(seconds=1800), + ) + assert b_exp == b2 diff --git a/tests/models/thl/test_buyer.py b/tests/models/thl/test_buyer.py new file mode 100644 index 0000000..eebb828 --- /dev/null +++ b/tests/models/thl/test_buyer.py @@ -0,0 +1,23 @@ +from generalresearch.models import Source +from generalresearch.models.thl.survey.buyer import BuyerCountryStat + + +def test_buyer_country_stat(): + bcs = BuyerCountryStat( + country_iso="us", + source=Source.TESTING, + code="123", + task_count=100, + conversion_alpha=40, + conversion_beta=190, + dropoff_alpha=20, + dropoff_beta=50, + long_fail_rate=1, + loi_excess_ratio=1, + user_report_coeff=1, + recon_likelihood=0.05, + ) + assert bcs.score + print(bcs.score) + print(bcs.conversion_p20) + print(bcs.dropoff_p60) diff --git a/tests/models/thl/test_contest/__init__.py b/tests/models/thl/test_contest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/thl/test_contest/test_contest.py b/tests/models/thl/test_contest/test_contest.py new file mode 100644 index 0000000..d53eee5 --- /dev/null +++ b/tests/models/thl/test_contest/test_contest.py @@ -0,0 +1,23 @@ +import pytest +from generalresearch.models.thl.user import User + + +class TestContest: + """In many of the Contest related tests, we often want a consistent + Product throughout, and multiple different users that may be + involved in the Contest... so redefine the product fixture along with + some users in here that are scoped="class" so they stay around for + each of the test functions + """ + + @pytest.fixture(scope="function") + def user_1(self, user_factory, product) -> User: + return user_factory(product=product) + + @pytest.fixture(scope="function") + def user_2(self, user_factory, product) -> User: + return user_factory(product=product) + + @pytest.fixture(scope="function") + def user_3(self, user_factory, product) -> User: + return user_factory(product=product) diff --git a/tests/models/thl/test_contest/test_leaderboard_contest.py b/tests/models/thl/test_contest/test_leaderboard_contest.py new file mode 100644 index 0000000..98f3215 --- /dev/null +++ b/tests/models/thl/test_contest/test_leaderboard_contest.py @@ -0,0 +1,213 @@ +from datetime import timezone +from uuid import uuid4 + +import pytest + +from generalresearch.currency import USDCent +from generalresearch.managers.leaderboard.manager import LeaderboardManager +from generalresearch.models.thl.contest import ContestPrize +from generalresearch.models.thl.contest.definitions import ( + ContestType, + ContestPrizeKind, +) +from generalresearch.models.thl.contest.leaderboard import ( + LeaderboardContest, +) +from generalresearch.models.thl.contest.utils import ( + distribute_leaderboard_prizes, +) +from generalresearch.models.thl.leaderboard import LeaderboardRow +from tests.models.thl.test_contest.test_contest import TestContest + + +class TestLeaderboardContest(TestContest): + + @pytest.fixture + def leaderboard_contest( + self, product, thl_redis, user_manager + ) -> "LeaderboardContest": + board_key = f"leaderboard:{product.uuid}:us:weekly:2025-05-26:complete_count" + + c = LeaderboardContest( + uuid=uuid4().hex, + product_id=product.uuid, + contest_type=ContestType.LEADERBOARD, + leaderboard_key=board_key, + name="$15 1st place, $10 2nd, $5 3rd place US weekly", + prizes=[ + ContestPrize( + name="$15 Cash", + estimated_cash_value=USDCent(15_00), + cash_amount=USDCent(15_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=1, + ), + ContestPrize( + name="$10 Cash", + estimated_cash_value=USDCent(10_00), + cash_amount=USDCent(10_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=2, + ), + ContestPrize( + name="$5 Cash", + estimated_cash_value=USDCent(5_00), + cash_amount=USDCent(5_00), + kind=ContestPrizeKind.CASH, + leaderboard_rank=3, + ), + ], + ) + c._redis_client = thl_redis + c._user_manager = user_manager + return c + + def test_init(self, leaderboard_contest, thl_redis, user_1, user_2): + model = leaderboard_contest.leaderboard_model + assert leaderboard_contest.end_condition.ends_at is not None + + lbm = LeaderboardManager( + redis_client=thl_redis, + board_code=model.board_code, + country_iso=model.country_iso, + freq=model.freq, + product_id=leaderboard_contest.product_id, + within_time=model.period_start_local, + ) + + lbm.hit_complete_count(product_user_id=user_1.product_user_id) + lbm.hit_complete_count(product_user_id=user_2.product_user_id) + lbm.hit_complete_count(product_user_id=user_2.product_user_id) + + lb = leaderboard_contest.get_leaderboard() + print(lb) + + def test_win(self, leaderboard_contest, thl_redis, user_1, user_2, user_3): + model = leaderboard_contest.leaderboard_model + lbm = LeaderboardManager( + redis_client=thl_redis, + board_code=model.board_code, + country_iso=model.country_iso, + freq=model.freq, + product_id=leaderboard_contest.product_id, + within_time=model.period_start_local.astimezone(tz=timezone.utc), + ) + + lbm.hit_complete_count(product_user_id=user_1.product_user_id) + lbm.hit_complete_count(product_user_id=user_1.product_user_id) + + lbm.hit_complete_count(product_user_id=user_2.product_user_id) + + lbm.hit_complete_count(product_user_id=user_3.product_user_id) + + leaderboard_contest.end_contest() + assert len(leaderboard_contest.all_winners) == 3 + + # Prizes are $15, $10, $5. user 2 and 3 ties for 2nd place, so they split (10 + 5) + assert leaderboard_contest.all_winners[0].awarded_cash_amount == USDCent(15_00) + assert ( + leaderboard_contest.all_winners[0].user.product_user_id + == user_1.product_user_id + ) + assert leaderboard_contest.all_winners[0].prize == leaderboard_contest.prizes[0] + assert leaderboard_contest.all_winners[1].awarded_cash_amount == USDCent( + 15_00 / 2 + ) + assert leaderboard_contest.all_winners[2].awarded_cash_amount == USDCent( + 15_00 / 2 + ) + + +class TestLeaderboardContestPrizes: + + def test_distribute_prizes_1(self): + prizes = [USDCent(15_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=10, rank=2), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # a gets first prize, b gets nothing. + assert result == { + "a": USDCent(15_00), + } + + def test_distribute_prizes_2(self): + prizes = [USDCent(15_00), USDCent(10_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=10, rank=2), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # a gets first prize, b gets 2nd prize + assert result == { + "a": USDCent(15_00), + "b": USDCent(10_00), + } + + def test_distribute_prizes_3(self): + prizes = [USDCent(15_00), USDCent(10_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # A gets first prize, no-one gets $10 + assert result == { + "a": USDCent(15_00), + } + + def test_distribute_prizes_4(self): + prizes = [USDCent(15_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=20, rank=1), + LeaderboardRow(bpuid="c", value=20, rank=1), + LeaderboardRow(bpuid="d", value=20, rank=1), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # 4-way tie for the $15 prize; it gets split + assert result == { + "a": USDCent(3_75), + "b": USDCent(3_75), + "c": USDCent(3_75), + "d": USDCent(3_75), + } + + def test_distribute_prizes_5(self): + prizes = [USDCent(15_00), USDCent(10_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=20, rank=1), + LeaderboardRow(bpuid="c", value=10, rank=3), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # 2-way tie for the $15 prize; the top two prizes get split. Rank 3 + # and below get nothing + assert result == { + "a": USDCent(12_50), + "b": USDCent(12_50), + } + + def test_distribute_prizes_6(self): + prizes = [USDCent(15_00), USDCent(10_00), USDCent(5_00)] + leaderboard_rows = [ + LeaderboardRow(bpuid="a", value=20, rank=1), + LeaderboardRow(bpuid="b", value=10, rank=2), + LeaderboardRow(bpuid="c", value=10, rank=2), + LeaderboardRow(bpuid="d", value=10, rank=2), + ] + result = distribute_leaderboard_prizes(prizes, leaderboard_rows) + + # A gets first prize, 3 way tie for 2nd rank: they split the 2nd and + # 3rd place prizes (10 + 5)/3 + assert result == { + "a": USDCent(15_00), + "b": USDCent(5_00), + "c": USDCent(5_00), + "d": USDCent(5_00), + } diff --git a/tests/models/thl/test_contest/test_raffle_contest.py b/tests/models/thl/test_contest/test_raffle_contest.py new file mode 100644 index 0000000..e1c0a15 --- /dev/null +++ b/tests/models/thl/test_contest/test_raffle_contest.py @@ -0,0 +1,300 @@ +from collections import Counter +from uuid import uuid4 + +import pytest +from pytest import approx + +from generalresearch.currency import USDCent +from generalresearch.models.thl.contest import ( + ContestPrize, + ContestEndCondition, +) +from generalresearch.models.thl.contest.contest_entry import ContestEntry +from generalresearch.models.thl.contest.definitions import ( + ContestEntryType, + ContestPrizeKind, + ContestType, + ContestStatus, + ContestEndReason, +) +from generalresearch.models.thl.contest.raffle import RaffleContest + +from tests.models.thl.test_contest.test_contest import TestContest + + +class TestRaffleContest(TestContest): + + @pytest.fixture(scope="function") + def raffle_contest(self, product) -> RaffleContest: + return RaffleContest( + product_id=product.uuid, + name=f"Raffle Contest {uuid4().hex}", + contest_type=ContestType.RAFFLE, + entry_type=ContestEntryType.CASH, + prizes=[ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ) + ], + end_condition=ContestEndCondition(target_entry_amount=100), + ) + + @pytest.fixture(scope="function") + def ended_raffle_contest(self, raffle_contest, utc_now) -> RaffleContest: + # Fake ending the contest + raffle_contest = raffle_contest.model_copy() + raffle_contest.update( + status=ContestStatus.COMPLETED, + ended_at=utc_now, + end_reason=ContestEndReason.ENDS_AT, + ) + return raffle_contest + + +class TestRaffleContestUserView(TestRaffleContest): + + def test_user_view(self, raffle_contest, user): + from generalresearch.models.thl.contest.raffle import RaffleUserView + + data = { + "current_amount": USDCent(1_00), + "product_user_id": user.product_user_id, + "user_amount": USDCent(1), + "user_amount_today": USDCent(1), + } + r = RaffleUserView.model_validate(raffle_contest.model_dump() | data) + res = r.model_dump(mode="json") + + assert res["product_user_id"] == user.product_user_id + assert res["user_amount_today"] == 1 + assert res["current_win_probability"] == approx(0.01, rel=0.000001) + assert res["projected_win_probability"] == approx(0.01, rel=0.000001) + + # Now change the amount + r.current_amount = USDCent(1_01) + res = r.model_dump(mode="json") + assert res["current_win_probability"] == approx(0.0099, rel=0.001) + assert res["projected_win_probability"] == approx(0.0099, rel=0.001) + + def test_win_pct(self, raffle_contest, user): + from generalresearch.models.thl.contest.raffle import RaffleUserView + + data = { + "current_amount": USDCent(10), + "product_user_id": user.product_user_id, + "user_amount": USDCent(1), + "user_amount_today": USDCent(1), + } + r = RaffleUserView.model_validate(raffle_contest.model_dump() | data) + r.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + # Raffle has 10 entries, user has 1 entry. + # There are 2 prizes. + assert r.current_win_probability == approx(expected=0.2, rel=0.01) + # He can only possibly win 1 prize + assert r.current_prize_count_probability[1] == approx(expected=0.2, rel=0.01) + # He has a 0 prob of winning 2 prizes + assert r.current_prize_count_probability[2] == 0 + # Contest end when there are 100 entries, so 1/100 * 2 prizes + assert r.projected_win_probability == approx(expected=0.02, rel=0.01) + + # Change to user having 2 entries (out of 10) + # Still with 2 prizes + r.user_amount = USDCent(2) + assert r.current_win_probability == approx(expected=0.3777, rel=0.01) + # 2/10 chance of winning 1st, 8/9 change of not winning 2nd, plus the + # same in the other order + p = (2 / 10) * (8 / 9) * 2 # 0.355555 + assert r.current_prize_count_probability[1] == approx(p, rel=0.01) + p = (2 / 10) * (1 / 9) # 0.02222 + assert r.current_prize_count_probability[2] == approx(p, rel=0.01) + + +class TestRaffleContestWinners(TestRaffleContest): + + def test_winners_1_prize(self, ended_raffle_contest, user_1, user_2, user_3): + ended_raffle_contest.entries = [ + ContestEntry( + user=user_1, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_2, + amount=USDCent(2), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_3, + amount=USDCent(3), + entry_type=ContestEntryType.CASH, + ), + ] + + # There is 1 prize. If we select a winner 1000 times, we'd expect user 1 + # to win ~ 1/6th of the time, user 2 ~2/6th and 3 3/6th. + winners = ended_raffle_contest.select_winners() + assert len(winners) == 1 + + c = Counter( + [ + ended_raffle_contest.select_winners()[0].user.user_id + for _ in range(10000) + ] + ) + assert c[user_1.user_id] == approx( + 10000 * 1 / 6, rel=0.1 + ) # 10% relative tolerance + assert c[user_2.user_id] == approx(10000 * 2 / 6, rel=0.1) + assert c[user_3.user_id] == approx(10000 * 3 / 6, rel=0.1) + + def test_winners_2_prizes(self, ended_raffle_contest, user_1, user_2, user_3): + ended_raffle_contest.prizes.append( + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ) + ) + ended_raffle_contest.entries = [ + ContestEntry( + user=user_3, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_1, + amount=USDCent(9999999), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_2, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ] + # In this scenario, user 1 should win both prizes + winners = ended_raffle_contest.select_winners() + assert len(winners) == 2 + # Two different prizes + assert len({w.prize.name for w in winners}) == 2 + # Same user + assert all(w.user.user_id == user_1.user_id for w in winners) + + def test_winners_2_prizes_1_entry(self, ended_raffle_contest, user_3): + ended_raffle_contest.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + ended_raffle_contest.entries = [ + ContestEntry( + user=user_3, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ] + + # One prize goes unclaimed + winners = ended_raffle_contest.select_winners() + assert len(winners) == 1 + + def test_winners_2_prizes_1_entry_2_pennies(self, ended_raffle_contest, user_3): + ended_raffle_contest.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + ended_raffle_contest.entries = [ + ContestEntry( + user=user_3, + amount=USDCent(2), + entry_type=ContestEntryType.CASH, + ), + ] + # User wins both prizes + winners = ended_raffle_contest.select_winners() + assert len(winners) == 2 + + def test_winners_3_prizes_3_entries( + self, ended_raffle_contest, product, user_1, user_2, user_3 + ): + ended_raffle_contest.prizes = [ + ContestPrize( + name="iPod 64GB White", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Black", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ContestPrize( + name="iPod 64GB Red", + kind=ContestPrizeKind.PHYSICAL, + estimated_cash_value=USDCent(100_00), + ), + ] + ended_raffle_contest.entries = [ + ContestEntry( + user=user_1, + amount=USDCent(1), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_2, + amount=USDCent(2), + entry_type=ContestEntryType.CASH, + ), + ContestEntry( + user=user_3, + amount=USDCent(3), + entry_type=ContestEntryType.CASH, + ), + ] + + winners = ended_raffle_contest.select_winners() + assert len(winners) == 3 + + winners = [ended_raffle_contest.select_winners() for _ in range(10000)] + + # There's 3 winners, the 1st should follow the same percentages + c = Counter([w[0].user.user_id for w in winners]) + + assert c[user_1.user_id] == approx(10000 * 1 / 6, rel=0.1) + assert c[user_2.user_id] == approx(10000 * 2 / 6, rel=0.1) + assert c[user_3.user_id] == approx(10000 * 3 / 6, rel=0.1) + + # Assume the 1st user won + ended_raffle_contest.entries.pop(0) + winners = [ended_raffle_contest.select_winners() for _ in range(10000)] + c = Counter([w[0].user.user_id for w in winners]) + assert c[user_2.user_id] == approx(10000 * 2 / 5, rel=0.1) + assert c[user_3.user_id] == approx(10000 * 3 / 5, rel=0.1) diff --git a/tests/models/thl/test_ledger.py b/tests/models/thl/test_ledger.py new file mode 100644 index 0000000..d706357 --- /dev/null +++ b/tests/models/thl/test_ledger.py @@ -0,0 +1,130 @@ +from datetime import datetime, timezone +from decimal import Decimal +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from generalresearch.models.thl.ledger import LedgerAccount, Direction, AccountType +from generalresearch.models.thl.ledger import LedgerTransaction, LedgerEntry + + +class TestLedgerTransaction: + + def test_create(self): + # Can create with nothing ... + t = LedgerTransaction() + assert [] == t.entries + assert {} == t.metadata + t = LedgerTransaction( + created=datetime.now(tz=timezone.utc), + metadata={"a": "b", "user": "1234"}, + ext_description="foo", + ) + + def test_ledger_entry(self): + with pytest.raises(expected_exception=ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=0, + ) + assert "Input should be greater than 0" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=2**65, + ) + assert "Input should be less than 9223372036854775807" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=Decimal("1"), + ) + assert "Input should be a valid integer" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=1.2, + ) + assert "Input should be a valid integer" in str(cm.value) + + def test_entries(self): + entries = [ + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=100, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="5927621462814f9893be807db850a31b", + amount=100, + ), + ] + LedgerTransaction(entries=entries) + + def test_raises_entries(self): + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=100, + ), + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="5927621462814f9893be807db850a31b", + amount=100, + ), + ] + with pytest.raises(ValidationError) as e: + LedgerTransaction(entries=entries) + assert "ledger entries must balance" in str(e.value) + + entries = [ + LedgerEntry( + direction=Direction.DEBIT, + account_uuid="3f3735eaed264c2a9f8a114934afa121", + amount=100, + ), + LedgerEntry( + direction=Direction.CREDIT, + account_uuid="5927621462814f9893be807db850a31b", + amount=101, + ), + ] + with pytest.raises(ValidationError) as cm: + LedgerTransaction(entries=entries) + assert "ledger entries must balance" in str(cm.value) + + +class TestLedgerAccount: + + def test_initialization(self): + u = uuid4().hex + name = f"test-{u[:8]}" + + with pytest.raises(ValidationError) as cm: + LedgerAccount( + display_name=name, + qualified_name="bad bunny", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + ) + assert "qualified name should start with" in str(cm.value) + + with pytest.raises(ValidationError) as cm: + LedgerAccount( + display_name=name, + qualified_name="fish sticks:bp_wallet", + normal_balance=Direction.DEBIT, + account_type=AccountType.BP_WALLET, + currency="fish sticks", + ) + assert "Invalid UUID" in str(cm.value) diff --git a/tests/models/thl/test_marketplace_condition.py b/tests/models/thl/test_marketplace_condition.py new file mode 100644 index 0000000..217616d --- /dev/null +++ b/tests/models/thl/test_marketplace_condition.py @@ -0,0 +1,382 @@ +import pytest +from pydantic import ValidationError + + +class TestMarketplaceCondition: + + def test_list_or(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_list_or_negate(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_list_and(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a1", "a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) + user_qas = {"q1": {"a1"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_list_and_negate(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a1", "a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a2"], + logical_operator=LogicalOperator.AND, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=True, + value_type=ConditionValueType.LIST, + values=["a1", "a2", "a3"], + logical_operator=LogicalOperator.AND, + ) + assert c.evaluate_criterion(user_qas) is None + + def test_ranges(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"2", "50"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["10-20"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) is None + # --- negate + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.RANGE, + values=["10-20"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + # --- AND + with pytest.raises(expected_exception=ValidationError): + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-20"], + logical_operator=LogicalOperator.AND, + ) + + def test_ranges_to_list(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"2", "50"}} + MarketplaceCondition._CONVERT_LIST_TO_RANGE = ["q1"] + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-12", "3-5"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + assert ConditionValueType.LIST == c.value_type + assert ["1", "10", "11", "12", "2", "3", "4", "5"] == c.values + + def test_ranges_infinity(self): + from generalresearch.models import LogicalOperator + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"2", "50"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "10-inf"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion(user_qas) + user_qas = {"q1": {"5", "50"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["1-4", "60-inf"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion(user_qas) + + # need to test negative infinity! + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["inf-40"], + logical_operator=LogicalOperator.OR, + ) + assert c.evaluate_criterion({"q1": {"5", "50"}}) + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.RANGE, + values=["inf-40"], + logical_operator=LogicalOperator.OR, + ) + assert not c.evaluate_criterion({"q1": {"50"}}) + + def test_answered(self): + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_qas = {"q1": {"a2"}} + c = MarketplaceCondition( + question_id="q1", + negate=False, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=False, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q1", + negate=True, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert not c.evaluate_criterion(user_qas) + c = MarketplaceCondition( + question_id="q2", + negate=True, + value_type=ConditionValueType.ANSWERED, + values=[], + ) + assert c.evaluate_criterion(user_qas) + + def test_invite(self): + from generalresearch.models.thl.survey.condition import ( + MarketplaceCondition, + ConditionValueType, + ) + + user_groups = {"g1", "g2", "g3"} + c = MarketplaceCondition( + question_id=None, + negate=False, + value_type=ConditionValueType.RECONTACT, + values=["g1", "g4"], + ) + assert c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) + c = MarketplaceCondition( + question_id=None, + negate=False, + value_type=ConditionValueType.RECONTACT, + values=["g4"], + ) + assert not c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) + + c = MarketplaceCondition( + question_id=None, + negate=True, + value_type=ConditionValueType.RECONTACT, + values=["g1", "g4"], + ) + assert not c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) + c = MarketplaceCondition( + question_id=None, + negate=True, + value_type=ConditionValueType.RECONTACT, + values=["g4"], + ) + assert c.evaluate_criterion(user_qas=dict(), user_groups=user_groups) diff --git a/tests/models/thl/test_payout.py b/tests/models/thl/test_payout.py new file mode 100644 index 0000000..3a51328 --- /dev/null +++ b/tests/models/thl/test_payout.py @@ -0,0 +1,10 @@ +class TestBusinessPayoutEvent: + + def test_validate(self): + from generalresearch.models.gr.business import Business + + instance = Business.model_validate_json( + json_data='{"id":123,"uuid":"947f6ba5250d442b9a66cde9ee33605a","name":"Example » Demo","kind":"c","tax_number":null,"contact":null,"addresses":[],"teams":[{"id":53,"uuid":"8e4197dcaefe4f1f831a02b212e6b44a","name":"Example » Demo","memberships":null,"gr_users":null,"businesses":null,"products":null}],"products":[{"id":"fc23e741b5004581b30e6478363525df","id_int":1234,"name":"Example","enabled":true,"payments_enabled":true,"created":"2025-04-14T13:25:37.279403Z","team_id":"9e4197dcaefe4f1f831a02b212e6b44a","business_id":"857f6ba6160d442b9a66cde9ee33605a","tags":[],"commission_pct":"0.050000","redirect_url":"https://pam-api-us.reppublika.com/v2/public/4970ef00-0ef7-11f0-9962-05cb6323c84c/grl/status","harmonizer_domain":"https://talk.generalresearch.com/","sources_config":{"user_defined":[{"name":"w","active":false,"banned_countries":[],"allow_mobile_ip":true,"supplier_id":null,"allow_pii_only_buyers":false,"allow_unhashed_buyers":false,"withhold_profiling":false,"pass_unconditional_eligible_unknowns":true,"address":null,"allow_vpn":null,"distribute_harmonizer_active":null}]},"session_config":{"max_session_len":600,"max_session_hard_retry":5,"min_payout":"0.14"},"payout_config":{"payout_format":null,"payout_transformation":null},"user_wallet_config":{"enabled":false,"amt":false,"supported_payout_types":["CASH_IN_MAIL","PAYPAL","TANGO"],"min_cashout":null},"user_create_config":{"min_hourly_create_limit":0,"max_hourly_create_limit":null},"offerwall_config":{},"profiling_config":{"enabled":true,"grs_enabled":true,"n_questions":null,"max_questions":10,"avg_question_count":5.0,"task_injection_freq_mult":1.0,"non_us_mult":2.0,"hidden_questions_expiration_hours":168},"user_health_config":{"banned_countries":[],"allow_ban_iphist":true},"yield_man_config":{},"balance":null,"payouts_total_str":null,"payouts_total":null,"payouts":null,"user_wallet":{"enabled":false,"amt":false,"supported_payout_types":["CASH_IN_MAIL","PAYPAL","TANGO"],"min_cashout":null}}],"bank_accounts":[],"balance":{"product_balances":[{"product_id":"fc14e741b5004581b30e6478363414df","last_event":null,"bp_payment_credit":780251,"adjustment_credit":4678,"adjustment_debit":26446,"supplier_credit":0,"supplier_debit":451513,"user_bonus_credit":0,"user_bonus_debit":0,"issued_payment":0,"payout":780251,"payout_usd_str":"$7,802.51","adjustment":-21768,"expense":0,"net":758483,"payment":451513,"payment_usd_str":"$4,515.13","balance":306970,"retainer":76742,"retainer_usd_str":"$767.42","available_balance":230228,"available_balance_usd_str":"$2,302.28","recoup":0,"recoup_usd_str":"$0.00","adjustment_percent":0.027898714644390074}],"payout":780251,"payout_usd_str":"$7,802.51","adjustment":-21768,"expense":0,"net":758483,"net_usd_str":"$7,584.83","payment":451513,"payment_usd_str":"$4,515.13","balance":306970,"balance_usd_str":"$3,069.70","retainer":76742,"retainer_usd_str":"$767.42","available_balance":230228,"available_balance_usd_str":"$2,302.28","adjustment_percent":0.027898714644390074,"recoup":0,"recoup_usd_str":"$0.00"},"payouts_total_str":"$4,515.13","payouts_total":451513,"payouts":[{"bp_payouts":[{"uuid":"40cf2c3c341e4f9d985be4bca43e6116","debit_account_uuid":"3a058056da85493f9b7cdfe375aad0e0","cashout_method_uuid":"602113e330cf43ae85c07d94b5100291","created":"2025-08-02T09:18:20.433329Z","amount":345735,"status":"COMPLETE","ext_ref_id":null,"payout_type":"ACH","request_data":{},"order_data":null,"product_id":"fc14e741b5004581b30e6478363414df","method":"ACH","amount_usd":345735,"amount_usd_str":"$3,457.35"}],"amount":345735,"amount_usd_str":"$3,457.35","created":"2025-08-02T09:18:20.433329Z","line_items":1,"ext_ref_id":null},{"bp_payouts":[{"uuid":"63ce1787087248978919015c8fcd5ab9","debit_account_uuid":"3a058056da85493f9b7cdfe375aad0e0","cashout_method_uuid":"602113e330cf43ae85c07d94b5100291","created":"2025-06-10T22:16:18.765668Z","amount":105778,"status":"COMPLETE","ext_ref_id":"11175997868","payout_type":"ACH","request_data":{},"order_data":null,"product_id":"fc14e741b5004581b30e6478363414df","method":"ACH","amount_usd":105778,"amount_usd_str":"$1,057.78"}],"amount":105778,"amount_usd_str":"$1,057.78","created":"2025-06-10T22:16:18.765668Z","line_items":1,"ext_ref_id":"11175997868"}]}' + ) + + assert isinstance(instance, Business) diff --git a/tests/models/thl/test_payout_format.py b/tests/models/thl/test_payout_format.py new file mode 100644 index 0000000..dc91f39 --- /dev/null +++ b/tests/models/thl/test_payout_format.py @@ -0,0 +1,46 @@ +import pytest +from pydantic import BaseModel + +from generalresearch.models.thl.payout_format import ( + PayoutFormatType, + PayoutFormatField, + format_payout_format, +) + + +class PayoutFormatTestClass(BaseModel): + payout_format: PayoutFormatType = PayoutFormatField + + +class TestPayoutFormat: + def test_payout_format_cls(self): + # valid + PayoutFormatTestClass(payout_format="{payout*10:,.0f} Points") + PayoutFormatTestClass(payout_format="{payout:.0f}") + PayoutFormatTestClass(payout_format="${payout/100:.2f}") + + # invalid + with pytest.raises(expected_exception=Exception) as e: + PayoutFormatTestClass(payout_format="{payout10:,.0f} Points") + + with pytest.raises(expected_exception=Exception) as e: + PayoutFormatTestClass(payout_format="payout:,.0f} Points") + + with pytest.raises(expected_exception=Exception): + PayoutFormatTestClass(payout_format="payout") + + with pytest.raises(expected_exception=Exception): + PayoutFormatTestClass(payout_format="{payout;import sys:.0f}") + + def test_payout_format(self): + assert "1,230 Points" == format_payout_format( + payout_format="{payout*10:,.0f} Points", payout_int=123 + ) + + assert "123" == format_payout_format( + payout_format="{payout:.0f}", payout_int=123 + ) + + assert "$1.23" == format_payout_format( + payout_format="${payout/100:.2f}", payout_int=123 + ) diff --git a/tests/models/thl/test_product.py b/tests/models/thl/test_product.py new file mode 100644 index 0000000..52f60c2 --- /dev/null +++ b/tests/models/thl/test_product.py @@ -0,0 +1,1130 @@ +import os +import shutil +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from typing import Optional +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from generalresearch.currency import USDCent +from generalresearch.models import Source +from generalresearch.models.thl.product import ( + Product, + PayoutConfig, + PayoutTransformation, + ProfilingConfig, + SourcesConfig, + IntegrationMode, + SupplyConfig, + SourceConfig, + SupplyPolicy, +) + + +class TestProduct: + + def test_init(self): + # By default, just a Pydantic instance doesn't have an id_int + instance = Product.model_validate( + dict( + id="968a9acc79b74b6fb49542d82516d284", + name="test-968a9acc", + redirect_url="https://www.google.com/hey", + ) + ) + assert instance.id_int is None + + res = instance.model_dump_json() + # We're not excluding anything here, only in the "*Out" variants + assert "id_int" in res + + def test_init_db(self, product_manager): + # By default, just a Pydantic instance doesn't have an id_int + instance = product_manager.create_dummy() + assert isinstance(instance.id_int, int) + + res = instance.model_dump_json() + + # we json skip & exclude + res = instance.model_dump() + + def test_redirect_url(self): + p = Product.model_validate( + dict( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + sources=[{"name": "d", "active": True}], + name="test-968a9acc", + max_session_len=600, + team_id="8b5e94afd8a246bf8556ad9986486baa", + redirect_url="https://www.google.com/hey", + ) + ) + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = "" + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = None + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = "http://www.example.com/test/?a=1&b=2" + + with pytest.raises(expected_exception=ValidationError): + p.redirect_url = "http://www.example.com/test/?a=1&b=2&tsid=" + + p.redirect_url = "https://www.example.com/test/?a=1&b=2" + c = p.generate_bp_redirect(tsid="c6ab6ba1e75b44e2bf5aab00fc68e3b7") + assert ( + c + == "https://www.example.com/test/?a=1&b=2&tsid=c6ab6ba1e75b44e2bf5aab00fc68e3b7" + ) + + def test_harmonizer_domain(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + assert p.harmonizer_domain == "https://profile.generalresearch.com/" + p.harmonizer_domain = "https://profile.generalresearch.com/" + p.harmonizer_domain = "https://profile.generalresearch.com" + assert p.harmonizer_domain == "https://profile.generalresearch.com/" + with pytest.raises(expected_exception=Exception): + p.harmonizer_domain = "" + with pytest.raises(expected_exception=Exception): + p.harmonizer_domain = None + with pytest.raises(expected_exception=Exception): + # no https + p.harmonizer_domain = "http://profile.generalresearch.com" + with pytest.raises(expected_exception=Exception): + # "/a" at the end + p.harmonizer_domain = "https://profile.generalresearch.com/a" + + def test_payout_xform(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + { + "f": "payout_transformation_percent", + "kwargs": {"pct": "0.5", "min_payout": "0.10"}, + } + ) + + assert ( + "payout_transformation_percent" == p.payout_config.payout_transformation.f + ) + assert 0.5 == p.payout_config.payout_transformation.kwargs.pct + assert ( + Decimal("0.10") == p.payout_config.payout_transformation.kwargs.min_payout + ) + assert p.payout_config.payout_transformation.kwargs.max_payout is None + + # This calls get_payout_transformation_func + # 50% of $1.00 + assert Decimal("0.50") == p.calculate_user_payment(Decimal(1)) + # with a min + assert Decimal("0.10") == p.calculate_user_payment(Decimal("0.15")) + + with pytest.raises(expected_exception=ValidationError) as cm: + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + {"f": "payout_transformation_percent", "kwargs": {}} + ) + assert "1 validation error for PayoutTransformation\nkwargs.pct" in str( + cm.value + ) + + with pytest.raises(expected_exception=ValidationError) as cm: + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + {"f": "payout_transformation_percent"} + ) + + assert "1 validation error for PayoutTransformation\nkwargs" in str(cm.value) + + with pytest.warns(expected_warning=Warning) as w: + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + { + "f": "payout_transformation_percent", + "kwargs": {"pct": 1, "min_payout": "0.5"}, + } + ) + assert "Are you sure you want to pay respondents >95% of CPI?" in "".join( + [str(i.message) for i in w] + ) + + p.payout_config = PayoutConfig() + assert p.calculate_user_payment(Decimal("0.15")) is None + + def test_payout_xform_amt(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + + p.payout_config.payout_transformation = PayoutTransformation.model_validate( + { + "f": "payout_transformation_amt", + } + ) + + assert "payout_transformation_amt" == p.payout_config.payout_transformation.f + + # This calls get_payout_transformation_func + # 95% of $1.00 + assert p.calculate_user_payment(Decimal(1)) == Decimal("0.95") + assert p.calculate_user_payment(Decimal("1.05")) == Decimal("1.00") + + assert p.calculate_user_payment( + Decimal("0.10"), user_wallet_balance=Decimal(0) + ) == Decimal("0.07") + assert p.calculate_user_payment( + Decimal("1.05"), user_wallet_balance=Decimal(0) + ) == Decimal("0.97") + assert p.calculate_user_payment( + Decimal(".05"), user_wallet_balance=Decimal(1) + ) == Decimal("0.02") + # final balance will be <0, so pay the full amount + assert p.calculate_user_payment( + Decimal(".50"), user_wallet_balance=Decimal(-1) + ) == p.calculate_user_payment(Decimal("0.50")) + # final balance will be >0, so do the 7c rounding + assert p.calculate_user_payment( + Decimal(".50"), user_wallet_balance=Decimal("-0.10") + ) == ( + p.calculate_user_payment(Decimal(".40"), user_wallet_balance=Decimal(0)) + - Decimal("-0.10") + ) + + def test_payout_xform_none(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + payout_config=PayoutConfig(payout_format=None, payout_transformation=None), + ) + assert p.format_payout_format(Decimal("1.00")) is None + + pt = PayoutTransformation.model_validate( + {"kwargs": {"pct": 0.5}, "f": "payout_transformation_percent"} + ) + p.payout_config = PayoutConfig( + payout_format="{payout*10:,.0f} Points", payout_transformation=pt + ) + assert p.format_payout_format(Decimal("1.00")) == "1,000 Points" + + def test_profiling(self): + p = Product( + id="968a9acc79b74b6fb49542d82516d284", + created="2023-09-21T22:13:09.274672Z", + commission_pct=Decimal("0.05"), + enabled=True, + name="test-968a9acc", + team_id="8b5e94afd8a246bf8556ad9986486baa", + harmonizer_domain="profile.generalresearch.com", + redirect_url="https://www.google.com/hey", + ) + assert p.profiling_config.enabled is True + + p.profiling_config = ProfilingConfig(max_questions=1) + assert p.profiling_config.max_questions == 1 + + def test_bp_account(self, product, thl_lm): + assert product.bp_account is None + + product.prefetch_bp_account(thl_lm=thl_lm) + + from generalresearch.models.thl.ledger import LedgerAccount + + assert isinstance(product.bp_account, LedgerAccount) + + +class TestGlobalProduct: + # We have one product ID that is special; we call it the Global + # Product ID and in prod the. This product stores a bunch of extra + # things in the SourcesConfig + + def test_init_and_props(self): + instance = Product( + name="Global Config", + redirect_url="https://www.example.com", + sources_config=SupplyConfig( + policies=[ + # This is the config for Dynata that any BP is allowed to use + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PLATFORM, + ), + # Spectrum that is using OUR credentials, that anyone is allowed to use. + # Same as the dynata config above, just that the dynata supplier_id is + # inferred by the dynata-grpc; it's not required to be set. + SupplyPolicy( + address=["https://spectrum.internal:50051"], + active=True, + name=Source.SPECTRUM, + supplier_id="example-supplier-id", + # implicit Scope = GLOBAL + # default integration_mode=IntegrationMode.PLATFORM, + ), + # A spectrum config with a different supplier_id, but + # it is OUR supplier, and we are paid for the completes. Only a certain BP + # can use this config. + SupplyPolicy( + address=["https://spectrum.internal:50051"], + active=True, + name=Source.SPECTRUM, + supplier_id="example-supplier-id", + team_ids=["d42194c2dfe44d7c9bec98123bc4a6c0"], + # implicit Scope = TEAM + # default integration_mode=IntegrationMode.PLATFORM, + ), + # The supplier ID is associated with THEIR + # credentials, and we do not get paid for this activity. + SupplyPolicy( + address=["https://cint.internal:50051"], + active=True, + name=Source.CINT, + supplier_id="example-supplier-id", + product_ids=["db8918b3e87d4444b60241d0d3a54caa"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + # We could have another global cint integration available + # to anyone also, or we could have another like above + SupplyPolicy( + address=["https://cint.internal:50051"], + active=True, + name=Source.CINT, + supplier_id="example-supplier-id", + team_ids=["b163972a59584de881e5eab01ad10309"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ), + ) + + assert Product.model_validate_json(instance.model_dump_json()) == instance + + s = instance.sources_config + # Cint should NOT have a global config + assert set(s.global_scoped_policies_dict.keys()) == { + Source.DYNATA, + Source.SPECTRUM, + } + + # The spectrum global config is the one that isn't scoped to a + # specific supplier + assert ( + s.global_scoped_policies_dict[Source.SPECTRUM].supplier_id + == "grl-supplier-id" + ) + + assert set(s.team_scoped_policies_dict.keys()) == { + "b163972a59584de881e5eab01ad10309", + "d42194c2dfe44d7c9bec98123bc4a6c0", + } + # This team has one team-scoped config, and it's for spectrum + assert s.team_scoped_policies_dict[ + "d42194c2dfe44d7c9bec98123bc4a6c0" + ].keys() == {Source.SPECTRUM} + + # For a random product/team, it'll just have the globally-scoped config + random_product = uuid4().hex + random_team = uuid4().hex + res = instance.sources_config.get_policies_for( + product_id=random_product, team_id=random_team + ) + assert res == s.global_scoped_policies_dict + + # It'll have the global config plus cint, and it should use the PRODUCT + # scoped config, not the TEAM scoped! + res = instance.sources_config.get_policies_for( + product_id="db8918b3e87d4444b60241d0d3a54caa", + team_id="b163972a59584de881e5eab01ad10309", + ) + assert set(res.keys()) == { + Source.DYNATA, + Source.SPECTRUM, + Source.CINT, + } + assert res[Source.CINT].supplier_id == "example-supplier-id" + + def test_source_vs_supply_validate(self): + # sources_config can be a SupplyConfig or SourcesConfig. + # make sure they get model_validated correctly + gp = Product( + name="Global Config", + redirect_url="https://www.example.com", + sources_config=SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PLATFORM, + ) + ] + ), + ) + bp = Product( + name="test product config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig( + user_defined=[ + SourceConfig( + active=False, + name=Source.DYNATA, + ) + ] + ), + ) + assert Product.model_validate_json(gp.model_dump_json()) == gp + assert Product.model_validate_json(bp.model_dump_json()) == bp + + def test_validations(self): + with pytest.raises( + ValidationError, match="Can only have one GLOBAL policy per Source" + ): + SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PLATFORM, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ) + with pytest.raises( + ValidationError, + match="Can only have one PRODUCT policy per Source per BP", + ): + SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + product_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PLATFORM, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + product_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ) + with pytest.raises( + ValidationError, + match="Can only have one TEAM policy per Source per Team", + ): + SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + team_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PLATFORM, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + team_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PASS_THROUGH, + ), + ] + ) + + +class TestGlobalProductConfigFor: + def test_no_user_defined(self): + sc = SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + ) + ] + ) + product = Product( + name="Test Product Config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig(), + ) + res = sc.get_config_for_product(product=product) + assert len(res.policies) == 1 + + def test_user_defined_merge(self): + sc = SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + banned_countries=["mx"], + active=True, + name=Source.DYNATA, + ), + SupplyPolicy( + address=["https://dynata.internal:50051"], + banned_countries=["ca"], + active=True, + name=Source.DYNATA, + team_ids=[uuid4().hex], + ), + ] + ) + product = Product( + name="Test Product Config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig( + user_defined=[ + SourceConfig( + name=Source.DYNATA, + active=False, + banned_countries=["us"], + ) + ] + ), + ) + res = sc.get_config_for_product(product=product) + assert len(res.policies) == 1 + assert not res.policies[0].active + assert res.policies[0].banned_countries == ["mx", "us"] + + def test_no_eligible(self): + sc = SupplyConfig( + policies=[ + SupplyPolicy( + address=["https://dynata.internal:50051"], + active=True, + name=Source.DYNATA, + team_ids=["7e417dec1c8a406e8554099b46e518ca"], + integration_mode=IntegrationMode.PLATFORM, + ) + ] + ) + product = Product( + name="Test Product Config", + redirect_url="https://www.example.com", + sources_config=SourcesConfig(), + ) + res = sc.get_config_for_product(product=product) + assert len(res.policies) == 0 + + +class TestProductFinancials: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_balance( + self, + business, + product_factory, + user_factory, + mnt_filepath, + bp_payout_factory, + thl_lm, + lm, + duration, + offset, + thl_redis_config, + start, + thl_web_rr, + brokerage_product_payout_event_manager, + session_with_tx_factory, + delete_ledger_db, + create_main_accounts, + client_no_amm, + ledger_collection, + pop_ledger_merge, + delete_df_collection, + ): + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + from generalresearch.models.thl.finance import ProductBalances + from generalresearch.currency import USDCent + + p1: Product = product_factory(business=business) + u1: User = user_factory(product=p1) + bp_wallet = thl_lm.get_account_or_create_bp_wallet(product=p1) + thl_lm.get_account_or_create_user_wallet(user=u1) + brokerage_product_payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 0 + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".50"), + started=start + timedelta(days=1), + ) + assert thl_lm.get_account_balance(account=bp_wallet) == 48 + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 1 + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal("1.00"), + started=start + timedelta(days=2), + ) + assert thl_lm.get_account_balance(account=bp_wallet) == 143 + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 2 + + with pytest.raises(expected_exception=AssertionError) as cm: + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert "Cannot build Product Balance" in str(cm.value) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert isinstance(p1.balance, ProductBalances) + assert p1.balance.payout == 143 + assert p1.balance.adjustment == 0 + assert p1.balance.expense == 0 + assert p1.balance.net == 143 + assert p1.balance.balance == 143 + assert p1.balance.retainer == 35 + assert p1.balance.available_balance == 108 + + p1.prebuild_payouts( + thl_lm=thl_lm, + bp_pem=brokerage_product_payout_event_manager, + ) + assert p1.payouts is not None + assert len(p1.payouts) == 0 + assert p1.payouts_total == 0 + assert p1.payouts_total_str == "$0.00" + + # -- Now pay them out... + + bp_payout_factory( + product=p1, + amount=USDCent(50), + created=start + timedelta(days=3), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 3 + + # RM the entire directories + shutil.rmtree(ledger_collection.archive_path) + os.makedirs(ledger_collection.archive_path, exist_ok=True) + shutil.rmtree(pop_ledger_merge.archive_path) + os.makedirs(pop_ledger_merge.archive_path, exist_ok=True) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert isinstance(p1.balance, ProductBalances) + assert p1.balance.payout == 143 + assert p1.balance.adjustment == 0 + assert p1.balance.expense == 0 + assert p1.balance.net == 143 + assert p1.balance.balance == 93 + assert p1.balance.retainer == 23 + assert p1.balance.available_balance == 70 + + p1.prebuild_payouts( + thl_lm=thl_lm, + bp_pem=brokerage_product_payout_event_manager, + ) + assert p1.payouts is not None + assert len(p1.payouts) == 1 + assert p1.payouts_total == 50 + assert p1.payouts_total_str == "$0.50" + + # -- Now pay ou another!. + + bp_payout_factory( + product=p1, + amount=USDCent(5), + created=start + timedelta(days=4), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + assert len(thl_lm.get_tx_filtered_by_account(account_uuid=bp_wallet.uuid)) == 4 + + # RM the entire directories + shutil.rmtree(ledger_collection.archive_path) + os.makedirs(ledger_collection.archive_path, exist_ok=True) + shutil.rmtree(pop_ledger_merge.archive_path) + os.makedirs(pop_ledger_merge.archive_path, exist_ok=True) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + p1.prebuild_balance( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + ) + assert isinstance(p1.balance, ProductBalances) + assert p1.balance.payout == 143 + assert p1.balance.adjustment == 0 + assert p1.balance.expense == 0 + assert p1.balance.net == 143 + assert p1.balance.balance == 88 + assert p1.balance.retainer == 22 + assert p1.balance.available_balance == 66 + + p1.prebuild_payouts( + thl_lm=thl_lm, + bp_pem=brokerage_product_payout_event_manager, + ) + assert p1.payouts is not None + assert len(p1.payouts) == 2 + assert p1.payouts_total == 55 + assert p1.payouts_total_str == "$0.55" + + +class TestProductBalance: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_inconsistent( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + ): + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete and Build Parquets 1st time + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # 2. Payout and build Parquets 2nd time + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + with pytest.raises(expected_exception=AssertionError) as cm: + product.prebuild_balance( + thl_lm=thl_lm, ds=mnt_filepath, client=client_no_amm + ) + assert "Sql and Parquet Balance inconsistent" in str(cm) + + def test_not_inconsistent( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + ): + # This is very similar to the test_complete_payout_pq_inconsistent + # test, however this time we're only going to assign the payout + # in real time, and not in the past. This means that even if we + # build the parquet files multiple times, they will include the + # payout. + + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete and Build Parquets 1st time + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # 2. Payout and build Parquets 2nd time but this payout is "now" + # so it hasn't already been archived + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=datetime.now(tz=timezone.utc), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # We just want to call this to confirm it doesn't raise. + product.prebuild_balance(thl_lm=thl_lm, ds=mnt_filepath, client=client_no_amm) + + +class TestProductPOPFinancial: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_base( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + ): + # This is very similar to the test_complete_payout_pq_inconsistent + # test, however this time we're only going to assign the payout + # in real time, and not in the past. This means that even if we + # build the parquet files multiple times, they will include the + # payout. + + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete and Build Parquets 1st time + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # --- test --- + assert product.pop_financial is None + product.prebuild_pop_financial( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + pop_ledger=pop_ledger_merge, + ) + + from generalresearch.models.thl.finance import POPFinancial + + assert isinstance(product.pop_financial, list) + assert isinstance(product.pop_financial[0], POPFinancial) + pf1: POPFinancial = product.pop_financial[0] + assert isinstance(pf1.time, datetime) + assert pf1.payout == 71 + assert pf1.net == 71 + assert pf1.adjustment == 0 + for adj in pf1.adjustment_types: + assert adj.amount == 0 + + +class TestProductCache: + + @pytest.fixture + def start(self) -> "datetime": + return datetime(year=2018, month=3, day=14, hour=0, tzinfo=timezone.utc) + + @pytest.fixture + def offset(self) -> str: + return "30d" + + @pytest.fixture + def duration(self) -> Optional["timedelta"]: + return None + + def test_basic( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + ): + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + # Confirm the default / null behavior + rc = thl_redis_config.create_redis_client() + res: Optional[str] = rc.get(product.cache_key) + assert res is None + with pytest.raises(expected_exception=AssertionError): + product.set_cache( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + bp_pem=brokerage_product_payout_event_manager, + redis_config=thl_redis_config, + ) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + # Now try again with everything in place + product.set_cache( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + bp_pem=brokerage_product_payout_event_manager, + redis_config=thl_redis_config, + ) + + # Fetch from cache and assert the instance loaded from redis + res: Optional[str] = rc.get(product.cache_key) + assert isinstance(res, str) + from generalresearch.models.thl.ledger import LedgerAccount + + assert isinstance(product.bp_account, LedgerAccount) + + p1: Product = Product.model_validate_json(res) + assert p1.balance.product_id == product.uuid + assert p1.balance.payout_usd_str == "$0.71" + assert p1.balance.retainer_usd_str == "$0.17" + assert p1.balance.available_balance_usd_str == "$0.54" + + def test_neg_balance_cache( + self, + product, + mnt_filepath, + thl_lm, + client_no_amm, + thl_redis_config, + brokerage_product_payout_event_manager, + delete_ledger_db, + create_main_accounts, + delete_df_collection, + ledger_collection, + business, + user_factory, + product_factory, + session_with_tx_factory, + pop_ledger_merge, + start, + bp_payout_factory, + payout_event_manager, + adj_to_fail_with_tx_factory, + ): + # Now let's load it up and actually test some things + delete_ledger_db() + create_main_accounts() + delete_df_collection(coll=ledger_collection) + + from generalresearch.models.thl.product import Product + from generalresearch.models.thl.user import User + + u1: User = user_factory(product=product) + + # 1. Complete + s1 = session_with_tx_factory( + user=u1, + wall_req_cpi=Decimal(".75"), + started=start + timedelta(days=1), + ) + + # 2. Payout + payout_event_manager.set_account_lookup_table(thl_lm=thl_lm) + bp_payout_factory( + product=product, + amount=USDCent(71), + ext_ref_id=uuid4().hex, + created=start + timedelta(days=1, minutes=1), + skip_wallet_balance_check=True, + skip_one_per_day_check=True, + ) + + # 3. Recon + adj_to_fail_with_tx_factory( + session=s1, + created=start + timedelta(days=1, minutes=1), + ) + + # Finally, process everything: + ledger_collection.initial_load(client=None, sync=True) + pop_ledger_merge.build(client=client_no_amm, ledger_coll=ledger_collection) + + product.set_cache( + thl_lm=thl_lm, + ds=mnt_filepath, + client=client_no_amm, + bp_pem=brokerage_product_payout_event_manager, + redis_config=thl_redis_config, + ) + + # Fetch from cache and assert the instance loaded from redis + rc = thl_redis_config.create_redis_client() + res: Optional[str] = rc.get(product.cache_key) + assert isinstance(res, str) + + p1: Product = Product.model_validate_json(res) + assert p1.balance.product_id == product.uuid + assert p1.balance.payout_usd_str == "$0.71" + assert p1.balance.adjustment == -71 + assert p1.balance.expense == 0 + assert p1.balance.net == 0 + assert p1.balance.balance == -71 + assert p1.balance.retainer_usd_str == "$0.00" + assert p1.balance.available_balance_usd_str == "$0.00" diff --git a/tests/models/thl/test_product_userwalletconfig.py b/tests/models/thl/test_product_userwalletconfig.py new file mode 100644 index 0000000..4583c46 --- /dev/null +++ b/tests/models/thl/test_product_userwalletconfig.py @@ -0,0 +1,56 @@ +from itertools import groupby +from random import shuffle as rshuffle + +from generalresearch.models.thl.product import ( + UserWalletConfig, +) + +from generalresearch.models.thl.wallet import PayoutType + + +def all_equal(iterable): + g = groupby(iterable) + return next(g, True) and not next(g, False) + + +class TestProductUserWalletConfig: + + def test_init(self): + instance = UserWalletConfig() + + assert isinstance(instance, UserWalletConfig) + + # Check the defaults + assert not instance.enabled + assert not instance.amt + + assert isinstance(instance.supported_payout_types, set) + assert len(instance.supported_payout_types) == 3 + + assert instance.min_cashout is None + + def test_model_dump(self): + instance = UserWalletConfig() + + # If we use the defaults, the supported_payout_types are always + # in the same order because they're the same + assert isinstance(instance.model_dump_json(), str) + res = [] + for idx in range(100): + res.append(instance.model_dump_json()) + assert all_equal(res) + + def test_model_dump_payout_types(self): + res = [] + for idx in range(100): + + # Generate a random order of PayoutTypes each time + payout_types = [e for e in PayoutType] + rshuffle(payout_types) + instance = UserWalletConfig.model_validate( + {"supported_payout_types": payout_types} + ) + + res.append(instance.model_dump_json()) + + assert all_equal(res) diff --git a/tests/models/thl/test_soft_pair.py b/tests/models/thl/test_soft_pair.py new file mode 100644 index 0000000..bac0e8d --- /dev/null +++ b/tests/models/thl/test_soft_pair.py @@ -0,0 +1,24 @@ +from generalresearch.models import Source +from generalresearch.models.thl.soft_pair import SoftPairResult, SoftPairResultType + + +def test_model(): + from generalresearch.models.dynata.survey import ( + DynataCondition, + ConditionValueType, + ) + + c1 = DynataCondition( + question_id="1", value_type=ConditionValueType.LIST, values=["a", "b"] + ) + c2 = DynataCondition( + question_id="2", value_type=ConditionValueType.LIST, values=["c", "d"] + ) + sr = SoftPairResult( + pair_type=SoftPairResultType.CONDITIONAL, + source=Source.DYNATA, + survey_id="xxx", + conditions={c1, c2}, + ) + assert sr.grpc_string == "xxx:1;2" + assert sr.survey_sid == "d:xxx" diff --git a/tests/models/thl/test_upkquestion.py b/tests/models/thl/test_upkquestion.py new file mode 100644 index 0000000..e67427e --- /dev/null +++ b/tests/models/thl/test_upkquestion.py @@ -0,0 +1,414 @@ +import pytest +from pydantic import ValidationError + + +class TestUpkQuestion: + + def test_importance(self): + from generalresearch.models.thl.profiling.upk_question import ( + UPKImportance, + ) + + ui = UPKImportance(task_score=1, task_count=None) + ui = UPKImportance(task_score=0) + with pytest.raises(ValidationError) as e: + UPKImportance(task_score=-1) + assert "Input should be greater than or equal to 0" in str(e.value) + + def test_pattern(self): + from generalresearch.models.thl.profiling.upk_question import ( + PatternValidation, + ) + + s = PatternValidation(message="hi", pattern="x") + with pytest.raises(ValidationError) as e: + s.message = "sfd" + assert "Instance is frozen" in str(e.value) + + def test_mc(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestion, + UpkQuestionConfigurationMC, + ) + + q = UpkQuestion( + id="601377a0d4c74529afc6293a8e5c3b5e", + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="whats up", + choices=[ + UpkQuestionChoice(id="1", text="sky", order=1), + UpkQuestionChoice(id="2", text="moon", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=2), + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.SINGLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=1), + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + with pytest.raises(ValidationError) as e: + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.SINGLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=2), + ) + assert "max_select must be 1 if the selector is SA" in str(e.value) + + with pytest.raises(ValidationError) as e: + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_select=4), + ) + assert "max_select must be >= len(choices)" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as e: + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes or no", + choices=[ + UpkQuestionChoice(id="1", text="yes", order=1), + UpkQuestionChoice(id="2", text="no", order=2), + ], + configuration=UpkQuestionConfigurationMC(max_length=2), + ) + assert "Extra inputs are not permitted" in str(e.value) + + def test_te(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionType, + UpkQuestion, + UpkQuestionSelectorTE, + UpkQuestionValidation, + PatternValidation, + UpkQuestionConfigurationTE, + ) + + q = UpkQuestion( + id="601377a0d4c74529afc6293a8e5c3b5e", + country_iso="us", + language_iso="eng", + type=UpkQuestionType.TEXT_ENTRY, + selector=UpkQuestionSelectorTE.MULTI_LINE, + text="whats up", + choices=[], + configuration=UpkQuestionConfigurationTE(max_length=2), + validation=UpkQuestionValidation( + patterns=[PatternValidation(pattern=".", message="x")] + ), + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + assert q.choices is None + + def test_deserialization(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + q = UpkQuestion.model_validate( + { + "id": "601377a0d4c74529afc6293a8e5c3b5e", + "ext_question_id": "m:2342", + "country_iso": "us", + "language_iso": "eng", + "text": "whats up", + "choices": [ + {"id": "1", "text": "yes", "order": 1}, + {"id": "2", "text": "no", "order": 2}, + ], + "importance": None, + "type": "MC", + "selector": "SA", + "configuration": None, + } + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + q = UpkQuestion.model_validate( + { + "id": "601377a0d4c74529afc6293a8e5c3b5e", + "ext_question_id": "m:2342", + "country_iso": "us", + "language_iso": "eng", + "text": "whats up", + "choices": [ + {"id": "1", "text": "yes", "order": 1}, + {"id": "2", "text": "no", "order": 2}, + ], + "importance": None, + "question_type": "MC", + "selector": "MA", + "configuration": {"max_select": 2}, + } + ) + assert q == UpkQuestion.model_validate(q.model_dump(mode="json")) + + def test_from_morning(self): + from generalresearch.models.morning.question import ( + MorningQuestion, + MorningQuestionType, + ) + + q = MorningQuestion( + **{ + "id": "gender", + "country_iso": "us", + "language_iso": "eng", + "name": "Gender", + "text": "What is your gender?", + "type": "s", + "options": [ + {"id": "1", "text": "yes", "order": 1}, + {"id": "2", "text": "no", "order": 2}, + ], + } + ) + q.to_upk_question() + q = MorningQuestion( + country_iso="us", + language_iso="eng", + type=MorningQuestionType.text_entry, + text="how old r u", + id="a", + name="age", + ) + q.to_upk_question() + + def test_order(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestionChoice, + UpkQuestionSelectorMC, + UpkQuestionType, + UpkQuestion, + order_exclusive_options, + ) + + q = UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes, no, or NA?", + choices=[ + UpkQuestionChoice(id="1", text="NA", order=0), + UpkQuestionChoice(id="2", text="no", order=1), + UpkQuestionChoice(id="3", text="yes", order=2), + ], + ) + order_exclusive_options(q) + assert ( + UpkQuestion( + country_iso="us", + language_iso="eng", + type=UpkQuestionType.MULTIPLE_CHOICE, + selector=UpkQuestionSelectorMC.MULTIPLE_ANSWER, + text="yes, no, or NA?", + choices=[ + UpkQuestionChoice(id="2", text="no", order=0), + UpkQuestionChoice(id="3", text="yes", order=1), + UpkQuestionChoice(id="1", text="NA", order=2, exclusive=True), + ], + ) + == q + ) + + +class TestUpkQuestionValidateAnswer: + def test_validate_answer_SA(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + question = UpkQuestion.model_validate( + { + "choices": [ + {"order": 0, "choice_id": "0", "choice_text": "Male"}, + {"order": 1, "choice_id": "1", "choice_text": "Female"}, + {"order": 2, "choice_id": "2", "choice_text": "Other"}, + ], + "selector": "SA", + "country_iso": "us", + "question_id": "5d6d9f3c03bb40bf9d0a24f306387d7c", + "language_iso": "eng", + "question_text": "What is your gender?", + "question_type": "MC", + } + ) + answer = ("0",) + assert question.validate_question_answer(answer)[0] is True + answer = ("3",) + assert question.validate_question_answer(answer) == ( + False, + "Invalid Options Selected", + ) + answer = ("0", "0") + assert question.validate_question_answer(answer) == ( + False, + "Multiple of the same answer submitted", + ) + answer = ("0", "1") + assert question.validate_question_answer(answer) == ( + False, + "Single Answer MC question with >1 selected " "answers", + ) + + def test_validate_answer_MA(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + question = UpkQuestion.model_validate( + { + "choices": [ + { + "order": 0, + "choice_id": "none", + "exclusive": True, + "choice_text": "None of the above", + }, + { + "order": 1, + "choice_id": "female_under_1", + "choice_text": "Female under age 1", + }, + { + "order": 2, + "choice_id": "male_under_1", + "choice_text": "Male under age 1", + }, + { + "order": 3, + "choice_id": "female_1", + "choice_text": "Female age 1", + }, + {"order": 4, "choice_id": "male_1", "choice_text": "Male age 1"}, + { + "order": 5, + "choice_id": "female_2", + "choice_text": "Female age 2", + }, + ], + # I removed a bunch of choices fyi + "selector": "MA", + "country_iso": "us", + "question_id": "3b65220db85f442ca16bb0f1c0e3a456", + "language_iso": "eng", + "question_text": "Please indicate the age and gender of your child or children:", + "question_type": "MC", + } + ) + answer = ("none",) + assert question.validate_question_answer(answer)[0] is True + answer = ("male_1",) + assert question.validate_question_answer(answer)[0] is True + answer = ("male_1", "female_1") + assert question.validate_question_answer(answer)[0] is True + answer = ("xxx",) + assert question.validate_question_answer(answer) == ( + False, + "Invalid Options Selected", + ) + answer = ("male_1", "male_1") + assert question.validate_question_answer(answer) == ( + False, + "Multiple of the same answer submitted", + ) + answer = ("male_1", "xxx") + assert question.validate_question_answer(answer) == ( + False, + "Invalid Options Selected", + ) + answer = ("male_1", "none") + assert question.validate_question_answer(answer) == ( + False, + "Invalid exclusive selection", + ) + + def test_validate_answer_TE(self): + from generalresearch.models.thl.profiling.upk_question import ( + UpkQuestion, + ) + + question = UpkQuestion.model_validate( + { + "selector": "SL", + "validation": { + "patterns": [ + { + "message": "Must enter a valid zip code: XXXXX", + "pattern": "^[0-9]{5}$", + } + ] + }, + "country_iso": "us", + "question_id": "543de254e9ca4d9faded4377edab82a9", + "language_iso": "eng", + "configuration": {"max_length": 5, "min_length": 5}, + "question_text": "What is your zip code?", + "question_type": "TE", + } + ) + answer = ("33143",) + assert question.validate_question_answer(answer)[0] is True + answer = ("33143", "33143") + assert question.validate_question_answer(answer) == ( + False, + "Multiple of the same answer submitted", + ) + answer = ("33143", "12345") + assert question.validate_question_answer(answer) == ( + False, + "Only one answer allowed", + ) + answer = ("111",) + assert question.validate_question_answer(answer) == ( + False, + "Must enter a valid zip code: XXXXX", + ) diff --git a/tests/models/thl/test_user.py b/tests/models/thl/test_user.py new file mode 100644 index 0000000..4f10861 --- /dev/null +++ b/tests/models/thl/test_user.py @@ -0,0 +1,688 @@ +import json +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from random import randint, choice as rand_choice +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + + +class TestUserUserID: + + def test_valid(self): + from generalresearch.models.thl.user import User + + val = randint(1, 2**30) + user = User(user_id=val) + assert user.user_id == val + + def test_type(self): + from generalresearch.models.thl.user import User + + # It will cast str to int + assert User(user_id="1").user_id == 1 + + # It will cast float to int + assert User(user_id=1.0).user_id == 1 + + # It will cast Decimal to int + assert User(user_id=Decimal("1.0")).user_id == 1 + + # pydantic Validation error is a ValueError, let's check both.. + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=Decimal("1.00000001")) + assert "1 validation error for User" in str(cm.value) + assert "user_id" in str(cm.value) + assert "Input should be a valid integer," in str(cm.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=Decimal("1.00000001")) + assert "1 validation error for User" in str(cm.value) + assert "user_id" in str(cm.value) + assert "Input should be a valid integer," in str(cm.value) + + def test_zero(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be greater than 0" in str(cm.value) + + def test_negative(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=-1) + assert "1 validation error for User" in str(cm.value) + assert "Input should be greater than 0" in str(cm.value) + + def test_too_big(self): + from generalresearch.models.thl.user import User + + val = 2**31 + with pytest.raises(expected_exception=ValidationError) as cm: + User(user_id=val) + assert "1 validation error for User" in str(cm.value) + assert "Input should be less than 2147483648" in str(cm.value) + + def test_identifiable(self): + from generalresearch.models.thl.user import User + + val = randint(1, 2**30) + user = User(user_id=val) + assert user.is_identifiable + + +class TestUserProductID: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + + user = User(user_id=self.user_id, product_id=product_id) + assert user.user_id == self.user_id + assert user.product_id == product_id + + def test_type(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=0.0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=Decimal("0")) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + def test_empty(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id="") + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 32 characters" in str(cm.value) + + def test_invalid_len(self): + from generalresearch.models.thl.user import User + + # Valid uuid4s are 32 char long + product_id = uuid4().hex[:31] + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User", str(cm.value) + assert "String should have at least 32 characters", str(cm.value) + + product_id = uuid4().hex * 2 + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + product_id = uuid4().hex + product_id *= 2 + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + def test_invalid_uuid(self): + from generalresearch.models.thl.user import User + + # Modify the UUID to break it + product_id = uuid4().hex[:31] + "x" + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "Invalid UUID" in str(cm.value) + + def test_invalid_hex_form(self): + from generalresearch.models.thl.user import User + + # Sure not in hex form, but it'll get caught for being the + # wrong length before anything else + product_id = str(uuid4()) # '1a93447e-c77b-4cfa-b58e-ed4777d57110' + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + def test_identifiable(self): + """Can't create a User with only a product_id because it also + needs to the product_user_id""" + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + with pytest.raises(expected_exception=ValueError) as cm: + User(product_id=product_id) + assert "1 validation error for User" in str(cm.value) + assert "Value error, User is not identifiable" in str(cm.value) + + +class TestUserProductUserID: + user_id = randint(1, 2**30) + + def randomword(self, length: int = 50): + # Raw so nothing is escaped to add additional backslashes + _bpuid_allowed = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&()*+,-.:;<=>?@[]^_{|}~" + return "".join(rand_choice(_bpuid_allowed) for i in range(length)) + + def test_valid(self): + from generalresearch.models.thl.user import User + + product_user_id = uuid4().hex[:12] + user = User(user_id=self.user_id, product_user_id=product_user_id) + + assert user.user_id == self.user_id + assert user.product_user_id == product_user_id + + def test_type(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=0.0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, product_user_id=Decimal("0")) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + def test_empty(self): + from generalresearch.models.thl.user import User + + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id="") + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 3 characters" in str(cm.value) + + def test_invalid_len(self): + from generalresearch.models.thl.user import User + + product_user_id = self.randomword(251) + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 128 characters" in str(cm.value) + + product_user_id = self.randomword(2) + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 3 characters" in str(cm.value) + + def test_invalid_chars_space(self): + from generalresearch.models.thl.user import User + + product_user_id = f"{self.randomword(50)} {self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String cannot contain spaces" in str(cm.value) + + def test_invalid_chars_slash(self): + from generalresearch.models.thl.user import User + + product_user_id = f"{self.randomword(50)}\{self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String cannot contain backslash" in str(cm.value) + + product_user_id = f"{self.randomword(50)}/{self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String cannot contain slash" in str(cm.value) + + def test_invalid_chars_backtick(self): + """Yes I could keep doing these specific character checks. However, + I wanted a test that made sure the regex was hit. I do not know + how we want to provide with the level of specific String checks + we do in here for specific error messages.""" + from generalresearch.models.thl.user import User + + product_user_id = f"{self.randomword(50)}`{self.randomword(50)}" + with pytest.raises(expected_exception=ValueError) as cm: + User(user_id=self.user_id, product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "String is not valid regex" in str(cm.value) + + def test_unique_from_product_id(self): + # We removed this filter b/c these users already exist. the manager checks for this + # though and we can't create new users like this + pass + # product_id = uuid4().hex + # + # with pytest.raises(ValueError) as cm: + # User(product_id=product_id, product_user_id=product_id) + # assert "1 validation error for User", str(cm.exception)) + # assert "product_user_id must not equal the product_id", str(cm.exception)) + + def test_identifiable(self): + """Can't create a User with only a product_user_id because it also + needs to the product_id""" + from generalresearch.models.thl.user import User + + product_user_id = uuid4().hex + with pytest.raises(ValueError) as cm: + User(product_user_id=product_user_id) + assert "1 validation error for User" in str(cm.value) + assert "Value error, User is not identifiable" in str(cm.value) + + +class TestUserUUID: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + uuid_pk = uuid4().hex + + user = User(user_id=self.user_id, uuid=uuid_pk) + assert user.user_id == self.user_id + assert user.uuid == uuid_pk + + def test_type(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=0.0) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=Decimal("0")) + assert "1 validation error for User", str(cm.value) + assert "Input should be a valid string" in str(cm.value) + + def test_empty(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid="") + assert "1 validation error for User", str(cm.value) + assert "String should have at least 32 characters", str(cm.value) + + def test_invalid_len(self): + from generalresearch.models.thl.user import User + + # Valid uuid4s are 32 char long + uuid_pk = uuid4().hex[:31] + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "String should have at least 32 characters" in str(cm.value) + + # Valid uuid4s are 32 char long + uuid_pk = uuid4().hex + uuid_pk *= 2 + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + def test_invalid_uuid(self): + from generalresearch.models.thl.user import User + + # Modify the UUID to break it + uuid_pk = uuid4().hex[:31] + "x" + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "Invalid UUID" in str(cm.value) + + def test_invalid_hex_form(self): + from generalresearch.models.thl.user import User + + # Sure not in hex form, but it'll get caught for being the + # wrong length before anything else + uuid_pk = str(uuid4()) # '1a93447e-c77b-4cfa-b58e-ed4777d57110' + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "String should have at most 32 characters" in str(cm.value) + + uuid_pk = str(uuid4())[:32] # '1a93447e-c77b-4cfa-b58e-ed4777d57110' + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, uuid=uuid_pk) + assert "1 validation error for User" in str(cm.value) + assert "Invalid UUID" in str(cm.value) + + def test_identifiable(self): + from generalresearch.models.thl.user import User + + user_uuid = uuid4().hex + user = User(uuid=user_uuid) + assert user.is_identifiable + + +class TestUserCreated: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + dt = datetime.now(tz=timezone.utc) + user.created = dt + + assert user.created == dt + + def test_tz_naive_throws_init(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=datetime.now(tz=None)) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_naive_throws_setter(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + with pytest.raises(ValueError) as cm: + user.created = datetime.now(tz=None) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_utc(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User( + user_id=self.user_id, + created=datetime.now(tz=timezone(-timedelta(hours=8))), + ) + assert "1 validation error for User" in str(cm.value) + assert "Timezone is not UTC" in str(cm.value) + + def test_not_in_future(self): + from generalresearch.models.thl.user import User + + the_future = datetime.now(tz=timezone.utc) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=the_future) + assert "1 validation error for User" in str(cm.value) + assert "Input is in the future" in str(cm.value) + + def test_after_anno_domini(self): + from generalresearch.models.thl.user import User + + before_ad = datetime( + year=2015, month=1, day=1, tzinfo=timezone.utc + ) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=before_ad) + assert "1 validation error for User" in str(cm.value) + assert "Input is before Anno Domini" in str(cm.value) + + +class TestUserLastSeen: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + dt = datetime.now(tz=timezone.utc) + user.last_seen = dt + + assert user.last_seen == dt + + def test_tz_naive_throws_init(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, last_seen=datetime.now(tz=None)) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_naive_throws_setter(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id) + with pytest.raises(ValueError) as cm: + user.last_seen = datetime.now(tz=None) + assert "1 validation error for User" in str(cm.value) + assert "Input should have timezone info" in str(cm.value) + + def test_tz_utc(self): + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User( + user_id=self.user_id, + last_seen=datetime.now(tz=timezone(-timedelta(hours=8))), + ) + assert "1 validation error for User" in str(cm.value) + assert "Timezone is not UTC" in str(cm.value) + + def test_not_in_future(self): + from generalresearch.models.thl.user import User + + the_future = datetime.now(tz=timezone.utc) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, last_seen=the_future) + assert "1 validation error for User" in str(cm.value) + assert "Input is in the future" in str(cm.value) + + def test_after_anno_domini(self): + from generalresearch.models.thl.user import User + + before_ad = datetime( + year=2015, month=1, day=1, tzinfo=timezone.utc + ) + timedelta(minutes=1) + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, last_seen=before_ad) + assert "1 validation error for User" in str(cm.value) + assert "Input is before Anno Domini" in str(cm.value) + + +class TestUserBlocked: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + user = User(user_id=self.user_id, blocked=True) + assert user.blocked + + def test_str_casting(self): + """We don't want any of these to work, and that's why + we set strict=True on the column""" + from generalresearch.models.thl.user import User + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="true") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="True") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="1") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="yes") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked="no") + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, blocked=uuid4().hex) + assert "1 validation error for User" in str(cm.value) + assert "Input should be a valid boolean" in str(cm.value) + + +class TestUserTiming: + user_id = randint(1, 2**30) + + def test_valid(self): + from generalresearch.models.thl.user import User + + created = datetime.now(tz=timezone.utc) - timedelta(minutes=60) + last_seen = datetime.now(tz=timezone.utc) - timedelta(minutes=59) + + user = User(user_id=self.user_id, created=created, last_seen=last_seen) + assert user.created == created + assert user.last_seen == last_seen + + def test_created_first(self): + from generalresearch.models.thl.user import User + + created = datetime.now(tz=timezone.utc) - timedelta(minutes=60) + last_seen = datetime.now(tz=timezone.utc) - timedelta(minutes=59) + + with pytest.raises(ValueError) as cm: + User(user_id=self.user_id, created=last_seen, last_seen=created) + assert "1 validation error for User" in str(cm.value) + assert "User created time invalid" in str(cm.value) + + +class TestUserModelVerification: + """Tests that may be dependent on more than 1 attribute""" + + def test_identifiable(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + user = User(product_id=product_id, product_user_id=product_user_id) + assert user.is_identifiable + + def test_valid_helper(self): + from generalresearch.models.thl.user import User + + user_bool = User.is_valid_ubp( + product_id=uuid4().hex, product_user_id=uuid4().hex + ) + assert user_bool + + user_bool = User.is_valid_ubp(product_id=uuid4().hex, product_user_id=" - - - ") + assert not user_bool + + +class TestUserSerialization: + + def test_basic_json(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + + user = User( + product_id=product_id, + product_user_id=product_user_id, + created=datetime.now(tz=timezone.utc), + blocked=False, + ) + + d = json.loads(user.to_json()) + assert d.get("product_id") == product_id + assert d.get("product_user_id") == product_user_id + assert not d.get("blocked") + + assert d.get("product") is None + assert d.get("created").endswith("Z") + + def test_basic_dict(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + + user = User( + product_id=product_id, + product_user_id=product_user_id, + created=datetime.now(tz=timezone.utc), + blocked=False, + ) + + d = user.to_dict() + assert d.get("product_id") == product_id + assert d.get("product_user_id") == product_user_id + assert not d.get("blocked") + + assert d.get("product") is None + assert d.get("created").tzinfo == timezone.utc + + def test_from_json(self): + from generalresearch.models.thl.user import User + + product_id = uuid4().hex + product_user_id = uuid4().hex + + user = User( + product_id=product_id, + product_user_id=product_user_id, + created=datetime.now(tz=timezone.utc), + blocked=False, + ) + + u = User.model_validate_json(user.to_json()) + assert u.product_id == product_id + assert u.product is None + assert u.created.tzinfo == timezone.utc + + +class TestUserMethods: + + def test_audit_log(self, user, audit_log_manager): + assert user.audit_log is None + user.prefetch_audit_log(audit_log_manager=audit_log_manager) + assert user.audit_log == [] + + audit_log_manager.create_dummy(user_id=user.user_id) + user.prefetch_audit_log(audit_log_manager=audit_log_manager) + assert len(user.audit_log) == 1 + + def test_transactions( + self, user_factory, thl_lm, session_with_tx_factory, product_user_wallet_yes + ): + u1 = user_factory(product=product_user_wallet_yes) + + assert u1.transactions is None + u1.prefetch_transactions(thl_lm=thl_lm) + assert u1.transactions == [] + + session_with_tx_factory(user=u1) + + u1.prefetch_transactions(thl_lm=thl_lm) + assert len(u1.transactions) == 1 + + @pytest.mark.skip(reason="TODO") + def test_location_history(self, user): + assert user.location_history is None diff --git a/tests/models/thl/test_user_iphistory.py b/tests/models/thl/test_user_iphistory.py new file mode 100644 index 0000000..46018e0 --- /dev/null +++ b/tests/models/thl/test_user_iphistory.py @@ -0,0 +1,45 @@ +from datetime import timezone, datetime, timedelta + +from generalresearch.models.thl.user_iphistory import ( + UserIPHistory, + UserIPRecord, +) + + +def test_collapse_ip_records(): + # This does not exist in a db, so we do not need fixtures/ real user ids, whatever + now = datetime.now(tz=timezone.utc) - timedelta(days=1) + # Gets stored most recent first. This is reversed, but the validator will order it + records = [ + UserIPRecord(ip="1.2.3.5", created=now + timedelta(minutes=1)), + UserIPRecord( + ip="1e5c:de49:165a:6aa0:4f89:1433:9af7:aaaa", + created=now + timedelta(minutes=2), + ), + UserIPRecord( + ip="1e5c:de49:165a:6aa0:4f89:1433:9af7:bbbb", + created=now + timedelta(minutes=3), + ), + UserIPRecord(ip="1.2.3.5", created=now + timedelta(minutes=4)), + UserIPRecord( + ip="1e5c:de49:165a:6aa0:4f89:1433:9af7:cccc", + created=now + timedelta(minutes=5), + ), + UserIPRecord( + ip="6666:de49:165a:6aa0:4f89:1433:9af7:aaaa", + created=now + timedelta(minutes=6), + ), + UserIPRecord(ip="1.2.3.6", created=now + timedelta(minutes=7)), + ] + iph = UserIPHistory(user_id=1, ips=records) + res = iph.collapse_ip_records() + + # We should be left with one of the 1.2.3.5 ipv4s, + # and only the 1e5c::cccc and the 6666 ipv6 addresses + assert len(res) == 4 + assert [x.ip for x in res] == [ + "1.2.3.6", + "6666:de49:165a:6aa0:4f89:1433:9af7:aaaa", + "1e5c:de49:165a:6aa0:4f89:1433:9af7:cccc", + "1.2.3.5", + ] diff --git a/tests/models/thl/test_user_metadata.py b/tests/models/thl/test_user_metadata.py new file mode 100644 index 0000000..3d851dc --- /dev/null +++ b/tests/models/thl/test_user_metadata.py @@ -0,0 +1,46 @@ +import pytest + +from generalresearch.models import MAX_INT32 +from generalresearch.models.thl.user_profile import UserMetadata + + +class TestUserMetadata: + + def test_default(self): + # You can initialize it with nothing + um = UserMetadata() + assert um.email_address is None + assert um.email_sha1 is None + + def test_user_id(self): + # This does NOT validate that the user_id exists. When we attempt a db operation, + # at that point it will fail b/c of the foreign key constraint. + UserMetadata(user_id=MAX_INT32 - 1) + + with pytest.raises(expected_exception=ValueError) as cm: + UserMetadata(user_id=MAX_INT32) + assert "Input should be less than 2147483648" in str(cm.value) + + def test_email(self): + um = UserMetadata(email_address="e58375d80f5f4a958138004aae44c7ca@example.com") + assert ( + um.email_sha256 + == "fd219d8b972b3d82e70dc83284027acc7b4a6de66c42261c1684e3f05b545bc0" + ) + assert um.email_sha1 == "a82578f02b0eed28addeb81317417cf239ede1c3" + assert um.email_md5 == "9073a7a3c21cfd6160d1899fb736cd1c" + + # You cannot set the hashes directly + with pytest.raises(expected_exception=AttributeError) as cm: + um.email_md5 = "x" * 32 + # assert "can't set attribute 'email_md5'" in str(cm.value) + assert "property 'email_md5' of 'UserMetadata' object has no setter" in str( + cm.value + ) + + # assert it hasn't changed anything + assert um.email_md5 == "9073a7a3c21cfd6160d1899fb736cd1c" + + # If you update the email, all the hashes change + um.email_address = "greg@example.com" + assert um.email_md5 != "9073a7a3c21cfd6160d1899fb736cd1c" diff --git a/tests/models/thl/test_user_streak.py b/tests/models/thl/test_user_streak.py new file mode 100644 index 0000000..72efd05 --- /dev/null +++ b/tests/models/thl/test_user_streak.py @@ -0,0 +1,96 @@ +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +import pytest +from pydantic import ValidationError + +from generalresearch.models.thl.user_streak import ( + UserStreak, + StreakPeriod, + StreakFulfillment, + StreakState, +) + + +def test_user_streak_empty_fail(): + us = UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=0, + longest_streak=0, + state=StreakState.BROKEN, + ) + assert us.time_remaining_in_period is None + + with pytest.raises( + ValidationError, match="StreakState.BROKEN but current_streak not 0" + ): + UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=0, + state=StreakState.BROKEN, + ) + + with pytest.raises( + ValidationError, match="Current streak can't be longer than longest streak" + ): + UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=0, + state=StreakState.ACTIVE, + ) + + +def test_user_streak_remaining(): + us = UserStreak( + period=StreakPeriod.DAY, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=1, + state=StreakState.AT_RISK, + ) + now = datetime.now(tz=ZoneInfo("America/New_York")) + end_of_today = now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( + days=1 + ) + print(f"{now.isoformat()=}, {end_of_today.isoformat()=}") + expected = (end_of_today - now).total_seconds() + assert us.time_remaining_in_period.total_seconds() == pytest.approx(expected, abs=1) + + +def test_user_streak_remaining_month(): + us = UserStreak( + period=StreakPeriod.MONTH, + fulfillment=StreakFulfillment.COMPLETE, + country_iso="us", + user_id=1, + last_fulfilled_period_start=None, + current_streak=1, + longest_streak=1, + state=StreakState.AT_RISK, + ) + now = datetime.now(tz=ZoneInfo("America/New_York")) + end_of_month = ( + now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + timedelta(days=32) + ).replace(day=1) + print(f"{now.isoformat()=}, {end_of_month.isoformat()=}") + expected = (end_of_month - now).total_seconds() + assert us.time_remaining_in_period.total_seconds() == pytest.approx(expected, abs=1) + print(us.time_remaining_in_period) diff --git a/tests/models/thl/test_wall.py b/tests/models/thl/test_wall.py new file mode 100644 index 0000000..057aad2 --- /dev/null +++ b/tests/models/thl/test_wall.py @@ -0,0 +1,207 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import ( + Status, + StatusCode1, + WallStatusCode2, +) +from generalresearch.models.thl.session import Wall + + +class TestWall: + + def test_wall_json(self): + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + ext_status_code_1="1.0", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + started=datetime(2023, 1, 1, 0, 0, 1, tzinfo=timezone.utc), + finished=datetime(2023, 1, 1, 0, 10, 1, tzinfo=timezone.utc), + ) + s = w.to_json() + w2 = Wall.from_json(s) + assert w == w2 + + def test_status_status_code_agreement(self): + # should not raise anything + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.MARKETPLACE_FAIL, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + with pytest.raises(expected_exception=ValidationError) as e: + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.GRS_ABANDON, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + assert "If status is f, status_code_1 should be in" in str(e.value) + + with pytest.raises(expected_exception=ValidationError) as cm: + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.GRS_ABANDON, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + assert "If status is f, status_code_1 should be in" in str(e.value) + + def test_status_code_1_2_agreement(self): + # should not raise anything + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.MARKETPLACE_FAIL, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + status_code_2=None, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + status_code_2=None, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + + with pytest.raises(expected_exception=ValidationError) as e: + Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + status=Status.FAIL, + status_code_1=StatusCode1.BUYER_FAIL, + status_code_2=WallStatusCode2.COMPLETE_TOO_FAST, + started=datetime.now(timezone.utc), + finished=datetime.now(timezone.utc) + timedelta(seconds=1), + ) + assert "If status_code_1 is 1, status_code_2 should be in" in str(e.value) + + def test_annotate_status_code(self): + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + ) + w.annotate_status_codes("1.0") + assert Status.COMPLETE == w.status + assert StatusCode1.COMPLETE == w.status_code_1 + assert w.status_code_2 is None + assert "1.0" == w.ext_status_code_1 + assert w.ext_status_code_2 is None + + def test_buyer_too_long(self): + buyer_id = uuid4().hex + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + buyer_id=buyer_id, + ) + assert buyer_id == w.buyer_id + + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + buyer_id=None, + ) + assert w.buyer_id is None + + w = Wall( + user_id=1, + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + survey_id="yyy", + buyer_id=buyer_id + "abc123", + ) + assert buyer_id == w.buyer_id + + @pytest.mark.skip(reason="TODO") + def test_more_stuff(self): + # todo: .update, test status logic + pass diff --git a/tests/models/thl/test_wall_session.py b/tests/models/thl/test_wall_session.py new file mode 100644 index 0000000..ab140e9 --- /dev/null +++ b/tests/models/thl/test_wall_session.py @@ -0,0 +1,326 @@ +from datetime import datetime, timezone, timedelta +from decimal import Decimal + +import pytest + +from generalresearch.models import Source +from generalresearch.models.thl.definitions import Status, StatusCode1 +from generalresearch.models.thl.session import Session, Wall +from generalresearch.models.thl.user import User + + +class TestWallSession: + + def test_session_with_no_wall_events(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + s = Session(user=User(user_id=1), started=started) + assert s.status is None + assert s.status_code_1 is None + + # todo: this needs to be set explicitly, not this way + # # If I have no wall events, it's a fail + # s.determine_session_status() + # assert s.status == Status.FAIL + # assert s.status_code_1 == StatusCode1.SESSION_START_FAIL + + def test_session_timeout_with_only_grs(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + s = Session(user=User(user_id=1), started=started) + w = Wall( + user_id=1, + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + ) + s.append_wall_event(w) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + assert Status.TIMEOUT == s.status + assert StatusCode1.GRS_ABANDON == s.status_code_1 + + def test_session_with_only_grs_fail(self): + # todo: this needs to be set explicitly, not this way + pass + # started = datetime(2023, 1, 1, tzinfo=timezone.utc) + # s = Session(user=User(user_id=1), started=started) + # w = Wall(user_id=1, source=Source.GRS, req_survey_id='xxx', + # req_cpi=Decimal(1), session_id=1) + # s.append_wall_event(w) + # w.finish(status=Status.FAIL, status_code_1=StatusCode1.PS_FAIL) + # s.determine_session_status() + # assert s.status == Status.FAIL + # assert s.status_code_1 == StatusCode1.GRS_FAIL + + def test_session_with_only_grs_complete(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + # A Session is started + s = Session(user=User(user_id=1), started=started) + + # The User goes into a GRS survey, and completes it + # @gstupp - should a GRS be allowed with a req_cpi > 0? + w = Wall( + user_id=1, + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + ) + s.append_wall_event(w) + w.finish(status=Status.COMPLETE, status_code_1=StatusCode1.COMPLETE) + + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.FAIL + + # @gstupp changed this behavior on 11/2023 (51471b6ae671f21212a8b1fad60b508181cbb8ca) + # I don't know which is preferred or the consequences of each. However, + # now it's a SESSION_CONTINUE_FAIL instead of a SESSION_START_FAIL so + # change this so the test passes + # self.assertEqual(s.status_code_1, StatusCode1.SESSION_START_FAIL) + assert s.status_code_1 == StatusCode1.SESSION_CONTINUE_FAIL + + @pytest.mark.skip(reason="TODO") + def test_session_with_only_non_grs_complete(self): + # todo: this needs to be set explicitly, not this way + pass + # # This fails... until payout stuff is done + # started = datetime(2023, 1, 1, tzinfo=timezone.utc) + # s = Session(user=User(user_id=1), started=started) + # w = Wall(source=Source.DYNATA, req_survey_id='xxx', req_cpi=Decimal('1.00001'), + # session_id=1, user_id=1) + # s.append_wall_event(w) + # w.finish(status=Status.COMPLETE, status_code_1=StatusCode1.COMPLETE) + # s.determine_session_status() + # assert s.status == Status.COMPLETE + # assert s.status_code_1 is None + + def test_session_with_only_non_grs_fail(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + + s.append_wall_event(w) + w.finish(status=Status.FAIL, status_code_1=StatusCode1.BUYER_FAIL) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.FAIL + assert s.status_code_1 == StatusCode1.BUYER_FAIL + assert s.payout is None + + def test_session_with_only_non_grs_timeout(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + + s.append_wall_event(w) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.TIMEOUT + assert s.status_code_1 == StatusCode1.BUYER_ABANDON + assert s.payout is None + + def test_session_with_grs_and_external(self): + started = datetime(year=2023, month=1, day=1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + started=started, + ) + + s.append_wall_event(w) + w.finish( + status=Status.COMPLETE, + status_code_1=StatusCode1.COMPLETE, + finished=started + timedelta(minutes=10), + ) + + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + s.append_wall_event(w) + w.finish( + status=Status.ABANDON, + finished=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + status_code_1=StatusCode1.BUYER_ABANDON, + ) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.ABANDON + assert s.status_code_1 == StatusCode1.BUYER_ABANDON + assert s.payout is None + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.GRS, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + ) + s.append_wall_event(w) + w.finish(status=Status.COMPLETE, status_code_1=StatusCode1.COMPLETE) + w = Wall( + source=Source.DYNATA, + req_survey_id="xxx", + req_cpi=Decimal("1.00001"), + session_id=1, + user_id=1, + ) + s.append_wall_event(w) + w.finish(status=Status.FAIL, status_code_1=StatusCode1.PS_DUPLICATE) + + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + + assert s.status == Status.FAIL + assert s.status_code_1 == StatusCode1.PS_DUPLICATE + assert s.payout is None + + def test_session_marketplace_fail(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.CINT, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + started=started, + ) + s.append_wall_event(w) + w.finish( + status=Status.FAIL, + status_code_1=StatusCode1.MARKETPLACE_FAIL, + finished=started + timedelta(minutes=10), + ) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + assert Status.FAIL == s.status + assert StatusCode1.SESSION_CONTINUE_QUALITY_FAIL == s.status_code_1 + + def test_session_unknown(self): + started = datetime(2023, 1, 1, tzinfo=timezone.utc) + + s = Session(user=User(user_id=1), started=started) + w = Wall( + source=Source.CINT, + req_survey_id="xxx", + req_cpi=Decimal(1), + session_id=1, + user_id=1, + started=started, + ) + s.append_wall_event(w) + w.finish( + status=Status.FAIL, + status_code_1=StatusCode1.UNKNOWN, + finished=started + timedelta(minutes=10), + ) + status, status_code_1 = s.determine_session_status() + s.update(status=status, status_code_1=status_code_1) + assert Status.FAIL == s.status + assert StatusCode1.BUYER_FAIL == s.status_code_1 + + +# class TestWallSessionPayout: +# product_id = uuid4().hex +# +# def test_session_payout_with_only_non_grs_complete(self): +# sql_helper = self.make_sql_helper() +# user = User(user_id=1, product_id=self.product_id) +# s = Session(user=user, started=datetime(2023, 1, 1, tzinfo=timezone.utc)) +# w = Wall(source=Source.DYNATA, req_survey_id='xxx', req_cpi=Decimal('1.00001')) +# s.append_wall_event(w) +# w.handle_callback(status=Status.COMPLETE) +# s.determine_session_status() +# s.determine_payout(sql_helper=sql_helper) +# assert s.status == Status.COMPLETE +# assert s.status_code_1 is None +# # we're assuming here the commission on this BP is 8.5% and doesn't get changed by someone! +# assert s.payout == Decimal('0.88') +# +# def test_session_payout(self): +# sql_helper = self.make_sql_helper() +# user = User(user_id=1, product_id=self.product_id) +# s = Session(user=user, started=datetime(2023, 1, 1, tzinfo=timezone.utc)) +# w = Wall(source=Source.GRS, req_survey_id='xxx', req_cpi=1) +# s.append_wall_event(w) +# w.handle_callback(status=Status.COMPLETE) +# w = Wall(source=Source.DYNATA, req_survey_id='xxx', req_cpi=Decimal('1.00001')) +# s.append_wall_event(w) +# w.handle_callback(status=Status.COMPLETE) +# s.determine_session_status() +# s.determine_payout(commission_pct=Decimal('0.05')) +# assert s.status == Status.COMPLETE +# assert s.status_code_1 is None +# assert s.payout == Decimal('0.93') + + +# def test_get_from_uuid_vendor_wall(self): +# sql_helper = self.make_sql_helper() +# sql_helper.get_or_create("auth_user", "id", {"id": 1}, { +# "id": 1, "password": "1", +# "last_login": None, "is_superuser": 0, +# "username": "a", "first_name": "a", +# "last_name": "a", "email": "a", +# "is_staff": 0, "is_active": 1, +# "date_joined": "2023-10-13 14:03:20.000000"}) +# sql_helper.get_or_create("vendor_wallsession", "id", {"id": 324}, {"id": 324}) +# sql_helper.create("vendor_wall", { +# "id": "7b3e380babc840b79abf0030d408bbd9", +# "status": "c", +# "started": "2023-10-10 00:51:13.415444", +# "finished": "2023-10-10 01:08:00.676947", +# "req_loi": 1200, +# "req_cpi": 0.63, +# "req_survey_id": "8070750", +# "survey_id": "8070750", +# "cpi": 0.63, +# "user_id": 1, +# "report_notes": None, +# "report_status": None, +# "status_code": "1", +# "req_survey_hashed_opp": None, +# "session_id": 324, +# "source": "i", +# "ubp_id": None +# }) +# Wall +# w = Wall.get_from_uuid_vendor_wall('7b3e380babc840b79abf0030d408bbd9', sql_helper=sql_helper, +# session_id=1) +# assert w.status == Status.COMPLETE +# assert w.source == Source.INNOVATE +# assert w.uuid == '7b3e380babc840b79abf0030d408bbd9' +# assert w.cpi == Decimal('0.63') +# assert w.survey_id == '8070750' +# assert w.user_id == 1 diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..d280de0 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto \ No newline at end of file diff --git a/tests/sql_helper.py b/tests/sql_helper.py new file mode 100644 index 0000000..247f0cd --- /dev/null +++ b/tests/sql_helper.py @@ -0,0 +1,53 @@ +from uuid import uuid4 + +import pytest +from pydantic import MySQLDsn, MariaDBDsn, ValidationError + + +class TestSqlHelper: + def test_db_property(self): + from generalresearch.sql_helper import SqlHelper + + db_name = uuid4().hex[:8] + dsn = MySQLDsn(f"mysql://root@localhost/{db_name}") + instance = SqlHelper(dsn=dsn) + + assert instance.db == db_name + assert instance.db_name == db_name + assert instance.dbname == db_name + + def test_scheme(self): + from generalresearch.sql_helper import SqlHelper + + dsn = MySQLDsn(f"mysql://root@localhost/test") + instance = SqlHelper(dsn=dsn) + assert instance.is_mysql() + + # This needs psycopg2 installed, and don't need to make this a + # requirement of the package ... todo? + # dsn = PostgresDsn(f"postgres://root@localhost/test") + # instance = SqlHelper(dsn=dsn) + # self.assertTrue(instance.is_postgresql()) + + with pytest.raises(ValidationError): + SqlHelper(dsn=MariaDBDsn(f"maria://root@localhost/test")) + + def test_row_decode(self): + from generalresearch.sql_helper import decode_uuids + + valid_uuid4_1 = "bf432839fd0d4436ab1581af5eb98f26" + valid_uuid4_2 = "e1d8683b9c014e9d80eb120c2fc95288" + invalid_uuid4_2 = "2f3b9edf5a3da6198717b77604775ec1" + + row1 = { + "b": valid_uuid4_1, + "c": valid_uuid4_2, + } + + row2 = { + "a": valid_uuid4_1, + "b": invalid_uuid4_2, + } + + assert row1 == decode_uuids(row1) + assert row1 != decode_uuids(row2) diff --git a/tests/wall_status_codes/__init__.py b/tests/wall_status_codes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/wall_status_codes/test_analyze.py b/tests/wall_status_codes/test_analyze.py new file mode 100644 index 0000000..da6efb3 --- /dev/null +++ b/tests/wall_status_codes/test_analyze.py @@ -0,0 +1,150 @@ +from generalresearch.models.thl.definitions import StatusCode1, Status +from generalresearch.wall_status_codes import innovate + + +class TestInnovate: + def test_complete(self): + status, status_code_1, status_code_2 = innovate.annotate_status_code("1", None) + assert Status.COMPLETE == status + assert StatusCode1.COMPLETE == status_code_1 + assert status_code_2 is None + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "1", "whatever" + ) + assert Status.COMPLETE == status + assert StatusCode1.COMPLETE == status_code_1 + assert status_code_2 is None + + def test_unknown(self): + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "69420", None + ) + assert Status.FAIL == status + assert StatusCode1.UNKNOWN == status_code_1 + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "69420", "Speeder" + ) + assert Status.FAIL == status + assert StatusCode1.UNKNOWN == status_code_1 + + def test_ps(self): + status, status_code_1, status_code_2 = innovate.annotate_status_code("5", None) + assert Status.FAIL == status + assert StatusCode1.PS_FAIL == status_code_1 + # The ext_status_code_2 should reclassify this as PS_FAIL + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "8", "DeviceType" + ) + assert Status.FAIL == status + assert StatusCode1.PS_FAIL == status_code_1 + # this should be reclassified from PS_FAIL to PS_OQ + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "5", "Group NA" + ) + assert Status.FAIL == status + assert StatusCode1.PS_OVERQUOTA == status_code_1 + + def test_dupe(self): + # innovate calls it a quality, should be dupe + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "8", "Duplicated to token Tq2SwRVX7PUWnFunGPAYWHk" + ) + assert Status.FAIL == status + assert StatusCode1.PS_DUPLICATE == status_code_1 + # stay as quality + status, status_code_1, status_code_2 = innovate.annotate_status_code( + "8", "Selected threat potential score at joblevel not allow the survey" + ) + assert Status.FAIL == status + assert StatusCode1.PS_QUALITY == status_code_1 + + +# todo: fix me: This got broke because I opened the csv in libreoffice and it broke it +# status codes "1.0" -> 1 (facepalm) + +# class TestAllCsv:: +# def get_wall(self) -> pd.DataFrame: +# df = pd.read_csv(os.path.join(os.path.dirname(__file__), "wall_excerpt.csv.gz")) +# df['started'] = pd.to_datetime(df['started']) +# df['finished'] = pd.to_datetime(df['finished']) +# return df +# +# def test_dynata(self): +# df = self.get_wall() +# df = df[df.source == Source.DYNATA] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(dynata.annotate_status_code(row.status_code)), axis=1) +# self.assertEqual(1419, len(df[df.t_status == Status.COMPLETE])) +# assert len(df[df.t_status == Status.FAIL]) == 2109 +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) == 0 +# assert 1000 < len(df[df.t_status_code_1 == StatusCode1.BUYER_FAIL]) < 1100 +# assert 30 < len(df[df.t_status_code_1 == StatusCode1.PS_BLOCKED]) < 40 +# +# def test_fullcircle(self): +# df = self.get_wall() +# df = df[df.source == Source.FULL_CIRCLE] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(fullcircle.annotate_status_code(row.status_code)), axis=1) +# # assert len(df[df.t_status == Status.COMPLETE]) == 1419 +# # assert len(df[df.t_status == Status.FAIL]) == 2109 +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) == 0 +# +# def test_innovate(self): +# df = self.get_wall() +# df = df[df.source == Source.INNOVATE] +# df = df[~df.status.isin({'r', 'e'})] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(innovate.annotate_status_code( +# row.status_code, row.status_code_2 if pd.notnull(row.status_code_2) else None)), +# axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# # assert len(df[df.t_status == Status.COMPLETE]) == 1419 +# # assert len(df[df.t_status == Status.FAIL]) == 2109 +# # assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) == 0 +# +# def test_morning(self): +# df = self.get_wall() +# df = df[df.source == Source.MORNING_CONSULT] +# df = df[df.status_code.notnull()] +# # we have to do this for old values... +# df['status_code'] = df['status_code'].apply(short_code_to_status_codes_morning.get) +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(morning.annotate_status_code(row.status, row.status_code)), axis=1) +# dff = df[df.t_status != Status.COMPLETE] +# assert len(dff[dff.t_status_code_1 == StatusCode1.UNKNOWN]) / len(dff) < 0.05 +# +# def test_pollfish(self): +# df = self.get_wall() +# df = df[df.source == Source.POLLFISH] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(pollfish.annotate_status_code(row.status_code)), axis=1) +# +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# +# def test_pollfish(self): +# df = self.get_wall() +# df = df[df.source == Source.PRECISION] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(precision.annotate_status_code(row.status_code)), axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# +# def test_sago(self): +# df = self.get_wall() +# df = df[df.source == Source.SAGO] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(sago.annotate_status_code(row.status_code)), axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 +# +# def test_spectrum(self): +# df = self.get_wall() +# df = df[df.source == Source.SPECTRUM] +# df = df[df.status_code.notnull()] +# df[['t_status', 't_status_code_1', 't_status_code_2']] = df.apply( +# lambda row: pd.Series(spectrum.annotate_status_code(row.status_code)), axis=1) +# assert len(df[df.t_status_code_1 == StatusCode1.UNKNOWN]) / len(df) < 0.05 diff --git a/tests/wxet/__init__.py b/tests/wxet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/wxet/models/__init__.py b/tests/wxet/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/wxet/models/test_definitions.py b/tests/wxet/models/test_definitions.py new file mode 100644 index 0000000..543b9f1 --- /dev/null +++ b/tests/wxet/models/test_definitions.py @@ -0,0 +1,113 @@ +import pytest + + +class TestWXETStatusCode1: + + def test_is_pre_task_entry_fail_pre(self): + from generalresearch.wxet.models.definitions import ( + WXETStatusCode1, + ) + + assert WXETStatusCode1.UNKNOWN.is_pre_task_entry_fail + assert WXETStatusCode1.WXET_FAIL.is_pre_task_entry_fail + assert WXETStatusCode1.WXET_ABANDON.is_pre_task_entry_fail + + def test_is_pre_task_entry_fail_post(self): + from generalresearch.wxet.models.definitions import ( + WXETStatusCode1, + ) + + assert not WXETStatusCode1.BUYER_OVER_QUOTA.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_DUPLICATE.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_TASK_NOT_AVAILABLE.is_pre_task_entry_fail + + assert not WXETStatusCode1.BUYER_ABANDON.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_FAIL.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_QUALITY_FAIL.is_pre_task_entry_fail + assert not WXETStatusCode1.BUYER_POSTBACK_NOT_RECEIVED.is_pre_task_entry_fail + assert not WXETStatusCode1.COMPLETE.is_pre_task_entry_fail + + +class TestCheckWXETStatusConsistent: + + def test_completes(self): + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + check_wxet_status_consistent, + ) + + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.COMPLETE, + status_code_1=WXETStatusCode1.UNKNOWN, + status_code_2=None, + ) + + assert ( + "Invalid StatusCode1 when Status=COMPLETE. Use WXETStatusCode1.COMPLETE" + == str(cm.value) + ) + + def test_abandon(self): + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + check_wxet_status_consistent, + ) + + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.ABANDON, + status_code_1=WXETStatusCode1.COMPLETE, + status_code_2=None, + ) + assert ( + "Invalid StatusCode1 when Status=ABANDON. Use WXET_ABANDON or BUYER_ABANDON" + == str(cm.value) + ) + + def test_fail(self): + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + check_wxet_status_consistent, + ) + + for sc1 in [ + WXETStatusCode1.COMPLETE, + WXETStatusCode1.WXET_ABANDON, + WXETStatusCode1.WXET_ABANDON, + ]: + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.FAIL, + status_code_1=sc1, + status_code_2=None, + ) + assert "Invalid StatusCode1 when Status=FAIL." == str(cm.value) + + def test_status_code_2(self): + """Any StatusCode2 should fail if the StatusCode1 isn't + StatusCode1.WXET_FAIL + """ + + from generalresearch.wxet.models.definitions import ( + WXETStatus, + WXETStatusCode1, + WXETStatusCode2, + check_wxet_status_consistent, + ) + + for sc2 in WXETStatusCode2: + with pytest.raises(AssertionError) as cm: + check_wxet_status_consistent( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.COMPLETE, + status_code_2=sc2, + ) + + assert "Invalid StatusCode1 when Status=FAIL." == str(cm.value) diff --git a/tests/wxet/models/test_finish_type.py b/tests/wxet/models/test_finish_type.py new file mode 100644 index 0000000..7bdeea7 --- /dev/null +++ b/tests/wxet/models/test_finish_type.py @@ -0,0 +1,136 @@ +import pytest + +from generalresearch.wxet.models.definitions import WXETStatus, WXETStatusCode1 +from generalresearch.wxet.models.finish_type import FinishType, is_a_finish + + +class TestFinishType: + + def test_init_entrance(self): + instance = FinishType.ENTRANCE + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 5 == len(finish_statuses) + + def test_init_complete(self): + instance = FinishType.COMPLETE + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 1 == len(finish_statuses) + + def test_init_fail_or_complete(self): + instance = FinishType.FAIL_OR_COMPLETE + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 2 == len(finish_statuses) + + def test_init_fail(self): + instance = FinishType.FAIL + finish_statuses = instance.finish_statuses + assert isinstance(finish_statuses, set) + assert 1 == len(finish_statuses) + + +class TestFunctionIsAFinish: + + def test_init_ft_entrance(self): + assert is_a_finish( + status=None, + status_code_1=None, + finish_type=FinishType.ENTRANCE, + ) + + assert is_a_finish( + status=WXETStatus.ABANDON, + status_code_1=None, + finish_type=FinishType.ENTRANCE, + ) + + assert is_a_finish( + status=WXETStatus.ABANDON, + status_code_1=WXETStatusCode1.BUYER_ABANDON, + finish_type=FinishType.ENTRANCE, + ) + + # If it's a WXET Abandon, they ever entered the Task so don't + # consider it a Finish + assert not is_a_finish( + status=WXETStatus.ABANDON, + status_code_1=WXETStatusCode1.WXET_ABANDON, + finish_type=FinishType.ENTRANCE, + ) + + def test_init_ft_complete(self): + assert is_a_finish( + status=WXETStatus.COMPLETE, + status_code_1=None, + finish_type=FinishType.COMPLETE, + ) + + assert is_a_finish( + status=WXETStatus.COMPLETE, + status_code_1=WXETStatusCode1.COMPLETE, + finish_type=FinishType.COMPLETE, + ) + + def test_init_ft_fail_or_complete(self): + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=None, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.BUYER_FAIL, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + # If it's a WXET Fail, the Worker never made it into a WXET Task + # experience, so it should not be considered a Finish + assert not is_a_finish( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.WXET_FAIL, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + assert is_a_finish( + status=WXETStatus.COMPLETE, + status_code_1=WXETStatusCode1.COMPLETE, + finish_type=FinishType.FAIL_OR_COMPLETE, + ) + + def test_init_ft_fail(self): + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=None, + finish_type=FinishType.FAIL, + ) + + assert is_a_finish( + status=WXETStatus.FAIL, + status_code_1=WXETStatusCode1.BUYER_FAIL, + finish_type=FinishType.FAIL, + ) + + def test_invalid_status_code_1(self): + for ft in FinishType: + for s in WXETStatus: + with pytest.raises(expected_exception=AssertionError) as cm: + is_a_finish( + status=s, + status_code_1=WXETStatus.COMPLETE, + finish_type=ft, + ) + assert "Invalid status_code_1" == str(cm.value) + + def test_invalid_none_status(self): + for ft in FinishType: + for sc1 in WXETStatusCode1: + with pytest.raises(expected_exception=AssertionError) as cm: + is_a_finish( + status=None, + status_code_1=sc1, + finish_type=ft, + ) + assert "Cannot provide status_code_1 without a status" == str(cm.value) -- cgit v1.2.3