from __future__ import annotations
import json
import warnings
from collections.abc import Collection
from copy import deepcopy
from pathlib import Path
from typing import Any, get_args, get_origin
from pydantic import BaseModel, RootModel
from perago.errors import TaskDefinitionError
from perago.models import WorkspaceInput, WorkspaceOutput
from perago.task import TaskDefinition
# Value written to the ``version`` field of the generated ``inputSchema`` and
# ``outputSchema`` envelopes.
TASKDEF_SCHEMA_VERSION = 1
# Value written to the ``type`` field of the generated schema envelopes.
TASKDEF_SCHEMA_TYPE = "JSON"
# Keys that ``_strip_object_schema_metadata`` keeps on a model-level object
# schema; every other key on such a schema is removed.
_MODEL_SCHEMA_STRUCTURAL_KEYS = frozenset({"$defs", "additionalProperties", "properties", "required", "type"})
# Keys that ``schema_for_model`` removes from the inlined schema via
# ``_strip_schema_metadata_keys``.
_GENERATED_SCHEMA_METADATA_KEYS = frozenset({"title"})
# Schema keywords whose value is a mapping keyed by names rather than by schema
# keywords; the keys of such a mapping are not treated as metadata.
_SCHEMA_NAME_MAPPING_KEYS = frozenset({"properties"})
# Maps each Conductor TaskDef field name to the attribute path that
# ``_control_fields`` walks, starting from ``task.controls``.
# ``responseTimeoutSeconds`` is the exception: ``_control_fields`` resolves it
# through ``_response_timeout_seconds`` and does not walk its path.
CONTROL_FIELD_MAP = {
"retryCount": ("retry", "count"),
"retryLogic": ("retry", "logic"),
"retryDelaySeconds": ("retry", "delay_seconds"),
"maxRetryDelaySeconds": ("retry", "max_delay_seconds"),
"backoffJitterMs": ("retry", "jitter_ms"),
"totalTimeoutSeconds": ("timeout", "total_seconds"),
"timeoutPolicy": ("timeout", "policy"),
"timeoutSeconds": ("timeout", "seconds"),
"responseTimeoutSeconds": ("response_timeout_seconds",),
"pollTimeoutSeconds": ("timeout", "poll_seconds"),
"concurrentExecLimit": ("limits", "concurrent_exec_limit"),
"rateLimitFrequencyInSeconds": ("limits", "rate_limit_frequency_in_seconds"),
"rateLimitPerFrequency": ("limits", "rate_limit_per_frequency"),
}
# [docs]
def build_taskdef(task: TaskDefinition) -> dict[str, Any]:
    """
    Assemble the Conductor TaskDef mapping that describes one Perago task.

    This is the programmatic counterpart of ``perago extract``. The mapping
    combines the task identity (name, owner e-mail and optional description),
    the control fields derived from the task's controls, and closed JSON
    Schemas for the task input and output.

    Parameters
    ----------
    task : TaskDefinition
        Validated task definition, for example one returned by
        :func:`perago.load_module_task` or stored on a decorated function as
        ``__perago_task__``.

    Returns
    -------
    dict of str to Any
        JSON-compatible TaskDef mapping. For workspace tasks ``inputKeys`` is
        ``["workspace", "params"]`` and ``outputKeys`` is
        ``["workspace", "result"]``; tasks without a workspace only carry
        ``params`` and ``result``.

    Raises
    ------
    TaskDefinitionError
        If the params or output model graph contains a Pydantic ``RootModel``.

    See Also
    --------
    write_taskdef : Write the generated TaskDef mapping to a JSON file.

    Notes
    -----
    Workspace guardrails, workspace prefixes, LakeFS connection settings and
    publish budget internals are not part of the generated mapping.
    ``responseTimeoutSeconds`` is always taken from
    ``timeout.response_seconds``; for tasks whose workspace is not read-only a
    warning is emitted when that value is shorter than the publish budget's
    response timeout.

    Examples
    --------
    >>> task_def = build_taskdef(load_module_task("app.workers.features_build"))
    >>> task_def["name"]
    'features.build'
    """
    validate_no_root_task_models(task)

    # Property order doubles as key order: "workspace" (when present) comes
    # before "params" / "result".
    input_properties: dict[str, Any] = {}
    output_properties: dict[str, Any] = {}
    if task.has_workspace:
        input_properties["workspace"] = schema_for_model(WorkspaceInput)
        output_properties["workspace"] = schema_for_model(WorkspaceOutput)
    input_properties["params"] = schema_for_model(task.params_model)
    output_properties["result"] = schema_for_model(task.output_model)

    # Every property is required, so the key lists mirror the property names.
    input_keys: list[str] = list(input_properties)
    output_keys: list[str] = list(output_properties)

    def envelope(direction: str, properties: dict[str, Any], keys: list[str]) -> dict[str, Any]:
        # Conductor schema wrapper around the closed object schema.
        return {
            "name": f"{task.name}.{direction}",
            "version": TASKDEF_SCHEMA_VERSION,
            "type": TASKDEF_SCHEMA_TYPE,
            "data": _object_schema(properties, keys),
        }

    taskdef: dict[str, Any] = {"name": task.name, "ownerEmail": task.owner_email}
    description = task.description
    if description is not None:
        taskdef["description"] = description

    taskdef.update(_control_fields(task))
    taskdef["inputKeys"] = input_keys
    taskdef["outputKeys"] = output_keys
    taskdef["inputSchema"] = envelope("input", input_properties, input_keys)
    taskdef["outputSchema"] = envelope("output", output_properties, output_keys)
    return taskdef
# [docs]
def write_taskdef(task: TaskDefinition, output: Path) -> Path:
    """
    Serialize the TaskDef for ``task`` and store it as a JSON file.

    Missing parent directories are created. The document is written with a
    two-space indent, a trailing newline and UTF-8 encoding, keeping the key
    order produced by :func:`build_taskdef`.

    Parameters
    ----------
    task : TaskDefinition
        Validated task definition to serialize.
    output : pathlib.Path
        Destination file. Its suffix must be ``.json`` and it must not be an
        existing directory.

    Returns
    -------
    pathlib.Path
        ``output``, returned after the file has been written.

    Raises
    ------
    ValueError
        If ``output`` does not end with ``.json`` or is an existing directory.

    See Also
    --------
    build_taskdef : Build the TaskDef mapping without writing a file.

    Examples
    --------
    >>> task_def = load_module_task("app.workers.metadata_validate")
    >>> write_taskdef(task_def, Path("generated/metadata.validate.json"))
    PosixPath('generated/metadata.validate.json')
    """
    has_json_suffix = output.suffix == ".json"
    if not has_json_suffix:
        raise ValueError("output must be a JSON file path, for example generated/features.build.json")

    points_at_directory = output.exists() and output.is_dir()
    if points_at_directory:
        raise ValueError("output must be a JSON file path, not a directory")

    # The directory is created before the TaskDef is built, as in the
    # established call order.
    output.parent.mkdir(parents=True, exist_ok=True)
    document = json.dumps(build_taskdef(task), indent=2, sort_keys=False)
    output.write_text(f"{document}\n", encoding="utf-8")
    return output
def schema_for_model(model: type[BaseModel]) -> dict[str, Any]:
    """
    Produce a self-contained, closed JSON Schema for a Pydantic model.

    The schema generated by ``model.model_json_schema()`` is processed in four
    steps: model-level metadata is stripped, ``$defs`` references are inlined,
    generated metadata keys are removed (property names are left alone), and
    every object schema without an explicit ``additionalProperties`` gets
    ``additionalProperties: false``.

    Parameters
    ----------
    model : type[BaseModel]
        Pydantic model class to describe.

    Returns
    -------
    dict of str to Any
        The processed schema.
    """
    generated = model.model_json_schema()
    _strip_model_schema_metadata(generated)

    # From here on the work happens on the inlined copy.
    flattened = _inline_refs(generated)
    _strip_schema_metadata_keys(
        flattened,
        _GENERATED_SCHEMA_METADATA_KEYS,
        preserve_mapping_keys=_SCHEMA_NAME_MAPPING_KEYS,
    )
    _close_object_schemas(flattened)
    return flattened
def task_models_with_config(task: TaskDefinition) -> list[type[BaseModel]]:
    """
    List the task models that carry a non-empty ``model_config``.

    The params model graph is scanned first, then the output model graph.
    Each model appears once, at the position where it was first found.

    Parameters
    ----------
    task : TaskDefinition
        Task whose ``params_model`` and ``output_model`` graphs are scanned.

    Returns
    -------
    list of type[BaseModel]
        Models with a truthy ``model_config``, in discovery order.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique = dict.fromkeys(
        candidate
        for root in (task.params_model, task.output_model)
        for candidate in _iter_model_graph(root)
        if candidate.model_config
    )
    return list(unique)
def task_models_with_root_model(task: TaskDefinition) -> list[type[BaseModel]]:
    """
    List the task models that derive from Pydantic ``RootModel``.

    The params model graph is scanned first, then the output model graph.
    Each model appears once, at the position where it was first found.

    Parameters
    ----------
    task : TaskDefinition
        Task whose ``params_model`` and ``output_model`` graphs are scanned.

    Returns
    -------
    list of type[BaseModel]
        ``RootModel`` subclasses, in discovery order.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique = dict.fromkeys(
        candidate
        for root in (task.params_model, task.output_model)
        for candidate in _iter_model_graph(root)
        if issubclass(candidate, RootModel)
    )
    return list(unique)
def validate_no_root_task_models(task: TaskDefinition) -> None:
    """
    Reject tasks whose model graphs contain a Pydantic ``RootModel``.

    Parameters
    ----------
    task : TaskDefinition
        Task to check.

    Raises
    ------
    TaskDefinitionError
        If any model reachable from the params or output model is a
        ``RootModel``; the message lists the offending class names.
    """
    offenders = task_models_with_root_model(task)
    if offenders:
        offender_names = ", ".join(offender.__name__ for offender in offenders)
        raise TaskDefinitionError(
            f"Pydantic RootModel on task model(s) {offender_names} is not supported; "
            "Perago task contracts must use ordinary BaseModel object models."
        )
def _control_fields(task: TaskDefinition) -> dict[str, Any]:
    """
    Collect the Conductor control fields that are set for ``task``.

    Each entry of ``CONTROL_FIELD_MAP`` is resolved by walking its attribute
    path from ``task.controls``. ``responseTimeoutSeconds`` is resolved through
    ``_response_timeout_seconds`` instead. Fields that resolve to ``None`` are
    omitted.

    Parameters
    ----------
    task : TaskDefinition
        Task whose controls are read.

    Returns
    -------
    dict of str to Any
        Conductor field name to value, in ``CONTROL_FIELD_MAP`` order.
    """

    def walk(attribute_path: tuple[str, ...]) -> object:
        # Follow the attribute chain starting at the task controls.
        node: object = task.controls
        for attribute in attribute_path:
            node = getattr(node, attribute)
        return node

    collected: dict[str, Any] = {}
    for conductor_name, attribute_path in CONTROL_FIELD_MAP.items():
        resolved = (
            _response_timeout_seconds(task)
            if conductor_name == "responseTimeoutSeconds"
            else walk(attribute_path)
        )
        if resolved is None:
            continue
        collected[conductor_name] = resolved
    return collected
def _response_timeout_seconds(task: TaskDefinition) -> int:
    """
    Return the value used for the ``responseTimeoutSeconds`` TaskDef field.

    The value is always ``task.controls.timeout.response_seconds``. Unless the
    task has a read-only workspace, a ``UserWarning`` is emitted when a publish
    budget is configured and the response timeout is shorter than the
    budget's ``response_timeout_seconds``.

    Parameters
    ----------
    task : TaskDefinition
        Task whose controls and workspace are inspected.

    Returns
    -------
    int
        The configured response timeout in seconds.
    """
    workspace = task.workspace
    if workspace is not None and workspace.read_only:
        # Read-only workspaces skip the publish budget comparison entirely.
        return task.controls.timeout.response_seconds

    budget = task.controls.publish_budget
    configured = task.controls.timeout.response_seconds
    too_short = budget is not None and configured < budget.response_timeout_seconds
    if too_short:
        message = (
            f"Task {task.name!r} has TaskControls.timeout.response_seconds={configured} "
            f"which is shorter than publish_budget.response_timeout_seconds={budget.response_timeout_seconds}; "
            "responseTimeoutSeconds will use timeout.response_seconds"
        )
        # stacklevel=4 attributes the warning three frames above this function.
        warnings.warn(message, UserWarning, stacklevel=4)
    return configured
def _object_schema(properties: dict[str, Any], required: list[str]) -> dict[str, Any]:
    """
    Wrap ``properties`` in a closed JSON Schema object.

    Parameters
    ----------
    properties : dict of str to Any
        Property name to property schema. Stored as-is, not copied.
    required : list of str
        Names of required properties. Stored as-is, not copied.

    Returns
    -------
    dict of str to Any
        Object schema with ``additionalProperties`` set to ``False``.
    """
    wrapped: dict[str, Any] = {"type": "object"}
    wrapped["properties"] = properties
    wrapped["required"] = required
    wrapped["additionalProperties"] = False
    return wrapped
def _inline_refs(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of ``schema`` with local ``$defs`` references inlined.

    Every mapping holding a ``"$ref"`` of the form ``"#/$defs/<name>"`` is
    replaced by a copy of that definition. Keys that sit next to the ``$ref``
    override same-named keys of the definition. The top-level ``$defs`` entry
    is dropped from the result, and the input schema is not modified.

    Parameters
    ----------
    schema : dict of str to Any
        JSON Schema mapping, optionally with a top-level ``$defs`` entry.

    Returns
    -------
    dict of str to Any
        A new schema without ``$defs`` and without local ``$defs`` references.

    Raises
    ------
    TaskDefinitionError
        If a definition refers to itself, directly or through other
        definitions. Such a schema cannot be inlined; without this check the
        expansion would recurse until ``RecursionError``.
    KeyError
        If a reference names a definition that is missing from ``$defs``.
    """
    copied = deepcopy(schema)
    defs = copied.pop("$defs", {})
    # Names of the definitions currently being expanded, outermost first.
    expanding: list[str] = []

    def visit(value: Any) -> Any:
        if isinstance(value, dict):
            ref = value.get("$ref")
            if isinstance(ref, str) and ref.startswith("#/$defs/"):
                name = ref.removeprefix("#/$defs/")
                if name in expanding:
                    chain = " -> ".join([*expanding[expanding.index(name):], name])
                    raise TaskDefinitionError(
                        f"Recursive schema reference {chain} cannot be inlined; "
                        "Perago task contracts must not use self-referencing models."
                    )
                replacement = deepcopy(defs[name])
                # Siblings of the $ref belong to the referencing site, so they
                # are resolved before the definition is marked as expanding.
                siblings = {key: visit(item) for key, item in value.items() if key != "$ref"}
                replacement.update(siblings)
                expanding.append(name)
                try:
                    return visit(replacement)
                finally:
                    expanding.pop()
            return {key: visit(item) for key, item in value.items()}
        if isinstance(value, list):
            return [visit(item) for item in value]
        return value

    return visit(copied)
def _close_object_schemas(schema: Any) -> None:
    """
    Set ``additionalProperties: False`` on object schemas that lack the key.

    ``schema`` is walked recursively through dict values and list items and is
    modified in place. A dict whose ``"type"`` equals ``"object"`` receives
    ``additionalProperties`` only when the key is absent, so an existing value
    is left alone. Values that are neither dict nor list are ignored.

    Parameters
    ----------
    schema : Any
        Schema fragment to process.
    """
    if isinstance(schema, dict):
        if schema.get("type") == "object":
            schema.setdefault("additionalProperties", False)
        children = schema.values()
    elif isinstance(schema, list):
        children = schema
    else:
        return

    for child in children:
        _close_object_schemas(child)
def _strip_schema_metadata_keys(schema: Any, keys: Collection[str], *, preserve_mapping_keys: Collection[str]) -> None:
    """
    Remove metadata keys from ``schema`` in place, sparing name mappings.

    Dicts and lists are walked recursively. On every dict that is a schema,
    each key listed in ``keys`` is removed. A dict stored under one of
    ``preserve_mapping_keys`` (for example ``properties``) is a name mapping:
    its keys are user-chosen names and are kept, while each of its values is
    processed as a schema again.

    Parameters
    ----------
    schema : Any
        Schema fragment to clean. Values other than dict and list are ignored.
    keys : Collection of str
        Keys to remove from schema dicts.
    preserve_mapping_keys : Collection of str
        Schema keywords whose value is a name-to-schema mapping.

    Notes
    -----
    The values of a name mapping are always treated as schemas, whatever the
    name is. Previously a property that was itself named like one of
    ``preserve_mapping_keys`` (a field called ``properties``) had its schema
    mistaken for a name mapping, so its metadata keys were not removed.
    """

    def visit(value: Any, *, in_preserved_mapping: bool = False) -> None:
        if isinstance(value, dict):
            if not in_preserved_mapping:
                for key in keys:
                    value.pop(key, None)
            for key, item in value.items():
                # Inside a name mapping every key is a name, never a schema
                # keyword, so only keywords of a real schema can open one.
                child_is_mapping = not in_preserved_mapping and key in preserve_mapping_keys
                visit(item, in_preserved_mapping=child_is_mapping)
        elif isinstance(value, list):
            for item in value:
                visit(item, in_preserved_mapping=in_preserved_mapping)

    visit(schema)
def _strip_model_schema_metadata(schema: dict[str, Any]) -> None:
    """
    Strip non-structural keys from a model schema and its object definitions.

    The top-level schema is always stripped. Entries of its ``$defs`` mapping
    are stripped only when they are dicts whose ``"type"`` is ``"object"``.
    Nothing more happens when ``$defs`` is not a dict. The schema is modified
    in place.

    Parameters
    ----------
    schema : dict of str to Any
        Model-level JSON Schema to clean.
    """
    _strip_object_schema_metadata(schema)

    definitions = schema.get("$defs", {})
    if isinstance(definitions, dict):
        for entry in definitions.values():
            if not isinstance(entry, dict):
                continue
            if entry.get("type") != "object":
                continue
            _strip_object_schema_metadata(entry)
def _strip_object_schema_metadata(schema: dict[str, Any]) -> None:
    """
    Delete every key of ``schema`` not in ``_MODEL_SCHEMA_STRUCTURAL_KEYS``.

    Only the given mapping is touched; nested schemas are left as they are.

    Parameters
    ----------
    schema : dict of str to Any
        Object schema to clean in place.
    """
    # Collect first: a dict must not change size while it is being iterated.
    removable = [key for key in schema if key not in _MODEL_SCHEMA_STRUCTURAL_KEYS]
    for key in removable:
        del schema[key]
def _iter_model_graph(model: type[BaseModel]) -> list[type[BaseModel]]:
    """
    Return ``model`` and every model reachable through its field annotations.

    The graph is walked with an explicit stack. Each model is reported once,
    in the order it is first taken off the stack, so ``model`` comes first.

    Parameters
    ----------
    model : type[BaseModel]
        Root model of the graph.

    Returns
    -------
    list of type[BaseModel]
        The visited models without duplicates.
    """
    visited: set[type[BaseModel]] = set()
    discovered: list[type[BaseModel]] = []
    stack = [model]
    while stack:
        candidate = stack.pop()
        if candidate not in visited:
            visited.add(candidate)
            discovered.append(candidate)
            # Queue every model mentioned by this model's field annotations.
            for field_info in candidate.model_fields.values():
                stack += _iter_annotation_models(field_info.annotation)
    return discovered
def _iter_annotation_models(annotation: Any) -> list[type[BaseModel]]:
    """
    Collect the ``BaseModel`` classes mentioned anywhere in an annotation.

    The annotation itself is checked first, then its type arguments are
    searched recursively, and finally its generic origin is checked.
    Duplicates are not removed.

    Parameters
    ----------
    annotation : Any
        A field annotation, such as a class, a parametrized generic or a union.

    Returns
    -------
    list of type[BaseModel]
        The model classes found, in the order described above.
    """
    models: list[type[BaseModel]] = []
    if _is_model_class(annotation):
        models.append(annotation)
    for argument in get_args(annotation):
        models.extend(_iter_annotation_models(argument))
    origin = get_origin(annotation)
    if _is_model_class(origin):
        models.append(origin)
    return models


def _is_model_class(candidate: Any) -> bool:
    """Return True when ``candidate`` is a real class derived from ``BaseModel``."""
    # On Python 3.9 and 3.10 a builtin alias such as ``list[int]`` passes
    # ``isinstance(..., type)`` although it is not a class, and ``issubclass``
    # then raises TypeError. Aliases have a generic origin, real classes do not.
    return (
        isinstance(candidate, type)
        and get_origin(candidate) is None
        and issubclass(candidate, BaseModel)
    )