Ajout d'un outil de benchmark des modèles d'analyse + support des modèles à raisonnement
- Nouvelle commande `inkflow benchmark` : compare la sortie d'analyse aux fichiers de référence (data/<slug>/reference/), met plusieurs modèles en concurrence, table rich + rapport JSON. Métriques : attribution de locuteur, incises, type/glued. Flags --models, --temperature, --reasoning, --stream, --use-cached + suivi par chapitre. - analysis/benchmark.py : scoring pur (testable) + runner multi-modèles (un MLX à la fois). - gemma.py : support des modèles à raisonnement (retrait de la pensée, désactivation via enable_thinking hors --reasoning, arrêt anticipé sur JSON complet, plafond + température dédiés anti-boucle), récupération du chat_template manquant (fix Mistral), streaming des tokens (set_token_sink). - settings.py : gemma_reasoning, gemma_reasoning_max_tokens, gemma_reasoning_temperature. - Tests : test_benchmark.py (scoring pur), test_gemma_reasoning.py. Conclusion benchmark : Qwen3.6-27B-8bit non-raisonnant = meilleur modèle d'analyse.
This commit is contained in:
473
backend/inkflow/analysis/benchmark.py
Normal file
473
backend/inkflow/analysis/benchmark.py
Normal file
@@ -0,0 +1,473 @@
|
||||
"""Benchmark des modeles d'analyse contre les fichiers de reference.
|
||||
|
||||
Les fichiers `data/<slug>/reference/chNN.json` sont des verites terrain corrigees
|
||||
a la main (meme schema `ChapterAnalysis`). Ce module compare la sortie d'un
|
||||
modele (`hypothese`) a ces references et chiffre la qualite sur trois dimensions :
|
||||
|
||||
1. **Attribution du locuteur** (le point faible du petit modele local) ;
|
||||
2. **Incises** (bornes start/end dans la replique) ;
|
||||
3. **Type narration/dialogue** et flag `glued_to_prev` (garde-fou de regression).
|
||||
|
||||
Le scoring (`score_chapter`/`aggregate`) est **pur** : aucune dependance MLX ni
|
||||
disque, teste comme `analysis.segmenter`. Le runner (`run_benchmark`) met
|
||||
plusieurs modeles en concurrence : il relance `analyze_chapter` en memoire (sans
|
||||
ecraser les artefacts) avec chaque `model_id`, score, libere le modele, enchaine.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models import Cast, ChapterAnalysis, Incise, Segment, SegmentType
|
||||
|
||||
# --- Normalisation -----------------------------------------------------------
|
||||
|
||||
_WS_RE = re.compile(r"\s+")
|
||||
|
||||
|
||||
def _norm_text(text: str) -> str:
|
||||
"""Texte normalise pour l'alignement (insensible aux espaces/casse)."""
|
||||
return _WS_RE.sub(" ", text).strip().casefold()
|
||||
|
||||
|
||||
def _alias_map(cast: Optional[Cast]) -> dict[str, str]:
|
||||
"""alias/nom (casefold) -> nom canonique, pour ne pas penaliser les variantes."""
|
||||
mapping: dict[str, str] = {}
|
||||
if cast is None:
|
||||
return mapping
|
||||
for c in cast.characters:
|
||||
canon = c.name.strip()
|
||||
mapping[canon.casefold()] = canon
|
||||
for alias in c.aliases:
|
||||
mapping[alias.strip().casefold()] = canon
|
||||
return mapping
|
||||
|
||||
|
||||
def _norm_speaker(name: str, alias_map: dict[str, str]) -> str:
|
||||
key = (name or "").strip().casefold()
|
||||
return alias_map.get(key, key)
|
||||
|
||||
|
||||
# --- Comptes bruts (permettent une micro-moyenne sur plusieurs chapitres) ----
|
||||
|
||||
@dataclass
|
||||
class _Counts:
|
||||
seg_total: int = 0
|
||||
seg_correct: int = 0 # locuteur correct (tous segments)
|
||||
dlg_total: int = 0
|
||||
dlg_correct: int = 0 # locuteur correct (dialogues seuls)
|
||||
type_total: int = 0
|
||||
type_correct: int = 0
|
||||
glued_total: int = 0
|
||||
glued_correct: int = 0
|
||||
inc_exact_tp: int = 0
|
||||
inc_exact_fp: int = 0
|
||||
inc_exact_fn: int = 0
|
||||
inc_ov_tp: int = 0
|
||||
inc_ov_fp: int = 0
|
||||
inc_ov_fn: int = 0
|
||||
errors: list["SpeakerError"] = field(default_factory=list)
|
||||
confusion: dict[str, dict[str, int]] = field(default_factory=dict)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
def add(self, other: "_Counts") -> None:
|
||||
self.seg_total += other.seg_total
|
||||
self.seg_correct += other.seg_correct
|
||||
self.dlg_total += other.dlg_total
|
||||
self.dlg_correct += other.dlg_correct
|
||||
self.type_total += other.type_total
|
||||
self.type_correct += other.type_correct
|
||||
self.glued_total += other.glued_total
|
||||
self.glued_correct += other.glued_correct
|
||||
self.inc_exact_tp += other.inc_exact_tp
|
||||
self.inc_exact_fp += other.inc_exact_fp
|
||||
self.inc_exact_fn += other.inc_exact_fn
|
||||
self.inc_ov_tp += other.inc_ov_tp
|
||||
self.inc_ov_fp += other.inc_ov_fp
|
||||
self.inc_ov_fn += other.inc_ov_fn
|
||||
self.errors.extend(other.errors)
|
||||
self.warnings.extend(other.warnings)
|
||||
for exp, gots in other.confusion.items():
|
||||
dst = self.confusion.setdefault(exp, {})
|
||||
for got, n in gots.items():
|
||||
dst[got] = dst.get(got, 0) + n
|
||||
|
||||
|
||||
# --- Modeles de rapport (serialisables) --------------------------------------
|
||||
|
||||
class SpeakerError(BaseModel):
|
||||
index: int # index du segment dans le chapitre
|
||||
text_excerpt: str
|
||||
expected: str
|
||||
got: str
|
||||
|
||||
|
||||
class ChapterScore(BaseModel):
|
||||
index: int # -1 pour l'agregat
|
||||
n_segments: int = 0
|
||||
n_dialogue: int = 0
|
||||
# attribution du locuteur
|
||||
speaker_acc_all: float = 1.0
|
||||
speaker_acc_dialogue: float = 1.0
|
||||
# incises
|
||||
incise_exact_p: float = 1.0
|
||||
incise_exact_r: float = 1.0
|
||||
incise_exact_f1: float = 1.0
|
||||
incise_overlap_p: float = 1.0
|
||||
incise_overlap_r: float = 1.0
|
||||
incise_overlap_f1: float = 1.0
|
||||
# type / glued
|
||||
type_acc: float = 1.0
|
||||
glued_acc: float = 1.0
|
||||
# detail
|
||||
errors: list[SpeakerError] = Field(default_factory=list)
|
||||
confusion: dict[str, dict[str, int]] = Field(default_factory=dict)
|
||||
alignment_warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ModelScore(BaseModel):
|
||||
model_id: str
|
||||
elapsed_s: float = 0.0
|
||||
error: Optional[str] = None # rempli si le modele a echoue (chargement, etc.)
|
||||
per_chapter: list[ChapterScore] = Field(default_factory=list)
|
||||
aggregate: Optional[ChapterScore] = None
|
||||
|
||||
|
||||
class BenchmarkReport(BaseModel):
|
||||
slug: str
|
||||
generated_at: str # horodatage pose par la couche I/O (CLI)
|
||||
chapters: list[int] = Field(default_factory=list)
|
||||
settings_snapshot: dict = Field(default_factory=dict)
|
||||
models: list[ModelScore] = Field(default_factory=list)
|
||||
|
||||
|
||||
# --- Metriques pures ---------------------------------------------------------
|
||||
|
||||
def _prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
|
||||
p = tp / (tp + fp) if (tp + fp) else 1.0
|
||||
r = tp / (tp + fn) if (tp + fn) else 1.0
|
||||
f1 = (2 * p * r / (p + r)) if (p + r) else 0.0
|
||||
return p, r, f1
|
||||
|
||||
|
||||
def _ratio(correct: int, total: int) -> float:
|
||||
return correct / total if total else 1.0
|
||||
|
||||
|
||||
def _iou(a: Incise, b: Incise) -> float:
|
||||
inter = max(0, min(a.end, b.end) - max(a.start, b.start))
|
||||
union = (a.end - a.start) + (b.end - b.start) - inter
|
||||
return inter / union if union else 0.0
|
||||
|
||||
|
||||
def _match_incises(ref: list[Incise], hyp: list[Incise]) -> tuple[int, int, int, int, int, int]:
|
||||
"""Compare deux listes de spans : (exact tp/fp/fn, overlap tp/fp/fn).
|
||||
|
||||
Exact = memes (start, end). Overlap = appariement glouton IoU >= 0.5.
|
||||
"""
|
||||
ref_keys = [(i.start, i.end) for i in ref]
|
||||
hyp_keys = [(i.start, i.end) for i in hyp]
|
||||
# appariement exact 1:1 (pas de double comptage si doublons improbables)
|
||||
used = [False] * len(ref_keys)
|
||||
ex_tp = 0
|
||||
for hk in hyp_keys:
|
||||
for j, rk in enumerate(ref_keys):
|
||||
if not used[j] and hk == rk:
|
||||
used[j] = True
|
||||
ex_tp += 1
|
||||
break
|
||||
ex_fp = len(hyp_keys) - ex_tp
|
||||
ex_fn = len(ref_keys) - ex_tp
|
||||
|
||||
used = [False] * len(ref)
|
||||
ov_tp = 0
|
||||
for h in hyp:
|
||||
best_j, best_iou = -1, 0.0
|
||||
for j, r in enumerate(ref):
|
||||
if used[j]:
|
||||
continue
|
||||
iou = _iou(r, h)
|
||||
if iou >= 0.5 and iou > best_iou:
|
||||
best_j, best_iou = j, iou
|
||||
if best_j >= 0:
|
||||
used[best_j] = True
|
||||
ov_tp += 1
|
||||
ov_fp = len(hyp) - ov_tp
|
||||
ov_fn = len(ref) - ov_tp
|
||||
return ex_tp, ex_fp, ex_fn, ov_tp, ov_fp, ov_fn
|
||||
|
||||
|
||||
def align(ref: ChapterAnalysis, hyp: ChapterAnalysis) -> list[tuple[Optional[Segment], Optional[Segment]]]:
|
||||
"""Aligne les segments hypothese sur la reference.
|
||||
|
||||
Cas nominal (segmentation deterministe) : meme nombre + memes textes -> 1:1.
|
||||
Sinon, alignement par `difflib.SequenceMatcher` sur les textes normalises ;
|
||||
les segments orphelins ressortent en paires avec `None`.
|
||||
"""
|
||||
rt = [_norm_text(s.text) for s in ref.segments]
|
||||
ht = [_norm_text(s.text) for s in hyp.segments]
|
||||
if rt == ht:
|
||||
return list(zip(ref.segments, hyp.segments))
|
||||
|
||||
pairs: list[tuple[Optional[Segment], Optional[Segment]]] = []
|
||||
sm = difflib.SequenceMatcher(a=rt, b=ht, autojunk=False)
|
||||
for tag, i1, i2, j1, j2 in sm.get_opcodes():
|
||||
if tag == "equal":
|
||||
for k in range(i2 - i1):
|
||||
pairs.append((ref.segments[i1 + k], hyp.segments[j1 + k]))
|
||||
elif tag == "replace":
|
||||
for k in range(max(i2 - i1, j2 - j1)):
|
||||
r = ref.segments[i1 + k] if i1 + k < i2 else None
|
||||
h = hyp.segments[j1 + k] if j1 + k < j2 else None
|
||||
pairs.append((r, h))
|
||||
elif tag == "delete":
|
||||
for k in range(i1, i2):
|
||||
pairs.append((ref.segments[k], None))
|
||||
elif tag == "insert":
|
||||
for k in range(j1, j2):
|
||||
pairs.append((None, hyp.segments[k]))
|
||||
return pairs
|
||||
|
||||
|
||||
def _score_counts(ref: ChapterAnalysis, hyp: ChapterAnalysis,
|
||||
cast: Optional[Cast]) -> _Counts:
|
||||
amap = _alias_map(cast)
|
||||
c = _Counts()
|
||||
for r, h in align(ref, hyp):
|
||||
if r is None:
|
||||
c.warnings.append(f"segment hypothese sans correspondance: {h.text[:60]!r}")
|
||||
continue
|
||||
if h is None:
|
||||
c.warnings.append(f"segment reference non couvert: {r.text[:60]!r}")
|
||||
continue
|
||||
|
||||
# type
|
||||
c.type_total += 1
|
||||
if r.type == h.type:
|
||||
c.type_correct += 1
|
||||
# glued
|
||||
c.glued_total += 1
|
||||
if r.glued_to_prev == h.glued_to_prev:
|
||||
c.glued_correct += 1
|
||||
|
||||
# locuteur
|
||||
exp = _norm_speaker(r.speaker, amap)
|
||||
got = _norm_speaker(h.speaker, amap)
|
||||
c.seg_total += 1
|
||||
ok = exp == got
|
||||
if ok:
|
||||
c.seg_correct += 1
|
||||
if r.type is SegmentType.DIALOGUE:
|
||||
c.dlg_total += 1
|
||||
if ok:
|
||||
c.dlg_correct += 1
|
||||
else:
|
||||
c.errors.append(SpeakerError(
|
||||
index=r_index(ref, r), text_excerpt=r.text[:80],
|
||||
expected=r.speaker, got=h.speaker))
|
||||
row = c.confusion.setdefault(r.speaker, {})
|
||||
row[h.speaker] = row.get(h.speaker, 0) + 1
|
||||
|
||||
# incises (sur les dialogues de la reference)
|
||||
ex_tp, ex_fp, ex_fn, ov_tp, ov_fp, ov_fn = _match_incises(r.incises, h.incises)
|
||||
c.inc_exact_tp += ex_tp
|
||||
c.inc_exact_fp += ex_fp
|
||||
c.inc_exact_fn += ex_fn
|
||||
c.inc_ov_tp += ov_tp
|
||||
c.inc_ov_fp += ov_fp
|
||||
c.inc_ov_fn += ov_fn
|
||||
return c
|
||||
|
||||
|
||||
def r_index(analysis: ChapterAnalysis, seg: Segment) -> int:
|
||||
"""Position d'un segment dans le chapitre (identite d'objet)."""
|
||||
for i, s in enumerate(analysis.segments):
|
||||
if s is seg:
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
def _counts_to_score(index: int, c: _Counts) -> ChapterScore:
|
||||
ex_p, ex_r, ex_f1 = _prf(c.inc_exact_tp, c.inc_exact_fp, c.inc_exact_fn)
|
||||
ov_p, ov_r, ov_f1 = _prf(c.inc_ov_tp, c.inc_ov_fp, c.inc_ov_fn)
|
||||
return ChapterScore(
|
||||
index=index,
|
||||
n_segments=c.seg_total,
|
||||
n_dialogue=c.dlg_total,
|
||||
speaker_acc_all=_ratio(c.seg_correct, c.seg_total),
|
||||
speaker_acc_dialogue=_ratio(c.dlg_correct, c.dlg_total),
|
||||
incise_exact_p=ex_p, incise_exact_r=ex_r, incise_exact_f1=ex_f1,
|
||||
incise_overlap_p=ov_p, incise_overlap_r=ov_r, incise_overlap_f1=ov_f1,
|
||||
type_acc=_ratio(c.type_correct, c.type_total),
|
||||
glued_acc=_ratio(c.glued_correct, c.glued_total),
|
||||
errors=c.errors,
|
||||
confusion=c.confusion,
|
||||
alignment_warnings=c.warnings,
|
||||
)
|
||||
|
||||
|
||||
def score_chapter(ref: ChapterAnalysis, hyp: ChapterAnalysis,
|
||||
cast: Optional[Cast] = None) -> ChapterScore:
|
||||
"""Score une hypothese contre une reference pour un chapitre."""
|
||||
return _counts_to_score(ref.index, _score_counts(ref, hyp, cast))
|
||||
|
||||
|
||||
def aggregate(scores: list[ChapterScore], counts: list[_Counts]) -> ChapterScore:
|
||||
"""Micro-moyenne (pooling de tous les segments) sur plusieurs chapitres."""
|
||||
total = _Counts()
|
||||
for c in counts:
|
||||
total.add(c)
|
||||
return _counts_to_score(-1, total)
|
||||
|
||||
|
||||
# --- Runner multi-modeles ----------------------------------------------------
|
||||
|
||||
def _reference_chapters(slug: str, chapters: Optional[list[int]]) -> list[int]:
|
||||
"""Index des chapitres disposant d'une reference (filtres par `chapters`)."""
|
||||
from ..config import book_data_dir
|
||||
|
||||
ref_dir = book_data_dir(slug) / "reference"
|
||||
found: list[int] = []
|
||||
if ref_dir.exists():
|
||||
for p in sorted(ref_dir.glob("ch*.json")):
|
||||
m = re.match(r"ch(\d+)\.json$", p.name)
|
||||
if m:
|
||||
found.append(int(m.group(1)))
|
||||
if chapters is not None:
|
||||
found = [i for i in found if i in chapters]
|
||||
return found
|
||||
|
||||
|
||||
def _load_reference(slug: str, index: int) -> ChapterAnalysis:
|
||||
from ..config import book_data_dir
|
||||
|
||||
path = book_data_dir(slug) / "reference" / f"ch{index:02d}.json"
|
||||
return ChapterAnalysis.model_validate_json(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _build_model_score(model_id: str, per_chapter: list[ChapterScore],
|
||||
counts: list[_Counts], elapsed: float) -> ModelScore:
|
||||
return ModelScore(
|
||||
model_id=model_id, elapsed_s=elapsed, per_chapter=per_chapter,
|
||||
aggregate=aggregate(per_chapter, counts) if per_chapter else None,
|
||||
)
|
||||
|
||||
|
||||
def run_benchmark(slug: str, model_ids: list[str], *,
|
||||
chapters: Optional[list[int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
reasoning: Optional[bool] = None,
|
||||
use_cached: bool = False,
|
||||
progress: Optional[Callable[[str], None]] = None) -> BenchmarkReport:
|
||||
"""Met plusieurs modeles en concurrence sur les chapitres de reference.
|
||||
|
||||
`use_cached=True` : compare les artefacts `analysis/chNN.json` existants (pas
|
||||
de modele charge ; `model_ids` est ignore, un seul resultat "cache").
|
||||
Sinon, pour chaque `model_id`, relance `analyze_chapter` en memoire (sans
|
||||
`save_analysis`) et score. Un seul MLX reside en RAM a la fois.
|
||||
|
||||
`progress` : callback optionnel appele a chaque etape (chargement, chapitre
|
||||
analyse, modele termine) pour suivre l'avancement d'un run long.
|
||||
"""
|
||||
from ..epub.parser import load_book, load_chapter_text
|
||||
from ..settings import get_settings
|
||||
from ..store import artifacts
|
||||
|
||||
emit = progress or (lambda _msg: None)
|
||||
|
||||
targets = _reference_chapters(slug, chapters)
|
||||
if not targets:
|
||||
raise ValueError(
|
||||
f"Aucune reference trouvee pour {slug!r} "
|
||||
f"(data/{slug}/reference/chNN.json).")
|
||||
|
||||
references = {i: _load_reference(slug, i) for i in targets}
|
||||
cast = artifacts.load_cast(slug)
|
||||
settings = get_settings()
|
||||
|
||||
snapshot = {
|
||||
"gemma_temperature": temperature if temperature is not None
|
||||
else settings.gemma_temperature,
|
||||
"gemma_max_tokens": settings.gemma_max_tokens,
|
||||
"gemma_reasoning": reasoning if reasoning is not None
|
||||
else settings.gemma_reasoning,
|
||||
"dedup_use_gemma": settings.dedup_use_gemma,
|
||||
"retro_pass_use_gemma": settings.retro_pass_use_gemma,
|
||||
"prompt_speakers_hash": hash(settings.prompt_speakers) & 0xFFFFFFFF,
|
||||
}
|
||||
report = BenchmarkReport(
|
||||
slug=slug, generated_at="", chapters=targets,
|
||||
settings_snapshot=snapshot)
|
||||
|
||||
if use_cached:
|
||||
per_chapter, counts = [], []
|
||||
for i in targets:
|
||||
hyp = artifacts.load_analysis(slug, i)
|
||||
cnt = _score_counts(references[i], hyp, cast)
|
||||
counts.append(cnt)
|
||||
per_chapter.append(_counts_to_score(i, cnt))
|
||||
report.models.append(_build_model_score("<cached>", per_chapter, counts, 0.0))
|
||||
return report
|
||||
|
||||
from .gemma import Gemma, _load
|
||||
from .segmenter import analyze_chapter
|
||||
|
||||
book = load_book(slug)
|
||||
by_index = {c.index: c for c in book.chapters}
|
||||
|
||||
# Epingle temperature/reasoning en memoire (jamais save_settings -> pas
|
||||
# d'ecriture disque), restaure en sortie.
|
||||
original_temp = settings.gemma_temperature
|
||||
original_reasoning = settings.gemma_reasoning
|
||||
if temperature is not None:
|
||||
settings.gemma_temperature = temperature
|
||||
if reasoning is not None:
|
||||
settings.gemma_reasoning = reasoning
|
||||
try:
|
||||
for mi, model_id in enumerate(model_ids, 1):
|
||||
t0 = time.perf_counter()
|
||||
per_chapter, counts = [], []
|
||||
model_err: Optional[str] = None
|
||||
emit(f"[{mi}/{len(model_ids)}] {model_id} — chargement du modele…")
|
||||
try:
|
||||
gemma = Gemma(model_id=model_id)
|
||||
for i in targets:
|
||||
ch = by_index.get(i)
|
||||
if ch is None:
|
||||
continue
|
||||
emit(f" ch{i:02d} — analyse en cours…")
|
||||
tc = time.perf_counter()
|
||||
ct = load_chapter_text(slug, ch)
|
||||
hyp, _ = analyze_chapter(
|
||||
ch, ct, gemma,
|
||||
book_chars=list(cast.characters), dedup_gemma=None)
|
||||
cnt = _score_counts(references[i], hyp, cast)
|
||||
counts.append(cnt)
|
||||
cs = _counts_to_score(i, cnt)
|
||||
per_chapter.append(cs)
|
||||
emit(f" ch{i:02d} — OK en {time.perf_counter() - tc:.0f}s "
|
||||
f"(locuteur dlg {cs.speaker_acc_dialogue:.0%}, "
|
||||
f"{len(cs.errors)} erreurs)")
|
||||
except Exception as exc: # noqa: BLE001 — un modele KO ne stoppe pas les autres
|
||||
model_err = f"{type(exc).__name__}: {exc}"
|
||||
emit(f" ! echec: {model_err[:120]}")
|
||||
finally:
|
||||
_load.cache_clear() # libere le modele avant le suivant
|
||||
ms = _build_model_score(
|
||||
model_id, per_chapter, counts, time.perf_counter() - t0)
|
||||
ms.error = model_err
|
||||
report.models.append(ms)
|
||||
if not model_err and ms.aggregate is not None:
|
||||
emit(f"[{mi}/{len(model_ids)}] {model_id} — termine en "
|
||||
f"{ms.elapsed_s:.0f}s (locuteur dlg "
|
||||
f"{ms.aggregate.speaker_acc_dialogue:.1%})")
|
||||
finally:
|
||||
settings.gemma_temperature = original_temp
|
||||
settings.gemma_reasoning = original_reasoning
|
||||
|
||||
return report
|
||||
@@ -17,6 +17,13 @@ from ..settings import get_settings
|
||||
_JSON_SPAN_RE = re.compile(r"(\{.*\}|\[.*\])", re.DOTALL)
|
||||
_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL)
|
||||
|
||||
# Marqueurs de FIN de chaine de pensee : on ne garde que ce qui suit le dernier.
|
||||
# - balises type DeepSeek-R1 / Qwen-think
|
||||
# - format a canaux type Gemma 4 / Harmony (la pensee est close par <channel|>)
|
||||
_REASONING_END_MARKERS = ("</think>", "<channel|>", "<|channel|>")
|
||||
# Prefixe de canal/think non ferme reste en tete (pensee tronquee) : a retirer.
|
||||
_REASONING_OPEN_RE = re.compile(r"^\s*(?:<\|?channel\|?>\s*\w*|<think>)", re.IGNORECASE)
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def _load(model_id: str):
|
||||
@@ -25,6 +32,46 @@ def _load(model_id: str):
|
||||
return load(model_id)
|
||||
|
||||
|
||||
# Hook de streaming optionnel. Si defini, `generate()` diffuse chaque morceau de
|
||||
# texte AU FIL de la generation (pensee comprise, avant tout nettoyage) en
|
||||
# appelant ce callback. Utilise par `inkflow benchmark --stream` pour voir les
|
||||
# tokens en temps reel. None -> generation par lot classique (plus rapide).
|
||||
_TOKEN_SINK: Optional[Any] = None
|
||||
|
||||
|
||||
def set_token_sink(callback) -> None:
|
||||
"""Definit (ou retire avec None) le callback de streaming des tokens."""
|
||||
global _TOKEN_SINK
|
||||
_TOKEN_SINK = callback
|
||||
|
||||
|
||||
def _resolve_chat_template(model_id: str, tokenizer) -> Optional[str]:
|
||||
"""Renvoie un template de chat a passer explicitement, ou None.
|
||||
|
||||
Certaines conversions (Mistral recents...) logent leur template dans un
|
||||
fichier `chat_template.jinja` que le downloader de mlx-lm n'embarque pas
|
||||
toujours : `tokenizer.chat_template` est alors vide et `apply_chat_template`
|
||||
echoue. On recupere alors le fichier officiel du repo. None si le tokenizer
|
||||
possede deja un template (cas courant) ou si aucun n'est disponible.
|
||||
"""
|
||||
if getattr(tokenizer, "chat_template", None):
|
||||
return None
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
# Selon les conversions : fichier Jinja brut, ou JSON {"chat_template": ...}.
|
||||
for fname in ("chat_template.jinja", "chat_template.json"):
|
||||
try:
|
||||
text = Path(hf_hub_download(model_id, fname)).read_text(encoding="utf-8")
|
||||
except Exception: # noqa: BLE001 — fichier absent, on tente le suivant
|
||||
continue
|
||||
if fname.endswith(".json"):
|
||||
data = json.loads(text)
|
||||
return data.get("chat_template") if isinstance(data, dict) else None
|
||||
return text
|
||||
return None # aucun template dispo -> apply_chat_template levera une erreur claire
|
||||
|
||||
|
||||
class Gemma:
|
||||
"""Petite facade autour de mlx-lm pour piloter Gemma."""
|
||||
|
||||
@@ -32,10 +79,13 @@ class Gemma:
|
||||
self.model_id = model_id or get_settings().gemma_model
|
||||
self._model = None
|
||||
self._tokenizer = None
|
||||
self._chat_template = None # template recupere si absent du tokenizer
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
if self._model is None:
|
||||
self._model, self._tokenizer = _load(self.model_id)
|
||||
self._chat_template = _resolve_chat_template(
|
||||
self.model_id, self._tokenizer)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@@ -53,27 +103,76 @@ class Gemma:
|
||||
settings = get_settings()
|
||||
if max_tokens is None:
|
||||
max_tokens = settings.gemma_max_tokens
|
||||
# En mode raisonnement, plafond dedie (garde-fou anti-boucle) ; la
|
||||
# generation s'arrete de toute facon des que le JSON post-pensee est
|
||||
# complet (cf. boucle de streaming ci-dessous).
|
||||
if settings.gemma_reasoning:
|
||||
max_tokens = max(max_tokens, settings.gemma_reasoning_max_tokens)
|
||||
if temperature is None:
|
||||
temperature = settings.gemma_temperature
|
||||
from mlx_lm import generate
|
||||
# Decodage glouton (temp 0) + raisonnement = boucles de pensee sans fin.
|
||||
# On force un echantillonnage minimal en mode raisonnement.
|
||||
if settings.gemma_reasoning and temperature == 0.0:
|
||||
temperature = settings.gemma_reasoning_temperature
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
messages = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
# Modeles hybrides (Qwen3...) : hors mode raisonnement, on DESACTIVE la
|
||||
# pensee via enable_thinking=False -> JSON direct, bien plus rapide. Avec
|
||||
# --reasoning, on laisse penser puis on retire la pensee en aval. Ce
|
||||
# kwarg est ignore par les templates qui ne l'utilisent pas (Gemma...).
|
||||
template_kwargs = {}
|
||||
if not settings.gemma_reasoning:
|
||||
template_kwargs["enable_thinking"] = False
|
||||
formatted = self._tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False
|
||||
messages, add_generation_prompt=True, tokenize=False,
|
||||
chat_template=self._chat_template, # None -> celui du tokenizer
|
||||
**template_kwargs,
|
||||
)
|
||||
sampler = make_sampler(temp=temperature)
|
||||
return generate(
|
||||
self._model,
|
||||
self._tokenizer,
|
||||
prompt=formatted,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
verbose=False,
|
||||
)
|
||||
# On streame (token par token) si : un sink est branche (--stream) OU on
|
||||
# est en mode raisonnement (pour pouvoir s'arreter des que la reponse est
|
||||
# prete, sans subir les boucles de pensee sans fin). Sinon, lot rapide.
|
||||
if _TOKEN_SINK is not None or settings.gemma_reasoning:
|
||||
from mlx_lm import stream_generate
|
||||
parts = []
|
||||
seen_end = False # marqueur de fin de pensee rencontre
|
||||
for resp in stream_generate(
|
||||
self._model, self._tokenizer, prompt=formatted,
|
||||
max_tokens=max_tokens, sampler=sampler,
|
||||
):
|
||||
parts.append(resp.text)
|
||||
if _TOKEN_SINK is not None:
|
||||
_TOKEN_SINK(resp.text)
|
||||
# Arret anticipe : une fois la pensee close, des que le JSON
|
||||
# post-pensee est complet, inutile de continuer a generer.
|
||||
if settings.gemma_reasoning and ("}" in resp.text or "]" in resp.text):
|
||||
buf = "".join(parts)
|
||||
if not seen_end:
|
||||
seen_end = any(mk in buf for mk in _REASONING_END_MARKERS)
|
||||
if seen_end and _has_complete_json(_strip_reasoning(buf)):
|
||||
break
|
||||
if _TOKEN_SINK is not None:
|
||||
_TOKEN_SINK("\n") # separe les generations successives
|
||||
raw = "".join(parts)
|
||||
else:
|
||||
from mlx_lm import generate
|
||||
raw = generate(
|
||||
self._model,
|
||||
self._tokenizer,
|
||||
prompt=formatted,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
verbose=False,
|
||||
)
|
||||
# Retire la chaine de pensee des modeles a raisonnement (sinon des
|
||||
# fragments de la "pensee" parasitent l'extraction JSON en aval).
|
||||
if settings.gemma_reasoning:
|
||||
return _strip_reasoning(raw)
|
||||
return raw
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
@@ -101,6 +200,35 @@ class Gemma:
|
||||
raise ValueError(f"Reponse JSON invalide apres {retries + 1} essais: {last_err}")
|
||||
|
||||
|
||||
def _strip_reasoning(text: str) -> str:
|
||||
"""Retire la chaine de pensee d'un modele a raisonnement.
|
||||
|
||||
Ne conserve que ce qui suit le dernier marqueur de fin de pensee
|
||||
(`</think>`, `<channel|>`...). Si seul un marqueur d'ouverture non ferme
|
||||
subsiste (pensee tronquee par le budget de tokens), on le retire en tete
|
||||
pour eviter de parser la pensee a la place de la reponse.
|
||||
"""
|
||||
t = text
|
||||
for marker in _REASONING_END_MARKERS:
|
||||
if marker in t:
|
||||
t = t.rsplit(marker, 1)[-1]
|
||||
t = _REASONING_OPEN_RE.sub("", t)
|
||||
return t.strip()
|
||||
|
||||
|
||||
def _has_complete_json(text: str) -> bool:
|
||||
"""True si `text` contient deja un objet/array JSON complet et parsable.
|
||||
|
||||
Sert a stopper la generation des modeles a raisonnement des que la reponse
|
||||
finale est ecrite (evite de consommer le budget en boucles de pensee).
|
||||
"""
|
||||
try:
|
||||
_extract_json(text)
|
||||
return True
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _extract_json(text: str) -> Any:
|
||||
"""Extrait le premier objet/array JSON d'une reponse libre du modele.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import typer
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from .config import ensure_dirs
|
||||
from .config import book_data_dir, ensure_dirs
|
||||
from .epub.parser import load_book, load_chapter_text, parse_epub
|
||||
from .models import Cast
|
||||
from .store import artifacts
|
||||
@@ -95,6 +95,104 @@ def analyze(
|
||||
console.print(f"[green]Casting[/] : {len(chars)} personnages -> cast.json")
|
||||
|
||||
|
||||
@app.command()
|
||||
def benchmark(
|
||||
slug: str,
|
||||
models: Optional[str] = typer.Option(
|
||||
None, help="Modeles a comparer, separes par des virgules (def: modele courant)."),
|
||||
chapter: Optional[int] = typer.Option(
|
||||
None, help="Restreindre a un chapitre (def: tous ceux avec reference)."),
|
||||
temperature: Optional[float] = typer.Option(
|
||||
None, help="Epingle la temperature Gemma (repro). Ex: 0.0."),
|
||||
reasoning: bool = typer.Option(
|
||||
False, "--reasoning", help="Modeles a raisonnement : retire la pensee + budget tokens accru."),
|
||||
use_cached: bool = typer.Option(
|
||||
False, "--use-cached", help="Compare les analysis/chNN.json existants (pas de modele)."),
|
||||
stream: bool = typer.Option(
|
||||
False, "--stream", help="Affiche les tokens generes en temps reel (pensee + reponse)."),
|
||||
):
|
||||
"""Met des modeles en concurrence sur les chapitres de reference (vs reference/)."""
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from .analysis import gemma as _gemma
|
||||
from .analysis.benchmark import run_benchmark
|
||||
from .settings import get_settings
|
||||
|
||||
model_ids = ([m.strip() for m in models.split(",") if m.strip()]
|
||||
if models else [get_settings().gemma_model])
|
||||
chapters = [chapter] if chapter is not None else None
|
||||
|
||||
label = "artefacts en cache" if use_cached else f"{len(model_ids)} modele(s)"
|
||||
console.print(f"[blue]Benchmark[/] {slug} ({label}) — suivi par chapitre :")
|
||||
# Suivi en clair (lignes persistantes), avec horodatage pour voir l'avancement
|
||||
# d'un run long. On evite console.status (spinner) qui n'imprime rien.
|
||||
def _progress(msg: str) -> None:
|
||||
from datetime import datetime as _dt
|
||||
console.print(f"[dim]{_dt.now():%H:%M:%S}[/] {msg}")
|
||||
|
||||
# Streaming des tokens : ecriture brute sur stdout (sans markup rich) pour
|
||||
# voir defiler pensee et reponse. Necessite stdout non bufferise cote shell.
|
||||
if stream:
|
||||
def _sink(piece: str) -> None:
|
||||
sys.stdout.write(piece)
|
||||
sys.stdout.flush()
|
||||
_gemma.set_token_sink(_sink)
|
||||
try:
|
||||
report = run_benchmark(
|
||||
slug, model_ids, chapters=chapters,
|
||||
temperature=temperature,
|
||||
reasoning=reasoning if reasoning else None,
|
||||
use_cached=use_cached,
|
||||
progress=_progress)
|
||||
finally:
|
||||
if stream:
|
||||
_gemma.set_token_sink(None)
|
||||
report.generated_at = datetime.now().isoformat(timespec="seconds")
|
||||
|
||||
# Table comparative : une ligne par modele (agregat micro-moyenne).
|
||||
table = Table(title=f"Benchmark {slug} — chapitres {report.chapters}")
|
||||
table.add_column("modele")
|
||||
for col in ("speaker_dlg", "speaker_all", "incise_f1", "type", "glued", "temps(s)"):
|
||||
table.add_column(col, justify="right")
|
||||
for ms in report.models:
|
||||
if ms.error or ms.aggregate is None:
|
||||
table.add_row(ms.model_id, f"[red]{ms.error or 'aucun chapitre'}[/]")
|
||||
continue
|
||||
a = ms.aggregate
|
||||
table.add_row(
|
||||
ms.model_id,
|
||||
f"{a.speaker_acc_dialogue:.1%}",
|
||||
f"{a.speaker_acc_all:.1%}",
|
||||
f"{a.incise_overlap_f1:.2f}",
|
||||
f"{a.type_acc:.1%}",
|
||||
f"{a.glued_acc:.1%}",
|
||||
f"{ms.elapsed_s:.0f}",
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
# Detail des erreurs d'attribution (les pires) par modele.
|
||||
for ms in report.models:
|
||||
errs = [e for cs in ms.per_chapter for e in cs.errors]
|
||||
if not errs:
|
||||
continue
|
||||
console.print(f"\n[bold]{ms.model_id}[/] — {len(errs)} erreur(s) de locuteur:")
|
||||
for e in errs[:15]:
|
||||
console.print(
|
||||
f" ch·seg{e.index:>3} attendu=[green]{e.expected}[/] "
|
||||
f"obtenu=[red]{e.got}[/] — {e.text_excerpt!r}")
|
||||
if len(errs) > 15:
|
||||
console.print(f" [dim]… +{len(errs) - 15} autres[/]")
|
||||
|
||||
# Rapport JSON horodate.
|
||||
out_dir = book_data_dir(slug) / "benchmark"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
stamp = report.generated_at.replace(":", "").replace("-", "")
|
||||
out_path = out_dir / f"{stamp}.json"
|
||||
out_path.write_text(report.model_dump_json(indent=2), encoding="utf-8")
|
||||
console.print(f"\n[green]Rapport[/] -> {out_path}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def pronounce(
|
||||
slug: str,
|
||||
|
||||
@@ -74,7 +74,22 @@ class Settings(BaseModel):
|
||||
|
||||
# --- Generation Gemma ---
|
||||
gemma_temperature: float = Field(0.1, ge=0.0, le=2.0)
|
||||
gemma_max_tokens: int = Field(2048, ge=64, le=8192)
|
||||
gemma_max_tokens: int = Field(2048, ge=64, le=16384)
|
||||
# Modeles a raisonnement (Gemma 4, DeepSeek-R1, Qwen-think...) : ils emettent
|
||||
# une chaine de pensee avant la reponse. Active le retrait de cette pensee
|
||||
# (canaux <|channel>thought.../<channel|>, balises <think>...</think>) AVANT
|
||||
# le parsing JSON, et releve le plafond de tokens (la pensee en consomme).
|
||||
gemma_reasoning: bool = False
|
||||
# Plafond de tokens en mode raisonnement (la pensee en consomme beaucoup).
|
||||
# La generation s'arrete de toute facon des que la reponse JSON post-pensee
|
||||
# est complete ; ce plafond est un garde-fou contre les boucles de pensee
|
||||
# sans fin (certains modeles tournent en rond a temperature 0).
|
||||
gemma_reasoning_max_tokens: int = Field(4096, ge=256, le=16384)
|
||||
# Temperature en mode raisonnement. Le decodage GLOUTON (temp 0) fait boucler
|
||||
# les modeles a raisonnement (repetitions sans fin) ; Qwen & co recommandent
|
||||
# un echantillonnage. Si la temperature effective est 0, on bascule sur
|
||||
# celle-ci. Rend le benchmark non deterministe en mode raisonnement (inevitable).
|
||||
gemma_reasoning_temperature: float = Field(0.6, ge=0.0, le=2.0)
|
||||
|
||||
# --- Prompts systeme (analyse) ---
|
||||
prompt_speakers: str = DEFAULT_PROMPT_SPEAKERS
|
||||
|
||||
Reference in New Issue
Block a user