diff --git a/backend/inkflow/analysis/benchmark.py b/backend/inkflow/analysis/benchmark.py new file mode 100644 index 0000000..dd8e8f9 --- /dev/null +++ b/backend/inkflow/analysis/benchmark.py @@ -0,0 +1,473 @@ +"""Benchmark des modeles d'analyse contre les fichiers de reference. + +Les fichiers `data/{slug}/reference/chNN.json` sont des verites terrain corrigees +a la main (meme schema `ChapterAnalysis`). Ce module compare la sortie d'un +modele (`hypothese`) a ces references et chiffre la qualite sur trois dimensions : + +1. **Attribution du locuteur** (le point faible du petit modele local) ; +2. **Incises** (bornes start/end dans la replique) ; +3. **Type narration/dialogue** et flag `glued_to_prev` (garde-fou de regression). + +Le scoring (`score_chapter`/`aggregate`) est **pur** : aucune dependance MLX ni +disque, teste comme `analysis.segmenter`. Le runner (`run_benchmark`) met +plusieurs modeles en concurrence : il relance `analyze_chapter` en memoire (sans +ecraser les artefacts) avec chaque `model_id`, score, libere le modele, enchaine. 
def _alias_map(cast: Optional["Cast"]) -> dict[str, str]:
    """Build a case-insensitive lookup from any known name or alias to its canonical name.

    Keys are stripped and casefolded; values are the stripped canonical
    ``Character.name``. The map lets speaker comparison treat "Drummer" and
    "Camina Drummer" as the same person when the cast declares the alias.

    Args:
        cast: The book cast, or ``None`` when no cast is available.

    Returns:
        A dict mapping casefolded name/alias to canonical name; empty when
        ``cast`` is ``None``.
    """
    mapping: dict[str, str] = {}
    if cast is None:
        return mapping

    # Pass 1: canonical names. They are registered first so that an alias
    # declared on another character can never hijack a real character's name.
    for character in cast.characters:
        canon = character.name.strip()
        if canon:
            mapping.setdefault(canon.casefold(), canon)

    # Pass 2: aliases. Blank aliases are skipped: a "" key would make every
    # unattributed segment (empty speaker) resolve to this character.
    for character in cast.characters:
        canon = character.name.strip()
        if not canon:
            continue
        for alias in character.aliases:
            key = alias.strip().casefold()
            if key:
                mapping.setdefault(key, canon)
    return mapping
class ChapterScore(BaseModel):
    """Serializable quality report for one chapter, or for the pooled aggregate.

    Every ratio defaults to 1.0, the same value the scoring helpers return
    when a metric has nothing to measure (zero denominator).
    """
    index: int  # chapter index; -1 marks the aggregate
    n_segments: int = 0  # number of aligned segments that were scored
    n_dialogue: int = 0  # how many of those are dialogue in the reference
    # speaker attribution
    speaker_acc_all: float = 1.0
    speaker_acc_dialogue: float = 1.0
    # incises: precision / recall / F1, for exact spans and for overlapping spans
    incise_exact_p: float = 1.0
    incise_exact_r: float = 1.0
    incise_exact_f1: float = 1.0
    incise_overlap_p: float = 1.0
    incise_overlap_r: float = 1.0
    incise_overlap_f1: float = 1.0
    # segment type and glued_to_prev flag
    type_acc: float = 1.0
    glued_acc: float = 1.0
    # details: individual speaker errors, expected -> got -> count, alignment notes
    errors: list[SpeakerError] = Field(default_factory=list)
    confusion: dict[str, dict[str, int]] = Field(default_factory=dict)
    alignment_warnings: list[str] = Field(default_factory=list)
def _prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Return (precision, recall, F1) from true-positive/false-positive/false-negative counts.

    Precision is 1.0 when nothing was predicted and recall is 1.0 when there
    was nothing to find; F1 is 0.0 when precision and recall are both zero.
    """
    predicted = tp + fp
    expected = tp + fn
    precision = 1.0 if not predicted else tp / predicted
    recall = 1.0 if not expected else tp / expected
    total = precision + recall
    if not total:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / total
def align(ref: ChapterAnalysis, hyp: ChapterAnalysis) -> list[tuple[Optional[Segment], Optional[Segment]]]:
    """Pair every hypothesis segment with its reference counterpart.

    When both chapters have the same normalized texts in the same order, the
    segments are paired one to one. Otherwise the normalized texts are diffed
    with ``difflib.SequenceMatcher`` and each opcode run is paired position by
    position; whichever side runs out first is padded with ``None``, so
    unmatched segments surface as ``(segment, None)`` or ``(None, segment)``.
    """
    ref_segments = ref.segments
    hyp_segments = hyp.segments
    ref_texts = [_norm_text(seg.text) for seg in ref_segments]
    hyp_texts = [_norm_text(seg.text) for seg in hyp_segments]
    if ref_texts == hyp_texts:
        return list(zip(ref_segments, hyp_segments))

    matcher = difflib.SequenceMatcher(a=ref_texts, b=hyp_texts, autojunk=False)
    aligned: list[tuple[Optional[Segment], Optional[Segment]]] = []
    for _tag, r_lo, r_hi, h_lo, h_hi in matcher.get_opcodes():
        # equal / replace / delete / insert all reduce to the same padded
        # pairing: "delete" has an empty hypothesis run, "insert" an empty
        # reference run, and "equal" two runs of the same length.
        ref_run = ref_segments[r_lo:r_hi]
        hyp_run = hyp_segments[h_lo:h_hi]
        for offset in range(max(len(ref_run), len(hyp_run))):
            ref_seg = ref_run[offset] if offset < len(ref_run) else None
            hyp_seg = hyp_run[offset] if offset < len(hyp_run) else None
            aligned.append((ref_seg, hyp_seg))
    return aligned
def r_index(analysis: ChapterAnalysis, seg: Segment) -> int:
    """Return the position of ``seg`` in ``analysis.segments``, or -1 if absent.

    The lookup is by object identity, not equality, so two segments with the
    same content are still told apart.
    """
    return next(
        (pos for pos, candidate in enumerate(analysis.segments) if candidate is seg),
        -1,
    )
def run_benchmark(slug: str, model_ids: list[str], *,
                  chapters: Optional[list[int]] = None,
                  temperature: Optional[float] = None,
                  reasoning: Optional[bool] = None,
                  use_cached: bool = False,
                  progress: Optional[Callable[[str], None]] = None) -> BenchmarkReport:
    """Score one or several models against the hand-corrected reference chapters.

    With ``use_cached=True`` the existing ``analysis/chNN.json`` artifacts are
    scored as-is: no model is loaded, ``model_ids`` is ignored and the report
    holds a single result labelled ``"cache"``. Otherwise each ``model_id`` is
    run through ``analyze_chapter`` in memory (nothing is saved), scored, and
    released before the next one starts.

    Args:
        slug: Book identifier.
        model_ids: Models to compare (ignored when ``use_cached`` is true).
        chapters: Optional restriction to these chapter indexes.
        temperature: Overrides ``settings.gemma_temperature`` for the run.
        reasoning: Overrides ``settings.gemma_reasoning`` for the run.
        use_cached: Score stored analyses instead of running models.
        progress: Optional callback receiving one message per step.

    Returns:
        A ``BenchmarkReport`` whose ``generated_at`` is left empty for the
        caller to fill in.

    Raises:
        ValueError: If no reference chapter exists for ``slug``.
    """
    import zlib

    from ..epub.parser import load_book, load_chapter_text
    from ..settings import get_settings
    from ..store import artifacts

    emit = progress or (lambda _msg: None)

    targets = _reference_chapters(slug, chapters)
    if not targets:
        raise ValueError(
            f"Aucune reference trouvee pour {slug!r} "
            f"(data/{slug}/reference/chNN.json).")

    references = {i: _load_reference(slug, i) for i in targets}
    cast = artifacts.load_cast(slug)
    settings = get_settings()

    snapshot = {
        "gemma_temperature": temperature if temperature is not None
        else settings.gemma_temperature,
        "gemma_max_tokens": settings.gemma_max_tokens,
        "gemma_reasoning": reasoning if reasoning is not None
        else settings.gemma_reasoning,
        "dedup_use_gemma": settings.dedup_use_gemma,
        "retro_pass_use_gemma": settings.retro_pass_use_gemma,
        # CRC32 rather than hash(): str hashes are salted per process, so the
        # builtin would yield a different fingerprint on every run and two
        # reports could never be compared on this field.
        "prompt_speakers_hash": zlib.crc32(
            settings.prompt_speakers.encode("utf-8")) & 0xFFFFFFFF,
    }
    report = BenchmarkReport(
        slug=slug, generated_at="", chapters=targets,
        settings_snapshot=snapshot)

    if use_cached:
        per_chapter, counts = [], []
        for i in targets:
            hyp = artifacts.load_analysis(slug, i)
            cnt = _score_counts(references[i], hyp, cast)
            counts.append(cnt)
            per_chapter.append(_counts_to_score(i, cnt))
        # Label the row so it is identifiable in the table and the JSON report
        # (an empty model id rendered as a blank line).
        report.models.append(_build_model_score("cache", per_chapter, counts, 0.0))
        return report

    from .gemma import Gemma, _load
    from .segmenter import analyze_chapter

    book = load_book(slug)
    by_index = {c.index: c for c in book.chapters}
    # The scoring helpers accept a missing cast, so the runner tolerates it too.
    # NOTE(review): confirm whether artifacts.load_cast can actually return None.
    book_chars = list(cast.characters) if cast is not None else []

    # Pin temperature/reasoning on the in-memory settings object only (nothing
    # is written to disk) and restore the previous values on the way out.
    original_temp = settings.gemma_temperature
    original_reasoning = settings.gemma_reasoning
    if temperature is not None:
        settings.gemma_temperature = temperature
    if reasoning is not None:
        settings.gemma_reasoning = reasoning
    try:
        for mi, model_id in enumerate(model_ids, 1):
            t0 = time.perf_counter()
            per_chapter, counts = [], []
            model_err: Optional[str] = None
            emit(f"[{mi}/{len(model_ids)}] {model_id} — chargement du modele…")
            try:
                gemma = Gemma(model_id=model_id)
                for i in targets:
                    ch = by_index.get(i)
                    if ch is None:
                        # A reference exists for a chapter the book does not
                        # have: report it instead of shrinking the run silently.
                        emit(f" ch{i:02d} — absent du livre, ignore")
                        continue
                    emit(f" ch{i:02d} — analyse en cours…")
                    tc = time.perf_counter()
                    ct = load_chapter_text(slug, ch)
                    hyp, _ = analyze_chapter(
                        ch, ct, gemma,
                        book_chars=book_chars, dedup_gemma=None)
                    cnt = _score_counts(references[i], hyp, cast)
                    counts.append(cnt)
                    cs = _counts_to_score(i, cnt)
                    per_chapter.append(cs)
                    emit(f" ch{i:02d} — OK en {time.perf_counter() - tc:.0f}s "
                         f"(locuteur dlg {cs.speaker_acc_dialogue:.0%}, "
                         f"{len(cs.errors)} erreurs)")
            except Exception as exc:  # noqa: BLE001 — one failing model must not stop the others
                model_err = f"{type(exc).__name__}: {exc}"
                emit(f" ! echec: {model_err[:120]}")
            finally:
                _load.cache_clear()  # release the model before loading the next one
            # Chapters scored before a failure are kept alongside the error.
            ms = _build_model_score(
                model_id, per_chapter, counts, time.perf_counter() - t0)
            ms.error = model_err
            report.models.append(ms)
            if not model_err and ms.aggregate is not None:
                emit(f"[{mi}/{len(model_ids)}] {model_id} — termine en "
                     f"{ms.elapsed_s:.0f}s (locuteur dlg "
                     f"{ms.aggregate.speaker_acc_dialogue:.1%})")
    finally:
        settings.gemma_temperature = original_temp
        settings.gemma_reasoning = original_reasoning

    return report
def _resolve_chat_template(model_id: str, tokenizer) -> Optional[str]:
    """Return a chat template to pass explicitly, or None.

    Some model conversions ship their template in a separate
    ``chat_template.jinja`` / ``chat_template.json`` file that is not always
    loaded into the tokenizer; ``tokenizer.chat_template`` is then empty. In
    that case the file is fetched from the model repository.

    Args:
        model_id: Repository id handed to ``hf_hub_download``.
        tokenizer: Tokenizer whose ``chat_template`` attribute is inspected.

    Returns:
        The template text, or None when the tokenizer already has a template
        or when no usable template file could be obtained.
    """
    if getattr(tokenizer, "chat_template", None):
        return None
    from pathlib import Path

    from huggingface_hub import hf_hub_download

    # Depending on the conversion: a raw Jinja file, or JSON {"chat_template": ...}.
    for fname in ("chat_template.jinja", "chat_template.json"):
        try:
            text = Path(hf_hub_download(model_id, fname)).read_text(encoding="utf-8")
        except Exception:  # noqa: BLE001 — file missing/unreachable, try the next candidate
            continue
        if fname.endswith(".json"):
            try:
                data = json.loads(text)
            except ValueError:
                # Malformed JSON is treated like a missing file rather than
                # crashing the model load (JSONDecodeError subclasses ValueError).
                continue
            template = data.get("chat_template") if isinstance(data, dict) else None
        else:
            template = text
        # An empty or non-string template is unusable: keep looking.
        if isinstance(template, str) and template.strip():
            return template
    return None  # nothing available: apply_chat_template will report the problem
settings.gemma_max_tokens + # En mode raisonnement, plafond dedie (garde-fou anti-boucle) ; la + # generation s'arrete de toute facon des que le JSON post-pensee est + # complet (cf. boucle de streaming ci-dessous). + if settings.gemma_reasoning: + max_tokens = max(max_tokens, settings.gemma_reasoning_max_tokens) if temperature is None: temperature = settings.gemma_temperature - from mlx_lm import generate + # Decodage glouton (temp 0) + raisonnement = boucles de pensee sans fin. + # On force un echantillonnage minimal en mode raisonnement. + if settings.gemma_reasoning and temperature == 0.0: + temperature = settings.gemma_reasoning_temperature from mlx_lm.sample_utils import make_sampler messages = [] if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": prompt}) + # Modeles hybrides (Qwen3...) : hors mode raisonnement, on DESACTIVE la + # pensee via enable_thinking=False -> JSON direct, bien plus rapide. Avec + # --reasoning, on laisse penser puis on retire la pensee en aval. Ce + # kwarg est ignore par les templates qui ne l'utilisent pas (Gemma...). + template_kwargs = {} + if not settings.gemma_reasoning: + template_kwargs["enable_thinking"] = False formatted = self._tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False + messages, add_generation_prompt=True, tokenize=False, + chat_template=self._chat_template, # None -> celui du tokenizer + **template_kwargs, ) sampler = make_sampler(temp=temperature) - return generate( - self._model, - self._tokenizer, - prompt=formatted, - max_tokens=max_tokens, - sampler=sampler, - verbose=False, - ) + # On streame (token par token) si : un sink est branche (--stream) OU on + # est en mode raisonnement (pour pouvoir s'arreter des que la reponse est + # prete, sans subir les boucles de pensee sans fin). Sinon, lot rapide. 
+ if _TOKEN_SINK is not None or settings.gemma_reasoning: + from mlx_lm import stream_generate + parts = [] + seen_end = False # marqueur de fin de pensee rencontre + for resp in stream_generate( + self._model, self._tokenizer, prompt=formatted, + max_tokens=max_tokens, sampler=sampler, + ): + parts.append(resp.text) + if _TOKEN_SINK is not None: + _TOKEN_SINK(resp.text) + # Arret anticipe : une fois la pensee close, des que le JSON + # post-pensee est complet, inutile de continuer a generer. + if settings.gemma_reasoning and ("}" in resp.text or "]" in resp.text): + buf = "".join(parts) + if not seen_end: + seen_end = any(mk in buf for mk in _REASONING_END_MARKERS) + if seen_end and _has_complete_json(_strip_reasoning(buf)): + break + if _TOKEN_SINK is not None: + _TOKEN_SINK("\n") # separe les generations successives + raw = "".join(parts) + else: + from mlx_lm import generate + raw = generate( + self._model, + self._tokenizer, + prompt=formatted, + max_tokens=max_tokens, + sampler=sampler, + verbose=False, + ) + # Retire la chaine de pensee des modeles a raisonnement (sinon des + # fragments de la "pensee" parasitent l'extraction JSON en aval). + if settings.gemma_reasoning: + return _strip_reasoning(raw) + return raw def generate_json( self, @@ -101,6 +200,35 @@ class Gemma: raise ValueError(f"Reponse JSON invalide apres {retries + 1} essais: {last_err}") +def _strip_reasoning(text: str) -> str: + """Retire la chaine de pensee d'un modele a raisonnement. + + Ne conserve que ce qui suit le dernier marqueur de fin de pensee + (``, ``...). Si seul un marqueur d'ouverture non ferme + subsiste (pensee tronquee par le budget de tokens), on le retire en tete + pour eviter de parser la pensee a la place de la reponse. 
@app.command()
def benchmark(
    slug: str,
    models: Optional[str] = typer.Option(
        None, help="Modeles a comparer, separes par des virgules (def: modele courant)."),
    chapter: Optional[int] = typer.Option(
        None, help="Restreindre a un chapitre (def: tous ceux avec reference)."),
    temperature: Optional[float] = typer.Option(
        None, help="Epingle la temperature Gemma (repro). Ex: 0.0."),
    reasoning: bool = typer.Option(
        False, "--reasoning", help="Modeles a raisonnement : retire la pensee + budget tokens accru."),
    use_cached: bool = typer.Option(
        False, "--use-cached", help="Compare les analysis/chNN.json existants (pas de modele)."),
    stream: bool = typer.Option(
        False, "--stream", help="Affiche les tokens generes en temps reel (pensee + reponse)."),
):
    """Met des modeles en concurrence sur les chapitres de reference (vs reference/)."""
    # The docstring above is the command's user-facing help text and is kept
    # verbatim. Flow: run the benchmark, print one comparison row per model,
    # list the first speaker-attribution errors, then write a timestamped JSON
    # report under the book's "benchmark" directory.
    import sys
    from datetime import datetime

    from rich.markup import escape

    from .analysis import gemma as _gemma
    from .analysis.benchmark import run_benchmark
    from .settings import get_settings

    model_ids = ([m.strip() for m in models.split(",") if m.strip()]
                 if models else [get_settings().gemma_model])
    if not model_ids and not use_cached:
        # e.g. --models "," : nothing would run and an empty report would be written.
        raise typer.BadParameter("aucun modele valide", param_hint="--models")
    chapters = [chapter] if chapter is not None else None

    label = "artefacts en cache" if use_cached else f"{len(model_ids)} modele(s)"
    console.print(f"[blue]Benchmark[/] {escape(slug)} ({label}) — suivi par chapitre :")

    # Plain persistent lines with a timestamp, so a long run can be followed;
    # a spinner (console.status) would print nothing durable.
    def _progress(msg: str) -> None:
        # Messages may embed exception text: escape it so stray brackets are
        # not interpreted as Rich markup.
        console.print(f"[dim]{datetime.now():%H:%M:%S}[/] {escape(msg)}")

    # Token streaming: raw writes to stdout (no Rich markup) to watch the
    # thinking and the answer scroll by. Requires an unbuffered stdout.
    if stream:
        def _sink(piece: str) -> None:
            sys.stdout.write(piece)
            sys.stdout.flush()
        _gemma.set_token_sink(_sink)
    try:
        report = run_benchmark(
            slug, model_ids, chapters=chapters,
            temperature=temperature,
            # The flag can only switch reasoning on; when absent, the
            # configured setting is left untouched.
            reasoning=reasoning if reasoning else None,
            use_cached=use_cached,
            progress=_progress)
    finally:
        if stream:
            _gemma.set_token_sink(None)
    report.generated_at = datetime.now().isoformat(timespec="seconds")

    # Comparison table: one row per model (micro-averaged aggregate).
    table = Table(title=f"Benchmark {escape(slug)} — chapitres {report.chapters}")
    table.add_column("modele")
    for col in ("speaker_dlg", "speaker_all", "incise_f1", "type", "glued", "temps(s)"):
        table.add_column(col, justify="right")
    for ms in report.models:
        if ms.error or ms.aggregate is None:
            table.add_row(
                escape(ms.model_id),
                f"[red]{escape(ms.error or 'aucun chapitre')}[/]")
            continue
        a = ms.aggregate
        table.add_row(
            escape(ms.model_id),
            f"{a.speaker_acc_dialogue:.1%}",
            f"{a.speaker_acc_all:.1%}",
            f"{a.incise_overlap_f1:.2f}",
            f"{a.type_acc:.1%}",
            f"{a.glued_acc:.1%}",
            f"{ms.elapsed_s:.0f}",
        )
    console.print(table)

    # Speaker-attribution errors per model (first 15 only).
    for ms in report.models:
        errs = [e for cs in ms.per_chapter for e in cs.errors]
        if not errs:
            continue
        console.print(
            f"\n[bold]{escape(ms.model_id)}[/] — {len(errs)} erreur(s) de locuteur:")
        for e in errs[:15]:
            # Speaker names and excerpts come from book text and model output;
            # escaping keeps bracketed text from being parsed as style tags.
            console.print(
                f" ch·seg{e.index:>3} attendu=[green]{escape(e.expected)}[/] "
                f"obtenu=[red]{escape(e.got)}[/] — {escape(repr(e.text_excerpt))}")
        if len(errs) > 15:
            console.print(f" [dim]… +{len(errs) - 15} autres[/]")

    # Timestamped JSON report.
    out_dir = book_data_dir(slug) / "benchmark"
    out_dir.mkdir(parents=True, exist_ok=True)
    stamp = report.generated_at.replace(":", "").replace("-", "")
    out_path = out_dir / f"{stamp}.json"
    out_path.write_text(report.model_dump_json(indent=2), encoding="utf-8")
    console.print(f"\n[green]Rapport[/] -> {escape(str(out_path))}")
: ils emettent + # une chaine de pensee avant la reponse. Active le retrait de cette pensee + # (canaux <|channel>thought.../, balises ...) AVANT + # le parsing JSON, et releve le plafond de tokens (la pensee en consomme). + gemma_reasoning: bool = False + # Plafond de tokens en mode raisonnement (la pensee en consomme beaucoup). + # La generation s'arrete de toute facon des que la reponse JSON post-pensee + # est complete ; ce plafond est un garde-fou contre les boucles de pensee + # sans fin (certains modeles tournent en rond a temperature 0). + gemma_reasoning_max_tokens: int = Field(4096, ge=256, le=16384) + # Temperature en mode raisonnement. Le decodage GLOUTON (temp 0) fait boucler + # les modeles a raisonnement (repetitions sans fin) ; Qwen & co recommandent + # un echantillonnage. Si la temperature effective est 0, on bascule sur + # celle-ci. Rend le benchmark non deterministe en mode raisonnement (inevitable). + gemma_reasoning_temperature: float = Field(0.6, ge=0.0, le=2.0) # --- Prompts systeme (analyse) --- prompt_speakers: str = DEFAULT_PROMPT_SPEAKERS diff --git a/backend/tests/test_benchmark.py b/backend/tests/test_benchmark.py new file mode 100644 index 0000000..7a4565c --- /dev/null +++ b/backend/tests/test_benchmark.py @@ -0,0 +1,159 @@ +"""Tests purs du scoring de benchmark (sans Gemma ni disque). + +Monte des `ChapterAnalysis` synthetiques et verifie les metriques : +alignement, attribution du locuteur (avec normalisation d'alias), incises +(exact vs chevauchement), type/glued, et micro-moyenne sur plusieurs chapitres. 
def _seg(text, *, type="narration", speaker="narrateur", glued=False, incises=None):
    """Build a synthetic ``Segment`` for tests; narration by the narrator unless overridden."""
    segment_type = SegmentType(type)
    incise_list = incises if incises else []
    return Segment(
        type=segment_type,
        text=text,
        speaker=speaker,
        glued_to_prev=glued,
        incises=incise_list,
    )
def test_speaker_normalisation_alias():
    """An alias only counts as the right speaker when the cast declares it."""
    canonical, alias = "Camina Drummer", "Drummer"
    reference = _chap(5, [_seg("— A.", type="dialogue", speaker=canonical)])
    hypothesis = _chap(5, [_seg("— A.", type="dialogue", speaker=alias)])
    cast = Cast(characters=[Character(name=canonical, aliases=[alias])])

    without_cast = score_chapter(reference, hypothesis, None)
    with_cast = score_chapter(reference, hypothesis, cast)

    # no cast: the two names differ, so the attribution is an error
    assert without_cast.speaker_acc_dialogue == 0.0
    # with cast: the alias resolves to the canonical name, so it is correct
    assert with_cast.speaker_acc_dialogue == 1.0
incises=[])]) + score = score_chapter(ref, hyp) + assert score.incise_overlap_r < 1.0 + assert score.incise_overlap_p == 1.0 + + +# --- Type / glued ------------------------------------------------------------ + +def test_type_et_glued(): + ref = _chap(5, [_seg("A", type="narration"), _seg("— B", type="dialogue", glued=True)]) + hyp = _chap(5, [_seg("A", type="dialogue"), _seg("— B", type="dialogue", glued=False)]) + score = score_chapter(ref, hyp) + assert score.type_acc == 0.5 + assert score.glued_acc == 0.5 + + +# --- Agregat (micro-moyenne) ------------------------------------------------- + +def test_aggregate_micro_moyenne(): + # ch1 : 1 dialogue correct ; ch2 : 3 dialogues dont 1 faux + ref1 = _chap(1, [_seg("— A.", type="dialogue", speaker="X")]) + hyp1 = _chap(1, [_seg("— A.", type="dialogue", speaker="X")]) + ref2 = _chap(2, [ + _seg("— B.", type="dialogue", speaker="X"), + _seg("— C.", type="dialogue", speaker="Y"), + _seg("— D.", type="dialogue", speaker="Z"), + ]) + hyp2 = _chap(2, [ + _seg("— B.", type="dialogue", speaker="X"), + _seg("— C.", type="dialogue", speaker="Y"), + _seg("— D.", type="dialogue", speaker="WRONG"), + ]) + c1, c2 = _score_counts(ref1, hyp1, None), _score_counts(ref2, hyp2, None) + s1, s2 = score_chapter(ref1, hyp1), score_chapter(ref2, hyp2) + agg = aggregate([s1, s2], [c1, c2]) + # micro : 3 corrects / 4 dialogues = 0.75 (et non moyenne de 1.0 et 0.666) + assert agg.n_dialogue == 4 + assert abs(agg.speaker_acc_dialogue - 0.75) < 1e-9 + assert agg.index == -1 diff --git a/backend/tests/test_gemma_reasoning.py b/backend/tests/test_gemma_reasoning.py new file mode 100644 index 0000000..94337a7 --- /dev/null +++ b/backend/tests/test_gemma_reasoning.py @@ -0,0 +1,67 @@ +"""Tests purs de `_strip_reasoning` (retrait de la chaine de pensee). + +Sans charger de modele : on verifie que la pensee est retiree et que +`_extract_json` recupere bien la reponse FINALE (et non un fragment JSON +parasite present dans la pensee). 
+""" +from __future__ import annotations + +from inkflow.analysis.gemma import ( + _extract_json, + _has_complete_json, + _strip_reasoning, +) + + +def test_has_complete_json_arret_anticipe(): + # JSON complet -> True (on peut stopper la generation) + assert _has_complete_json('voici: {"speaker": "Marie"}') + assert _has_complete_json('[{"a": 1}]') + # JSON tronque (reponse pas encore finie) -> False (on continue) + assert not _has_complete_json('{"speaker": "Mar') + assert not _has_complete_json('texte sans json') + # cas streaming reel : pensee close + fence json en cours mais objet complet + buf = _strip_reasoning('...```json\n{"speaker": "Marie"}') + assert _has_complete_json(buf) + + +def test_format_a_canaux_gemma4(): + raw = ( + "<|channel>thought\n" + "Thinking Process: la capitale est Paris. Exemple: {\"capitale\": \"...\"}\n" + "```json\n{\"capitale\": \"Paris\"}\n```" + ) + cleaned = _strip_reasoning(raw) + # la pensee (et son JSON d'exemple parasite) a disparu + assert "Thinking Process" not in cleaned + assert '"..."' not in cleaned + # le JSON extrait est bien la reponse finale + assert _extract_json(cleaned) == {"capitale": "Paris"} + + +def test_balises_think_deepseek(): + raw = "je reflechis, peut-etre [1,2]\n[{\"speaker\": \"Holden\"}]" + cleaned = _strip_reasoning(raw) + assert "reflechis" not in cleaned + assert _extract_json(cleaned) == [{"speaker": "Holden"}] + + +def test_sans_raisonnement_inchange(): + raw = '{"speaker": "Kajri"}' + assert _strip_reasoning(raw) == raw + assert _extract_json(_strip_reasoning(raw)) == {"speaker": "Kajri"} + + +def test_pensee_tronquee_sans_fermeture(): + # pensee non fermee (budget de tokens epuise) : le prefixe de canal saute, + # on ne renvoie pas le marqueur d'ouverture. + raw = "<|channel>thought\nje commence a reflechir mais c'est coupe" + cleaned = _strip_reasoning(raw) + assert not cleaned.startswith("<|channel") + assert "