diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index 179b9e1..5b59329 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -2,7 +2,7 @@ import asyncio import json import logging import shlex -from typing import Annotated +from typing import Annotated, Iterator from uuid import UUID from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request, status @@ -34,12 +34,10 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]: class NixBuildException(HTTPException): - def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]): - self.uuid = uuid + def __init__(self, msg: str, loc: list = ["body", "flake_attr"]): detail = [ { "loc": loc, - "uuid": str(uuid), "msg": msg, "type": "value_error", } @@ -65,7 +63,7 @@ class BuildVmTask(BaseTask): vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm" self.log.debug(f"vm_path: {vm_path}") - self.run_cmd(vm_path) + self.run_cmd([vm_path]) self.finished = True except Exception as e: self.failed = True @@ -103,7 +101,6 @@ async def inspect_vm( if proc.returncode != 0: raise NixBuildException( - "" f""" Failed to evaluate vm from '{flake_url}#{flake_attr}'. command: {shlex.join(cmd)} @@ -127,7 +124,7 @@ async def get_status(uuid: UUID) -> VmStatusResponse: @router.get("/api/vms/{uuid}/logs") async def get_logs(uuid: UUID) -> StreamingResponse: # Generator function that yields log lines as they are available - def stream_logs(): + def stream_logs() -> Iterator[str]: task = get_task(uuid) for proc in task.procs: diff --git a/pkgs/clan-cli/clan_cli/webui/task_manager.py b/pkgs/clan-cli/clan_cli/webui/task_manager.py index 25890ee..21374cb 100644 --- a/pkgs/clan-cli/clan_cli/webui/task_manager.py +++ b/pkgs/clan-cli/clan_cli/webui/task_manager.py @@ -5,6 +5,7 @@ import select import shlex import subprocess import threading +from typing import Any from uuid import UUID, uuid4 @@ -105,14 +106,14 @@ def get_task(uuid: UUID) -> BaseTask: return POOL[uuid] -def register_task(task: BaseTask, *kwargs) -> UUID: +def register_task(task: type, *args: Any) -> UUID: global POOL if not issubclass(task, BaseTask): raise TypeError("task must be a subclass of BaseTask") uuid = uuid4() - inst_task = task(uuid, *kwargs) + inst_task = task(uuid, *args) POOL[uuid] = inst_task inst_task.start() return uuid diff --git a/pkgs/clan-cli/tests/test_vms_api.py b/pkgs/clan-cli/tests/test_vms_api.py index 3f62ae7..2939bc5 100644 --- a/pkgs/clan-cli/tests/test_vms_api.py +++ b/pkgs/clan-cli/tests/test_vms_api.py @@ -2,6 +2,7 @@ from pathlib import Path import pytest from api import TestClient +from httpx import SyncByteStream # @pytest.mark.impure # def test_inspect(api: TestClient, test_flake_with_core: Path) -> None: @@ -41,6 +42,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None: response = api.get(f"/api/vms/{uuid}/logs") print("=========FLAKE LOGS==========") + assert isinstance(response.stream, SyncByteStream) for line in response.stream: assert line != b"", "Failed to get vm logs" print(line.decode("utf-8"), end="") @@ -48,6 +50,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None: assert response.status_code == 200, "Failed to get vm logs" response = api.get(f"/api/vms/{uuid}/logs") + assert isinstance(response.stream, SyncByteStream) print("=========VM LOGS==========") for line in response.stream: assert line != b"", "Failed to get vm logs"