summaryrefslogtreecommitdiff
path: root/jb/models/custom_types.py
blob: 70bc5c12f1bc09c981b7fdb8472277ee7ff028f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import re
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID

from pydantic import (
    AwareDatetime,
    StringConstraints,
    TypeAdapter,
    HttpUrl,
)
from pydantic.functional_serializers import PlainSerializer
from pydantic.functional_validators import AfterValidator, BeforeValidator
from pydantic.networks import UrlConstraints
from pydantic_core import Url
from typing_extensions import Annotated


def convert_datetime_to_iso_8601_with_z_suffix(dt: datetime) -> str:
    """Format ``dt`` as an ISO 8601 string in UTC with a ``Z`` suffix.

    The fractional-seconds field is always written (``%f``, six digits). This
    lets the output round-trip through ``convert_str_dt``, which parses with
    the fixed ``.%fZ`` format and would reject a string without microseconds.

    Timezone-aware datetimes are converted to UTC first, so the ``Z`` suffix
    is truthful for any offset. Naive datetimes are formatted as-is, i.e. they
    are assumed to already represent UTC.
    """
    # utcoffset() is None exactly when dt is naive; only convert aware values,
    #   since astimezone() on a naive datetime would assume local time.
    if dt.utcoffset() is not None:
        dt = dt.astimezone(timezone.utc)
    return dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ")


def convert_str_dt(v: Any) -> Any:
    """Parse a ``Z``-suffixed ISO 8601 string into a UTC-aware datetime.

    Intended as a pydantic ``BeforeValidator``. Only strings are parsed. Any
    other value, including ``None`` and ``datetime`` instances, is returned
    unchanged for the next validation step.

    The accepted format is exactly ``%Y-%m-%dT%H:%M:%S.%fZ``, the format
    written by ``convert_datetime_to_iso_8601_with_z_suffix``.

    :raises ValueError: if ``v`` is a string that does not match that format.
    """
    if not isinstance(v, str):
        return v
    # Explicit raise rather than ``assert``: asserts are stripped under
    #   ``python -O``. Pydantic reports a ValueError as a ValidationError.
    if not (v.endswith("Z") and "T" in v):
        raise ValueError("invalid format")
    return datetime.strptime(v, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
        tzinfo=timezone.utc
    )


def assert_utc(v: datetime) -> datetime:
    """Return ``v`` unchanged, rejecting datetimes whose tzinfo is not UTC.

    Intended as a pydantic ``AfterValidator``. A datetime passes only when
    its ``tzinfo`` compares equal to ``timezone.utc``. Non-datetime values
    are passed through untouched.

    :raises ValueError: ("Timezone is not UTC") for any other datetime.
    """
    # Explicit raise rather than ``assert``: under ``python -O`` an assert is
    #   removed and non-UTC datetimes would be accepted silently.
    if isinstance(v, datetime) and v.tzinfo != timezone.utc:
        raise ValueError("Timezone is not UTC")
    return v


# Our custom AwareDatetime that round-trips through JSON as an ISO 8601 string
#   in UTC with a "Z" suffix and a mandatory fractional-seconds field
#   (e.g. "2024-01-02T03:04:05.000000Z"):
#   - before validation, such strings are parsed by convert_str_dt;
#   - after validation, the value's tzinfo must equal timezone.utc (assert_utc);
#   - in JSON mode, non-None values are serialized by
#     convert_datetime_to_iso_8601_with_z_suffix, so the format string lives
#     in one place. Python-mode dumps keep the datetime object.
AwareDatetimeISO = Annotated[
    AwareDatetime,
    BeforeValidator(convert_str_dt),
    AfterValidator(assert_utc),
    PlainSerializer(
        convert_datetime_to_iso_8601_with_z_suffix,
        when_used="json-unless-none",
    ),
]

# Two-letter lowercase country code shaped like ISO 3166-1 alpha-2.
# "Like" because only the shape is validated here, not membership in the set
#   of supported values (see models.thl.locales for that).
CountryISOLike = Annotated[
    str,
    StringConstraints(min_length=2, max_length=2, pattern=r"^[a-z]{2}$"),
]
# Three-letter lowercase language code shaped like ISO 639-2/B. As with
#   CountryISOLike, only the shape is checked.
LanguageISOLike = Annotated[
    str,
    StringConstraints(min_length=3, max_length=3, pattern=r"^[a-z]{3}$"),
]


def check_valid_uuid(v: str) -> str:
    """Return ``v`` if it is a UUID in canonical ``.hex`` form, else raise.

    Canonical means exactly what ``UUID(v).hex`` produces: 32 lowercase hex
    digits with no dashes, braces or ``urn:uuid:`` prefix.

    :raises ValueError: ("Invalid UUID") if ``v`` cannot be parsed as a UUID
        or is not already in canonical form.
    """
    try:
        parsed = UUID(v)
    except (AttributeError, TypeError, ValueError):
        # UUID() raises ValueError for malformed strings and
        #   AttributeError/TypeError for non-str input; report all uniformly.
        raise ValueError("Invalid UUID") from None
    # Explicit check rather than ``assert``: asserts vanish under
    #   ``python -O``, which would let e.g. upper-case hex through.
    if parsed.hex != v:
        raise ValueError("Invalid UUID")
    return v


# A UUID stored as its 32-character lowercase ``.hex`` string (no dashes).
#   check_valid_uuid only accepts that exact canonical form; the UUID version
#   is not checked.
UUIDStr = Annotated[
    str,
    StringConstraints(max_length=32, min_length=32),
    AfterValidator(check_valid_uuid),
]
# Built once at import time: every TypeAdapter(...) call constructs a fresh
#   validator/serializer, so creating one per validated value is wasted work.
_UUID_ADAPTER = TypeAdapter(UUID)

# Like UUIDStr, but also accepts other UUID representations that pydantic can
#   parse (e.g. a UUID instance or the dashed string form). The value is first
#   coerced to the lowercase 32-character .hex string, then the usual UUIDStr
#   checks run.
UUIDStrCoerce = Annotated[
    str,
    StringConstraints(min_length=32, max_length=32),
    BeforeValidator(lambda value: _UUID_ADAPTER.validate_python(value).hex),
    AfterValidator(check_valid_uuid),
]

# Validates the value as an HttpUrl but stores the result as a plain str
#   (same idea as UUIDStr). This is needed because HttpUrl itself is not a
#   str: https://github.com/pydantic/pydantic/discussions/6395
# The adapter is built once; constructing a TypeAdapter per value is wasteful.
_HTTP_URL_ADAPTER = TypeAdapter(HttpUrl)
HttpUrlStr = Annotated[
    str,
    BeforeValidator(lambda value: str(_HTTP_URL_ADAPTER.validate_python(value))),
]

# A URL of at most 2083 characters whose scheme must be https.
HttpsUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=["https"])]
# The adapter is built once; constructing a TypeAdapter per value is wasteful.
_HTTPS_URL_ADAPTER = TypeAdapter(HttpsUrl)
# Validates the value as an HttpsUrl but stores the result as a plain str.
HttpsUrlStr = Annotated[
    str,
    BeforeValidator(lambda value: str(_HTTPS_URL_ADAPTER.validate_python(value))),
]


def check_valid_amt_boto3_id(v: str) -> str:
    """Return ``v`` if it has the shape of an AMT boto3 ID, else raise.

    A valid ID is exactly 20 or exactly 30 characters drawn from upper-case
    ASCII letters and digits. The 20-character form is what Amazon's test IDs
    look like.

    :raises ValueError: ("Invalid AMT Boto3 ID") for anything else.
    """
    matched = re.fullmatch(r"[A-Z0-9]{20}|[A-Z0-9]{30}", v)
    if matched is not None:
        return v
    raise ValueError("Invalid AMT Boto3 ID")


# AMT boto3 ID string: exactly 20 or 30 characters of A-Z / 0-9.
#   The length bounds reject obviously wrong values early;
#   check_valid_amt_boto3_id enforces the exact format.
AMTBoto3ID = Annotated[
    str,
    StringConstraints(max_length=30, min_length=20),
    AfterValidator(check_valid_amt_boto3_id),
]