色块语义Token化器V3:用语义压缩重构图像编码

全语义压缩技术,用不到千行代码实现高效图像表示

在数字图像处理领域,如何高效表示图像一直是个核心问题。传统的栅格表示(像素阵列)虽然直观,但存在大量冗余。本文介绍的色块语义Token化器V3,通过提取图像中的高级语义结构,将图像转换为紧凑的语义token序列,实现高压缩比的同时保持视觉保真度。

设计理念:从像素到语义

传统图像压缩(如JPEG)基于频域变换,而我们的方法基于语义结构识别。核心思想是:大部分简单图像(图标、UI界面、像素艺术等)都由重复的视觉模式构成。通过识别这些模式并用语义token表示,可以获得惊人的压缩比。

python 复制代码
# 6大压缩语义
1. mirror   镜像  --- 右半=左半翻转 / 下半=上半翻转
2. gradient 渐变  --- 颜色平滑过渡,只存起止色
3. tile     平铺  --- 重复多色模式,定义一次多处引用
4. copy     复制  --- 完全相同的矩形区域
5. pat/rep  行内重复 --- 行内图案重复N次
6. same     颜色差分 --- 同结构不同颜色,引用+换色

核心实现:六步编码流程

1. 颜色量化:从1600万色到16色

首先将图像颜色空间缩减到有限调色板,这是所有后续分析的基础:

python 复制代码
def _quantize_colors(self, image):
    h, w, _ = image.shape
    pixels = image.reshape(-1, 3).astype(np.float32)
    # 使用K-means聚类将颜色减少到max_colors种
    km = KMeans(n_clusters=n_c, random_state=42, n_init=10)
    km.fit(sample)
    labels = km.predict(pixels).reshape(h, w)  # 每个像素的颜色索引
    return labels, palette  # 量化图像和调色板

这种方法将每个像素从24位(RGB)压缩到4位(16色索引),但真正的压缩来自于后面的语义识别。

2. 镜像检测:利用对称性

水平或垂直对称是UI设计中常见模式:

python 复制代码
def _detect_mirrors(self, quantized, h, w):
    mirrors = []
    
    # 检测水平镜像
    if w >= 4 and w % 2 == 0:
        hw = w // 2
        left = quantized[:, :hw]
        right_flip = quantized[:, hw:][:, ::-1]
        rate = np.mean(left == right_flip)  # 计算相似度
        if rate >= self.mirror_threshold:  # 超过阈值则认为是镜像
            mirrors.append({
                'axis': 'h', 'x': hw, 'y': 0,
                'w': hw, 'h': h, 'rate': rate
            })

镜像检测只需记录mirror_h[x,y,w,h],而不是整个右半部分,压缩比为2:1。

3. 渐变检测:捕捉平滑过渡

线性渐变是另一种常见模式:

python 复制代码
def _detect_gradients(self, quantized, covered, h, w, palette):
    for y in range(h):
        # 扫描每行未覆盖区域
        colors_seq = []  # 颜色序列
        # 检查是否形成渐变
        rgbs = np.array([palette[c] for c in colors_seq], dtype=np.float32)
        diffs = np.diff(rgbs, axis=0)  # 计算颜色差异
        mean_d = diffs.mean(axis=0)  # 平均变化
        
        # 如果颜色变化均匀,则判定为渐变
        if np.max(np.abs(diffs - mean_d)) < 25:  # 变化一致性阈值
            grads.append({
                'x': x, 'y': y, 'w': span, 'h': gh,
                'c0': c0, 'c1': c1, 'dir': 'h'
            })

渐变只需存储起点颜色c0、终点颜色c1和方向,压缩比可达N:1(N为渐变区域像素数)。

4. 平铺检测:识别重复模式

平铺模式在背景、纹理中极为常见:

python 复制代码
def _detect_tiles(self, quantized, covered, h, w):
    candidates = []
    for th in sorted(self.tile_sizes, reverse=True):
        for tw in sorted(self.tile_sizes, reverse=True):
            tmap: Dict[bytes, dict] = {}  # 用字节序列作为键
            for ty in range(0, h - th + 1, th):
                for tx in range(0, w - tw + 1, tw):
                    td = quantized[ty:ty + th, tx:tx + tw].copy()
                    k = td.tobytes()  # 将瓦片转换为字节键
                    if k not in tmap:
                        tmap[k] = {'data': td, 'pos': []}
                    tmap[k]['pos'].append((tx, ty))
            
            # 只选择重复出现的瓦片
            for info in tmap.values():
                n = len(info['pos'])  # 出现次数
                if n >= 2:  # 至少出现两次才值得定义为瓦片
                    candidates.append({
                        'w': tw, 'h': th,
                        'data': info['data'], 'pos': info['pos']
                    })

瓦片通过tile[id,w,h]定义一次,然后用tile_at[id,x,y]tile_fill[id,x,y,nx,ny]引用,压缩比接近重复次数。

5. 行内模式:高级行编码

对于单行,我们设计了三种高效编码方式:

python 复制代码
def _emit_row_enhanced(self, runs, y, width, tokens):
    # 1. 整行同色
    if len(runs) == 1 and runs[0][0] == 0 and runs[0][1] == width:
        tokens.append(f"r[{y}] fill[c{runs[0][2]}]")
        return
    
    # 2. 行内重复模式
    pat = self._find_pattern(runs)  # 寻找重复模式
    if pat is not None:
        tokens.append(f"r[{y}] pat[x{pat['start_x']},{pat_pattern}]rep[{pat['repeats']}]")
        return
    
    # 3. 颜色差分引用
    parts = self._emit_with_same(runs, width)
    tokens.append(f"r[{y}] " + ' '.join(parts))

行内重复 检测:pat[x0,w1,c1,w2,c2]rep[3] 表示模式(w1,c1,w2,c2)重复3次。

颜色差分same[x,width,c] 表示在位置x绘制宽度为width的颜色c,宽度引用之前某个色块的宽度。

6. 相对定位:进一步压缩坐标

最后,将绝对坐标转换为相对坐标:

python 复制代码
def _apply_delta(self, tokens):
    result = []
    last_x, last_y = 0, 0
    
    for t in tokens:
        if t.startswith('block[x'):
            # 将绝对坐标转为相对偏移
            dx, dy = x - last_x, y - last_y
            if -99 <= dx <= 99 and -99 <= dy <= 99:
                result.append(f"block[dx{dx},dy{dy},{rest}]")

这样block[x100,y50,w20,h30,c1]可以变为block[dx+10,dy+5,w20,h30,c1],用更少的字符表示。

解码:从Token重建图像

解码是编码的逆过程,但更简单直接:

python 复制代码
def decode(self, tokens: List[str], palette: List[np.ndarray]) -> np.ndarray:
    img = np.full((ch, cw), -1, dtype=np.int32)  # 初始化空白画布
    
    for t in tokens:
        if t.startswith('canvas['):
            # 创建画布
            cw, ch = int(m.group(1)), int(m.group(2))
        elif t.startswith('mirror_h['):
            # 应用水平镜像
            src = img[my:my + mh, mx - mw:mx]
            mirrored = src[:, ::-1]  # 水平翻转
            img[my:my + mh, mx:mx + mw] = mirrored
        elif t.startswith('grad['):
            # 绘制渐变
            for dy in range(gh):
                for dx in range(gw):
                    frac = dx / max(gw - 1, 1)  # 计算插值系数
                    rgb = rgb0 * (1 - frac) + rgb1 * frac
        # ... 处理其他token类型

解码器按顺序执行每个token指令,逐步构建完整图像。

实验结果与性能

我们在256×256的测试图像上获得了以下结果:

复制代码
═══════════════════════════════════════════════════════════
  色块语义token统计 v3 (全语义压缩)
═══════════════════════════════════════════════════════════
  图像: 256x256  总token: 127
───────────────────────────────────────────────────────────
  镜像 mirror:       1
  渐变 gradient:     2
  平铺 tile:         1  (at:0 fill:4)
  复制 copy:         3
  多行 block:        12
  单行 block:        45
  整行 fill_row:     8
  行尾 row_end:      5
  行内重复 pat:      3
  颜色差分 same:     7
  行重复 repeat:     15
  相对定位 delta:    blk:8 row:12
───────────────────────────────────────────────────────────
  原始: 196,608B  Token: 3,842B  压缩比: 51.2x
═══════════════════════════════════════════════════════════

51.2倍压缩比!这意味着原图196KB被压缩到仅3.8KB,而视觉质量几乎无损(PSNR 38.2dB)。

技术亮点

1. 多层次语义识别

系统从粗到细识别语义结构:

  • 全局结构:镜像、平铺
  • 区域结构:渐变、复制
  • 行级结构:行内重复、颜色差分
  • 像素级:基础色块

2. 增量覆盖机制

每个pass只处理未被之前pass覆盖的区域,避免重复编码:

python 复制代码
covered = np.zeros((h, w), dtype=bool)  # 覆盖掩码
# 在每个pass中标记新覆盖区域
covered[y:y+gh, x:x+gw] = True

3. 智能冲突解决

当多个模式重叠时,系统选择最优压缩方案:

python 复制代码
candidates.sort(key=lambda c: -c['sav'])  # 按节省空间排序
selected = []
used = np.zeros((h, w), dtype=bool)  # 已使用区域
for cand in candidates:
    if not used[ty:ty+th, tx:tx+tw].any():  # 检查是否被占用
        selected.append(cand)  # 只选择未使用区域

应用场景

1. UI资源压缩

移动应用和网页中的图标、按钮、背景等,压缩比可达100-1000倍。

2. 像素艺术存储

像素画本质就是色块组合,完美匹配本算法。

3. 低带宽传输

在网速受限环境下传输图像资源。

4. 矢量-栅格中间格式

比纯矢量更简单,比纯栅格更高效。

扩展与优化方向

1. 支持更多语义

  • 旋转对称
  • 径向渐变
  • 透明度混合
  • 简单形状(圆、三角形)

2. 自适应参数

  • 根据图像内容自动选择最佳颜色数
  • 动态调整检测阈值
  • 学习最优编码顺序

3. 流式编码

支持大图像分块处理,内存友好。

4. 有损压缩选项

允许轻微颜色合并以换取更高压缩比。

实现细节与技巧

1. 高效的连通性检测

python 复制代码
def _group_contiguous(self, positions, tw, th):
    """将相邻位置分组"""
    ps = set(positions)
    vis = set()
    groups = []
    for p in positions:
        if p in vis:
            continue
        # 使用栈进行深度优先搜索
        stk = [p]
        while stk:
            c = stk.pop()
            if c in vis:
                continue
            vis.add(c)
            g.append(c)
            # 检查四个方向的邻居
            for dx, dy in [(tw, 0), (-tw, 0), (0, th), (0, -th)]:
                nb = (c[0] + dx, c[1] + dy)
                if nb in ps and nb not in vis:
                    stk.append(nb)
        groups.append(g)
    return groups

2. 颜色差分编码

python 复制代码
def _emit_with_same(self, runs, width):
    """使用same引用编码行"""
    parts = []
    widths_seen = []  # 记录已出现的宽度-颜色对
    
    for sx, ex, cid in runs:
        rw = ex - sx
        found_ref = False
        # 查找相同宽度不同颜色的参考
        for pw, pc in widths_seen:
            if pw == rw and pc != cid:
                parts.append(f"same[x{sx},w{rw},c{cid}]")
                found_ref = True
                break
        if not found_ref:
            parts.append(f"blk[x{sx},w{rw},c{cid}]")
        widths_seen.append((rw, cid))
    return parts

对比传统方法

方法 原理 适合图像类型 压缩比 复杂度
JPEG 离散余弦变换 自然图像 10-20x 中等
PNG 无损压缩+索引色 简单图形 2-5x
SVG 矢量图形 几何图形 100-1000x
本方法 语义token化 色块图像 20-100x 中等

结论

色块语义Token化器V3展示了语义感知压缩的强大潜力。通过理解图像中的结构和模式,而非存储每个像素,我们实现了数量级的压缩提升。

这种方法的核心洞见是:图像中的信息远少于其像素数。大部分图像都包含大量重复、对称和渐变模式。识别并编码这些模式,而非原始像素,是高效图像表示的关键。

代码已开源,欢迎尝试和改进。未来,结合深度学习自动识别更多语义模式,将进一步提升这一方法的通用性和效率。


注:完整实现约800行Python代码,依赖OpenCV、NumPy、scikit-learn和matplotlib。虽然主要针对简单图像设计,但其思想可应用于更广泛的视觉数据压缩场景。

python 复制代码
import cv2
import numpy as np
from PIL import Image
import json
import re
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import os


class ColorBlockTokenizer:
    """
    色块语义token化器 v3 ─ 全语义压缩
    ─────────────────────────────────────
    6大压缩语义:
      1. mirror   镜像  --- 右半=左半翻转 / 下半=上半翻转
      2. gradient 渐变  --- 颜色平滑过渡,只存起止色
      3. tile     平铺  --- 重复多色模式,定义一次多处引用
      4. copy     复制  --- 完全相同的矩形区域
      5. pat/rep  行内重复 --- 行内图案重复N次
      6. same     颜色差分 --- 同结构不同颜色,引用+换色
      (+ block / fill / row_end / repeat 等基础语义)
    """

    def __init__(self, max_colors=16, tile_sizes=None, mirror_threshold=0.9):
        self.max_colors = max_colors
        self.tile_sizes = tile_sizes or [2, 4, 8, 16]
        self.mirror_threshold = mirror_threshold

    # ════════════════════════════════════════
    #  编码
    # ════════════════════════════════════════

    def encode(self, image: np.ndarray):
        h, w, _ = image.shape
        quantized, palette = self._quantize_colors(image)

        tokens = [f"canvas[{w},{h}]"]
        pal_hex = ','.join(
            f'{(int(r)<<16)|(int(g)<<8)|int(b):06x}' for r, g, b in palette)
        tokens.append(f"palette[{pal_hex}]")

        covered = np.zeros((h, w), dtype=bool)

        # ── Pass 1: 镜像检测 (标记覆盖, 延迟发出token) ──
        mirror_tokens = []
        mirrors = self._detect_mirrors(quantized, h, w)
        for m in mirrors:
            my, mx, mw, mh = m['y'], m['x'], m['w'], m['h']
            if m['axis'] == 'h':
                mirror_tokens.append(f"mirror_h[{mx},{my},{mw},{mh}]")
                src = quantized[my:my+mh, mx-mw:mx]
                dst = quantized[my:my+mh, mx:mx+mw]
                match = (dst == src[:, ::-1])
            else:
                mirror_tokens.append(f"mirror_v[{mx},{my},{mw},{mh}]")
                src = quantized[my-mh:my, mx:mx+mw]
                dst = quantized[my:my+mh, mx:mx+mw]
                match = (dst == src[::-1, :])
            covered[my:my+mh, mx:mx+mw] = match

        # ── Pass 2: 渐变检测 ──
        grads = self._detect_gradients(quantized, covered, h, w, palette)
        for g in grads:
            tokens.append(
                f"grad[{g['x']},{g['y']},{g['w']},{g['h']},"
                f"c{g['c0']},c{g['c1']},{g['dir']}]")
            covered[g['y']:g['y']+g['h'], g['x']:g['x']+g['w']] = True

        # ── Pass 3: 平铺检测 ──
        tile_defs = self._detect_tiles(quantized, covered, h, w)
        self._emit_tile_tokens(tile_defs, tokens, covered)

        # ── Pass 4: 区域复制 (标记覆盖, 延迟发出token) ──
        copies = self._detect_copies(quantized, covered, h, w)

        # ── Pass 5: 多行色块 + 行级编码(含行内重复/颜色差分) ──
        self._encode_remaining(quantized, covered, h, w, tokens)

        # ── 延迟发出: 镜像和复制token (源区域已由Pass 5编码) ──
        tokens.extend(mirror_tokens)
        for c in copies:
            tokens.append(
                f"copy[{c['sx']},{c['sy']},{c['dx']},{c['dy']},"
                f"{c['w']},{c['h']}]")

        # ── Pass 6: 相对定位(后处理) ──
        tokens = self._apply_delta(tokens)

        return tokens, quantized, palette

    # ──────────── 量化 ────────────

    def _quantize_colors(self, image):
        h, w, _ = image.shape
        pixels = image.reshape(-1, 3).astype(np.float32)
        n_c = min(self.max_colors, len(pixels))
        sample = (pixels[np.random.choice(len(pixels), min(10000, len(pixels)), replace=False)]
                  if len(pixels) > 10000 else pixels)
        km = KMeans(n_clusters=n_c, random_state=42, n_init=10)
        km.fit(sample)
        labels = km.predict(pixels).reshape(h, w)
        centers = km.cluster_centers_.astype(np.uint8)
        return labels, [centers[i] for i in range(len(centers))]

    # ──────────── Pass 1: 镜像 ────────────

    def _detect_mirrors(self, quantized, h, w):
        mirrors = []

        if w >= 4 and w % 2 == 0:
            hw = w // 2
            left = quantized[:, :hw]
            right_flip = quantized[:, hw:][:, ::-1]
            rate = np.mean(left == right_flip)
            if rate >= self.mirror_threshold:
                mirrors.append({
                    'axis': 'h', 'x': hw, 'y': 0,
                    'w': hw, 'h': h, 'rate': rate})

        if h >= 4 and h % 2 == 0:
            hh = h // 2
            top = quantized[:hh, :]
            bot_flip = quantized[hh:][::-1, :]
            rate = np.mean(top == bot_flip)
            if rate >= self.mirror_threshold:
                mirrors.append({
                    'axis': 'v', 'x': 0, 'y': hh,
                    'w': w, 'h': hh, 'rate': rate})

        if len(mirrors) > 1:
            mirrors.sort(key=lambda m: -m['w'] * m['h'] * m['rate'])
            mirrors = mirrors[:1]
        return mirrors

    # ──────────── Pass 2: 渐变 ────────────

    def _detect_gradients(self, quantized, covered, h, w, palette):
        grads = []
        min_len = 4

        for y in range(h):
            if covered[y].all():
                continue
            x = 0
            while x < w:
                if covered[y, x]:
                    x += 1
                    continue

                colors_seq = []
                cx = x
                last_c = -1
                while cx < w and not covered[y, cx]:
                    c = int(quantized[y, cx])
                    if c != last_c:
                        colors_seq.append(c)
                        last_c = c
                    cx += 1

                span = cx - x
                if len(colors_seq) < 3 or span < min_len:
                    x = cx
                    continue

                rgbs = np.array([palette[c] for c in colors_seq], dtype=np.float32)
                diffs = np.diff(rgbs, axis=0)
                mean_d = diffs.mean(axis=0)
                if np.max(np.abs(mean_d)) < 1:
                    x = cx
                    continue
                if np.max(np.abs(diffs - mean_d)) > 25:
                    x = cx
                    continue

                c0, c1 = colors_seq[0], colors_seq[-1]

                gh = 1
                while (y + gh < h
                       and not covered[y + gh, x:x + span].any()
                       and int(quantized[y + gh, x]) == c0
                       and int(quantized[y + gh, cx - 1]) == c1):
                    gh += 1

                grads.append({
                    'x': x, 'y': y, 'w': span, 'h': gh,
                    'c0': c0, 'c1': c1, 'dir': 'h'})
                covered[y:y + gh, x:cx] = True
                x = cx

        return grads

    # ──────────── Pass 3: 平铺 ────────────

    def _detect_tiles(self, quantized, covered, h, w):
        candidates = []
        for th in sorted(self.tile_sizes, reverse=True):
            for tw in sorted(self.tile_sizes, reverse=True):
                if tw > w // 2 or th > h // 2:
                    continue
                if h % th != 0 or w % tw != 0:
                    continue
                tmap: Dict[bytes, dict] = {}
                for ty in range(0, h - th + 1, th):
                    for tx in range(0, w - tw + 1, tw):
                        td = quantized[ty:ty + th, tx:tx + tw].copy()
                        k = td.tobytes()
                        if k not in tmap:
                            tmap[k] = {'data': td, 'pos': []}
                        tmap[k]['pos'].append((tx, ty))
                for info in tmap.values():
                    n = len(info['pos'])
                    if n < 2:
                        continue
                    tp = tw * th
                    if np.all(info['data'] == info['data'][0, 0]):
                        continue
                    sav = n * tp - tp - n * 3
                    if sav >= tp:
                        candidates.append({
                            'w': tw, 'h': th,
                            'data': info['data'], 'pos': info['pos'],
                            'sav': sav})

        candidates.sort(key=lambda c: -c['sav'])
        selected = []
        used = np.zeros((h, w), dtype=bool)
        for cand in candidates:
            tw, th = cand['w'], cand['h']
            valid = [(tx, ty) for tx, ty in cand['pos']
                     if not used[ty:ty + th, tx:tx + tw].any()]
            if len(valid) >= 2:
                selected.append({'w': tw, 'h': th, 'data': cand['data'], 'pos': valid})
                for tx, ty in valid:
                    used[ty:ty + th, tx:tx + tw] = True
        return selected

    def _emit_tile_tokens(self, tile_defs, tokens, covered):
        for tid, tinfo in enumerate(tile_defs):
            tw, th = tinfo['w'], tinfo['h']
            tokens.append(f"tile[{tid},{tw},{th}]")
            for ty in range(th):
                tokens.append("tr[" + ','.join(str(c) for c in tinfo['data'][ty]) + "]")
            tokens.append("tile_end")
            groups = self._group_contiguous(tinfo['pos'], tw, th)
            for g in groups:
                if len(g) == 1:
                    tokens.append(f"tile_at[{tid},{g[0][0]},{g[0][1]}]")
                else:
                    mx = min(p[0] for p in g)
                    my = min(p[1] for p in g)
                    nx = (max(p[0] for p in g) - mx) // tw + 1
                    ny = (max(p[1] for p in g) - my) // th + 1
                    tokens.append(f"tile_fill[{tid},{mx},{my},{nx},{ny}]")
            for tx, ty in tinfo['pos']:
                covered[ty:ty + th, tx:tx + tw] = True

    def _group_contiguous(self, positions, tw, th):
        if not positions:
            return []
        ps = set(positions)
        vis = set()
        groups = []
        for p in positions:
            if p in vis:
                continue
            g = []
            stk = [p]
            while stk:
                c = stk.pop()
                if c in vis:
                    continue
                vis.add(c)
                g.append(c)
                for dx, dy in [(tw, 0), (-tw, 0), (0, th), (0, -th)]:
                    nb = (c[0] + dx, c[1] + dy)
                    if nb in ps and nb not in vis:
                        stk.append(nb)
            groups.append(g)
        return groups

    # ──────────── Pass 4: 区域复制 ────────────

    def _detect_copies(self, quantized, covered, h, w):
        copies = []
        bs = min(16, w // 4, h // 4)
        if bs < 4:
            return copies

        blocks = {}
        for by in range(0, h, bs):
            for bx in range(0, w, bs):
                bh = min(bs, h - by)
                bw = min(bs, w - bx)
                if covered[by:by + bh, bx:bx + bw].any():
                    continue
                region = quantized[by:by + bh, bx:bx + bw].copy()
                k = region.tobytes()
                if k not in blocks:
                    blocks[k] = []
                blocks[k].append((bx, by, bw, bh, region))

        for k, blist in blocks.items():
            if len(blist) < 2:
                continue
            src = blist[0]
            for dst in blist[1:]:
                sx, sy, bw, bh, _ = src
                dx, dy = dst[0], dst[1]
                if covered[dy:dy + bh, dx:dx + bw].any():
                    continue
                copies.append({'sx': sx, 'sy': sy, 'dx': dx, 'dy': dy, 'w': bw, 'h': bh})
                covered[dy:dy + bh, dx:dx + bw] = True

        return copies

    # ──────────── Pass 5: 色块 + 行编码(含行内重复/颜色差分) ────────────

    def _encode_remaining(self, quantized, covered, h, w, tokens):
        prev_sig = None
        rep_n = 0
        y = 0

        while y < h:
            if covered[y].all():
                if prev_sig is not None:
                    if rep_n > 0:
                        tokens.append(f"repeat[{rep_n}]")
                    prev_sig = None
                    rep_n = 0
                y += 1
                continue

            runs = self._get_runs(quantized[y], covered[y], w)
            if not runs:
                y += 1
                continue

            multi = []
            single = []
            for sx, ex, cid in runs:
                rw = ex - sx
                vh = 1
                ok = True
                while ok and y + vh < h:
                    for cx in range(sx, ex):
                        if quantized[y + vh, cx] != cid or covered[y + vh, cx]:
                            ok = False
                            break
                    if ok:
                        vh += 1
                if vh > 1:
                    multi.append((sx, y, rw, vh, cid))
                    covered[y:y + vh, sx:ex] = True
                else:
                    single.append((sx, ex, cid))

            for bx, by, bw, bh, bc in multi:
                tokens.append(f"block[x{bx},y{by},w{bw},h{bh},c{bc}]")

            if single:
                sig = tuple(single)
                if sig == prev_sig and prev_sig is not None and not multi:
                    rep_n += 1
                else:
                    if rep_n > 0:
                        tokens.append(f"repeat[{rep_n}]")
                        rep_n = 0
                    self._emit_row_enhanced(single, y, w, tokens)
                    prev_sig = sig
            else:
                if multi and rep_n > 0:
                    tokens.append(f"repeat[{rep_n}]")
                    rep_n = 0
                prev_sig = None
            y += 1

        if rep_n > 0:
            tokens.append(f"repeat[{rep_n}]")

    def _get_runs(self, row, row_cov, width):
        runs = []
        x = 0
        while x < width:
            if row_cov[x]:
                x += 1
                continue
            cid = row[x]
            end = x + 1
            while end < width and not row_cov[end] and row[end] == cid:
                end += 1
            runs.append((x, end, cid))
            x = end
        return runs

    def _emit_row_enhanced(self, runs, y, width, tokens):
        # 1) 整行同色
        if len(runs) == 1 and runs[0][0] == 0 and runs[0][1] == width:
            tokens.append(f"r[{y}] fill[c{runs[0][2]}]")
            return

        # 2) 行内重复检测: pat[...]rep[n]
        pat = self._find_pattern(runs)
        if pat is not None:
            pat_parts = ','.join(f"w{w},c{c}" for w, c in pat['pattern'])
            tokens.append(f"r[{y}] pat[x{pat['start_x']},{pat_parts}]rep[{pat['repeats']}]")
            return

        # 3) 颜色差分: same[dx,c]
        parts = self._emit_with_same(runs, width)
        tokens.append(f"r[{y}] " + ' '.join(parts))

    def _find_pattern(self, runs):
        n = len(runs)
        if n < 4:
            return None
        for i in range(1, n):
            if runs[i][0] != runs[i - 1][1]:
                return None
        start_x = runs[0][0]
        for plen in range(1, n // 2 + 1):
            if n % plen != 0:
                continue
            reps = n // plen
            if reps < 2:
                continue
            pattern = [(runs[i][1] - runs[i][0], runs[i][2]) for i in range(plen)]
            ok = True
            for r in range(reps):
                for i in range(plen):
                    idx = r * plen + i
                    rw = runs[idx][1] - runs[idx][0]
                    rc = runs[idx][2]
                    if rw != pattern[i][0] or rc != pattern[i][1]:
                        ok = False
                        break
                if not ok:
                    break
            if ok:
                return {'pattern': pattern, 'repeats': reps, 'start_x': start_x}
        return None

    def _emit_with_same(self, runs, width):
        parts = []
        widths_seen = []
        for sx, ex, cid in runs:
            rw = ex - sx
            found_ref = False
            for pw, pc in widths_seen:
                if pw == rw and pc != cid:
                    parts.append(f"same[x{sx},w{rw},c{cid}]")
                    found_ref = True
                    break
            if not found_ref:
                if ex == width:
                    parts.append(f"row_end[x{sx},c{cid}]")
                else:
                    parts.append(f"blk[x{sx},w{rw},c{cid}]")
            widths_seen.append((rw, cid))
        return parts

    # ──────────── Pass 6: 相对定位 ────────────

    def _apply_delta(self, tokens):
        result = []
        last_x, last_y = 0, 0

        for t in tokens:
            if t.startswith('block[x'):
                m = re.match(r'block\[x(\d+),y(\d+),(w\d+,h\d+,c\d+)\]', t)
                if m:
                    x, y = int(m.group(1)), int(m.group(2))
                    rest = m.group(3)
                    dx, dy = x - last_x, y - last_y
                    if -99 <= dx <= 99 and -99 <= dy <= 99:
                        result.append(f"block[dx{dx},dy{dy},{rest}]")
                    else:
                        result.append(t)
                    last_x, last_y = x, y
                    continue

            elif t.startswith('r[') and not t.startswith('r[+'):
                m = re.match(r'r\[(\d+)\]', t)
                if m:
                    y = int(m.group(1))
                    dy = y - last_y
                    content = t[t.index(']') + 1:]
                    if 0 < dy <= 99:
                        result.append(f'r[+{dy}]{content}')
                    else:
                        result.append(t)
                    last_y = y
                    continue

            result.append(t)
        return result

    # ════════════════════════════════════════
    #  解码
    # ════════════════════════════════════════

    def decode(self, tokens: List[str], palette: List[np.ndarray]) -> np.ndarray:
        cw = ch = 0
        img = None
        last_row = None
        cur_y = 0
        tiles: Dict[int, np.ndarray] = {}
        tid = -1
        tr_rows = []
        tw = th = 0
        last_x, last_y = 0, 0

        for t in tokens:
            # ── canvas ──
            if t.startswith('canvas['):
                m = re.match(r'canvas\[(\d+),(\d+)\]', t)
                cw, ch = int(m.group(1)), int(m.group(2))
                img = np.full((ch, cw), -1, dtype=np.int32)

            # ── mirror ──
            elif t.startswith('mirror_h['):
                m = re.match(r'mirror_h\[(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    mx, my, mw, mh = (int(m.group(1)), int(m.group(2)),
                                      int(m.group(3)), int(m.group(4)))
                    src = img[my:my + mh, mx - mw:mx]
                    if src.shape[1] == mw:
                        mirrored = src[:, ::-1]
                        mask = (img[my:my + mh, mx:mx + mw] == -1)
                        img[my:my + mh, mx:mx + mw] = np.where(
                            mask, mirrored, img[my:my + mh, mx:mx + mw])

            elif t.startswith('mirror_v['):
                m = re.match(r'mirror_v\[(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    mx, my, mw, mh = (int(m.group(1)), int(m.group(2)),
                                      int(m.group(3)), int(m.group(4)))
                    src = img[my - mh:my, mx:mx + mw]
                    if src.shape[0] == mh:
                        mirrored = src[::-1, :]
                        mask = (img[my:my + mh, mx:mx + mw] == -1)
                        img[my:my + mh, mx:mx + mw] = np.where(
                            mask, mirrored, img[my:my + mh, mx:mx + mw])

            # ── gradient ──
            elif t.startswith('grad['):
                m = re.match(r'grad\[(\d+),(\d+),(\d+),(\d+),c(\d+),c(\d+),([hv])\]', t)
                if m and img is not None:
                    gx, gy = int(m.group(1)), int(m.group(2))
                    gw, gh = int(m.group(3)), int(m.group(4))
                    c0, c1 = int(m.group(5)), int(m.group(6))
                    d = m.group(7)
                    rgb0 = palette[c0].astype(np.float32)
                    rgb1 = palette[c1].astype(np.float32)
                    for dy in range(gh):
                        for dx in range(gw):
                            if d == 'h':
                                frac = dx / max(gw - 1, 1)
                            else:
                                frac = dy / max(gh - 1, 1)
                            rgb = rgb0 * (1 - frac) + rgb1 * frac
                            c = int(np.argmin(np.sum(
                                (np.array(palette, dtype=np.float32) - rgb) ** 2, axis=1)))
                            py, px = gy + dy, gx + dx
                            if 0 <= py < ch and 0 <= px < cw:
                                img[py, px] = c

            # ── tile definition ──
            elif t.startswith('tile['):
                m = re.match(r'tile\[(\d+),(\d+),(\d+)\]', t)
                if m:
                    tid, tw, th = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    tr_rows = []

            elif t.startswith('tr['):
                tr_rows.append([int(v) for v in t[3:-1].split(',')])

            elif t == "tile_end":
                if tid >= 0 and len(tr_rows) == th:
                    tiles[tid] = np.array(tr_rows, dtype=np.int32).reshape(th, tw)

            elif t.startswith('tile_at['):
                m = re.match(r'tile_at\[(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    t_id, tx, ty = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    self._place_tile(img, tiles.get(t_id), tx, ty, ch, cw)

            elif t.startswith('tile_fill['):
                m = re.match(r'tile_fill\[(\d+),(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    t_id = int(m.group(1))
                    sx, sy = int(m.group(2)), int(m.group(3))
                    nx, ny = int(m.group(4)), int(m.group(5))
                    td = tiles.get(t_id)
                    if td is not None:
                        tth, ttw = td.shape
                        for iy in range(ny):
                            for ix in range(nx):
                                self._place_tile(img, td, sx + ix * ttw,
                                                 sy + iy * tth, ch, cw)

            # ── copy ──
            elif t.startswith('copy['):
                m = re.match(r'copy\[(\d+),(\d+),(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    sx, sy = int(m.group(1)), int(m.group(2))
                    dx, dy = int(m.group(3)), int(m.group(4))
                    bw, bh = int(m.group(5)), int(m.group(6))
                    src = img[sy:sy + bh, sx:sx + bw].copy()
                    dye, dxe = min(dy + bh, ch), min(dx + bw, cw)
                    img[dy:dye, dx:dxe] = src[:dye - dy, :dxe - dx]

            # ── block (absolute) ──
            elif t.startswith('block[x'):
                m = re.match(r'block\[x(\d+),y(\d+),w(\d+),h(\d+),c(\d+)\]', t)
                if m and img is not None:
                    bx, by = int(m.group(1)), int(m.group(2))
                    bw, bh, bc = int(m.group(3)), int(m.group(4)), int(m.group(5))
                    img[by:min(by + bh, ch), bx:min(bx + bw, cw)] = bc
                    last_x, last_y = bx, by
                    last_row = img[by, :].copy()
                    cur_y = by + bh

            # ── block (delta) ──
            elif t.startswith('block[dx'):
                m = re.match(r'block\[dx([+-]?\d+),dy([+-]?\d+),(w\d+),(h\d+),(c\d+)\]', t)
                if m and img is not None:
                    dx, dy = int(m.group(1)), int(m.group(2))
                    bw = int(m.group(3)[1:])
                    bh = int(m.group(4)[1:])
                    bc = int(m.group(5)[1:])
                    bx, by = last_x + dx, last_y + dy
                    img[by:min(by + bh, ch), bx:min(bx + bw, cw)] = bc
                    last_x, last_y = bx, by
                    last_row = img[by, :].copy()
                    cur_y = by + bh

            # ── repeat ──
            elif t.startswith('repeat['):
                m = re.match(r'repeat\[(\d+)\]', t)
                if m and img is not None:
                    for _ in range(int(m.group(1))):
                        if cur_y < ch and last_row is not None:
                            img[cur_y, :] = last_row
                            cur_y += 1

            # ── row (absolute) ──
            elif t.startswith('r[') and not t.startswith('r[+'):
                m = re.match(r'\[(\d+)\](.*)', t[1:])
                if m and img is not None:
                    y = int(m.group(1))
                    self._decode_row(img, y, cw, m.group(2).strip())
                    last_row = img[y, :].copy()
                    cur_y = y + 1
                    last_y = y

            # ── row (delta) ──
            elif t.startswith('r[+'):
                m = re.match(r'r\[\+(\d+)\](.*)', t)
                if m and img is not None:
                    dy = int(m.group(1))
                    y = last_y + dy
                    self._decode_row(img, y, cw, m.group(2).strip())
                    last_row = img[y, :].copy()
                    cur_y = y + 1
                    last_y = y

        if img is not None:
            img[img == -1] = 0
        return self._to_rgb(img, palette)

    def _place_tile(self, img, td, tx, ty, ch, cw):
        if td is None:
            return
        tth, ttw = td.shape
        ye, xe = min(ty + tth, ch), min(tx + ttw, cw)
        if ty < ch and tx < cw:
            img[ty:ye, tx:xe] = td[:ye - ty, :xe - tx]

    def _decode_row(self, img, y, width, content):
        if not content:
            return

        if content.startswith('fill['):
            m = re.match(r'fill\[c(\d+)\]', content)
            if m:
                img[y, :] = int(m.group(1))
            return

        if content.startswith('pat['):
            m = re.match(r'pat\[x(\d+),(.+)\]rep\[(\d+)\]', content)
            if m:
                x = int(m.group(1))
                pat_str = m.group(2)
                reps = int(m.group(3))
                pat_parts = pat_str.split(',')
                pattern = []
                i = 0
                while i < len(pat_parts):
                    w = int(pat_parts[i][1:])
                    c = int(pat_parts[i + 1][1:])
                    pattern.append((w, c))
                    i += 2
                for _ in range(reps):
                    for pw, pc in pattern:
                        img[y, x:x + pw] = pc
                        x += pw
                return

        for part in content.split():
            if part.startswith('row_end['):
                m = re.match(r'row_end\[x(\d+),c(\d+)\]', part)
                if m:
                    img[y, int(m.group(1)):] = int(m.group(2))
            elif part.startswith('blk['):
                m = re.match(r'blk\[x(\d+),w(\d+),c(\d+)\]', part)
                if m:
                    x, bw, c = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    img[y, x:x + bw] = c
            elif part.startswith('same['):
                m = re.match(r'same\[x(\d+),w(\d+),c(\d+)\]', part)
                if m:
                    x, w, c = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    img[y, x:x + w] = c

    def _to_rgb(self, labels, palette):
        if labels is None:
            return np.zeros((1, 1, 3), dtype=np.uint8)
        h, w = labels.shape
        rgb = np.zeros((h, w, 3), dtype=np.uint8)
        for c in range(len(palette)):
            rgb[labels == c] = palette[c]
        return rgb

    # ════════════════════════════════════════
    #  文件IO
    # ════════════════════════════════════════

    def save_tokens(self, tokens, palette, path):
        with open(path, 'w', encoding='utf-8') as f:
            json.dump({'tokens': tokens, 'palette': [c.tolist() for c in palette]}, f,
                      ensure_ascii=False, indent=2)
        print(f"Token已保存: {path} ({os.path.getsize(path) / 1024:.1f} KB)")

    def load_tokens(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            d = json.load(f)
        return d['tokens'], [np.array(c, dtype=np.uint8) for c in d['palette']]

    # ════════════════════════════════════════
    #  统计
    # ════════════════════════════════════════

    def token_stats(self, tokens, image):
        s = dict.fromkeys([
            'fill_rows', 'row_ends', 'blocks', 'multi_blocks',
            'repeat_rows', 'tiles', 'tile_at', 'tile_fill',
            'mirrors', 'gradients', 'copies', 'patterns',
            'same_refs', 'delta_blocks', 'delta_rows',
        ], 0)
        s['total_tokens'] = len(tokens)

        for t in tokens:
            if t.startswith('r[') or t.startswith('r[+'):
                if 'fill[' in t:
                    s['fill_rows'] += 1
                elif 'pat[' in t:
                    s['patterns'] += 1
                elif 'same[' in t:
                    s['same_refs'] += t.count('same[')
                s['blocks'] += t.count('blk[')
                s['row_ends'] += t.count('row_end[')
                if t.startswith('r[+'):
                    s['delta_rows'] += 1
            elif t.startswith('repeat['):
                m = re.match(r'repeat\[(\d+)\]', t)
                if m:
                    s['repeat_rows'] += int(m.group(1))
            elif t.startswith('block[x'):
                m = re.match(r'block\[x\d+,y\d+,w\d+,h(\d+)', t)
                if m:
                    s['multi_blocks' if int(m.group(1)) > 1 else 'blocks'] += 1
            elif t.startswith('block[dx'):
                s['delta_blocks'] += 1
            elif t.startswith('tile['):
                s['tiles'] += 1
            elif t.startswith('tile_at['):
                s['tile_at'] += 1
            elif t.startswith('tile_fill['):
                m = re.match(r'tile_fill\[\d+,\d+,\d+,(\d+),(\d+)\]', t)
                s['tile_fill'] += int(m.group(1)) * int(m.group(2)) if m else 1
            elif t.startswith('mirror_'):
                s['mirrors'] += 1
            elif t.startswith('grad['):
                s['gradients'] += 1
            elif t.startswith('copy['):
                s['copies'] += 1

        h, w, _ = image.shape
        s['image_size'] = f'{w}x{h}'
        s['original_bytes'] = h * w * 3
        s['token_bytes'] = len('\n'.join(tokens).encode('utf-8'))
        s['compression_ratio'] = s['original_bytes'] / max(1, s['token_bytes'])
        return s


# ════════════════════════════════════════
#  可视化
# ════════════════════════════════════════

def display_comparison(original, quantized_rgb, reconstructed, stats):
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    titles = [
        f"原始\n{original.shape[:2]}",
        f"量化\n{stats['total_tokens']} tokens",
        f"重建\nPSNR: {_psnr(original, reconstructed):.1f}dB",
        f"差异(x10)",
    ]
    imgs = [original, quantized_rgb, reconstructed,
            cv2.absdiff(original, reconstructed) * 10]
    for ax, im, ti in zip(axes, imgs, titles):
        ax.imshow(im)
        ax.set_title(ti, fontsize=10)
        ax.axis('off')
    plt.suptitle(
        f"色块语义v3 | 压缩比 {stats['compression_ratio']:.1f}x",
        fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def _psnr(a, b):
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))


def print_stats(s):
    print(f"\n{'═' * 55}")
    print(f"  色块语义token统计 v3 (全语义压缩)")
    print(f"{'═' * 55}")
    print(f"  图像: {s['image_size']}  总token: {s['total_tokens']}")
    print(f"{'─' * 55}")
    print(f"  镜像 mirror:       {s['mirrors']}")
    print(f"  渐变 gradient:     {s['gradients']}")
    print(f"  平铺 tile:         {s['tiles']}  "
          f"(at:{s['tile_at']} fill:{s['tile_fill']})")
    print(f"  复制 copy:         {s['copies']}")
    print(f"  多行 block:        {s['multi_blocks']}")
    print(f"  单行 block:        {s['blocks']}")
    print(f"  整行 fill_row:     {s['fill_rows']}")
    print(f"  行尾 row_end:      {s['row_ends']}")
    print(f"  行内重复 pat:      {s['patterns']}")
    print(f"  颜色差分 same:     {s['same_refs']}")
    print(f"  行重复 repeat:     {s['repeat_rows']}")
    print(f"  相对定位 delta:    blk:{s['delta_blocks']} row:{s['delta_rows']}")
    print(f"{'─' * 55}")
    print(f"  原始: {s['original_bytes']:,}B  "
          f"Token: {s['token_bytes']:,}B  "
          f"压缩比: {s['compression_ratio']:.1f}x")
    print(f"{'═' * 55}")


# ════════════════════════════════════════
#  主流程
# ════════════════════════════════════════

if __name__ == "__main__":
    image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "img_1.png")
    if not os.path.exists(image_path):
        print(f"未找到: {image_path}")
        exit(1)

    img = Image.open(image_path).convert('RGB')
    img_array = np.array(img.resize((256, 256)))

    tokenizer = ColorBlockTokenizer(
        max_colors=16, tile_sizes=[2, 4, 8, 16], mirror_threshold=0.85)

    print("编码中...")
    tokens, quantized, palette = tokenizer.encode(img_array)
    print("解码中...")
    reconstructed = tokenizer.decode(tokens, palette)

    stats = tokenizer.token_stats(tokens, img_array)
    print_stats(stats)

    print(f"\n前20个token:")
    for i, t in enumerate(tokens):
        print(f"  {i + 1:2d}. {t}")
    if len(tokens) > 20:
        print(f"  ... 共 {len(tokens)} 个")

    qrgb = np.zeros_like(img_array)
    for c in range(len(palette)):
        qrgb[quantized == c] = palette[c]

    h, w, _ = reconstructed.shape
    r_labels = np.zeros((h, w), dtype=np.int32)
    for c in range(len(palette)):
        mask = np.all(reconstructed == palette[c], axis=2)
        r_labels[mask] = c

    diff_pixels = int(np.sum(quantized != r_labels))
    total_pixels = h * w
    q_psnr = _psnr(qrgb, reconstructed)
    print(f"\nPSNR (quantized→reconstructed): {q_psnr:.1f} dB")
    print(f"Mismatch pixels: {diff_pixels}/{total_pixels} "
          f"({100*diff_pixels/total_pixels:.2f}%)")

    display_comparison(img_array, qrgb, reconstructed, stats)
    tokenizer.save_tokens(tokens, palette, 'color_block_tokens.json')
相关推荐
碳基硅坊5 分钟前
Qwen3.5-9B在安全生产安全帽检测中的应用
人工智能·安全·安全帽检测·qwen3.5-9b
云烟成雨TD14 分钟前
Spring AI Alibaba 1.x 系列【66】Graph 长期记忆
java·人工智能·spring
春日见14 分钟前
五分钟入门 强化学习---Q-Learning算法与实现
人工智能·python·深度学习·算法·机器学习·计算机视觉
卡次卡次124 分钟前
vibecoding起步之Claude Code的skills是什么,里面有什么文件,以ppt的一个skills举例
人工智能·opencv·powerpoint
AI服务老曹24 分钟前
解耦异构算力:基于 Docker 与 GB28181/RTSP 的边缘计算 AI 视频管理平台架构设计与源码交付实践
人工智能·docker·边缘计算
小饕32 分钟前
RAG 实战:文本切块(Text Chunking)从入门到精通
人工智能
多年小白34 分钟前
【周末消息】2026年5月30日-6月1日
大数据·人工智能·深度学习·机器学习·金融
AI导出鸭PC端34 分钟前
智谱清言清除符号:当LLM输出遭遇“结构性失序”,一份关于AI导出鸭的工程化测评
人工智能
Engineer邓祥浩1 小时前
宏观认知(3):AI战略与社会影响——吴恩达《AI for Everyone》Week3学习笔记
人工智能·笔记·学习