From d1b32e07d6e682a6ee89ea480fee079b083a2573 Mon Sep 17 00:00:00 2001 From: malteos Date: Mon, 11 May 2026 12:00:21 +0200 Subject: [PATCH] feat: integrate PleIAs/CommonLingua byte-level LID model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `commonlingua` model wired into the registry, backed by a vendored copy of upstream's `model.py` (Apache 2.0, rev 43fe88d) so we don't need `weights_only=False` to load remote pickled code. The model is exposed via a new `[commonlingua]` optional extra that pulls only `torch` (no transformers stack); device selection mirrors AfroLID's MPS > CUDA > CPU. `requires_preprocessing = False` because the byte-level architecture relies on casing as a strong language signal — the OpenLID normer's lowercasing collapses Latin-script predictions. Eval on the full CommonLID dataset (373,230 samples) gives a micro accuracy of 77.58%, matching the model card's 77.63% claim. --- Makefile | 8 +- README.md | 4 +- pyproject.toml | 9 +- src/commonlid/models/__init__.py | 1 + src/commonlid/models/commonlingua.py | 115 +++++++++++ src/commonlid/vendor/commonlingua/__init__.py | 4 + src/commonlid/vendor/commonlingua/model.py | 186 ++++++++++++++++++ tests/models/test_commonlingua.py | 95 +++++++++ tests/models/test_model_registration.py | 1 + uv.lock | 8 +- 10 files changed, 425 insertions(+), 6 deletions(-) create mode 100644 src/commonlid/models/commonlingua.py create mode 100644 src/commonlid/vendor/commonlingua/__init__.py create mode 100644 src/commonlid/vendor/commonlingua/model.py create mode 100644 tests/models/test_commonlingua.py diff --git a/Makefile b/Makefile index 42074a4..4de827e 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ PACKAGE := src/commonlid .DEFAULT_GOAL := help .PHONY: help venv \ - install install-all install-afrolid install-notebooks install-leaderboard \ + install install-all install-afrolid install-commonlingua install-notebooks install-leaderboard \ lint format 
format-check typecheck \ test test-slow test-all check \ build clean \ @@ -24,6 +24,7 @@ help: @echo " venv Create a uv-managed virtualenv (.venv)" @echo " install Sync runtime + dev extras (lint/type/test)" @echo " install-afrolid install + the heavy [afrolid] extra (torch + transformers)" + @echo " install-commonlingua install + the [commonlingua] extra (torch only)" @echo " install-notebooks install + the [notebooks] extra (jupyterlab + matplotlib)" @echo " install-leaderboard install + the [leaderboard] extra (gradio)" @echo " install-all install + every optional extra" @@ -55,6 +56,9 @@ install: install-afrolid: uv sync --extra dev --extra afrolid $(PYTHON_FLAG) +install-commonlingua: + uv sync --extra dev --extra commonlingua $(PYTHON_FLAG) + install-notebooks: uv sync --extra dev --extra notebooks $(PYTHON_FLAG) @@ -62,7 +66,7 @@ install-leaderboard: uv sync --extra dev --extra leaderboard $(PYTHON_FLAG) install-all: - uv sync --extra dev --extra afrolid --extra notebooks --extra leaderboard $(PYTHON_FLAG) + uv sync --extra dev --extra afrolid --extra commonlingua --extra notebooks --extra leaderboard $(PYTHON_FLAG) lint: uv run ruff check $(SRC_DIRS) diff --git a/README.md b/README.md index e4944b1..f24fd14 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ From PyPI: pip install commonlid # core deps + classical LID models pip install "commonlid[llm]" # + DSPy-based LLM evaluation pip install "commonlid[afrolid]" # + torch/transformers for AfroLID +pip install "commonlid[commonlingua]" # + torch for the CommonLingua byte-level model pip install "commonlid[notebooks]" # + jupyterlab + matplotlib for paper_tables.ipynb pip install "commonlid[all]" # everything runtime-facing ``` @@ -192,7 +193,7 @@ from commonlid import list_models, list_datasets assert list_models() == [ "AfroLID", "GlotLID", "OpenLID-v2", "cld2", "cld3", - "fasttext", "funlangid", "pyfranc", + "commonlingua", "fasttext", "funlangid", "pyfranc", ] assert list_datasets() == [ 
"bibles_300", "bibles_300_nano", @@ -298,6 +299,7 @@ for line in preds_path.read_text().splitlines(): | `fasttext` | [facebook/fasttext-language-identification](https://huggingface.co/facebook/fasttext-language-identification) | fasttext | | `pyfranc` | [pyfranc](https://pypi.org/project/pyfranc/) | Pure Python | | `AfroLID` | [UBC-NLP/afrolid_1.5](https://huggingface.co/UBC-NLP/afrolid_1.5) | Requires `[afrolid]` extra | +| `commonlingua` | [PleIAs/CommonLingua](https://huggingface.co/PleIAs/CommonLingua) | 2.35M-param byte-level model, 334 languages; requires `[commonlingua]` extra | | `funlangid` | Vendored in `src/commonlid/vendor/fun_langid.py` | Simple char-4gram baseline | LLM models are instantiated dynamically (`DSPyLLMModel`) and not diff --git a/pyproject.toml b/pyproject.toml index 7178729..900a71b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,11 @@ llm = [ "botocore>=1.35", ] cld3 = ["cld3-py>=3.1"] +commonlingua = [ + # CommonLingua is a 2.35M-param byte-level model; needs torch but not the + # transformers stack that [afrolid] pulls in. + "torch>=2.4", +] leaderboard = [ # gradio 4.x imports HfFolder from huggingface_hub, which was removed in # huggingface-hub 1.0; gradio 5 dropped that import. @@ -88,7 +93,7 @@ notebooks = [ "nbclient>=0.10", ] all = [ - "commonlid[afrolid,llm]", + "commonlid[afrolid,llm,commonlingua]", ] [project.scripts] @@ -208,6 +213,8 @@ omit = [ # afrolid needs the heavy `[afrolid]` extra (torch + transformers); not # installed in dev and so exercised only via mocked unit tests. "src/commonlid/models/afrolid.py", + # commonlingua needs the `[commonlingua]` extra (torch); same precedent. 
+ "src/commonlid/models/commonlingua.py", ] [tool.coverage.report] diff --git a/src/commonlid/models/__init__.py b/src/commonlid/models/__init__.py index bd780fa..cb74c2f 100644 --- a/src/commonlid/models/__init__.py +++ b/src/commonlid/models/__init__.py @@ -11,6 +11,7 @@ from commonlid.models import afrolid as _afrolid # noqa: F401 from commonlid.models import cld2 as _cld2 # noqa: F401 from commonlid.models import cld3 as _cld3 # noqa: F401 +from commonlid.models import commonlingua as _commonlingua # noqa: F401 from commonlid.models import fasttext_ft as _fasttext_ft # noqa: F401 from commonlid.models import funlangid as _funlangid # noqa: F401 from commonlid.models import glotlid as _glotlid # noqa: F401 diff --git a/src/commonlid/models/commonlingua.py b/src/commonlid/models/commonlingua.py new file mode 100644 index 0000000..e0f07d3 --- /dev/null +++ b/src/commonlid/models/commonlingua.py @@ -0,0 +1,115 @@ +"""CommonLingua: PleIAs' byte-level LID model (PleIAs/CommonLingua). + +Requires the ``commonlid[commonlingua]`` extra (torch). The checkpoint +embeds its own ``lang2idx`` map, so no separate metadata file is fetched. +Device selection mirrors AfroLID: MPS > CUDA > CPU. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, ClassVar + +from commonlid.core.lid_model import LIDModel +from commonlid.core.registry import register_model + +if TYPE_CHECKING: + import torch + + +@register_model +class CommonLinguaModel(LIDModel): + model_id = "commonlingua" + # Byte-level model: casing carries strong language signal, so we feed + # raw UTF-8 and skip the OpenLID normer (which lowercases everything). 
+ requires_preprocessing: ClassVar[bool] = False + + _REPO_ID: ClassVar[str] = "PleIAs/CommonLingua" + _CHECKPOINT_FILENAME: ClassVar[str] = "model.pt" + _INTERNAL_BATCH: ClassVar[int] = 256 + + def __init__(self) -> None: + super().__init__() + self._model: Any = None + self._idx2lang: dict[int, str] | None = None + self._max_len: int | None = None + self._device: str | None = None + + def load(self) -> None: + if self._loaded: + return + try: + import torch + except ImportError as exc: + msg = "CommonLingua requires torch. Install with: pip install 'commonlid[commonlingua]'" + raise ImportError(msg) from exc + + from huggingface_hub import hf_hub_download + + from commonlid.vendor.commonlingua.model import CONFIGS, ByteHybrid + + ckpt_path = hf_hub_download(repo_id=self._REPO_ID, filename=self._CHECKPOINT_FILENAME) + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) + + if torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + model = ByteHybrid( # type: ignore[no-untyped-call] + num_classes=ckpt["num_classes"], + max_len=ckpt["max_len"], + **CONFIGS[ckpt["config"]], + ) + model.load_state_dict(ckpt["model_state_dict"]) + model.eval().to(device) + + self._model = model + self._idx2lang = {v: k for k, v in ckpt["lang2idx"].items()} + self._max_len = int(ckpt["max_len"]) + self._device = device + super().load() + + def _encode(self, texts: Sequence[str]) -> torch.Tensor: + import numpy as np + import torch + + assert self._max_len is not None + out = np.full((len(texts), self._max_len), 256, dtype=np.int64) + for i, t in enumerate(texts): + raw = t.encode("utf-8", errors="replace")[: self._max_len] + if raw: + out[i, : len(raw)] = np.frombuffer(raw, dtype=np.uint8) + return torch.from_numpy(out) + + def _predict_batch(self, texts: Sequence[str]) -> list[str | None]: + import torch + + if not self._loaded: + self.load() + assert self._idx2lang is not None + assert
self._device is not None + + results: list[str | None] = [] + for start in range(0, len(texts), self._INTERNAL_BATCH): + chunk = list(texts[start : start + self._INTERNAL_BATCH]) + batch = self._encode(chunk).to(self._device) + with torch.no_grad(): + logits = self._model(batch) + pred_idx = logits.argmax(dim=-1).cpu().tolist() + results.extend(self._idx2lang[int(i)] for i in pred_idx) + return results + + def discover_supported_languages(self) -> frozenset[str]: + """Return every ISO 639-3 code in the model's ``lang2idx`` map.""" + if not self._loaded: + self.load() + assert self._idx2lang is not None + codes: set[str] = set() + for code in self._idx2lang.values(): + conformed = self._conform(code) + if conformed is not None: + codes.add(conformed) + return frozenset(codes) diff --git a/src/commonlid/vendor/commonlingua/__init__.py b/src/commonlid/vendor/commonlingua/__init__.py new file mode 100644 index 0000000..ff4caf2 --- /dev/null +++ b/src/commonlid/vendor/commonlingua/__init__.py @@ -0,0 +1,4 @@ +"""Vendored PleIAs/CommonLingua architecture (Apache 2.0). + +Source: https://huggingface.co/PleIAs/CommonLingua +""" diff --git a/src/commonlid/vendor/commonlingua/model.py b/src/commonlid/vendor/commonlingua/model.py new file mode 100644 index 0000000..36b662c --- /dev/null +++ b/src/commonlid/vendor/commonlingua/model.py @@ -0,0 +1,186 @@ +"""Vendored from PleIAs/CommonLingua (revision 43fe88d75e94b11283b66daccbbe4a73e7bc1361) under Apache 2.0. + +Upstream: https://huggingface.co/PleIAs/CommonLingua/blob/main/model.py + +ByteHybrid: byte-level language identification (CommonLingua v7.2.1). 
+ +Operates directly on raw UTF-8 bytes - no tokenizer required: + + raw bytes -> byte-embed + trigram-hash-embed (summed) + -> 3 x depthwise Conv1D (k=15) + -> 1 x bidirectional attention (RoPE, 4 heads) + -> masked mean-pool + -> classification head (334 logits) + +The shipped checkpoint uses the ``base_ngram`` config: d_model=256, 4096 trigram +hash buckets x 64 dim, max_len=512 bytes. Total parameters ~ 2.35 M. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ByteNgramEmbed(nn.Module): + """Rolling polynomial hash of byte trigrams into a fixed-size table. + + Hash collisions act as regularisation; the small table (4096 x 64) + keeps parameter count bounded under arbitrary input distributions. + """ + + def __init__(self, num_buckets=4096, embed_dim=64, n=3): + super().__init__() + self.n = n + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, embed_dim) + + def forward(self, byte_ids): + B, T = byte_ids.shape + clamped = byte_ids.clamp(max=255) + padded = F.pad(clamped, (0, self.n - 1), value=0) + h = torch.zeros(B, T, dtype=torch.long, device=byte_ids.device) + for i in range(self.n): + h = h * 257 + padded[:, i:i + T] + return self.embed(h % self.num_buckets) + + +class ByteConvBlock(nn.Module): + """Causal depthwise Conv1D + SwiGLU FFN, with residual + layernorm.""" + + def __init__(self, d_model, kernel_size=15, expand=2): + super().__init__() + self.norm1 = nn.LayerNorm(d_model) + self.pad = kernel_size - 1 + self.conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model) + self.norm2 = nn.LayerNorm(d_model) + ffn = d_model * expand + self.ffn_gate = nn.Linear(d_model, ffn, bias=False) + self.ffn_up = nn.Linear(d_model, ffn, bias=False) + self.ffn_down = nn.Linear(ffn, d_model, bias=False) + + def forward(self, x): + residual = x + x = self.norm1(x).transpose(1, 2) + x = F.pad(x, (self.pad, 0)) + x = F.silu(self.conv(x)).transpose(1, 2) + x = residual + x + + residual = x + x = self.norm2(x) + 
x = self.ffn_down(F.silu(self.ffn_gate(x)) * self.ffn_up(x)) + return residual + x + + +def _rope(q, k): + head_dim = q.shape[-1] + seq_len = q.shape[-2] + freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim)) + t = torch.arange(seq_len, device=q.device) + a = torch.outer(t, freqs) + cos = a.cos().to(q.dtype) + sin = a.sin().to(q.dtype) + + def rot(x): + x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + return rot(q), rot(k) + + +class ByteAttnBlock(nn.Module): + """Bidirectional self-attention with RoPE + SwiGLU FFN.""" + + def __init__(self, d_model, n_heads=4, expand=2): + super().__init__() + self.n_heads = n_heads + self.head_dim = d_model // n_heads + self.norm1 = nn.LayerNorm(d_model) + self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) + self.out_proj = nn.Linear(d_model, d_model, bias=False) + self.norm2 = nn.LayerNorm(d_model) + ffn = d_model * expand + self.ffn_gate = nn.Linear(d_model, ffn, bias=False) + self.ffn_up = nn.Linear(d_model, ffn, bias=False) + self.ffn_down = nn.Linear(ffn, d_model, bias=False) + + def forward(self, x): + B, T, D = x.shape + residual = x + h = self.norm1(x) + qkv = self.qkv(h).reshape(B, T, 3, self.n_heads, self.head_dim) + q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2)) + q, k = _rope(q, k) + attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + attn = attn.softmax(dim=-1) + out = (attn @ v).transpose(1, 2).contiguous().view(B, T, D) + x = residual + self.out_proj(out) + + residual = x + h = self.norm2(x) + h = self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h)) + return residual + h + + +class ByteHybrid(nn.Module): + """Byte-level classifier with optional trigram-hash augmentation.""" + + def __init__( + self, + num_classes, + d_model=256, + n_conv=3, + n_attn=1, + n_heads=4, + ffn_expand=2, + max_len=512, + conv_kernel=15, + ngram_buckets=0, + ngram_dim=64, + ): + 
super().__init__() + self.max_len = max_len + + # Byte values 0-255 plus index 256 = padding token + self.embed = nn.Embedding(257, d_model, padding_idx=256) + + self.ngram_embed = None + if ngram_buckets > 0: + self.ngram_embed = ByteNgramEmbed(ngram_buckets, ngram_dim, n=3) + self.ngram_proj = nn.Linear(ngram_dim, d_model, bias=False) + + self.conv_layers = nn.ModuleList( + [ByteConvBlock(d_model, conv_kernel, ffn_expand) for _ in range(n_conv)] + ) + self.attn_layers = nn.ModuleList( + [ByteAttnBlock(d_model, n_heads, ffn_expand) for _ in range(n_attn)] + ) + self.final_norm = nn.LayerNorm(d_model) + self.head = nn.Sequential( + nn.Linear(d_model, d_model), + nn.GELU(), + nn.Dropout(0.1), + nn.Linear(d_model, num_classes), + ) + + def forward(self, byte_ids): + pad_mask = byte_ids != 256 + x = self.embed(byte_ids) + if self.ngram_embed is not None: + x = x + self.ngram_proj(self.ngram_embed(byte_ids)) + for layer in self.conv_layers: + x = layer(x) + for layer in self.attn_layers: + x = layer(x) + x = self.final_norm(x) + mask = pad_mask.unsqueeze(-1).to(x.dtype) + x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) + return self.head(x) + + +# Single shipped configuration. The checkpoint encodes which config it was +# trained with under the "config" key. +CONFIGS = { + "base_ngram": dict( + d_model=256, n_conv=3, n_attn=1, n_heads=4, conv_kernel=15, + ngram_buckets=4096, ngram_dim=64, + ), +} diff --git a/tests/models/test_commonlingua.py b/tests/models/test_commonlingua.py new file mode 100644 index 0000000..34c6d35 --- /dev/null +++ b/tests/models/test_commonlingua.py @@ -0,0 +1,95 @@ +"""Unit tests for CommonLinguaModel. + +The real model needs the ``[commonlingua]`` extra (torch), which is not in +``dev``. We mock the heavy bits and exercise the wrapper logic. 
+""" + +from __future__ import annotations + +import sys +from typing import Any, ClassVar + +import pytest + +from commonlid.models import commonlingua as commonlingua_mod +from commonlid.models.commonlingua import CommonLinguaModel + + +def test_load_raises_helpful_error_without_torch(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setitem(sys.modules, "torch", None) + with pytest.raises(ImportError, match=r"commonlid\[commonlingua\]"): + CommonLinguaModel().load() + + +def test_predict_returns_codes_from_idx2lang(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeTensor: + def __init__(self, values: list[int]) -> None: + self._values = values + + def argmax(self, dim: int = -1) -> _FakeTensor: + return self + + def cpu(self) -> _FakeTensor: + return self + + def tolist(self) -> list[int]: + return self._values + + def to(self, _device: str) -> _FakeTensor: + return self + + class _FakeModel: + def __init__(self) -> None: + self.calls = 0 + + def __call__(self, batch: Any) -> _FakeTensor: + self.calls += 1 + # Indices 0, 1, 2 -> eng, fra, deu via fake idx2lang. 
+ return _FakeTensor([0, 1, 2]) + + class _NoGrad: + def __enter__(self) -> None: + return None + + def __exit__(self, *_a: Any) -> None: + return None + + fake_torch = type(sys)("torch") + fake_torch.no_grad = lambda: _NoGrad() # type: ignore[attr-defined] + fake_torch.from_numpy = lambda arr: _FakeTensor(list(arr.flatten())) # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "torch", fake_torch) + + def fake_load(self: CommonLinguaModel) -> None: + self._model = _FakeModel() + self._idx2lang = {0: "eng", 1: "fra", 2: "deu"} + self._max_len = 512 + self._device = "cpu" + self._loaded = True + + monkeypatch.setattr(CommonLinguaModel, "load", fake_load) + preds = CommonLinguaModel().predict(["Hello", "Bonjour", "Hallo"]) + assert preds == ["eng", "fra", "deu"] + + +def test_discover_supported_languages_conforms_codes(monkeypatch: pytest.MonkeyPatch) -> None: + class _Marker: + idx2lang: ClassVar[dict[int, str]] = {0: "eng", 1: "jw", 2: "xxxxx"} + + def fake_load(self: CommonLinguaModel) -> None: + self._idx2lang = dict(_Marker.idx2lang) + self._loaded = True + + monkeypatch.setattr(CommonLinguaModel, "load", fake_load) + langs = CommonLinguaModel().discover_supported_languages() + assert "eng" in langs + assert "jav" in langs # jw -> jav via _conform + assert "xxxxx" not in langs + + +def test_model_registered() -> None: + from commonlid.core.registry import get_model + + model = get_model("commonlingua") + assert isinstance(model, CommonLinguaModel) + # Keep the module reference alive for coverage/mypy. 
+ assert commonlingua_mod.CommonLinguaModel is CommonLinguaModel diff --git a/tests/models/test_model_registration.py b/tests/models/test_model_registration.py index 20b1c7b..c46b92e 100644 --- a/tests/models/test_model_registration.py +++ b/tests/models/test_model_registration.py @@ -20,6 +20,7 @@ def _import_models() -> None: "OpenLID-v2", "cld2", "cld3", + "commonlingua", "fasttext", "funlangid", "pyfranc", diff --git a/uv.lock b/uv.lock index d0aa5ea..3448006 100644 --- a/uv.lock +++ b/uv.lock @@ -864,6 +864,9 @@ all = [ cld3 = [ { name = "cld3-py" }, ] +commonlingua = [ + { name = "torch" }, +] dev = [ { name = "azure-identity" }, { name = "botocore" }, @@ -902,7 +905,7 @@ requires-dist = [ { name = "botocore", marker = "extra == 'llm'", specifier = ">=1.35" }, { name = "cld3-py", marker = "extra == 'cld3'", specifier = ">=3.1" }, { name = "cld3-py", marker = "extra == 'dev'", specifier = ">=3.1" }, - { name = "commonlid", extras = ["afrolid", "llm"], marker = "extra == 'all'" }, + { name = "commonlid", extras = ["afrolid", "llm", "commonlingua"], marker = "extra == 'all'" }, { name = "datasets", specifier = ">=3.1.0" }, { name = "dspy", marker = "extra == 'dev'", specifier = ">=2.5" }, { name = "dspy", marker = "extra == 'llm'", specifier = ">=2.5" }, @@ -930,13 +933,14 @@ requires-dist = [ { name = "scikit-learn", specifier = ">=1.5" }, { name = "sentencepiece", marker = "extra == 'afrolid'", specifier = ">=0.2" }, { name = "torch", marker = "extra == 'afrolid'", specifier = ">=2.4" }, + { name = "torch", marker = "extra == 'commonlingua'", specifier = ">=2.4" }, { name = "tqdm", specifier = ">=4.67" }, { name = "transformers", marker = "extra == 'afrolid'", specifier = ">=4.46,<5" }, { name = "typer", specifier = ">=0.12" }, { name = "types-regex", marker = "extra == 'dev'" }, { name = "types-tqdm", marker = "extra == 'dev'" }, ] -provides-extras = ["afrolid", "llm", "cld3", "leaderboard", "dev", "notebooks", "all"] +provides-extras = ["afrolid", "llm", 
"cld3", "commonlingua", "leaderboard", "dev", "notebooks", "all"] [[package]] name = "contourpy"