aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorMax Nanis2026-02-25 16:20:18 -0500
committerMax Nanis2026-02-25 16:20:18 -0500
commit04aee0dc7e908ce020d2d2c3f8ffb4a96424b883 (patch)
treeefb99622da9a962a73921a945373c019f98e6273 /tests
parent8c1940445503fd6678d0961600f2be81622793a2 (diff)
downloadamt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.tar.gz
amt-jb-04aee0dc7e908ce020d2d2c3f8ffb4a96424b883.zip
test_notification (for sns mgmt), along with more type hinting on pytest conftest
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py131
-rw-r--r--tests/http/conftest.py49
-rw-r--r--tests/http/test_basic.py37
-rw-r--r--tests/http/test_notifications.py163
-rw-r--r--tests/http/test_preview.py39
-rw-r--r--tests/http/test_work.py98
6 files changed, 375 insertions, 142 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index 3318f1c..a33b149 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,7 @@
import copy
from datetime import datetime, timezone, timedelta
import os
-from typing import Optional
+from typing import Optional, TYPE_CHECKING, Callable, Dict, Any
from uuid import uuid4
from dotenv import load_dotenv
import pytest
@@ -20,6 +20,16 @@ from generalresearchutils.currency import USDCent
from jb.models.definitions import HitStatus, HitReviewStatus, AssignmentStatus
from jb.models.hit import HitType, HitQuestion, Hit
from tests import generate_amt_id
+from _pytest.config import Config
+
+if TYPE_CHECKING:
+ from jb.settings import Settings
+ from jb.managers.hit import HitQuestionManager, HitTypeManager, HitManager
+ from jb.managers.assignment import AssignmentManager
+ from jb.managers.bonus import BonusManager
+
+
+# --- IDs and Identifiers ---
@pytest.fixture
@@ -67,7 +77,7 @@ def pe_id() -> str:
@pytest.fixture(scope="session")
-def env_file_path(pytestconfig):
+def env_file_path(pytestconfig: Config) -> str:
root_path = pytestconfig.rootpath
env_path = os.path.join(root_path, ".env.test")
@@ -78,7 +88,7 @@ def env_file_path(pytestconfig):
@pytest.fixture(scope="session")
-def settings(env_file_path) -> "Settings":
+def settings(env_file_path: str) -> "Settings":
from jb.settings import Settings as JBSettings
s = JBSettings(_env_file=env_file_path)
@@ -90,7 +100,7 @@ def settings(env_file_path) -> "Settings":
@pytest.fixture(scope="session")
-def redis(settings):
+def redis(settings: "Settings"):
from generalresearchutils.redis_helper import RedisConfig
redis_config = RedisConfig(
@@ -103,7 +113,7 @@ def redis(settings):
@pytest.fixture(scope="session")
-def pg_config(settings) -> PostgresConfig:
+def pg_config(settings: "Settings") -> PostgresConfig:
return PostgresConfig(
dsn=settings.amt_jb_db,
connect_timeout=1,
@@ -115,8 +125,10 @@ def pg_config(settings) -> PostgresConfig:
@pytest.fixture(scope="session")
-def hqm(pg_config) -> "HitQuestionManager":
- assert "/unittest-" in pg_config.dsn.path
+def hqm(pg_config: PostgresConfig) -> "HitQuestionManager":
+ assert (
+ pg_config.dsn.path and "/unittest-" in pg_config.dsn.path
+ ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')"
from jb.managers.hit import HitQuestionManager
@@ -126,8 +138,10 @@ def hqm(pg_config) -> "HitQuestionManager":
@pytest.fixture(scope="session")
-def htm(pg_config) -> "HitTypeManager":
- assert "/unittest-" in pg_config.dsn.path
+def htm(pg_config: PostgresConfig) -> "HitTypeManager":
+ assert (
+ pg_config.dsn.path and "/unittest-" in pg_config.dsn.path
+ ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')"
from jb.managers.hit import HitTypeManager
@@ -137,8 +151,10 @@ def htm(pg_config) -> "HitTypeManager":
@pytest.fixture(scope="session")
-def hm(pg_config) -> "HitManager":
- assert "/unittest-" in pg_config.dsn.path
+def hm(pg_config: PostgresConfig) -> "HitManager":
+ assert (
+ pg_config.dsn.path and "/unittest-" in pg_config.dsn.path
+ ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')"
from jb.managers.hit import HitManager
@@ -148,8 +164,10 @@ def hm(pg_config) -> "HitManager":
@pytest.fixture(scope="session")
-def am(pg_config) -> "AssignmentManager":
- assert "/unittest-" in pg_config.dsn.path
+def am(pg_config: PostgresConfig) -> "AssignmentManager":
+ assert (
+ pg_config.dsn.path and "/unittest-" in pg_config.dsn.path
+ ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')"
from jb.managers.assignment import AssignmentManager
@@ -159,8 +177,10 @@ def am(pg_config) -> "AssignmentManager":
@pytest.fixture(scope="session")
-def bm(pg_config) -> "BonusManager":
- assert "/unittest-" in pg_config.dsn.path
+def bm(pg_config: PostgresConfig) -> "BonusManager":
+ assert (
+ pg_config.dsn.path and "/unittest-" in pg_config.dsn.path
+ ), "pg_config must point to a unittest database (dsn path must contain '/unittest-')"
from jb.managers.bonus import BonusManager
@@ -178,7 +198,7 @@ def question() -> HitQuestion:
@pytest.fixture
-def question_record(hqm, question) -> HitQuestion:
+def question_record(hqm: "HitQuestionManager", question: HitQuestion) -> HitQuestion:
return hqm.get_or_create(question)
@@ -197,14 +217,14 @@ def hit_type() -> HitType:
@pytest.fixture
-def hit_type_record(htm, hit_type) -> HitType:
+def hit_type_record(htm: "HitTypeManager", hit_type: HitType) -> HitType:
hit_type.amt_hit_type_id = generate_amt_id()
return htm.get_or_create(hit_type)
@pytest.fixture
-def hit_type_with_amt_id(htm, hit_type: HitType) -> HitType:
+def hit_type_with_amt_id(htm: "HitTypeManager", hit_type: HitType) -> HitType:
# This is a real hit type I've previously registered with amt (sandbox).
# It will always exist
hit_type.amt_hit_type_id = "3217B3DC4P5YW9DRV9R3X8O56V041J"
@@ -225,7 +245,9 @@ def amt_hit_id() -> str:
@pytest.fixture
-def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question) -> Hit:
+def hit(
+ amt_hit_id: str, amt_hit_type_id: str, amt_group_id: str, question: HitQuestion
+) -> Hit:
now = datetime.now(tz=timezone.utc)
return Hit.model_validate(
@@ -255,11 +277,11 @@ def hit(amt_hit_id, amt_hit_type_id, amt_group_id, question) -> Hit:
@pytest.fixture
def hit_record(
- hm,
- question_record,
- hit_type_record,
- hit,
- amt_hit_id,
+ hm: "HitManager",
+ question_record: HitQuestion,
+ hit_type_record: HitType,
+ hit: Hit,
+ amt_hit_id: str,
) -> Hit:
"""
Returns a hit that exists in our db, but does not in amazon (the amt ids
@@ -276,7 +298,9 @@ def hit_record(
@pytest.fixture
-def hit_in_amt(hm, question_record, hit_type_with_amt_id: HitType) -> Hit:
+def hit_in_amt(
+ hm: "HitManager", question_record: HitQuestion, hit_type_with_amt_id: HitType
+) -> Hit:
# Actually create a new HIT in amt (sandbox)
hit = AMTManager.create_hit_with_hit_type(
hit_type=hit_type_with_amt_id, question=question_record
@@ -290,7 +314,7 @@ def hit_in_amt(hm, question_record, hit_type_with_amt_id: HitType) -> Hit:
@pytest.fixture
-def assignment_stub(hit: Hit, amt_assignment_id, amt_worker_id):
+def assignment_stub(hit: Hit, amt_assignment_id: str, amt_worker_id: str):
now = datetime.now(tz=timezone.utc)
return AssignmentStub(
amt_assignment_id=amt_assignment_id,
@@ -303,11 +327,27 @@ def assignment_stub(hit: Hit, amt_assignment_id, amt_worker_id):
@pytest.fixture
+def assignment_stub_record(
+ am: "AssignmentManager", hit_record: Hit, assignment_stub: AssignmentStub
+) -> AssignmentStub:
+ """
+ Returns an AssignmentStub that exists in our db, but does not in
+ amazon (the amt ids are random). The mtwerk_hit, mtwerk_hittype, and
+ mtwerk_question records will also exist (in the db)
+ """
+ assignment_stub.hit_id = hit_record.id
+ am.create_stub(stub=assignment_stub)
+ return assignment_stub
+
+
+@pytest.fixture
def assignment_factory(hit: Hit):
- def inner(amt_worker_id: str = None):
+
+ def inner(amt_worker_id: str = None) -> Assignment:
now = datetime.now(tz=timezone.utc)
amt_assignment_id = generate_amt_id()
amt_worker_id = amt_worker_id or generate_amt_id()
+
return Assignment(
amt_assignment_id=amt_assignment_id,
amt_hit_id=hit.amt_hit_id,
@@ -324,7 +364,10 @@ def assignment_factory(hit: Hit):
@pytest.fixture
-def assignment_in_db_factory(am, assignment_factory):
+def assignment_record_factory(
+ am: "AssignmentManager", assignment_factory: Callable[..., Assignment]
+):
+
def inner(hit_id: int, amt_worker_id: Optional[str] = None):
a = assignment_factory(amt_worker_id=amt_worker_id)
a.hit_id = hit_id
@@ -335,22 +378,11 @@ def assignment_in_db_factory(am, assignment_factory):
return inner
-@pytest.fixture
-def assignment_stub_in_db(am, hit_record, assignment_stub) -> AssignmentStub:
- """
- Returns an AssignmentStub that exists in our db, but does not in amazon (the amt ids are random).
- The mtwerk_hit, mtwerk_hittype, and mtwerk_question records will also exist (in the db)
- """
- assignment_stub.hit_id = hit_record.id
- am.create_stub(assignment_stub)
- return assignment_stub
-
-
-# --- HIT ---
+# --- Response ---
@pytest.fixture
-def amt_response_metadata():
+def amt_response_metadata() -> Dict[str, Any]:
req_id = str(uuid4())
return {
"RequestId": req_id,
@@ -367,7 +399,7 @@ def amt_response_metadata():
@pytest.fixture
def create_hit_type_response(
- amt_hit_type_id, amt_response_metadata
+ amt_hit_type_id: str, amt_response_metadata: Dict[str, Any]
) -> CreateHITTypeResponseTypeDef:
return {
"HITTypeId": amt_hit_type_id,
@@ -377,7 +409,7 @@ def create_hit_type_response(
@pytest.fixture
def create_hit_with_hit_type_response(
- amt_hit_type_id, amt_hit_id, amt_response_metadata
+ amt_hit_type_id: str, amt_hit_id: str, amt_response_metadata
) -> CreateHITWithHITTypeResponseTypeDef:
amt_group_id = generate_amt_id(length=30)
return {
@@ -407,7 +439,7 @@ def create_hit_with_hit_type_response(
@pytest.fixture
def get_hit_response(
- amt_hit_type_id, amt_hit_id, amt_response_metadata
+ amt_hit_type_id: str, amt_hit_id: str, amt_response_metadata
) -> GetHITResponseTypeDef:
amt_group_id = generate_amt_id(length=30)
return {
@@ -447,13 +479,12 @@ def get_hit_response_reviewing(get_hit_response):
@pytest.fixture
def get_assignment_response(
- amt_hit_type_id,
- amt_hit_id,
- amt_assignment_id,
- amt_worker_id,
+ amt_hit_id: str,
+ amt_assignment_id: str,
+ amt_worker_id: str,
get_hit_response,
amt_response_metadata,
- tsid,
+ tsid: str,
) -> GetAssignmentResponseTypeDef:
hit_response = get_hit_response["HIT"]
local_now = datetime.now(tz=tzlocal())
@@ -484,7 +515,7 @@ def get_assignment_response(
@pytest.fixture
def get_assignment_response_no_tsid(
- get_assignment_response, amt_worker_id, amt_assignment_id
+ get_assignment_response, amt_worker_id: str, amt_assignment_id: str
):
res = copy.deepcopy(get_assignment_response)
res["Assignment"]["Answer"] = (
diff --git a/tests/http/conftest.py b/tests/http/conftest.py
index 200bf1c..4f11fde 100644
--- a/tests/http/conftest.py
+++ b/tests/http/conftest.py
@@ -1,10 +1,19 @@
import httpx
+import redis
import pytest
import requests_mock
from asgi_lifespan import LifespanManager
from httpx import AsyncClient, ASGITransport
+from typing import Dict, Any
from jb.main import app
+import json
+
+from httpx import AsyncClient
+import secrets
+
+from jb.config import JB_EVENTS_STREAM, settings
+from tests import generate_amt_id
@pytest.fixture(scope="session")
@@ -48,3 +57,43 @@ def httpxclient_ip(httpxclient):
def mock_requests():
with requests_mock.Mocker() as m:
yield m
+
+
+def generate_hex_id(length: int = 40) -> str:
+ # length is number of hex chars, so we need length//2 bytes
+ return secrets.token_hex(length // 2)
+
+
+@pytest.fixture
+def mturk_event_body_record(
+ hit_record: Hit, assignment_stub_record: AssignmentStub
+) -> Dict[str, Any]:
+ return {
+ "Type": "Notification",
+ "Message": json.dumps(
+ {
+ "Events": [
+ {
+ "EventType": "AssignmentSubmitted",
+ "EventTimestamp": "2025-10-16T18:45:51.000000Z",
+ "HITId": hit_record.amt_hit_id,
+ "AssignmentId": assignment_stub_record.amt_assignment_id,
+ "HITTypeId": hit_record.amt_hit_type_id,
+ }
+ ],
+ "EventDocId": generate_hex_id(),
+ "SourceAccount": settings.aws_owner_id,
+ "CustomerId": generate_amt_id(length=14),
+ "EventDocVersion": "2006-05-05",
+ }
+ ),
+ }
+
+
+@pytest.fixture()
+def clean_mturk_events_redis_stream(redis: redis.Redis):
+ redis.xtrim(JB_EVENTS_STREAM, maxlen=0)
+ assert redis.xlen(JB_EVENTS_STREAM) == 0
+ yield
+ redis.xtrim(JB_EVENTS_STREAM, maxlen=0)
+ assert redis.xlen(JB_EVENTS_STREAM) == 0
diff --git a/tests/http/test_basic.py b/tests/http/test_basic.py
index 18359da..e806fa9 100644
--- a/tests/http/test_basic.py
+++ b/tests/http/test_basic.py
@@ -2,23 +2,24 @@ import pytest
from httpx import AsyncClient
-@pytest.mark.anyio
-async def test_base(httpxclient: AsyncClient):
- client = httpxclient
- res = await client.get("/")
- # actually returns 404. old test expects 401. idk what is should be
- print(res.text)
- # assert res.status_code == 404
- assert res.status_code == 200
+class TestBase:
+ @pytest.mark.anyio
+ async def test_base(self, httpxclient: AsyncClient):
+ client = httpxclient
+ res = await client.get("/")
+ # actually returns 404. old test expects 401. idk what is should be
+ print(res.text)
+ # assert res.status_code == 404
+ assert res.status_code == 200
-@pytest.mark.anyio
-async def test_static_file_alias(httpxclient: AsyncClient):
- client = httpxclient
- """
- These are here for site crawlers and stuff..
- """
- for p in ["/robots.txt", "/sitemap.xml", "/favicon.ico"]:
- res = await client.get(p)
- assert res.status_code == 200, p
- assert res.json() == {}
+ @pytest.mark.anyio
+ async def test_static_file_alias(self, httpxclient: AsyncClient):
+ client = httpxclient
+ """
+ These are here for site crawlers and stuff..
+ """
+ for p in ["/robots.txt", "/sitemap.xml", "/favicon.ico"]:
+ res = await client.get(p)
+ assert res.status_code == 200, p
+ assert res.json() == {}
diff --git a/tests/http/test_notifications.py b/tests/http/test_notifications.py
index 6770044..4386863 100644
--- a/tests/http/test_notifications.py
+++ b/tests/http/test_notifications.py
@@ -1,71 +1,102 @@
-import json
-
import pytest
+import json
+import redis
+from typing import Dict, Any
from httpx import AsyncClient
-import secrets
+from uuid import uuid4
from jb.config import JB_EVENTS_STREAM, settings
from jb.models.event import MTurkEvent
-from tests import generate_amt_id
-
-
-def generate_hex_id(length: int = 40) -> str:
- # length is number of hex chars, so we need length//2 bytes
- return secrets.token_hex(length // 2)
-
-
-@pytest.fixture
-def example_mturk_event_body(amt_hit_id, amt_hit_type_id, amt_assignment_id):
- return {
- "Type": "Notification",
- "Message": json.dumps(
- {
- "Events": [
- {
- "EventType": "AssignmentSubmitted",
- "EventTimestamp": "2025-10-16T18:45:51.000000Z",
- "HITId": amt_hit_id,
- "AssignmentId": amt_assignment_id,
- "HITTypeId": amt_hit_type_id,
- }
- ],
- "EventDocId": generate_hex_id(),
- "SourceAccount": settings.aws_owner_id,
- "CustomerId": generate_amt_id(length=14),
- "EventDocVersion": "2006-05-05",
- }
- ),
- }
-
-
-@pytest.fixture()
-def clean_mturk_events_redis_stream(redis):
- redis.xtrim(JB_EVENTS_STREAM, maxlen=0)
- assert redis.xlen(JB_EVENTS_STREAM) == 0
- yield
- redis.xtrim(JB_EVENTS_STREAM, maxlen=0)
- assert redis.xlen(JB_EVENTS_STREAM) == 0
-
-
-@pytest.mark.anyio
-async def test_mturk_notifications(
- redis,
- httpxclient: AsyncClient,
- no_limit,
- example_mturk_event_body,
- amt_assignment_id,
- clean_mturk_events_redis_stream,
-):
- client = httpxclient
-
- res = await client.post(url=f"/{settings.sns_path}/", json=example_mturk_event_body)
- res.raise_for_status()
-
- msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100)
- msg_res = msg_res[0][1][0]
- msg_id, msg = msg_res
- redis.xdel(JB_EVENTS_STREAM, msg_id)
-
- msg_json = msg["data"]
- event = MTurkEvent.model_validate_json(msg_json)
- assert event.amt_assignment_id == amt_assignment_id
+from jb.models.hit import Hit
+
+
+class TestNotifications:
+
+ @pytest.mark.anyio
+ async def test_no_post_keys(
+ self,
+ httpxclient: AsyncClient,
+ ):
+ client = httpxclient
+
+ res = await client.post(url=f"/{settings.sns_path}/", json={})
+ assert res.status_code == 500
+ assert res.json() == {"detail": "Invalid JSON"}
+
+ @pytest.mark.anyio
+ async def test_no_post_data(
+ self,
+ httpxclient: AsyncClient,
+ ):
+ client = httpxclient
+
+ res = await client.post(url=f"/{settings.sns_path}/")
+ assert res.status_code == 500
+ assert res.json() == {"detail": "No POST data"}
+
+ @pytest.mark.anyio
+ async def test_invalid_post_data(
+ self,
+ httpxclient: AsyncClient,
+ ):
+ client = httpxclient
+
+ res = await client.post(
+ url=f"/{settings.sns_path}/", json={uuid4().hex: uuid4().hex}
+ )
+ assert res.status_code == 500
+ assert res.json() == {"detail": "Invalid JSON"}
+
+ @pytest.mark.anyio
+ async def test_mturk_notifications(
+ self,
+ redis: redis.Redis,
+ httpxclient: AsyncClient,
+ hit_record: Hit,
+ assignment_stub_record: AssignmentStub,
+ mturk_event_body_record: Dict[str, Any],
+ ):
+ client = httpxclient
+
+ json_msg = json.loads(mturk_event_body_record["Message"])
+ # Assert the mturk event is owned by the correct account
+ assert json_msg["SourceAccount"] == settings.aws_owner_id
+
+ # Assert the HIT and Assignment are "connected"
+ assert assignment_stub_record.hit_id == hit_record.id
+ assert assignment_stub_record.amt_hit_id == hit_record.amt_hit_id
+
+ # Assert the event body and the Assignment/HIT are "connected"
+ assert json_msg["Events"][0]["HITId"] == hit_record.amt_hit_id
+ assert json_msg["Events"][0]["HITId"] == assignment_stub_record.amt_hit_id
+ assert (
+ json_msg["Events"][0]["AssignmentId"]
+ == assignment_stub_record.amt_assignment_id
+ )
+
+ res = await client.post(
+ url=f"/{settings.sns_path}/", json=mturk_event_body_record
+ )
+ res.raise_for_status()
+
+ # AMT SNS needs to receive a 200 response to stop retrying the notification
+ assert res.status_code == 200
+ assert res.json() == {"status": "ok"}
+
+ # Check that the event was enqueued in Redis
+ msg_res = redis.xread(streams={JB_EVENTS_STREAM: 0}, count=1, block=100)
+ msg_res = msg_res[0][1][0]
+ msg_id, msg = msg_res
+ redis.xdel(JB_EVENTS_STREAM, msg_id)
+
+ msg_json = msg["data"]
+ event = MTurkEvent.model_validate_json(msg_json)
+
+ # Confirm that the event that we got from redis is what was POSTED by
+ # AMT SNS. We're using the fixture() so the values are the same as
+ # what was in the example_mturk_event_body
+ assert event.event_type == "AssignmentSubmitted"
+
+ assert event.amt_assignment_id == assignment_stub_record.amt_assignment_id
+ assert event.amt_hit_id == hit_record.amt_hit_id
+ assert event.amt_hit_type_id == hit_record.amt_hit_type_id
diff --git a/tests/http/test_preview.py b/tests/http/test_preview.py
index 2bdf265..467c63c 100644
--- a/tests/http/test_preview.py
+++ b/tests/http/test_preview.py
@@ -3,8 +3,8 @@
import pytest
from httpx import AsyncClient
-
from jb.models.hit import Hit
+from jb.models.assignment import AssignmentStub
class TestPreview:
@@ -15,8 +15,6 @@ class TestPreview:
res = await client.get("/preview/")
assert res.status_code == 200
- # the response is an html page
-
assert res.headers["content-type"] == "text/html; charset=utf-8"
assert res.num_bytes_downloaded == 507
@@ -25,11 +23,17 @@ class TestPreview:
assert "https://cdn.jamesbillings67.com/james-billings.js" in res.text
@pytest.mark.anyio
- async def test_preview_redirect_from_work(
- self, httpxclient: AsyncClient, amt_hit_id, amt_assignment_id
+ async def test_preview_redirect_from_work_random_str(
+ self, httpxclient: AsyncClient, amt_hit_id: str, amt_assignment_id: str
):
client = httpxclient
+ """
+ The redirect occurs regardless of any parameter validation. This is
+ because the query params will be used for record lookup on the work
+ page itself.
+ """
+
params = {
"workerId": None,
"assignmentId": amt_assignment_id,
@@ -39,3 +43,28 @@ class TestPreview:
assert res.status_code == 302
assert "/preview/" in res.headers["location"]
+
+ @pytest.mark.anyio
+ async def test_preview_redirect_from_work_records(
+ self,
+ httpxclient: AsyncClient,
+ hit_record: Hit,
+ assignment_stub_record: AssignmentStub,
+ ):
+ client = httpxclient
+
+ """
+ The redirect occurs regardless of any parameter validation. This is
+ because the query params will be used for record lookup on the work
+ page itself.
+ """
+
+ params = {
+ "workerId": None,
+ "assignmentId": assignment_stub_record.amt_assignment_id,
+ "hitId": hit_record.amt_hit_id,
+ }
+ res = await client.get("/work/", params=params)
+
+ assert res.status_code == 302
+ assert "/preview/" in res.headers["location"]
diff --git a/tests/http/test_work.py b/tests/http/test_work.py
index c69118b..66251f6 100644
--- a/tests/http/test_work.py
+++ b/tests/http/test_work.py
@@ -1,5 +1,9 @@
import pytest
from httpx import AsyncClient
+from jb.models.hit import Hit
+from jb.models.assignment import AssignmentStub
+
+from jb.managers.assignment import AssignmentManager
class TestWork:
@@ -8,17 +12,105 @@ class TestWork:
async def test_work(
self,
httpxclient: AsyncClient,
- hit_record,
- amt_assignment_id,
- amt_worker_id,
+ hit_record: Hit,
+ amt_assignment_id: str,
+ amt_worker_id: str,
):
client = httpxclient
+ assert isinstance(hit_record.id, int)
+
params = {
"workerId": amt_worker_id,
"assignmentId": amt_assignment_id,
"hitId": hit_record.amt_hit_id,
}
res = await client.get("/work/", params=params)
+ assert res.status_code == 200
+
+ @pytest.mark.anyio
+ async def test_work_no_hit_record(
+ self,
+ httpxclient: AsyncClient,
+ hit: Hit,
+ amt_assignment_id: str,
+ amt_worker_id: str,
+ ):
+ client = httpxclient
+
+ # Because no AssignmentStub record is created, and we're just using
+ # random strings as IDs, we should also confirm that the Hit record
+ # is not a saved record.
+ assert hit.id is None
+
+ params = {
+ "workerId": amt_worker_id,
+ "assignmentId": amt_assignment_id,
+ "hitId": hit.amt_hit_id,
+ }
+ res = await client.get("/work/", params=params)
+ assert res.status_code == 500
+
+ @pytest.mark.anyio
+ async def test_work_assignment_stub_existing(
+ self,
+ httpxclient: AsyncClient,
+ am: AssignmentManager,
+ hit: Hit,
+ assignment_stub_record: AssignmentStub,
+ amt_assignment_id: str,
+ amt_worker_id: str,
+ ):
+ client = httpxclient
+
+ # Because the AssignmentStub is created with a reference to the Hit,
+ # the Hit is actually a "Hit Record" (with a primary key), so it's
+ # saved in the database
+ assert isinstance(hit.id, int)
+
+ # Confirm that it exists in the database before the call
+ res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id)
+ assert isinstance(res, AssignmentStub)
+ assert isinstance(res.id, int)
+
+ params = {
+ "workerId": amt_worker_id,
+ "assignmentId": assignment_stub_record.amt_assignment_id,
+ "hitId": hit.amt_hit_id,
+ }
+ res = await client.get("/work/", params=params)
+ assert res.status_code == 200
+
+ # Confirm that it exists in the database
+ res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id)
+ assert isinstance(res, AssignmentStub)
+ assert isinstance(res.id, int)
+
+ @pytest.mark.anyio
+ async def test_work_assignment_stub_created(
+ self,
+ httpxclient: AsyncClient,
+ am: AssignmentManager,
+ hit_record: Hit,
+ assignment_stub: AssignmentStub,
+ amt_assignment_id: str,
+ amt_worker_id: str,
+ ):
+ client = httpxclient
+
+ # Confirm that it exists in the database before the call
+ res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id)
+ assert res is None
+ params = {
+ "workerId": amt_worker_id,
+ "assignmentId": assignment_stub.amt_assignment_id,
+ "hitId": hit_record.amt_hit_id,
+ }
+ res = await client.get("/work/", params=params)
assert res.status_code == 200
+
+ # Confirm that it exists in the database
+ res = am.get_stub_if_exists(amt_assignment_id=amt_assignment_id)
+ assert isinstance(res, AssignmentStub)
+ assert isinstance(res.id, int)