aboutsummaryrefslogtreecommitdiff
path: root/test_utils/managers/upk/conftest.py
blob: e28d085b4d2ff3ac55fc189a966f6e0faba33736 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import time
from typing import TYPE_CHECKING, Optional
from uuid import UUID

import pandas as pd
import pytest

from generalresearch.pg_helper import PostgresConfig

if TYPE_CHECKING:
    from generalresearch.managers.thl.category import CategoryManager


def insert_data_from_csv(
    thl_web_rw: PostgresConfig,
    table_name: str,
    fp: Optional[str] = None,
    disable_fk_checks: bool = False,
    df: Optional[pd.DataFrame] = None,
) -> None:
    """Insert every row of a CSV file or DataFrame into ``table_name``.

    Exactly one of ``fp`` and ``df`` must be supplied. The column names of the
    data are used verbatim as the column list of the INSERT statement.

    Args:
        thl_web_rw: Config object used to open the connection (via
            ``make_connection()``) that the rows are written through.
        table_name: Target table. Interpolated directly into the SQL text, so
            it must come from trusted code, never from external input.
        fp: Path of a CSV file; read with ``pd.read_csv(fp, dtype=str)``.
        disable_fk_checks: When true, ``SET CONSTRAINTS ALL DEFERRED`` is
            executed on the same cursor before the insert.
            NOTE(review): in PostgreSQL this only affects constraints that
            are declared DEFERRABLE -- confirm the schema declares them so.
        df: Already-loaded data to insert instead of reading ``fp``. The
            caller's frame is not modified.

    Raises:
        ValueError: If neither or both of ``fp`` and ``df`` are given.
    """
    # The previous `assert a or b and not (...)` parsed as `a or (b and ...)`,
    # so passing both sources slipped through and `df` was silently discarded.
    if (fp is None) == (df is None):
        raise ValueError("Exactly one of `fp` or `df` must be provided")
    if fp is not None:
        df = pd.read_csv(fp, dtype=str)

    # Missing cells become None (instead of NaN) in the parameter rows.
    # `where` returns a new frame, so the edits below never touch the input.
    df = df.where(pd.notnull(df), None)
    cols = list(df.columns)
    col_str = ", ".join(cols)
    values_str = ", ".join(["%s"] * len(cols))

    # Ids written in the 36-character dashed UUID form are converted to their
    # 32-character hex form. The decision is made from the first row only;
    # an empty frame or a non-string first id is left untouched.
    if "id" in df.columns and not df.empty:
        first_id = df["id"].iloc[0]
        if isinstance(first_id, str) and len(first_id) == 36:
            df["id"] = df["id"].map(lambda x: UUID(x).hex)
    args = df.to_dict("tight")["data"]

    with thl_web_rw.make_connection() as conn:
        with conn.cursor() as c:
            if disable_fk_checks:
                c.execute("SET CONSTRAINTS ALL DEFERRED")
            c.executemany(
                f"INSERT INTO {table_name} ({col_str}) VALUES ({values_str})",
                params_seq=args,
            )
        conn.commit()


@pytest.fixture(scope="session")
def category_data(
    thl_web_rw: PostgresConfig, category_manager: "CategoryManager"
) -> None:
    """Load ``marketplace_category`` from its CSV and back-fill ``parent_id``.

    The rows are inserted first (with ``disable_fk_checks=True``). The
    category manager's caches are then refreshed, and every category that has
    a ``parent_path`` gets its ``parent_id`` set to the id of the category
    whose ``path`` equals that ``parent_path``.
    """
    csv_path = os.path.join(os.path.dirname(__file__), "marketplace_category.csv.gz")
    insert_data_from_csv(
        thl_web_rw,
        table_name="marketplace_category",
        fp=csv_path,
        disable_fk_checks=True,
    )

    # Don't strictly need to do this, but probably we should
    category_manager.populate_caches()
    cats = category_manager.categories.values()

    # path -> primary key, used to resolve each category's parent.
    id_by_path = {}
    for cat in cats:
        id_by_path[cat.path] = cat.id

    # One parameter dict per category that actually has a parent.
    data = []
    for cat in cats:
        if cat.parent_path:
            data.append({"id": cat.id, "parent_id": id_by_path[cat.parent_path]})

    query = """
    UPDATE marketplace_category
    SET parent_id = %(parent_id)s
    WHERE id = %(id)s;
    """
    with thl_web_rw.make_connection() as conn:
        with conn.cursor() as cur:
            cur.executemany(query=query, params_seq=data)
        conn.commit()


@pytest.fixture(scope="session")
def property_data(thl_web_rw: PostgresConfig) -> None:
    """Insert ``marketplace_property.csv.gz`` into ``marketplace_property``.

    The CSV is looked up in the directory that contains this file.
    """
    here = os.path.dirname(__file__)
    insert_data_from_csv(
        thl_web_rw,
        table_name="marketplace_property",
        fp=os.path.join(here, "marketplace_property.csv.gz"),
    )


@pytest.fixture(scope="session")
def item_data(thl_web_rw: PostgresConfig) -> None:
    """Insert ``marketplace_item.csv.gz`` into ``marketplace_item``.

    The CSV is looked up in the directory that contains this file.
    """
    here = os.path.dirname(__file__)
    insert_data_from_csv(
        thl_web_rw,
        table_name="marketplace_item",
        fp=os.path.join(here, "marketplace_item.csv.gz"),
    )


@pytest.fixture(scope="session")
def propertycategoryassociation_data(
    thl_web_rw: PostgresConfig,
    category_data,
    property_data,
    category_manager: "CategoryManager",
) -> None:
    """Insert the property/category association rows.

    Each ``category_id`` value in the CSV is used as a key into
    ``category_manager.categories`` and replaced by that entry's ``id``
    before the rows are inserted. ``category_data`` and ``property_data`` are
    requested only so that those fixtures are set up first.
    """
    table_name = "marketplace_propertycategoryassociation"
    here = os.path.dirname(__file__)
    fp = os.path.join(here, f"{table_name}.csv.gz")

    # Refresh the manager so the lookup below sees the inserted categories.
    category_manager.populate_caches()

    def category_pk(key):
        # Need to lookup category pk from uuid
        return category_manager.categories[key].id

    frame = pd.read_csv(fp, dtype=str)
    frame["category_id"] = frame["category_id"].map(category_pk)
    insert_data_from_csv(thl_web_rw, table_name=table_name, df=frame)


@pytest.fixture(scope="session")
def propertycountry_data(thl_web_rw: PostgresConfig, property_data) -> None:
    """Insert ``marketplace_propertycountry.csv.gz`` into its table.

    ``property_data`` is requested only so that fixture is set up first; its
    value is not used here.
    """
    here = os.path.dirname(__file__)
    insert_data_from_csv(
        thl_web_rw,
        table_name="marketplace_propertycountry",
        fp=os.path.join(here, "marketplace_propertycountry.csv.gz"),
    )


@pytest.fixture(scope="session")
def propertymarketplaceassociation_data(
    thl_web_rw: PostgresConfig, property_data
) -> None:
    """Insert the property/marketplace association rows from their CSV.

    The CSV is named after the table and sits next to this file.
    ``property_data`` is requested only so that fixture is set up first.
    """
    table_name = "marketplace_propertymarketplaceassociation"
    here = os.path.dirname(__file__)
    insert_data_from_csv(
        thl_web_rw,
        table_name=table_name,
        fp=os.path.join(here, f"{table_name}.csv.gz"),
    )


@pytest.fixture(scope="session")
def propertyitemrange_data(
    thl_web_rw: PostgresConfig, property_data, item_data
) -> None:
    """Insert the ``marketplace_propertyitemrange`` rows from their CSV.

    The CSV is named after the table and sits next to this file.
    ``property_data`` and ``item_data`` are requested only so that those
    fixtures are set up first.
    """
    table_name = "marketplace_propertyitemrange"
    here = os.path.dirname(__file__)
    insert_data_from_csv(
        thl_web_rw,
        table_name=table_name,
        fp=os.path.join(here, f"{table_name}.csv.gz"),
    )


@pytest.fixture(scope="session")
def question_data(thl_web_rw: PostgresConfig) -> None:
    """Insert the ``marketplace_question`` rows from their CSV.

    The insert is requested with ``disable_fk_checks=True``. The CSV is named
    after the table and sits next to this file.
    """
    table_name = "marketplace_question"
    here = os.path.dirname(__file__)
    insert_data_from_csv(
        thl_web_rw,
        table_name=table_name,
        fp=os.path.join(here, f"{table_name}.csv.gz"),
        disable_fk_checks=True,
    )


@pytest.fixture(scope="session")
def clear_upk_tables(thl_web_rw: PostgresConfig):
    """Empty every UPK table with a single ``TRUNCATE`` statement.

    The statement uses ``RESTART IDENTITY CASCADE`` and is committed before
    the fixture returns.
    """
    table_str = ", ".join(
        (
            "marketplace_propertyitemrange",
            "marketplace_propertymarketplaceassociation",
            "marketplace_propertycategoryassociation",
            "marketplace_category",
            "marketplace_item",
            "marketplace_property",
            "marketplace_propertycountry",
            "marketplace_question",
        )
    )

    with thl_web_rw.make_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(f"TRUNCATE {table_str} RESTART IDENTITY CASCADE;")
        conn.commit()


@pytest.fixture(scope="session")
def upk_data(
    clear_upk_tables,
    category_data,
    property_data,
    item_data,
    propertycategoryassociation_data,
    propertycountry_data,
    propertymarketplaceassociation_data,
    propertyitemrange_data,
    question_data,
) -> None:
    """Aggregate fixture: request this to get all UPK tables populated.

    The fixture only pulls in the table-clearing fixture and every loader
    fixture, then pauses briefly; it yields no value.
    NOTE(review): the loaders do not themselves depend on
    ``clear_upk_tables``, so the truncate-before-load order relies on
    ``clear_upk_tables`` being listed first here -- confirm no test requests
    a loader fixture on its own before this one.
    """
    # Pause for two seconds to make sure the HarmonizerCache refresh loop
    # pulls these in
    settle_seconds = 2
    time.sleep(settle_seconds)


def test_fixtures(upk_data):
    """Smoke check: passes as long as the ``upk_data`` fixture can be built."""
    return None