"""Wrapper mlx-lm autour de Gemma pour l'analyse de texte. Charge le modele paresseusement (une seule fois par process) et expose des helpers de generation, dont un `generate_json` tolerant qui extrait le premier objet/array JSON valide de la sortie du modele. """ from __future__ import annotations import json import re from functools import lru_cache from typing import Any, Optional from ..settings import get_settings # Bornes d'un bloc JSON dans une reponse potentiellement bavarde. _JSON_SPAN_RE = re.compile(r"(\{.*\}|\[.*\])", re.DOTALL) _FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL) # Marqueurs de FIN de chaine de pensee : on ne garde que ce qui suit le dernier. # - balises type DeepSeek-R1 / Qwen-think # - format a canaux type Gemma 4 / Harmony (la pensee est close par ) _REASONING_END_MARKERS = ("", "", "<|channel|>") # Prefixe de canal/think non ferme reste en tete (pensee tronquee) : a retirer. _REASONING_OPEN_RE = re.compile(r"^\s*(?:<\|?channel\|?>\s*\w*|)", re.IGNORECASE) @lru_cache(maxsize=2) def _load(model_id: str): # Import paresseux : evite de charger mlx tant qu'on n'analyse pas. from mlx_lm import load return load(model_id) # Hook de streaming optionnel. Si defini, `generate()` diffuse chaque morceau de # texte AU FIL de la generation (pensee comprise, avant tout nettoyage) en # appelant ce callback. Utilise par `inkflow benchmark --stream` pour voir les # tokens en temps reel. None -> generation par lot classique (plus rapide). _TOKEN_SINK: Optional[Any] = None def set_token_sink(callback) -> None: """Definit (ou retire avec None) le callback de streaming des tokens.""" global _TOKEN_SINK _TOKEN_SINK = callback def _resolve_chat_template(model_id: str, tokenizer) -> Optional[str]: """Renvoie un template de chat a passer explicitement, ou None. Certaines conversions (Mistral recents...) logent leur template dans un fichier `chat_template.jinja` que le downloader de mlx-lm n'embarque pas toujours : `tokenizer.chat_template` est alors vide et `apply_chat_template` echoue. On recupere alors le fichier officiel du repo. None si le tokenizer possede deja un template (cas courant) ou si aucun n'est disponible. """ if getattr(tokenizer, "chat_template", None): return None from pathlib import Path from huggingface_hub import hf_hub_download # Selon les conversions : fichier Jinja brut, ou JSON {"chat_template": ...}. for fname in ("chat_template.jinja", "chat_template.json"): try: text = Path(hf_hub_download(model_id, fname)).read_text(encoding="utf-8") except Exception: # noqa: BLE001 — fichier absent, on tente le suivant continue if fname.endswith(".json"): data = json.loads(text) return data.get("chat_template") if isinstance(data, dict) else None return text return None # aucun template dispo -> apply_chat_template levera une erreur claire class Gemma: """Petite facade autour de mlx-lm pour piloter Gemma.""" def __init__(self, model_id: Optional[str] = None): self.model_id = model_id or get_settings().gemma_model self._model = None self._tokenizer = None self._chat_template = None # template recupere si absent du tokenizer def _ensure_loaded(self) -> None: if self._model is None: self._model, self._tokenizer = _load(self.model_id) self._chat_template = _resolve_chat_template( self.model_id, self._tokenizer) def generate( self, prompt: str, *, system: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None, ) -> str: """Genere une reponse texte a partir d'un prompt (template de chat). `max_tokens`/`temperature` non fournis -> valeurs des reglages courants. """ self._ensure_loaded() settings = get_settings() if max_tokens is None: max_tokens = settings.gemma_max_tokens # En mode raisonnement, plafond dedie (garde-fou anti-boucle) ; la # generation s'arrete de toute facon des que le JSON post-pensee est # complet (cf. boucle de streaming ci-dessous). if settings.gemma_reasoning: max_tokens = max(max_tokens, settings.gemma_reasoning_max_tokens) if temperature is None: temperature = settings.gemma_temperature # Decodage glouton (temp 0) + raisonnement = boucles de pensee sans fin. # On force un echantillonnage minimal en mode raisonnement. if settings.gemma_reasoning and temperature == 0.0: temperature = settings.gemma_reasoning_temperature from mlx_lm.sample_utils import make_sampler messages = [] if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": prompt}) # Modeles hybrides (Qwen3...) : hors mode raisonnement, on DESACTIVE la # pensee via enable_thinking=False -> JSON direct, bien plus rapide. Avec # --reasoning, on laisse penser puis on retire la pensee en aval. Ce # kwarg est ignore par les templates qui ne l'utilisent pas (Gemma...). template_kwargs = {} if not settings.gemma_reasoning: template_kwargs["enable_thinking"] = False formatted = self._tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=False, chat_template=self._chat_template, # None -> celui du tokenizer **template_kwargs, ) sampler = make_sampler(temp=temperature) # On streame (token par token) si : un sink est branche (--stream) OU on # est en mode raisonnement (pour pouvoir s'arreter des que la reponse est # prete, sans subir les boucles de pensee sans fin). Sinon, lot rapide. if _TOKEN_SINK is not None or settings.gemma_reasoning: from mlx_lm import stream_generate parts = [] seen_end = False # marqueur de fin de pensee rencontre for resp in stream_generate( self._model, self._tokenizer, prompt=formatted, max_tokens=max_tokens, sampler=sampler, ): parts.append(resp.text) if _TOKEN_SINK is not None: _TOKEN_SINK(resp.text) # Arret anticipe : une fois la pensee close, des que le JSON # post-pensee est complet, inutile de continuer a generer. if settings.gemma_reasoning and ("}" in resp.text or "]" in resp.text): buf = "".join(parts) if not seen_end: seen_end = any(mk in buf for mk in _REASONING_END_MARKERS) if seen_end and _has_complete_json(_strip_reasoning(buf)): break if _TOKEN_SINK is not None: _TOKEN_SINK("\n") # separe les generations successives raw = "".join(parts) else: from mlx_lm import generate raw = generate( self._model, self._tokenizer, prompt=formatted, max_tokens=max_tokens, sampler=sampler, verbose=False, ) # Retire la chaine de pensee des modeles a raisonnement (sinon des # fragments de la "pensee" parasitent l'extraction JSON en aval). if settings.gemma_reasoning: return _strip_reasoning(raw) return raw def generate_json( self, prompt: str, *, system: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None, retries: int = 1, ) -> Any: """Genere puis parse un JSON. Reessaie en cas d'echec de parsing. `max_tokens`/`temperature` non fournis -> valeurs des reglages courants. """ last_err: Optional[Exception] = None for attempt in range(retries + 1): raw = self.generate( prompt, system=system, max_tokens=max_tokens, temperature=temperature if attempt == 0 else 0.0, ) try: return _extract_json(raw) except Exception as exc: # noqa: BLE001 last_err = exc raise ValueError(f"Reponse JSON invalide apres {retries + 1} essais: {last_err}") def _strip_reasoning(text: str) -> str: """Retire la chaine de pensee d'un modele a raisonnement. Ne conserve que ce qui suit le dernier marqueur de fin de pensee (``, ``...). Si seul un marqueur d'ouverture non ferme subsiste (pensee tronquee par le budget de tokens), on le retire en tete pour eviter de parser la pensee a la place de la reponse. """ t = text for marker in _REASONING_END_MARKERS: if marker in t: t = t.rsplit(marker, 1)[-1] t = _REASONING_OPEN_RE.sub("", t) return t.strip() def _has_complete_json(text: str) -> bool: """True si `text` contient deja un objet/array JSON complet et parsable. Sert a stopper la generation des modeles a raisonnement des que la reponse finale est ecrite (evite de consommer le budget en boucles de pensee). """ try: _extract_json(text) return True except Exception: # noqa: BLE001 return False def _extract_json(text: str) -> Any: """Extrait le premier objet/array JSON d'une reponse libre du modele. Tolere le texte parasite avant/apres (y compris un 2e bloc) grace a raw_decode, qui s'arrete au premier JSON complet. """ text = text.strip() fence = _FENCE_RE.search(text) if fence: text = fence.group(1).strip() decoder = json.JSONDecoder() # Cherche le 1er debut de structure JSON et decode a partir de la. for i, ch in enumerate(text): if ch in "[{": try: obj, _ = decoder.raw_decode(text[i:]) return obj except json.JSONDecodeError: continue raise ValueError("aucun JSON trouve dans la reponse")