124 lines
4.0 KiB
Python
124 lines
4.0 KiB
Python
"""mlx-lm wrapper around Gemma for text analysis.

Loads the model lazily (only once per process) and exposes generation
helpers, including a tolerant `generate_json` that extracts the first valid
JSON object/array from the model output.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from functools import lru_cache
|
|
from typing import Any, Optional
|
|
|
|
from ..settings import get_settings
|
|
|
|
# Patterns used to locate a JSON payload inside a potentially chatty reply.
# Both use DOTALL so that `.` also matches newlines.

# Greedy span from a `{` to the last `}`, or from a `[` to the last `]`.
# NOTE(review): not referenced by the code visible in this section of the
# file — confirm whether it is still needed elsewhere.
_JSON_SPAN_RE = re.compile(r"(\{.*\}|\[.*\])", flags=re.DOTALL)

# Triple-backtick fence, optionally tagged `json`; group 1 lazily captures
# the content up to the next triple backtick.
_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", flags=re.DOTALL)
|
|
|
|
|
|
@lru_cache(maxsize=2)
def _load(model_id: str):
    """Load `model_id` through mlx-lm and return whatever `mlx_lm.load` yields.

    Results are memoised by `lru_cache` (at most two distinct ids are kept),
    so asking for an id that is still cached does not load it again.
    """
    # Deferred import: mlx_lm is only imported once a model is requested.
    import mlx_lm

    return mlx_lm.load(model_id)
|
|
|
|
|
|
class Gemma:
    """Small facade over mlx-lm for driving a Gemma model.

    Nothing is loaded at construction time: the model and tokenizer are
    fetched through the module-level `_load` helper on first generation.
    """

    def __init__(self, model_id: Optional[str] = None):
        """Record which model to use without loading it.

        Args:
            model_id: Identifier handed to `_load`. When omitted or empty,
                the `gemma_model` value of the current settings is used.
        """
        self.model_id = model_id or get_settings().gemma_model
        self._model = None
        self._tokenizer = None

    def _ensure_loaded(self) -> None:
        """Load model and tokenizer on first call; later calls do nothing."""
        if self._model is None:
            self._model, self._tokenizer = _load(self.model_id)

    def generate(
        self,
        prompt: str,
        *,
        system: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> str:
        """Generate a text reply for `prompt` using the tokenizer's chat template.

        Args:
            prompt: Content of the user message.
            system: Optional system message, added before the user message
                when non-empty.
            max_tokens: Generation limit; `None` falls back to the
                `gemma_max_tokens` setting.
            temperature: Sampling temperature; `None` falls back to the
                `gemma_temperature` setting.

        Returns:
            The value returned by `mlx_lm.generate`.
        """
        self._ensure_loaded()
        settings = get_settings()
        if max_tokens is None:
            max_tokens = settings.gemma_max_tokens
        if temperature is None:
            temperature = settings.gemma_temperature

        # Imported here so mlx-lm is only pulled in when generation happens.
        from mlx_lm import generate
        from mlx_lm.sample_utils import make_sampler

        # NOTE(review): whether the chat template accepts a "system" role
        # depends on the loaded model's template — confirm for the configured
        # Gemma checkpoint.
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        # tokenize=False: the template is rendered to a string prompt.
        formatted = self._tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        sampler = make_sampler(temp=temperature)
        return generate(
            self._model,
            self._tokenizer,
            prompt=formatted,
            max_tokens=max_tokens,
            sampler=sampler,
            verbose=False,
        )

    def generate_json(
        self,
        prompt: str,
        *,
        system: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        retries: int = 1,
    ) -> Any:
        """Generate a reply and parse JSON out of it, retrying on parse failure.

        Args:
            prompt: Content of the user message.
            system: Optional system message forwarded to `generate`.
            max_tokens: Forwarded to `generate` (`None` -> current settings).
            temperature: Used for the first attempt only (`None` -> current
                settings); every retry is run at temperature 0.0.
            retries: Number of extra attempts after the first one.

        Returns:
            The value produced by `_extract_json` for the first reply that
            parses.

        Raises:
            ValueError: If no attempt produced parseable JSON. The last
                parsing error is attached as `__cause__`.
        """
        last_err: Optional[Exception] = None
        for attempt in range(retries + 1):
            raw = self.generate(
                prompt,
                system=system,
                max_tokens=max_tokens,
                # Caller's temperature on the first try, 0.0 on every retry.
                # NOTE(review): successive retries at 0.0 with an identical
                # prompt presumably yield the same text — confirm whether
                # retries > 1 is useful.
                temperature=temperature if attempt == 0 else 0.0,
            )
            try:
                return _extract_json(raw)
            except Exception as exc:  # noqa: BLE001
                # Deliberately broad: any parsing failure triggers a retry.
                last_err = exc
        # Chain the last failure so the traceback shows why parsing failed.
        raise ValueError(
            f"Reponse JSON invalide apres {retries + 1} essais: {last_err}"
        ) from last_err
|
|
|
|
|
|
def _extract_json(text: str) -> Any:
    """Extract the first JSON object/array from a free-form model reply.

    Candidates are examined in order: the content of each triple-backtick
    fenced block (as matched by `_FENCE_RE`), then the whole reply. Within a
    candidate, decoding is attempted at every `{` or `[` and stops at the
    first position that yields a complete JSON value, so surrounding text
    (including a second JSON block) is tolerated.

    Args:
        text: Raw model output.

    Returns:
        The decoded Python value of the first JSON object or array found.

    Raises:
        ValueError: If no candidate contains a decodable JSON object/array.
    """
    text = text.strip()

    # Fenced blocks take priority over the surrounding prose. The whole reply
    # is kept as a last resort so that a fence without usable JSON (e.g. a
    # non-JSON code sample) does not hide a valid payload located elsewhere.
    candidates = [match.group(1) for match in _FENCE_RE.finditer(text)]
    candidates.append(text)

    decoder = json.JSONDecoder()
    for candidate in candidates:
        for start, ch in enumerate(candidate):
            if ch not in "[{":
                continue
            try:
                # raw_decode parses one JSON value at the start of the string
                # and ignores whatever follows it.
                obj, _ = decoder.raw_decode(candidate[start:])
            except json.JSONDecodeError:
                # NOTE: scanning resumes at the next bracket, so a nested
                # value inside a malformed/truncated outer structure can be
                # returned on its own.
                continue
            return obj
    raise ValueError("aucun JSON trouve dans la reponse")
|