"""Capture portable provenance metadata for data-processing outputs."""
from __future__ import annotations
import hashlib
import os
import re
import shlex
import subprocess
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
from ._json import to_jsonable
# Closed set of backend labels accepted wherever a ``Backend`` is annotated.
# NOTE(review): the consumer of this alias is not visible in this part of the
# file — presumably the ``backend`` field of an input-path record; confirm.
Backend = Literal["git-lfs", "dvc", "git", "filesystem", "unknown"]
[docs]
@dataclass(frozen=True)
class GitState:
    """Snapshot of a Git repository at one point in time.

    Instances are immutable (``frozen=True``).  The field notes below describe
    the values stored by ``get_git_state``, the producer visible in this module.
    """

    # Repository top-level directory (from ``git rev-parse --show-toplevel``).
    repo_root: Path
    # Output of ``git rev-parse HEAD``; ``None`` when that output is empty.
    commit: str | None
    # Output of ``git branch --show-current``; ``None`` when it is empty.
    branch: str | None
    # Canonicalized URL of the requested remote; ``None`` when unavailable.
    remote_url: str | None
    # True when ``git status --porcelain`` printed anything.
    dirty: bool
    # ``"+dirty"`` when ``dirty`` is set, otherwise the empty string.
    dirty_marker: str
    # ``git status --porcelain`` output with trailing newlines removed.
    status_short: str
    # SHA-256 hex digest over the staged and unstaged diff text; ``None`` when
    # the tree is clean, hashing was disabled, or both diffs were empty.
    diff_hash: str | None = None
[docs]
def canonicalize_remote_url(remote_url: str | None) -> str | None:
    """Return a portable, reader-friendly form of a Git remote URL.

    Recognized forms are rewritten as follows:

    * ``http://`` / ``https://`` URLs keep their scheme; any
      ``user[:password]@`` prefix on the host is removed so credentials never
      end up in recorded metadata.
    * ``git@host:path`` and ``ssh://[user@]host/path`` become
      ``https://host/path`` (the bare host alias ``github`` is expanded to
      ``github.com``).
    * ``github:path`` becomes ``https://github.com/path``.

    In every rewritten form, surrounding slashes and a trailing ``.git`` are
    dropped from the path.  Anything else is returned stripped but otherwise
    unchanged.

    Args:
        remote_url: Raw remote URL, or ``None``.

    Returns:
        The canonical URL, or ``None`` when the input is ``None``, empty, or
        whitespace only.
    """
    if not remote_url:
        return None
    remote_url = remote_url.strip()
    if not remote_url:
        return None

    def clean_path(path: str) -> str:
        path = path.strip("/")
        return path[:-4] if path.endswith(".git") else path

    if remote_url.startswith(("http://", "https://")):
        scheme, rest = remote_url.split("://", 1)
        authority, slash, path = rest.partition("/")
        # Userinfo can only appear in the authority component; cutting at the
        # last "@" there leaves any "@" inside the path untouched.
        authority = authority.rpartition("@")[2]
        return f"{scheme}://{clean_path(authority + slash + path)}"

    match = re.match(r"git@([^:]+):(.+)", remote_url)
    if match:
        host, path = match.groups()
        return f"https://{host}/{clean_path(path)}"

    match = re.match(r"ssh://(?:[^@/]+@)?([^/]+)/(.+)", remote_url)
    if match:
        host, path = match.groups()
        if host == "github":
            # Expand the bare "github" host alias to the real host name.
            host = "github.com"
        return f"https://{host}/{clean_path(path)}"

    if remote_url.startswith("github:"):
        return f"https://github.com/{clean_path(remote_url.removeprefix('github:'))}"
    return remote_url
[docs]
def run_git(args: Sequence[str], cwd: Path | str) -> tuple[bool, str, str]:
    """Invoke ``git`` with *args* in *cwd* and report the outcome without raising.

    Returns:
        ``(ok, stdout, error)``.  ``ok`` is true only for a zero exit status.
        On a non-zero exit, ``stdout`` is still passed through and ``error``
        is the stripped stderr, else the stripped stdout, else a generic
        message.  When the process cannot be started (``OSError``), ``stdout``
        is empty and ``error`` is the exception text.
    """
    try:
        completed = subprocess.run(
            ["git", *args],
            cwd=Path(cwd),
            check=False,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
    except OSError as exc:
        return False, "", str(exc)

    if completed.returncode == 0:
        return True, completed.stdout, ""

    message = (
        completed.stderr.strip()
        or completed.stdout.strip()
        or f"git {' '.join(args)} failed"
    )
    return False, completed.stdout, message
[docs]
def discover_repo_root(repo_dir: Path | str | None = None, *, max_parent_levels: int = 3) -> Path | None:
    """Locate the Git top-level directory for *repo_dir* or a nearby ancestor.

    The resolved start directory (the current working directory when
    *repo_dir* is ``None``) is probed first, followed by up to
    *max_parent_levels* of its parents.  The first directory in which
    ``git rev-parse --show-toplevel`` succeeds with non-empty output wins.

    Returns:
        The resolved repository root, or ``None`` when no probe succeeds.
    """
    origin = Path(repo_dir).expanduser() if repo_dir is not None else Path.cwd()
    here = origin.resolve()
    search_path = [here]
    for _ in range(max(0, max_parent_levels)):
        above = here.parent
        if above == here:
            # Reached the filesystem root; nothing further to climb.
            break
        search_path.append(above)
        here = above

    for directory in search_path:
        succeeded, stdout, _ = run_git(["rev-parse", "--show-toplevel"], cwd=directory)
        toplevel = stdout.strip()
        if succeeded and toplevel:
            return Path(toplevel).resolve()
    return None
[docs]
def get_git_state(
    repo_dir: Path | str = ".",
    *,
    remote: str = "origin",
    include_diff_hash: bool = True,
) -> GitState:
    """Capture the current Git state for a repository.

    Args:
        repo_dir: Directory inside the repository (parents are not searched
            by this function beyond what Git itself does).
        remote: Name of the remote whose URL is recorded.
        include_diff_hash: When true and the tree is dirty, hash the staged
            and unstaged diffs so two dirty states can be told apart.

    Returns:
        A ``GitState`` snapshot.  Any individual Git query that fails
        contributes ``None`` / an empty value rather than raising.

    Raises:
        RuntimeError: If no repository can be resolved from *repo_dir*.
    """
    repo_root = discover_repo_root(repo_dir, max_parent_levels=0)
    if repo_root is None:
        raise RuntimeError(f"Could not resolve git repository from {repo_dir}.")

    def git_text(*args: str) -> str:
        # A failing Git command can still write to stdout (``git rev-parse
        # HEAD`` prints "HEAD" in a repository without commits), so output is
        # only trusted when the command succeeded.
        ok, stdout, _ = run_git(args, cwd=repo_root)
        return stdout if ok else ""

    commit = git_text("rev-parse", "HEAD").strip() or None
    branch = git_text("branch", "--show-current").strip() or None
    remote_raw = git_text("remote", "get-url", remote).strip() or None
    status_output = git_text("status", "--porcelain")
    dirty = bool(status_output.strip())

    diff_hash = None
    if include_diff_hash and dirty:
        staged = git_text("diff", "--cached", "--binary", "--no-ext-diff", "--")
        unstaged = git_text("diff", "--binary", "--no-ext-diff", "--")
        if staged or unstaged:
            diff_hash = hashlib.sha256(f"{staged}{unstaged}".encode("utf-8", errors="replace")).hexdigest()

    return GitState(
        repo_root=repo_root,
        commit=commit,
        branch=branch,
        remote_url=canonicalize_remote_url(remote_raw),
        dirty=dirty,
        dirty_marker="+dirty" if dirty else "",
        status_short=status_output.rstrip("\n"),
        diff_hash=diff_hash,
    )
def _repo_name_from_remote(remote_url: str | None) -> str | None:
    """Return the last path segment of the canonicalized remote URL, if any."""
    canonical = canonicalize_remote_url(remote_url)
    if not canonical:
        return None
    trimmed = canonical.rstrip("/").removesuffix(".git")
    last_segment = trimmed.rpartition("/")[2]
    return last_segment or None
def _repo_name_from_path(path: str | None) -> str | None:
if not path:
return None
return Path(path).name or None
[docs]
def public_git_state(state: GitState | Mapping[str, Any]) -> dict[str, Any]:
    """Build the portable Git record that is safe to embed in public metadata.

    The input is first converted with ``to_jsonable``.  Both the current field
    names and the alternate spellings ``remote``, ``git_dirty``, ``git_head``
    and ``git_branch`` are accepted.  Entries whose value is ``None`` or the
    empty string are omitted from the result.
    """
    record = to_jsonable(state)
    lookup = record.get

    canonical_remote = canonicalize_remote_url(lookup("remote_url") or lookup("remote"))
    is_dirty = bool(lookup("dirty") or lookup("git_dirty"))

    # Explicit labels win; otherwise derive a name from the remote, then paths.
    display_name = (
        lookup("name")
        or lookup("repo")
        or lookup("package")
        or _repo_name_from_remote(canonical_remote)
        or _repo_name_from_path(lookup("repo_root"))
        or _repo_name_from_path(lookup("label"))
    )

    public: dict[str, Any] = {
        "name": display_name,
        "commit": lookup("commit") or lookup("git_head"),
        "branch": lookup("branch") or lookup("git_branch"),
        "remote_url": canonical_remote,
        "dirty": is_dirty,
    }

    if is_dirty:
        public["dirty_marker"] = lookup("dirty_marker") or "+dirty"
        dirty_only_keys = ("diff_hash", "status_hash", "status_short", "untracked_files")
        public.update({key: record[key] for key in dirty_only_keys if lookup(key)})

    # NOTE(review): the source listing lost its indentation; ``patch`` is
    # assumed to be copied regardless of the dirty flag — confirm.
    if lookup("patch"):
        public["patch"] = record["patch"]

    return {key: value for key, value in public.items() if value not in (None, "")}
def format_git_state(state: GitState | Mapping[str, Any]) -> str:
    """Render a Git state as ``[name@]<commit[:12]>[+dirty] (<branch>)``.

    A missing commit is shown as ``unknown`` and a missing branch as
    ``detached``.
    """
    public = public_git_state(state)
    short_commit = str(public.get("commit") or "unknown")[:12]
    branch_label = public.get("branch") or "detached"
    name = public.get("name")
    name_prefix = f"{name}@" if name else ""
    dirty_suffix = "+dirty" if public.get("dirty") else ""
    return f"{name_prefix}{short_commit}{dirty_suffix} ({branch_label})"
def _path_kind(path: Path) -> str:
if not path.exists():
return "missing"
if path.is_dir():
return "directory"
if path.is_file():
return "file"
return "other"
def _repo_relative(path: Path, repo_root: Path) -> str | None:
try:
return path.resolve().relative_to(repo_root).as_posix()
except (OSError, ValueError):
return None
def _git_status_for_path(repo_root: Path, rel: str) -> str:
    """Return ``git status --porcelain`` output limited to *rel* ("" on failure)."""
    succeeded, stdout, _ = run_git(["status", "--porcelain", "--", rel], cwd=repo_root)
    if not succeeded:
        return ""
    return stdout.rstrip("\n")
def _is_git_tracked(repo_root: Path, rel: str) -> bool:
    """Report whether ``git ls-files --error-unmatch`` succeeds for *rel*."""
    outcome = run_git(["ls-files", "--error-unmatch", "--", rel], cwd=repo_root)
    return outcome[0]
def _parse_lfs_pointer(text: str) -> dict[str, Any] | None:
lines = [line.strip() for line in text.splitlines()]
if not lines or lines[0] != "version https://git-lfs.github.com/spec/v1":
return None
pointer: dict[str, Any] = {"is_pointer_file": True}
for line in lines[1:]:
if line.startswith("oid sha256:"):
pointer["oid"] = line.removeprefix("oid sha256:")
elif line.startswith("size "):
raw = line.removeprefix("size ")
try:
pointer["size"] = int(raw)
except ValueError:
pointer["size"] = raw
return pointer
def _lfs_metadata(path: Path, repo_root: Path | None, rel: str | None) -> dict[str, Any]:
metadata: dict[str, Any] = {"is_pointer_file": False, "tracked_by_lfs": False}
if path.is_file():
try:
pointer = _parse_lfs_pointer(path.read_text(encoding="utf-8", errors="replace")[:512])
except OSError as err:
pointer = None
metadata["pointer_error"] = str(err)
if pointer is not None:
metadata.update(pointer)
if repo_root is None or rel is None:
return metadata
attr_ok, attr_output, _ = run_git(["check-attr", "filter", "--", rel], cwd=repo_root)
if attr_ok and attr_output.strip().endswith("filter: lfs"):
metadata["tracked_by_lfs"] = True
lfs_ok, lfs_output, lfs_error = run_git(["lfs", "ls-files", "--long", "--", rel], cwd=repo_root)
if lfs_ok and lfs_output.strip():
parts = lfs_output.split()
if parts:
metadata["oid"] = parts[0]
metadata["tracked_by_lfs"] = True
metadata["lfs_ls_files"] = lfs_output.strip()
elif lfs_error:
metadata["lfs_error"] = lfs_error
return metadata
def _simple_dvc_outputs(text: str) -> list[dict[str, Any]]:
outputs: list[dict[str, Any]] = []
current: dict[str, Any] | None = None
for raw_line in text.splitlines():
line = raw_line.strip()
if line.startswith("-"):
if current:
outputs.append(current)
current = {}
line = line[1:].strip()
if current is None or ":" not in line:
continue
key, value = line.split(":", 1)
key = key.strip()
value = value.strip().strip("\"'")
if key in {"path", "md5", "hash", "etag", "size", "nfiles"}:
current[key] = value
if current:
outputs.append(current)
return outputs
def _dvc_metadata(path: Path, repo_root: Path | None, rel: str | None) -> dict[str, Any]:
candidates: list[Path] = []
if repo_root is not None and rel is not None:
candidates.append(repo_root / f"{rel}.dvc")
candidates.append(repo_root / "dvc.lock")
candidates.append(path.with_name(f"{path.name}.dvc"))
metadata: dict[str, Any] = {"dvc_files": [], "outputs": []}
seen: set[Path] = set()
for candidate in candidates:
candidate = candidate.resolve()
if candidate in seen or not candidate.exists() or not candidate.is_file():
continue
seen.add(candidate)
try:
text = candidate.read_text(encoding="utf-8", errors="replace")
except OSError as err:
metadata.setdefault("errors", []).append({"path": str(candidate), "error": str(err)})
continue
if candidate.name == "dvc.lock" and rel is not None and rel not in text:
continue
dvc_info = {"path": str(candidate), "outputs": _simple_dvc_outputs(text)}
if repo_root is not None:
dvc_rel = _repo_relative(candidate, repo_root)
if dvc_rel is not None:
dvc_info["git_status"] = _git_status_for_path(repo_root, dvc_rel)
metadata["dvc_files"].append(dvc_info)
metadata["outputs"].extend(dvc_info["outputs"])
return metadata
[docs]
def summarize_directory(path: Path | str, *, max_entries: int = 20_000) -> dict[str, Any]:
    """Summarize a directory tree without embedding a full file listing.

    All regular files below *path* are counted and their sizes summed.  The
    manifest hash covers the relative path, size and ``st_mtime_ns`` of files
    in path order, limited to the first *max_entries* positions;
    ``manifest_truncated`` records whether any file fell beyond that limit.
    Files whose ``stat`` fails are still counted but add nothing else.
    """
    root = Path(path)
    files = [entry for entry in root.rglob("*") if entry.is_file()]
    files.sort(key=lambda entry: entry.as_posix())

    hasher = hashlib.sha256()
    byte_total = 0
    overflowed = False
    for position, entry in enumerate(files, start=1):
        try:
            info = entry.stat()
        except OSError:
            continue
        byte_total += info.st_size
        if position > max_entries:
            overflowed = True
            continue
        relative = entry.relative_to(root).as_posix()
        hasher.update(f"{relative}\0{info.st_size}\0{info.st_mtime_ns}\n".encode())

    return {
        "file_count": len(files),
        "total_bytes": byte_total,
        "manifest_hash": hasher.hexdigest(),
        "manifest_hash_kind": "paths-size-mtime-ns",
        "manifest_truncated": overflowed,
        "max_entries": max_entries,
    }
[docs]
def public_provenance(value: Any) -> Any:
    """Return the form of *value* intended to be written into public outputs.

    ``InputPathState`` and ``GitState`` objects go through their dedicated
    public converters.  Anything else is converted with ``to_jsonable``;
    mappings are then filtered key by key, lists are processed element-wise,
    and other values are returned as converted.  Mapping entries that end up
    ``None`` or empty strings are removed.
    """
    if isinstance(value, InputPathState):
        return public_input_path_state(value)
    if isinstance(value, GitState):
        return public_git_state(value)

    data = to_jsonable(value)
    if isinstance(data, Mapping):
        public: dict[str, Any] = {}
        for key, item in data.items():
            # Keys that are never carried into public output.
            if key in {"history_entry", "repo_root", "source_path"}:
                continue
            is_collection = isinstance(item, Iterable)
            if key == "project_repo":
                public[key] = public_git_state(item)
            elif key in ("software_repos", "configured_repos") and is_collection:
                public[key] = [public_git_state(state) for state in item]
            elif key == "input_paths" and is_collection:
                public[key] = [public_input_path_state(state) for state in item]
            elif key == "remote_url":
                public[key] = canonicalize_remote_url(str(item)) if item else None
            elif (key == "diff_hash" and not data.get("dirty")) or (key == "status_short" and not item):
                # A diff hash without a dirty flag, or an empty status, is omitted.
                continue
            else:
                public[str(key)] = public_provenance(item)
        return {key: item for key, item in public.items() if item not in (None, "")}

    if isinstance(data, list):
        return [public_provenance(item) for item in data]
    return data
[docs]
def clean_command_parts(parts: Sequence[str]) -> list[str]:
    """Return *parts* as strings with provenance sidecar flags removed.

    Both ``--provenance-json`` and ``--reprotrail-provenance-json`` are
    dropped, either as ``--flag value`` (the following item is dropped too)
    or as ``--flag=value``.
    """
    bare_flags = ("--provenance-json", "--reprotrail-provenance-json")
    inline_prefixes = tuple(f"{flag}=" for flag in bare_flags)

    tokens = iter([str(item) for item in parts])
    kept = []
    for token in tokens:
        if token in bare_flags:
            # Consume and discard the flag's value, if there is one.
            next(tokens, None)
        elif not token.startswith(inline_prefixes):
            kept.append(token)
    return kept
def _command_text(command: str | Sequence[str] | None) -> str:
if command is None:
return "unknown command"
if isinstance(command, str):
return command
return shlex.join(clean_command_parts(command))
[docs]
def build_cf_history_entry(
    command: str | Sequence[str] | None = None,
    *,
    git_state: GitState | Mapping[str, Any] | None = None,
    git_states: Sequence[GitState | Mapping[str, Any]] = (),
    input_states: Sequence[InputPathState | Mapping[str, Any]] = (),
    timestamp: datetime | None = None,
    include_inputs: bool = False,
) -> str:
    """Build one timestamped history line suitable for CF/xarray attrs.

    The line consists of ``"; "``-separated segments: a UTC ISO-8601
    timestamp (seconds precision; a naive *timestamp* is taken as UTC), the
    command text, an optional ``software=`` list (``git_state`` first, then
    ``git_states``) and, when *include_inputs* is set, an ``inputs=`` list of
    ``<file name>:<backend>`` pairs.
    """
    moment = timestamp or datetime.now(timezone.utc)
    if moment.tzinfo is None:
        moment = moment.replace(tzinfo=timezone.utc)
    stamp = moment.astimezone(timezone.utc).isoformat(timespec="seconds")

    segments = [stamp, _command_text(command)]

    leading = [] if git_state is None else [git_state]
    repo_states = leading + list(git_states)
    if repo_states:
        segments.append("software=" + ", ".join(map(format_git_state, repo_states)))

    if include_inputs and input_states:
        records = (to_jsonable(state) for state in input_states)
        labels = [
            f"{Path(record.get('path', 'unknown')).name}:{record.get('backend', 'unknown')}"
            for record in records
        ]
        segments.append("inputs=" + ", ".join(labels))

    return "; ".join(segments)
[docs]
def append_cf_history(existing: str | None, entry: str) -> str:
    """Put *entry* on a new first line ahead of the existing CF history text.

    When *existing* is ``None``, empty or whitespace only, *entry* alone is
    returned; otherwise the stripped existing text follows after a newline.
    """
    previous = existing.strip() if existing else ""
    if previous:
        return f"{entry}\n{previous}"
    return entry
[docs]
def append_xarray_history(obj: Any, entry: str, *, copy: bool = False) -> Any:
    """Add *entry* at the top of ``attrs["history"]`` on an xarray-like object.

    With ``copy=True`` the update is applied to ``obj.copy()``; otherwise
    *obj* is modified in place.  The updated object is returned.
    """
    target = obj.copy() if copy else obj
    target.attrs["history"] = append_cf_history(target.attrs.get("history"), entry)
    return target
[docs]
def enforce_clean_repos(
    repos: Iterable[Path | str],
    *,
    allow_dirty: bool = False,
    missing_ok: bool = True,
) -> list[GitState]:
    """Validate that repositories are clean unless dirty state is allowed.

    Args:
        repos: Repository directories to inspect.
        allow_dirty: When true, dirty repositories are accepted.
        missing_ok: When true, paths that do not exist or cannot be resolved
            to a repository are silently skipped; otherwise they are reported
            as failures.

    Returns:
        The captured state of every repository that could be inspected.

    Raises:
        RuntimeError: If any failure was collected.  The message names the
            ``--allow-dirty`` remedy only when a dirty repository is among
            the failures.
    """
    states: list[GitState] = []
    failures: list[str] = []
    saw_dirty = False

    for repo in repos:
        repo_path = Path(repo).expanduser()
        if not repo_path.exists():
            if not missing_ok:
                failures.append(f"{repo}: path does not exist")
            continue
        try:
            state = get_git_state(repo_path)
        except RuntimeError as err:
            if not missing_ok:
                failures.append(str(err))
            continue
        states.append(state)
        if state.dirty and not allow_dirty:
            saw_dirty = True
            failures.append(f"{state.repo_root} is dirty:\n{state.status_short}")

    if failures:
        # Suggesting --allow-dirty is only helpful when dirtiness is actually
        # one of the problems; missing repositories need a different fix.
        header = (
            "Dirty software repository state requires --allow-dirty."
            if saw_dirty
            else "Software repository validation failed."
        )
        raise RuntimeError(header + "\n" + "\n\n".join(failures))
    return states
[docs]
def env_allows_dirty(var: str = "REPROTRAIL_ALLOW_DIRTY") -> bool:
    """Tell whether environment variable *var* opts into dirty repositories.

    The value is trimmed and compared case-insensitively against ``1``,
    ``true``, ``yes`` and ``on``; an unset variable counts as not opted in.
    """
    raw = os.environ.get(var, "")
    normalized = raw.strip().lower()
    return normalized in {"1", "true", "yes", "on"}