#!/usr/bin/env python3
import argparse
import getpass
import json
import os
import platform
import random
import re
import shlex
import string
import subprocess
import sys
import time
from datetime import datetime

import requests

# ==============================================================================
# LLAMABENCH 6.2 - LLaMA.cpp Benchmarking Tool
# ==============================================================================

# Result files are written next to this script, independent of the current
# working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SAVE_DIR = os.path.join(SCRIPT_DIR, "llamabench_logs")

# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(SAVE_DIR, exist_ok=True)

# ------------------------------------------------------------------------------
# GLOBAL LOGGING
# ------------------------------------------------------------------------------

# Every line passed through log_msg(), in emission order.
text_report = list()


def log_msg(text):
    """Echo *text* to stdout, then append it to ``text_report``."""
    print(text)
    text_report.append(text)

# ------------------------------------------------------------------------------
# HARDWARE & SYSTEM UTILITIES
# ------------------------------------------------------------------------------

def get_cpu_name(cpuinfo_path='/proc/cpuinfo'):
    """Return a human-readable CPU model name.

    On Linux, the first "model name" entry of *cpuinfo_path* is returned.
    If that file is unreadable or has no such entry (e.g. many ARM
    systems), and on every other OS, ``platform.processor()`` is returned
    instead; it may be an empty string.

    Args:
        cpuinfo_path: cpuinfo-format file to read on Linux.
    """
    if platform.system() == "Linux":
        try:
            with open(cpuinfo_path, 'r') as f:
                for line in f:
                    if 'model name' in line:
                        # maxsplit=1: the model string itself may contain ':'.
                        return line.split(':', 1)[1].strip()
        except (OSError, ValueError, IndexError):
            # Unreadable/undecodable file or a malformed line: use the fallback.
            pass
    return platform.processor()

def get_sys_ram_usage_mb(meminfo_path='/proc/meminfo'):
    """Return RAM in use in MiB, excluding buffers and page cache.

    Computed from *meminfo_path* (values in kB) as
    MemTotal - MemFree - Buffers - Cached. A missing key counts as 0.
    Only implemented on Linux. Returns 0 on other systems, or when the file
    cannot be read or parsed.

    Args:
        meminfo_path: meminfo-format file to read on Linux.
    """
    if platform.system() != "Linux":
        return 0

    wanted = ('MemTotal', 'MemFree', 'Buffers', 'Cached')
    values = dict.fromkeys(wanted, 0)
    try:
        with open(meminfo_path, 'r') as f:
            for line in f:
                # Exact key match, so e.g. "SwapCached" is not taken for "Cached".
                key, _, rest = line.partition(':')
                if key in values:
                    values[key] = int(rest.split()[0])
    except (OSError, ValueError, IndexError):
        return 0

    used_kib = values['MemTotal'] - values['MemFree'] - values['Buffers'] - values['Cached']
    return used_kib / 1024.0

def _probe_output(command, prefix="", timeout=10):
    """Run a best-effort shell probe and return its stripped stdout.

    *prefix* is prepended verbatim (e.g. ``"docker exec NAME "``) and must
    already be shell-safe. Returns "" when the command cannot be started,
    exits non-zero, or runs longer than *timeout* seconds.
    """
    try:
        out = subprocess.check_output(prefix + command, shell=True,
                                      stderr=subprocess.DEVNULL, text=True,
                                      timeout=timeout)
    except (OSError, subprocess.SubprocessError):
        return ""
    return out.strip()


def get_gpu_drivers(container_name=None):
    """Detect GPU driver / runtime versions, optionally inside a container.

    Args:
        container_name: When given, every probe runs through
            ``docker exec <container_name>``.

    Returns:
        dict with any of the keys "cuda", "rocm", "vulkan". A key is left
        out when its probes fail or produce no usable output.
    """
    # The name is spliced into a shell command line, so quote it.
    # shlex.quote leaves valid Docker names ([A-Za-z0-9_.-]) untouched, so
    # cmd.exe on Windows still receives them unquoted.
    prefix = f"docker exec {shlex.quote(container_name)} " if container_name else ""
    drivers = {}

    # NVIDIA: nvidia-smi prints one line per GPU; keep the first.
    out = _probe_output("nvidia-smi --query-gpu=driver_version --format=csv,noheader", prefix)
    if out:
        drivers["cuda"] = out.split('\n')[0]
    # Fallback on Linux: read the kernel module's version file.
    if "cuda" not in drivers and platform.system() == "Linux":
        match = re.search(r'Kernel Module\s+([0-9\.]+)',
                          _probe_output("cat /proc/driver/nvidia/version", prefix))
        if match:
            drivers["cuda"] = match.group(1)

    # ROCm: rocm-core package version, else the install's version file.
    # Both drop the text from the first '-' onward.
    rocm = (_probe_output("dpkg-query -W -f='${Version}' rocm-core", prefix)
            or _probe_output("cat /opt/rocm/.info/version", prefix))
    if rocm:
        drivers["rocm"] = rocm.split('-')[0]
    else:
        # Last resort: the amdgpu kernel module's version.
        amdgpu = _probe_output("cat /sys/module/amdgpu/version", prefix)
        if amdgpu:
            drivers["rocm"] = f"amdgpu-{amdgpu}"

    # Vulkan: instance version reported by vulkaninfo.
    match = re.search(r'Version:\s*([0-9\.]+)',
                      _probe_output("vulkaninfo | grep 'Vulkan Instance Version'", prefix))
    if match:
        drivers["vulkan"] = match.group(1)

    return drivers

def get_userhost_info():
    """Collect host name, user name and coarse network location for the report.

    NOTE(review): this queries the third-party service https://ipinfo.io
    (3 s timeout). "country" and "(ISP)" stay "Unknown" when it is
    unreachable or answers with something other than a JSON object.

    Returns:
        dict with keys "host", "country", "(ISP)" and "username".
    """
    info = {
        "host": platform.node(),
        "country": "Unknown",
        "(ISP)": "Unknown",
        "username": "Unknown"
    }
    try:
        info["username"] = getpass.getuser()
    except (OSError, KeyError, ImportError):
        # getuser() raises different errors depending on platform/Python version.
        pass

    try:
        r = requests.get("https://ipinfo.io/json", timeout=3)
        data = r.json() if r.status_code == 200 else None
    except (requests.RequestException, ValueError):
        data = None
    if isinstance(data, dict):
        info["country"] = data.get("country", "Unknown")
        # ipinfo reports the provider in "org" (e.g. "AS7922 Comcast ..."), not "city".
        info["(ISP)"] = data.get("org", "Unknown")
    return info

# ------------------------------------------------------------------------------
# DOCKER & NETWORK CONNECTION PRIORITY ENGINE
# ------------------------------------------------------------------------------

def setup_network_and_docker(target_port):
    """Locate the llama.cpp API listening on *target_port*.

    Docker is checked first. A container that publishes the port is also
    reachable on 127.0.0.1, so a native probe alone would lose the container
    name that later steps use for `docker logs` / `docker exec`.

    Returns:
        (container_name, host_ip, port). container_name is None for a
        native server; all three are None when neither check succeeds.
    """
    wanted = str(target_port)

    try:
        ps_out = subprocess.check_output(
            ['docker', 'ps', '--format', '{{.Names}}\t{{.Ports}}'],
            text=True, timeout=15)
    except (OSError, subprocess.SubprocessError):
        # Docker not installed, daemon unreachable, or the call timed out.
        ps_out = ""

    for line in ps_out.splitlines():
        if not line:
            continue
        name, _, ports = line.partition('\t')
        # Host ports of TCP mappings such as "0.0.0.0:8080->8080/tcp",
        # "[::]:8080->8080/tcp" or ":::8080->8080/tcp".
        host_ports = re.findall(r'(?:[0-9\.]+|\[::\]|::):([0-9]+)->[0-9]+/tcp', ports)
        if wanted in host_ports:
            log_msg(f"[*] Docker container '{name}' detected mapping external port {target_port}.")
            return name, "127.0.0.1", target_port

    try:
        reachable = requests.get(f"http://127.0.0.1:{target_port}/v1/models", timeout=5).status_code == 200
    except requests.RequestException:
        reachable = False
    if reachable:
        log_msg(f"[*] Local API detected on 127.0.0.1:{target_port}. Assuming native host execution.")
        return None, "127.0.0.1", target_port

    return None, None, None

def fetch_log_content(container_name, candidates=None):
    """Return the llama.cpp server log text, or "" when none is found.

    If *container_name* is given, the output of `docker logs` (stdout and
    stderr merged, undecodable bytes dropped) is returned. Otherwise, or if
    that command fails, the first readable file in *candidates* is returned.

    Args:
        container_name: Docker container to read logs from, or None.
        candidates: Log file paths to try in order. Defaults to
            ~/server.log, ./server.log, ../server.log and ./build/server.log.
            Relative paths resolve against the current working directory.
    """
    if container_name:
        try:
            return subprocess.check_output(['docker', 'logs', container_name],
                                           stderr=subprocess.STDOUT, text=True,
                                           errors='ignore')
        except (OSError, subprocess.SubprocessError):
            pass  # Fall back to log files on the host.

    if candidates is None:
        candidates = [os.path.expanduser("~/server.log"), "./server.log",
                      "../server.log", "./build/server.log"]
    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            with open(path, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()
        except OSError:
            continue  # e.g. a directory or permission denied: try the next one
    return ""

# ------------------------------------------------------------------------------
# LOG PARSER & METRICS EXTRACTION
# ------------------------------------------------------------------------------

def parse_detailed_log(content):
    """Extract model, memory and device details from llama.cpp server log text.

    Args:
        content: Raw server log text; may be empty.

    Returns:
        dict pre-filled with the defaults below, overwritten with whatever
        the log contains. Values captured from the log are strings ("0",
        "Unknown" or "unknown" when absent). The exceptions are "exp_total"
        and "exp_used" (ints) and "gpus" (list of str). The memory totals
        "vram_model", "kv_cache", "compute" and "ssm_rs" are formatted with
        two decimals, or "0" when not found; presumably MiB as printed by
        llama.cpp. Each "*_brk" key holds a " | "-joined per-device
        breakdown.

    Any exception during parsing is printed, and the partially filled dict
    is returned.
    """
    # Defaults, returned unchanged for fields the log does not mention.
    d = {
        "arch": "Unknown", "vram_model": "0", "ram_model": "0",
        "kv_cache": "0", "kv_type": "unknown", "ssm_rs": "0", "compute": "0", 
        "exp_total": 0, "exp_used": 0, "n_ctx_log": "0",
        "layers_gpu": "0", "layers_total": "0", "mmap": "unknown",
        "threads": "Unknown", "gpus": [], "n_ctx_train": "0",
        "vram_model_brk": "", "kv_cache_brk": "", "ssm_rs_brk": "", "compute_brk": "",
        "model_file": "Unknown"
    }
    if not content: return d

    def extract_multi_device_mem(log_text, buffer_name):
        """Sum every '<device> <buffer_name> buffer size = N' line in *log_text*.

        Returns (total formatted with two decimals, " | "-joined "dev: value"
        breakdown), or ("0", "") when no line matches.

        NOTE(review): all matches are summed. A log holding several model loads
        (e.g. `docker logs` across container restarts) would be over-counted;
        confirm whether only the last load should count.
        """
        # Device names are listed explicitly. 'CPU' must be followed by whitespace
        # or '[...]', so e.g. 'CPU_Mapped' lines do not match.
        pattern = r'(CUDA\d+|CUDA_Host|ROCm\d+|ROCm_Host|RPC\d+|Vulkan\d+|Vulkan_Host|CPU)(?:\[.*?\])?\s+' + buffer_name + r'\s+buffer size\s+=\s+([\d.]+)'
        matches = re.findall(pattern, log_text)
        total = 0.0
        breakdown = []
        for dev, val_str in matches:
            val = float(val_str)
            total += val
            breakdown.append(f"{dev}: {val:.1f}")
        if total > 0:
            return f"{total:.2f}", " | ".join(breakdown)
        return "0", ""

    try:
        # Model file name only: the path is stripped on both '/' and '\\'.
        model_match = re.search(r"load_model:\s+loading model\s+'([^']+)'", content)
        if model_match:
            d["model_file"] = model_match.group(1).split('/')[-1].split('\\')[-1]

        # For single-valued fields below, the last occurrence in the log wins.
        arch = re.findall(r'print_info: arch\s+=\s+(\w+)', content)
        if arch: d["arch"] = arch[-1]
        
        # MoE expert counts. A value of "0" keeps the int default 0.
        exp_t = re.findall(r'print_info: n_expert\s+=\s+(\d+)', content)
        if exp_t and exp_t[-1] != "0": d["exp_total"] = int(exp_t[-1])
        
        exp_u = re.findall(r'print_info: n_expert_used\s+=\s+(\d+)', content)
        if exp_u and exp_u[-1] != "0": d["exp_used"] = int(exp_u[-1])

        # Preferred source: per-device rows of the form
        #   '<dev> [(desc)] | N = N + (N = A + B + C)'
        # A, B and C (the three '+'-separated numbers inside the parentheses)
        # are read as model, KV cache and compute sizes respectively. Rows
        # without that parenthesised group do not match.
        univ_mem_pattern = r'([A-Za-z0-9_]+)\s*(?:\([^)]+\))?\s*\|\s*\d+\s*=\s*\d+\s*\+\s*\(\s*\d+\s*=\s*(\d+)\s*\+\s*(\d+)\s*\+\s*(\d+)\s*\)'
        univ_matches = re.findall(univ_mem_pattern, content)
        
        if univ_matches:
            tot_model = tot_kv = tot_comp = 0.0
            brk_model, brk_kv, brk_comp = [], [], []
            for match in univ_matches:
                dev = match[0]
                m_val = float(match[1])
                k_val = float(match[2])
                c_val = float(match[3])
                
                # Totals include every row; breakdowns list only non-zero entries.
                tot_model += m_val; tot_kv += k_val; tot_comp += c_val
                if m_val > 0: brk_model.append(f"{dev}: {m_val:.1f}")
                if k_val > 0: brk_kv.append(f"{dev}: {k_val:.1f}")
                if c_val > 0: brk_comp.append(f"{dev}: {c_val:.1f}")
            
            d["vram_model"] = f"{tot_model:.2f}"; d["vram_model_brk"] = " | ".join(brk_model)
            d["kv_cache"] = f"{tot_kv:.2f}"; d["kv_cache_brk"] = " | ".join(brk_kv)
            d["compute"] = f"{tot_comp:.2f}"; d["compute_brk"] = " | ".join(brk_comp)
        else:
            # Fallback: sum the individual '<dev> <name> buffer size = N' lines.
            d["vram_model"], d["vram_model_brk"] = extract_multi_device_mem(content, "model")
            d["kv_cache"], d["kv_cache_brk"] = extract_multi_device_mem(content, "KV")
            d["compute"], d["compute_brk"] = extract_multi_device_mem(content, "compute")

        # RS buffers (presumably recurrent/SSM state, per the key name) are
        # always read from the buffer lines.
        d["ssm_rs"], d["ssm_rs_brk"] = extract_multi_device_mem(content, "RS")
        
        # Host 'CPU_Mapped' model buffer: last occurrence only, not summed.
        ram = re.findall(r'CPU_Mapped model buffer size\s+=\s+([\d.]+)', content)
        if ram: d["ram_model"] = ram[-1]

        # Text inside 'K (...):'; presumably the K-cache data type (e.g. f16). Verify.
        kv_type = re.findall(r'K \((.*?)\):', content)
        if kv_type: d["kv_type"] = kv_type[-1]
        
        ctx_train = re.findall(r'print_info: n_ctx_train\s+=\s+(\d+)', content)
        if ctx_train: d["n_ctx_train"] = ctx_train[-1]

        # Prefer the context size after an automatic reduction. Otherwise use
        # the last 'llama_context: n_ctx = N'.
        ctx_fit = re.findall(r'context size reduced from \d+ to (\d+)', content)
        if ctx_fit: d["n_ctx_log"] = ctx_fit[-1]
        else:
            ctx_log = re.findall(r'llama_context:\s+n_ctx\s+=\s+(\d+)', content)
            if ctx_log: d["n_ctx_log"] = ctx_log[-1]

        # 'offloaded X/Y layers to GPU' -> layers_gpu = X, layers_total = Y.
        layers = re.findall(r'offloaded (\d+)/(\d+) layers to GPU', content)
        if layers:
            d["layers_gpu"] = layers[-1][0]
            d["layers_total"] = layers[-1][1]
            
        mmap_match = re.findall(r'mmap\s*=\s*(true|false)', content)
        if mmap_match: d["mmap"] = mmap_match[-1].lower()
            
        threads_match = re.search(r'system info: n_threads = (\d+), n_threads_batch = \d+, total_threads = (\d+)', content)
        if threads_match:
            d["threads"] = f"{threads_match.group(1)} threads active (Total: {threads_match.group(2)})"
        
        # 'using device <id> (<name>) (<addr>) - <mem info>'. Entries are keyed by
        # the second parenthesised field, so repeated lines for the same device
        # collapse to one entry; the last one seen wins.
        gpu_matches = re.findall(r'using device (\w+)\s*\((.*?)\)\s*\(([^)]+)\)\s*-\s*([^\n]+)', content)
        if gpu_matches:
            gpus_dict = {}
            for dev_id, basic_name, pci_addr, mem_info in gpu_matches:
                gpus_dict[pci_addr] = f"[{dev_id}] {basic_name} ({mem_info.strip()})"
            d["gpus"] = list(gpus_dict.values())
        else:
            # Fallback: 'Device N: ...' lines, de-duplicated in first-seen order.
            gpu_lines = re.findall(r'(Device \d+:[^\n]+)', content)
            if gpu_lines: d["gpus"] = list(dict.fromkeys([g.strip() for g in gpu_lines]))
            
    # Any failure is reported, and the fields filled so far are kept.
    except Exception as e: print(f"Error parsing log: {e}")
    return d

# ------------------------------------------------------------------------------
# BENCHMARK ENGINE
# ------------------------------------------------------------------------------

def print_comparison_chart(metric_name, avg_val, server_val):
    """Log a two-bar text chart of benchmark vs. server-reported throughput.

    The larger of the two values spans 25 cells and the other bar is scaled
    to it. The server line also shows how far the server figure deviates
    from the benchmark average, in percent.

    Special cases:
      - A non-positive *avg_val* logs "No data to display.".
      - A non-positive *server_val* (server started without --metrics)
        draws only a full-width benchmark bar.
    """
    width = 25
    log_msg(f"\n 📈 {metric_name} summary (tok/s)")

    if avg_val <= 0:
        log_msg(" No data to display.")
        return

    if server_val <= 0:
        log_msg(f" benchmark(average): {'▓' * width} {avg_val:.1f}")
        log_msg(" server data:        [no --metrics]")
        return

    peak = max(avg_val, server_val)
    avg_cells = int((avg_val / peak) * width)
    srv_cells = int((server_val / peak) * width)
    # avg_val is positive on this path unless it is NaN; the guard avoids
    # dividing by a non-positive value in that case.
    if avg_val > 0:
        delta_pct = (server_val - avg_val) / avg_val * 100
    else:
        delta_pct = 0
    plus = "+" if delta_pct > 0 else ""

    log_msg(f" benchmark(average): {'▓' * avg_cells} {avg_val:.1f}")
    log_msg(f" server data:        {'▓' * srv_cells} {server_val:.1f} ({plus}{delta_pct:.1f}%)")

# Prompt shared by both TG modes (fixed-length and time-limited streaming).
_TG_PROMPT = "User: Explain the concept of entropy in detail.\nAssistant: "

_HELP_TEXT = """
================================================================================
LLAMA BENCH v6.2 - HELP
================================================================================

Usage: llamabench6.2.py [OPTIONS]

Options:
  -r, --rounds N     Number of test rounds (default: 5). If set to 7, runs 7 test cycles.
  -t, --time N       Time limit per round in seconds (default: unlimited).
                     If set, each round ends after N seconds and next round begins.
  -c, --context N    Context token limit per round (default: 128; unlimited when -t is set).
                     Round ends when this many tokens are generated.
  -p, --port N       Target API port (default: 8080).
  -h, --help         Show this help message and exit.

Behavior:
  - If both -t and -c are set, round ends when FIRST condition is met.
  - Passive Mode: The script connects to an existing LLaMA.cpp server. It does not kill or start processes.
  - Priority: Checks Docker containers publishing [PORT] first, then a native API on 127.0.0.1:[PORT].
  - Dual JSON: Saves main results to a lightweight JSON and raw server logs to a separate 'serverlog_*.json'.

Server Requirements:
  By default, the log file is read from ~/server.log on Linux and Windows systems
  ('docker logs <container>' is read instead when the server runs in Docker).
  Therefore, llama.cpp must always be run with the following flags:
    --metrics
    --log-file ~/server.log
    --log-colors off
    --flash-attn on

Output:
  Results are saved to the 'llamabench_logs/' directory (created alongside this script) as JSON files.
"""


def _parse_cli_args():
    """Parse the command line; print the help text and exit on -h/--help."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('-r', '--rounds', type=int, default=5)
    parser.add_argument('-t', '--time', type=float, default=None)
    parser.add_argument('-c', '--context', type=int, default=None)
    parser.add_argument('-p', '--port', default='8080')
    parser.add_argument('-h', '--help', action='store_true')
    # Unrecognised options are ignored rather than rejected.
    args, _unknown = parser.parse_known_args()
    if args.help:
        print(_HELP_TEXT)
        sys.exit(0)
    return args


def _dedupe_response_lines(text):
    """Drop code fences, blank lines and repeated lines from a model reply.

    Two lines count as repeats when they are equal after lower-casing and
    removing everything but ASCII letters and digits. The first occurrence
    is kept, and line order is preserved.
    """
    seen = set()
    kept = []
    for raw_line in text.replace('```', '').split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        key = re.sub(r'[^a-zA-Z0-9]', '', line).lower()
        if key not in seen:
            seen.add(key)
            kept.append(line)
    return "\n".join(kept)


def _timed_stream_generation(base_url, n_predict, time_limit):
    """Stream one TG request until *time_limit* seconds pass or the server stops.

    Timing is measured client-side from the first streamed token, so the rate
    is (tokens - 1) inter-token intervals over the elapsed time. The response
    is always closed on exit. Otherwise an unfinished (possibly unlimited,
    n_predict=-1) generation keeps running on the server into the next round.

    Returns:
        (tokens_per_second, gen_ms, tokens)
    """
    payload = {"prompt": _TG_PROMPT, "n_predict": n_predict, "temperature": 0.0,
               "ignore_eos": True, "stream": True, "reasoning_budget": 0}
    first_t = last_t = None
    tokens = 0
    with requests.post(f"{base_url}/completion", json=payload, stream=True) as resp:
        for raw in resp.iter_lines():
            if not raw:
                continue
            line = raw.decode('utf-8', errors='replace')
            if not line.startswith("data: "):
                continue
            data = line[6:]
            if data == "[DONE]":
                break
            try:
                chunk = json.loads(data)
            except ValueError:
                chunk = None  # unparseable event: still counted as one token
            if isinstance(chunk, dict) and chunk.get("stop") and not chunk.get("content"):
                break  # final summary event, not a generated token
            now = time.time()
            if first_t is None:
                first_t = now
            last_t = now
            tokens += 1
            if now - first_t >= time_limit:
                break
    gen_ms = (last_t - first_t) * 1000 if first_t and last_t and tokens > 1 else 0
    tg_val = (tokens - 1) / (gen_ms / 1000) if gen_ms > 0 else 0
    return tg_val, gen_ms, tokens


def _safe_filename_part(text):
    """Make *text* usable inside a file name.

    Any directory part (split on '/' or '\\') is dropped, runs of characters
    outside [A-Za-z0-9._-] become '_', and leading/trailing dots are
    stripped. Returns "model" if nothing is left.
    """
    base = re.split(r'[\\/]', str(text))[-1]
    return re.sub(r'[^A-Za-z0-9._-]+', '_', base).strip('.') or "model"


def run_bench():
    """Benchmark a running llama.cpp server and save the results as JSON.

    Passive: connects to an existing server (see setup_network_and_docker).
    It then:
      1. gathers hardware and memory facts from the server log,
      2. runs one stability prompt,
      3. runs args.rounds rounds of prompt processing (PP) and token
         generation (TG),
      4. writes result_*.json and serverlog_*.json into SAVE_DIR.

    Exits the process with status 1 when the server cannot be reached, the
    model list is malformed, or the stability check fails.
    """
    args = _parse_cli_args()

    # -c wins. Otherwise a time-limited round generates without a token cap
    # (-1), and a plain round generates 128 tokens.
    tg_tokens = args.context if args.context is not None else (-1 if args.time is not None else 128)
    target_port = args.port

    c_name, host_ip, api_port = setup_network_and_docker(target_port)
    if not host_ip:
        print(f"❌ ERROR: Cannot reach API on port {target_port}. Ensure LLaMA.cpp is running. Passive Mode active - exiting.")
        sys.exit(1)

    base_url = f"http://{host_ip}:{api_port}"
    try:
        m_resp = requests.get(f"{base_url}/v1/models", timeout=10).json()
        props = requests.get(f"{base_url}/props", timeout=10).json()
    except Exception as e:
        print(f"❌ Connection error on port {api_port}: {e}")
        sys.exit(1)
    try:
        m_id = m_resp["data"][0]["id"]
    except (KeyError, IndexError, TypeError):
        print(f"❌ Unexpected /v1/models response on port {api_port}: {m_resp}")
        sys.exit(1)

    log_content = fetch_log_content(c_name)
    log_data = parse_detailed_log(log_content)

    # Prefer the model file name parsed from the log over the API id.
    if log_data.get("model_file") and log_data["model_file"] != "Unknown":
        m_id = log_data["model_file"]

    cpu_name = get_cpu_name()
    drivers = get_gpu_drivers(c_name)
    server_build = props.get('build_info', 'Unknown')
    runin_env = "Docker" if c_name else "Host"
    userhost_info = get_userhost_info()

    # ---- System / memory table ---------------------------------------------
    log_msg("\n" + "=" * 105)
    log_msg(f"{'--- SYSTEM & HARDWARE ---':<50} | {'--- MEMORY ALLOCATION ---'}")
    log_msg("-" * 105)

    mmap_val = "ON" if log_data['mmap'] == "true" else ("OFF" if log_data['mmap'] == "false" else "UNKNOWN")
    tot_l = int(log_data['layers_total'])
    gpu_l = int(log_data['layers_gpu'])
    ram_l = tot_l - gpu_l if tot_l > 0 else 0

    # With every layer offloaded, the host model buffer is shown as
    # "BPE/Meta". Otherwise it is shown under "RAM Layers".
    ram_bpe = f"{log_data['ram_model']:>8} MiB" if ram_l == 0 else f"{'Inc. below':>8}"
    ram_layers = f"{'0.00':>8} MiB (0/{tot_l} layers)" if ram_l == 0 else f"{log_data['ram_model']:>8} MiB ({ram_l}/{tot_l} layers + BPE)"

    # Memory lines shared by the console table and the JSON "memory_data" section.
    vram_line = f"{log_data['vram_model']:>8} MiB ({gpu_l}/{tot_l} layers)" + (f" [{log_data['vram_model_brk']}]" if log_data['vram_model_brk'] else "")
    kv_line = f"{log_data['kv_cache']:>8} MiB [{log_data['kv_type']}]" + (f" [{log_data['kv_cache_brk']}]" if log_data['kv_cache_brk'] else "")
    compute_line = f"{log_data['compute']:>8} MiB" + (f" [{log_data['compute_brk']}]" if log_data['compute_brk'] else "")
    ssm_line = (f"{log_data['ssm_rs']:>8} MiB" + (f" [{log_data['ssm_rs_brk']}]" if log_data['ssm_rs_brk'] else "")) if log_data['ssm_rs'] != "0" else None

    col1 = [
        f"Source:       {'Docker (' + c_name + ')' if c_name else 'Native Host'}",
        f"System:       {platform.system()} {platform.release()}",
        f"Architecture: {log_data['arch']}",
        f"Detected Port:{api_port}",
        f"MMAP Status:  {mmap_val}"
    ]
    if log_data['exp_total'] > 0:
        col1.insert(4, f"MoE Experts:  {log_data['exp_used']} active / {log_data['exp_total']} total")

    col2 = [
        f"VRAM Model:   {vram_line}",
        f"KV Cache:     {kv_line}",
        f"Compute Buf:  {compute_line}",
        f"RAM BPE/Meta: {ram_bpe}",
        f"RAM Layers:   {ram_layers}"
    ]
    if ssm_line is not None:
        col2.insert(3, f"SSM/RS State: {ssm_line}")

    # Pad the shorter column so both print side by side.
    n_rows = max(len(col1), len(col2))
    col1 += [""] * (n_rows - len(col1))
    col2 += [""] * (n_rows - len(col2))
    for c1, c2 in zip(col1, col2):
        log_msg(f"{c1:<50} | {c2}")

    log_msg("-" * 105)
    log_msg(f"CPU Model:    {cpu_name}")
    driver_text = ", ".join(f"{k.upper()}: {v}" for k, v in drivers.items()) if drivers else "Unknown"
    log_msg(f"GPU Drivers:  {driver_text}")
    log_msg("GPUs:")
    for gpu in log_data['gpus'] or ["  Unknown Device"]:
        log_msg(f"  {gpu}")
    log_msg("=" * 105 + "\n")

    # ---- Context summary ---------------------------------------------------
    log_ctx, train_ctx = int(log_data.get('n_ctx_log', 0)), int(log_data.get('n_ctx_train', 0))
    final_ctx = (log_ctx if log_ctx > 0 else (props.get('default_generation_settings', {}).get('n_ctx', 0) or props.get('n_ctx', 0))) or 0
    # Blame VRAM only when the log reports an automatic reduction. A
    # user-chosen -c below n_ctx_train is normal.
    ctx_was_reduced = re.search(r'context size reduced from \d+ to \d+', log_content) is not None
    ctx_warning = f" (Reduced from {train_ctx:,} due to lack of VRAM!)" if ctx_was_reduced and 0 < final_ctx < train_ctx and final_ctx == log_ctx else ""

    log_msg("=" * 60)
    log_msg(f"{'--- System Breakdown ---':^60}")
    log_msg(f"Server Build:   {server_build}")
    log_msg(f"Context Limit:  {final_ctx:,} tokens{ctx_warning}")
    log_msg("-" * 60)
    log_msg(f"Benchmarking: {m_id} ({args.rounds} rounds)")
    log_msg("-" * 60)

    # ---- Stability check ---------------------------------------------------
    try:
        check_p = {"prompt": "<|im_start|>system\nYou are a loyal servant.<|im_end|>\n<|im_start|>user\nWrite exactly: Boss, I'm so ready I feel like I'm not ready.<|im_end|>\n<|im_start|>assistant\n", "n_predict": 100, "temperature": 0.0, "reasoning_budget": 0}
        resp_text = requests.post(f"{base_url}/completion", json=check_p).json().get('content', '').strip()
        resp_text = _dedupe_response_lines(resp_text)
        log_msg(f"Model response: '{resp_text}'\nStability: OK.")
    except Exception as e:
        log_msg(f"Error during stability check: {e}")
        sys.exit(1)

    log_msg("-" * 60)

    # ---- Benchmark rounds --------------------------------------------------
    results = {"pp": [], "tg": [], "ttft": [], "tg_ms": [], "tg_tok": []}
    first_ttft = None
    cached_rounds = 0

    ram_before = get_sys_ram_usage_mb()

    for i in range(args.rounds):
        # The random suffix makes each PP prompt distinct. NOTE(review): the
        # shared 'word' prefix may still be served from the prompt cache;
        # the TTFT check below detects and excludes such rounds.
        salt = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))

        r1 = requests.post(f"{base_url}/completion", json={"prompt": f"User: Process the following sequence:\n\n{'word ' * 450} ID: {salt}\n\nAssistant: ", "n_predict": 1, "temperature": 0.0, "reasoning_budget": 0}).json()
        t1 = r1.get("timings", {})

        ttft_val = t1.get("prompt_ms", 0)
        pp_val = t1.get("prompt_per_second", 0)

        # Rounds whose prompt time is under half the first round's are
        # treated as cache hits and left out of the PP average.
        is_cached = False
        if first_ttft is None and ttft_val > 0:
            first_ttft = ttft_val
            results["pp"].append(pp_val)
        elif first_ttft is not None and ttft_val < first_ttft * 0.5:
            is_cached = True
            cached_rounds += 1
        else:
            results["pp"].append(pp_val)

        results["ttft"].append(ttft_val)

        if args.time is None:
            r2 = requests.post(f"{base_url}/completion", json={"prompt": _TG_PROMPT, "n_predict": tg_tokens, "temperature": 0.0, "ignore_eos": True, "reasoning_budget": 0}).json()
            t2 = r2.get("timings", {})
            tg_val, gen_ms, gen_tok = t2.get("predicted_per_second", 0), t2.get("predicted_ms", 0), t2.get("predicted_n", 0)
        else:
            tg_val, gen_ms, gen_tok = _timed_stream_generation(base_url, tg_tokens, args.time)

        # Reject implausible rates (e.g. from a near-zero time divisor).
        if tg_val > 100000:
            tg_val = 0.0

        results["tg"].append(tg_val)
        results["tg_ms"].append(gen_ms)
        results["tg_tok"].append(gen_tok)

        cache_tag = " [CACHED PROMPT]" if is_cached else ""
        log_msg(f"Round {i+1:02d}: PP = {pp_val:>8.2f} t/s | TG = {tg_val:>6.2f} t/s | TTFT = {ttft_val:>7.2f} ms | Gen Time = {gen_ms:>7.2f} ms ({gen_ms/1000.0:.2f} s) | Tokens = {gen_tok}{cache_tag}")

    # A large host-RAM increase during the run is reported as a possible GTT spill.
    ram_after = get_sys_ram_usage_mb()
    ram_diff = ram_after - ram_before
    gtt_warning = "VRAM Capping Detected (GTT Leak)" if ram_diff > 1024 else None

    # ---- Averages ----------------------------------------------------------
    positive_tg = [v for v in results["tg"] if v > 0]
    avg_pp = sum(results["pp"]) / len(results["pp"]) if results["pp"] else 0
    avg_tg = sum(positive_tg) / len(positive_tg) if positive_tg else 0
    avg_ttft = sum(results["ttft"]) / len(results["ttft"]) if results["ttft"] else 0
    avg_gen_ms = sum(results["tg_ms"]) / len(results["tg_ms"]) if results["tg_ms"] else 0
    avg_tok = sum(results["tg_tok"]) / len(results["tg_tok"]) if results["tg_tok"] else 0

    log_msg("\n" + "="*60)
    log_msg(f"FINAL AVERAGES - {m_id}")
    log_msg("-" * 60)
    log_msg(f"Configured Token Limit (TG): {tg_tokens if tg_tokens != -1 else 'Unlimited (Time)'}")
    if cached_rounds > 0:
        log_msg(f"Prompt Caching Detected:     Ignored {cached_rounds} cached rounds for PP Average.")
    log_msg(f"Average Tokens Generated:    {avg_tok:.1f} tokens")
    log_msg(f"Average Latency (TTFT):      {avg_ttft:.2f} ms")
    log_msg(f"Average Gen Time (TG):       {avg_gen_ms:.2f} ms ({avg_gen_ms/1000.0:.2f} s)")

    # Server-side throughput, available only when llama.cpp runs with --metrics.
    try:
        met = requests.get(f"{base_url}/metrics", timeout=10).text
        srv_tg = float(re.search(r'llamacpp:predicted_tokens_seconds\s+([\d.]+)', met).group(1))
        srv_pp = float(re.search(r'llamacpp:prompt_tokens_seconds\s+([\d.]+)', met).group(1))
    except (requests.RequestException, AttributeError, ValueError):
        srv_tg, srv_pp = 0, 0

    print_comparison_chart("Token Generation", avg_tg, srv_tg)
    print_comparison_chart("Prompt Processing", avg_pp, srv_pp)

    if gtt_warning:
        log_msg(f"\n⚠️ PERFORMANCE WARNING: {gtt_warning} (+{ram_diff:.1f} MB RAM spiked)")

    log_msg("=" * 60)

    # ---- Output files ------------------------------------------------------
    gpu_tags = []
    for g in log_data["gpus"]:
        tag = re.search(r'\[(.*?)\]', g)
        if tag:
            gpu_tags.append(tag.group(1))
    gpu_suffix = _safe_filename_part("-".join(gpu_tags)) if gpu_tags else "CPU"
    # m_id may be an API id or full path (e.g. '/models/x.gguf'); keep only a
    # filename-safe basename so the result path stays inside SAVE_DIR.
    safe_name = _safe_filename_part(m_id.replace('.gguf', ''))
    file_base = f"{safe_name}_{gpu_suffix}_ctx{final_ctx}_{platform.system().lower()}"

    json_path = os.path.join(SAVE_DIR, f"result_{file_base}.json")
    serverlog_path = os.path.join(SAVE_DIR, f"serverlog_{file_base}.json")

    mem_data_dict = {
        "VRAM Model": vram_line,
        "KV Cache": kv_line,
        "Compute Buf": compute_line,
        "SSM/RS State": ssm_line,
        "RAM BPE/Meta": ram_bpe,
        "RAM Layers": ram_layers
    }

    final_json = {
        "meta": {
            "timestamp": datetime.now().isoformat(),
            "os": platform.platform(),
            "arch": log_data["arch"],
            "model": m_id,
            "mmap": log_data["mmap"],
            "sbuild": server_build,
            "runin": runin_env,
            "drivers": drivers
        },
        "userhost": userhost_info,
        "settings": {"rounds": args.rounds, "time": args.time or "\u221E", "context": args.context if args.context is not None else (128 if args.time is None else "\u221E")},
        "hardware": {"cpu_model": cpu_name, "threads": log_data["threads"], "gpus": log_data["gpus"]},
        "memory": {
            "vram_model_mib": float(log_data["vram_model"]), "ram_model_mib": float(log_data["ram_model"]),
            "kv_cache_mib": float(log_data["kv_cache"]), "kv_type": log_data["kv_type"],
            "ssm_mib": float(log_data["ssm_rs"]), "layers_gpu": int(log_data["layers_gpu"]), "layers_total": int(log_data["layers_total"])
        },
        "moe_diagnostics": {"expert_count": log_data["exp_total"], "expert_used_count": log_data["exp_used"]} if log_data["exp_total"] > 0 else None,
        "memory_data": {k: v for k, v in mem_data_dict.items() if v is not None},
        "results": {"pp": avg_pp, "tg": avg_tg, "ttft": avg_ttft, "gen_ms": avg_gen_ms, "tokens": avg_tok}
    }

    if gtt_warning:
        final_json["performance_warning"] = gtt_warning

    serverlog_json = {
        "data_type": "detailed_server_log",
        "serverlog_model": m_id,
        "serverlog_tg": avg_tg,
        "serverlog_pp": avg_pp,
        "serverlog_ttft": avg_ttft,
        "serverlog_vram_model": float(log_data["vram_model"]),
        "serverlog_benchmark_row_output": "\n".join(text_report),
        "serverlog_raw_content": log_content
    }

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({k: v for k, v in final_json.items() if v is not None}, f, indent=4)
    with open(serverlog_path, "w", encoding="utf-8") as f:
        json.dump(serverlog_json, f, indent=4)

    print(f"\nReport generated successfully in '{SAVE_DIR}':")
    print(f"   - Main Results: {os.path.basename(json_path)}")
    print(f"   - Server Logs:  {os.path.basename(serverlog_path)}")

if __name__ == "__main__":
    try:
        run_bench()
    except KeyboardInterrupt:
        # Ctrl+C ends the benchmark quietly, without a traceback.
        sys.exit(0)