Initial commit: InkFlow — EPUB vers livre audio local (MLX/Kokoro)
This commit is contained in:
123
backend/inkflow/analysis/gemma.py
Normal file
123
backend/inkflow/analysis/gemma.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Wrapper mlx-lm autour de Gemma pour l'analyse de texte.
|
||||
|
||||
Charge le modele paresseusement (une seule fois par process) et expose des
|
||||
helpers de generation, dont un `generate_json` tolerant qui extrait le premier
|
||||
objet/array JSON valide de la sortie du modele.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..settings import get_settings
|
||||
|
||||
# Bounds of a JSON block inside a potentially verbose model reply: a greedy
# match from an opening "{" to the last "}" (or from "[" to the last "]").
# NOTE(review): not referenced by any code visible in this module - confirm
# whether it is still needed.
_JSON_SPAN_RE = re.compile(r"(\{.*\}|\[.*\])", re.DOTALL)
# Captures the body of a triple-backtick code fence (optional "json" tag).
_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL)
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
def _load(model_id: str):
    """Load the mlx-lm model identified by `model_id`.

    The result of `mlx_lm.load` is returned as-is and memoised per
    `model_id` (at most two entries are kept by the LRU cache), so repeated
    calls with the same id reuse the already loaded objects.
    """
    # Deferred import: mlx is only pulled in once an analysis actually runs.
    from mlx_lm import load as _mlx_load

    return _mlx_load(model_id)
|
||||
|
||||
|
||||
class Gemma:
    """Small facade around mlx-lm for driving a Gemma model.

    The model and tokenizer are loaded lazily, on the first generation call,
    through the module-level cached `_load` helper.
    """

    def __init__(self, model_id: Optional[str] = None):
        """Remember which model to use; nothing is loaded yet.

        Args:
            model_id: Model identifier passed to `_load`. When omitted (or
                empty), the `gemma_model` value of the current settings is
                used.
        """
        self.model_id = model_id or get_settings().gemma_model
        self._model = None
        self._tokenizer = None

    def _ensure_loaded(self) -> None:
        """Load the model and tokenizer on first use; later calls are no-ops."""
        if self._model is None:
            self._model, self._tokenizer = _load(self.model_id)

    def generate(
        self,
        prompt: str,
        *,
        system: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> str:
        """Generate a text reply for `prompt` using the tokenizer chat template.

        Args:
            prompt: Content of the user message.
            system: Optional content of a system message placed before the
                user message.
            max_tokens: Generation budget; when None, the current settings'
                `gemma_max_tokens` is used.
            temperature: Sampling temperature; when None, the current
                settings' `gemma_temperature` is used.

        Returns:
            The value returned by `mlx_lm.generate` for the formatted prompt.
        """
        self._ensure_loaded()
        settings = get_settings()
        if max_tokens is None:
            max_tokens = settings.gemma_max_tokens
        if temperature is None:
            temperature = settings.gemma_temperature

        # Deferred imports: mlx is only needed once generation really happens.
        from mlx_lm import generate
        from mlx_lm.sample_utils import make_sampler

        messages = []
        if system:
            # NOTE(review): some Gemma chat templates reject a "system" role -
            # confirm this works with the configured model.
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        formatted = self._tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        sampler = make_sampler(temp=temperature)
        return generate(
            self._model,
            self._tokenizer,
            prompt=formatted,
            max_tokens=max_tokens,
            sampler=sampler,
            verbose=False,
        )

    def generate_json(
        self,
        prompt: str,
        *,
        system: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        retries: int = 1,
    ) -> Any:
        """Generate a reply and parse it as JSON, retrying on parse failure.

        Args:
            prompt: Content of the user message.
            system: Optional content of a system message.
            max_tokens: Forwarded to `generate` (None -> current settings).
            temperature: Forwarded to `generate` for the first attempt only
                (None -> current settings).
            retries: Number of additional attempts after the first one.

        Returns:
            The JSON value extracted from the first reply that parses.

        Raises:
            ValueError: if no attempt produced parseable JSON. The last
                parsing error is attached as `__cause__`.
        """
        last_err: Optional[Exception] = None
        for attempt in range(retries + 1):
            raw = self.generate(
                prompt,
                system=system,
                max_tokens=max_tokens,
                # Every retry overrides the temperature with 0.0 (presumably
                # to get a more conservative answer than the failed one).
                temperature=temperature if attempt == 0 else 0.0,
            )
            try:
                return _extract_json(raw)
            except Exception as exc:  # noqa: BLE001 - deliberate best-effort
                last_err = exc
        # Chain the last failure so its traceback is not lost.
        raise ValueError(
            f"Reponse JSON invalide apres {retries + 1} essais: {last_err}"
        ) from last_err
|
||||
|
||||
|
||||
def _extract_json(text: str) -> Any:
    """Extract the first JSON object/array from a free-form model reply.

    If the reply contains a triple-backtick code fence, the fence body is
    searched first; when it holds no decodable JSON, the whole reply is
    searched instead. Within a candidate string, decoding is attempted at each
    "[" or "{" in order and the first value that decodes is returned.
    `raw_decode` stops at the end of that value, so surrounding chatter
    (including a second JSON block) is ignored.

    Args:
        text: Raw reply produced by the model.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if no "[" / "{" position decodes to valid JSON.
    """
    text = text.strip()
    candidates = []
    fence = _FENCE_RE.search(text)
    if fence:
        candidates.append(fence.group(1).strip())
    # Fallback: a fence without usable JSON must not hide JSON located
    # elsewhere in the reply.
    candidates.append(text)

    decoder = json.JSONDecoder()
    for candidate in candidates:
        for start, ch in enumerate(candidate):
            if ch not in "[{":
                continue
            try:
                # Decode in place from `start`; slicing would copy the tail
                # of the string on every attempt.
                obj, _ = decoder.raw_decode(candidate, start)
            except json.JSONDecodeError:
                continue
            return obj
    raise ValueError("aucun JSON trouve dans la reponse")
|
||||
Reference in New Issue
Block a user