diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 11fe169cf..5eec03532 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -192,6 +192,11 @@ def __init__( type_v: KV cache data type for V (default: f16) spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. + Note: + Recurrent and hybrid models (Mamba, RWKV, Nemotron-A3B, Jamba) cannot + rewind their state and require full reset on history edits. This is handled + automatically to maintain compatibility. Standard transformers are unaffected. + Raises: ValueError: If the model path does not exist. @@ -553,6 +558,11 @@ def free_lora_adapter(): self._sampler = None + # Cache recurrent/hybrid model detection to avoid repeated FFI calls + self._is_recurrent_model = llama_cpp.llama_model_is_recurrent( + self._model.model + ) or llama_cpp.llama_model_is_hybrid(self._model.model) + @property def ctx(self) -> llama_cpp.llama_context_p: return self._ctx.ctx @@ -580,6 +590,19 @@ def eval_logits(self) -> Deque[List[float]]: maxlen=self._n_ctx if self._logits_all else 1, ) + @property + def _is_recurrent(self) -> bool: + """Check if model is recurrent (SSM) or hybrid (SSM+Attention). + + These models (Mamba, RWKV, Nemotron, Jamba, etc.) cannot rewind their + recurrent state without snapshots. Only strict forward progression or + full reset is allowed. + + Returns: + True if model has recurrent state that cannot be rewound. + """ + return self._is_recurrent_model + def tokenize( self, text: bytes, add_bos: bool = True, special: bool = False ) -> List[int]: @@ -638,6 +661,11 @@ def reset(self): """Reset the model state.""" self.n_tokens = 0 + if self._is_recurrent: + mem = llama_cpp.llama_get_memory(self._ctx.ctx) + if mem is not None: + llama_cpp.llama_memory_clear(mem, True) + def eval(self, tokens: Sequence[int]): """Evaluate a list of tokens. @@ -888,11 +916,22 @@ def generate( # Check for kv cache prefix match if reset and self.n_tokens > 0: longest_prefix = 0 - for a, b in zip(self._input_ids, tokens[:-1]): + for a, b in zip(self._input_ids, tokens): if a == b: longest_prefix += 1 else: break + + # Recurrent models cannot rewind state; reset if needed + if self._is_recurrent and longest_prefix < self.n_tokens: + longest_prefix = 0 + reset = True + if self.verbose: + print( + "Llama.generate: recurrent model requires full state reset", + file=sys.stderr, + ) + if longest_prefix > 0: if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1): reset = False diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f49e91787..d00685831 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f49e9178767d557a522618b16ce8694f9ddac628 +Subproject commit d006858316d4650bb4da0c6923294ccd741caefd