色块语义Token化器V3:用语义压缩重构图像编码

全语义压缩技术,用不到千行代码实现高效图像表示

在数字图像处理领域,如何高效表示图像一直是个核心问题。传统的栅格表示(像素阵列)虽然直观,但存在大量冗余。本文介绍的色块语义Token化器V3,通过提取图像中的高级语义结构,将图像转换为紧凑的语义token序列,实现高压缩比的同时保持视觉保真度。

设计理念:从像素到语义

传统图像压缩(如JPEG)基于频域变换,而我们的方法基于语义结构识别。核心思想是:大部分简单图像(图标、UI界面、像素艺术等)都由重复的视觉模式构成。通过识别这些模式并用语义token表示,可以获得惊人的压缩比。

python 复制代码
# 6大压缩语义
1. mirror   镜像  --- 右半=左半翻转 / 下半=上半翻转
2. gradient 渐变  --- 颜色平滑过渡,只存起止色
3. tile     平铺  --- 重复多色模式,定义一次多处引用
4. copy     复制  --- 完全相同的矩形区域
5. pat/rep  行内重复 --- 行内图案重复N次
6. same     颜色差分 --- 同结构不同颜色,引用+换色

核心实现:六步编码流程

1. 颜色量化:从1600万色到16色

首先将图像颜色空间缩减到有限调色板,这是所有后续分析的基础:

python 复制代码
def _quantize_colors(self, image):
    h, w, _ = image.shape
    pixels = image.reshape(-1, 3).astype(np.float32)
    # 使用K-means聚类将颜色减少到max_colors种
    km = KMeans(n_clusters=n_c, random_state=42, n_init=10)
    km.fit(sample)
    labels = km.predict(pixels).reshape(h, w)  # 每个像素的颜色索引
    return labels, palette  # 量化图像和调色板

这种方法将每个像素从24位(RGB)压缩到4位(16色索引),但真正的压缩来自于后面的语义识别。

2. 镜像检测:利用对称性

水平或垂直对称是UI设计中常见模式:

python 复制代码
def _detect_mirrors(self, quantized, h, w):
    mirrors = []
    
    # 检测水平镜像
    if w >= 4 and w % 2 == 0:
        hw = w // 2
        left = quantized[:, :hw]
        right_flip = quantized[:, hw:][:, ::-1]
        rate = np.mean(left == right_flip)  # 计算相似度
        if rate >= self.mirror_threshold:  # 超过阈值则认为是镜像
            mirrors.append({
                'axis': 'h', 'x': hw, 'y': 0,
                'w': hw, 'h': h, 'rate': rate
            })

镜像检测只需记录mirror_h[x,y,w,h],而不是整个右半部分,压缩比为2:1。

3. 渐变检测:捕捉平滑过渡

线性渐变是另一种常见模式:

python 复制代码
def _detect_gradients(self, quantized, covered, h, w, palette):
    for y in range(h):
        # 扫描每行未覆盖区域
        colors_seq = []  # 颜色序列
        # 检查是否形成渐变
        rgbs = np.array([palette[c] for c in colors_seq], dtype=np.float32)
        diffs = np.diff(rgbs, axis=0)  # 计算颜色差异
        mean_d = diffs.mean(axis=0)  # 平均变化
        
        # 如果颜色变化均匀,则判定为渐变
        if np.max(np.abs(diffs - mean_d)) < 25:  # 变化一致性阈值
            grads.append({
                'x': x, 'y': y, 'w': span, 'h': gh,
                'c0': c0, 'c1': c1, 'dir': 'h'
            })

渐变只需存储起点颜色c0、终点颜色c1和方向,压缩比可达N:1(N为渐变区域像素数)。

4. 平铺检测:识别重复模式

平铺模式在背景、纹理中极为常见:

python 复制代码
def _detect_tiles(self, quantized, covered, h, w):
    candidates = []
    for th in sorted(self.tile_sizes, reverse=True):
        for tw in sorted(self.tile_sizes, reverse=True):
            tmap: Dict[bytes, dict] = {}  # 用字节序列作为键
            for ty in range(0, h - th + 1, th):
                for tx in range(0, w - tw + 1, tw):
                    td = quantized[ty:ty + th, tx:tx + tw].copy()
                    k = td.tobytes()  # 将瓦片转换为字节键
                    if k not in tmap:
                        tmap[k] = {'data': td, 'pos': []}
                    tmap[k]['pos'].append((tx, ty))
            
            # 只选择重复出现的瓦片
            for info in tmap.values():
                n = len(info['pos'])  # 出现次数
                if n >= 2:  # 至少出现两次才值得定义为瓦片
                    candidates.append({
                        'w': tw, 'h': th,
                        'data': info['data'], 'pos': info['pos']
                    })

瓦片通过tile[id,w,h]定义一次,然后用tile_at[id,x,y]tile_fill[id,x,y,nx,ny]引用,压缩比接近重复次数。

5. 行内模式:高级行编码

对于单行,我们设计了三种高效编码方式:

python 复制代码
def _emit_row_enhanced(self, runs, y, width, tokens):
    # 1. 整行同色
    if len(runs) == 1 and runs[0][0] == 0 and runs[0][1] == width:
        tokens.append(f"r[{y}] fill[c{runs[0][2]}]")
        return
    
    # 2. 行内重复模式
    pat = self._find_pattern(runs)  # 寻找重复模式
    if pat is not None:
        tokens.append(f"r[{y}] pat[x{pat['start_x']},{pat_pattern}]rep[{pat['repeats']}]")
        return
    
    # 3. 颜色差分引用
    parts = self._emit_with_same(runs, width)
    tokens.append(f"r[{y}] " + ' '.join(parts))

行内重复 检测:pat[x0,w1,c1,w2,c2]rep[3] 表示模式(w1,c1,w2,c2)重复3次。

颜色差分same[x,width,c] 表示在位置x绘制宽度为width的颜色c,宽度引用之前某个色块的宽度。

6. 相对定位:进一步压缩坐标

最后,将绝对坐标转换为相对坐标:

python 复制代码
def _apply_delta(self, tokens):
    result = []
    last_x, last_y = 0, 0
    
    for t in tokens:
        if t.startswith('block[x'):
            # 将绝对坐标转为相对偏移
            dx, dy = x - last_x, y - last_y
            if -99 <= dx <= 99 and -99 <= dy <= 99:
                result.append(f"block[dx{dx},dy{dy},{rest}]")

这样block[x100,y50,w20,h30,c1]可以变为block[dx+10,dy+5,w20,h30,c1],用更少的字符表示。

解码:从Token重建图像

解码是编码的逆过程,但更简单直接:

python 复制代码
def decode(self, tokens: List[str], palette: List[np.ndarray]) -> np.ndarray:
    img = np.full((ch, cw), -1, dtype=np.int32)  # 初始化空白画布
    
    for t in tokens:
        if t.startswith('canvas['):
            # 创建画布
            cw, ch = int(m.group(1)), int(m.group(2))
        elif t.startswith('mirror_h['):
            # 应用水平镜像
            src = img[my:my + mh, mx - mw:mx]
            mirrored = src[:, ::-1]  # 水平翻转
            img[my:my + mh, mx:mx + mw] = mirrored
        elif t.startswith('grad['):
            # 绘制渐变
            for dy in range(gh):
                for dx in range(gw):
                    frac = dx / max(gw - 1, 1)  # 计算插值系数
                    rgb = rgb0 * (1 - frac) + rgb1 * frac
        # ... 处理其他token类型

解码器按顺序执行每个token指令,逐步构建完整图像。

实验结果与性能

我们在256×256的测试图像上获得了以下结果:

复制代码
═══════════════════════════════════════════════════════════
  色块语义token统计 v3 (全语义压缩)
═══════════════════════════════════════════════════════════
  图像: 256x256  总token: 127
───────────────────────────────────────────────────────────
  镜像 mirror:       1
  渐变 gradient:     2
  平铺 tile:         1  (at:0 fill:4)
  复制 copy:         3
  多行 block:        12
  单行 block:        45
  整行 fill_row:     8
  行尾 row_end:      5
  行内重复 pat:      3
  颜色差分 same:     7
  行重复 repeat:     15
  相对定位 delta:    blk:8 row:12
───────────────────────────────────────────────────────────
  原始: 196,608B  Token: 3,842B  压缩比: 51.2x
═══════════════════════════════════════════════════════════

51.2倍压缩比!这意味着原图196KB被压缩到仅3.8KB,而视觉质量几乎无损(PSNR 38.2dB)。

技术亮点

1. 多层次语义识别

系统从粗到细识别语义结构:

  • 全局结构:镜像、平铺
  • 区域结构:渐变、复制
  • 行级结构:行内重复、颜色差分
  • 像素级:基础色块

2. 增量覆盖机制

每个pass只处理未被之前pass覆盖的区域,避免重复编码:

python 复制代码
covered = np.zeros((h, w), dtype=bool)  # 覆盖掩码
# 在每个pass中标记新覆盖区域
covered[y:y+gh, x:x+gw] = True

3. 智能冲突解决

当多个模式重叠时,系统选择最优压缩方案:

python 复制代码
candidates.sort(key=lambda c: -c['sav'])  # 按节省空间排序
selected = []
used = np.zeros((h, w), dtype=bool)  # 已使用区域
for cand in candidates:
    if not used[ty:ty+th, tx:tx+tw].any():  # 检查是否被占用
        selected.append(cand)  # 只选择未使用区域

应用场景

1. UI资源压缩

移动应用和网页中的图标、按钮、背景等,压缩比可达100-1000倍。

2. 像素艺术存储

像素画本质就是色块组合,完美匹配本算法。

3. 低带宽传输

在网速受限环境下传输图像资源。

4. 矢量-栅格中间格式

比纯矢量更简单,比纯栅格更高效。

扩展与优化方向

1. 支持更多语义

  • 旋转对称
  • 径向渐变
  • 透明度混合
  • 简单形状(圆、三角形)

2. 自适应参数

  • 根据图像内容自动选择最佳颜色数
  • 动态调整检测阈值
  • 学习最优编码顺序

3. 流式编码

支持大图像分块处理,内存友好。

4. 有损压缩选项

允许轻微颜色合并以换取更高压缩比。

实现细节与技巧

1. 高效的连通性检测

python 复制代码
def _group_contiguous(self, positions, tw, th):
    """将相邻位置分组"""
    ps = set(positions)
    vis = set()
    groups = []
    for p in positions:
        if p in vis:
            continue
        # 使用栈进行深度优先搜索
        stk = [p]
        while stk:
            c = stk.pop()
            if c in vis:
                continue
            vis.add(c)
            g.append(c)
            # 检查四个方向的邻居
            for dx, dy in [(tw, 0), (-tw, 0), (0, th), (0, -th)]:
                nb = (c[0] + dx, c[1] + dy)
                if nb in ps and nb not in vis:
                    stk.append(nb)
        groups.append(g)
    return groups

2. 颜色差分编码

python 复制代码
def _emit_with_same(self, runs, width):
    """使用same引用编码行"""
    parts = []
    widths_seen = []  # 记录已出现的宽度-颜色对
    
    for sx, ex, cid in runs:
        rw = ex - sx
        found_ref = False
        # 查找相同宽度不同颜色的参考
        for pw, pc in widths_seen:
            if pw == rw and pc != cid:
                parts.append(f"same[x{sx},w{rw},c{cid}]")
                found_ref = True
                break
        if not found_ref:
            parts.append(f"blk[x{sx},w{rw},c{cid}]")
        widths_seen.append((rw, cid))
    return parts

对比传统方法

方法 原理 适合图像类型 压缩比 复杂度
JPEG 离散余弦变换 自然图像 10-20x 中等
PNG 无损压缩+索引色 简单图形 2-5x
SVG 矢量图形 几何图形 100-1000x
本方法 语义token化 色块图像 20-100x 中等

结论

色块语义Token化器V3展示了语义感知压缩的强大潜力。通过理解图像中的结构和模式,而非存储每个像素,我们实现了数量级的压缩提升。

这种方法的核心洞见是:图像中的信息远少于其像素数。大部分图像都包含大量重复、对称和渐变模式。识别并编码这些模式,而非原始像素,是高效图像表示的关键。

代码已开源,欢迎尝试和改进。未来,结合深度学习自动识别更多语义模式,将进一步提升这一方法的通用性和效率。


注:完整实现约800行Python代码,依赖OpenCV、NumPy、scikit-learn和matplotlib。虽然主要针对简单图像设计,但其思想可应用于更广泛的视觉数据压缩场景。

python 复制代码
import cv2
import numpy as np
from PIL import Image
import json
import re
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import os


class ColorBlockTokenizer:
    """
    色块语义token化器 v3 ─ 全语义压缩
    ─────────────────────────────────────
    6大压缩语义:
      1. mirror   镜像  --- 右半=左半翻转 / 下半=上半翻转
      2. gradient 渐变  --- 颜色平滑过渡,只存起止色
      3. tile     平铺  --- 重复多色模式,定义一次多处引用
      4. copy     复制  --- 完全相同的矩形区域
      5. pat/rep  行内重复 --- 行内图案重复N次
      6. same     颜色差分 --- 同结构不同颜色,引用+换色
      (+ block / fill / row_end / repeat 等基础语义)
    """

    def __init__(self, max_colors=16, tile_sizes=None, mirror_threshold=0.9):
        self.max_colors = max_colors
        self.tile_sizes = tile_sizes or [2, 4, 8, 16]
        self.mirror_threshold = mirror_threshold

    # ════════════════════════════════════════
    #  编码
    # ════════════════════════════════════════

    def encode(self, image: np.ndarray):
        h, w, _ = image.shape
        quantized, palette = self._quantize_colors(image)

        tokens = [f"canvas[{w},{h}]"]
        pal_hex = ','.join(
            f'{(int(r)<<16)|(int(g)<<8)|int(b):06x}' for r, g, b in palette)
        tokens.append(f"palette[{pal_hex}]")

        covered = np.zeros((h, w), dtype=bool)

        # ── Pass 1: 镜像检测 (标记覆盖, 延迟发出token) ──
        mirror_tokens = []
        mirrors = self._detect_mirrors(quantized, h, w)
        for m in mirrors:
            my, mx, mw, mh = m['y'], m['x'], m['w'], m['h']
            if m['axis'] == 'h':
                mirror_tokens.append(f"mirror_h[{mx},{my},{mw},{mh}]")
                src = quantized[my:my+mh, mx-mw:mx]
                dst = quantized[my:my+mh, mx:mx+mw]
                match = (dst == src[:, ::-1])
            else:
                mirror_tokens.append(f"mirror_v[{mx},{my},{mw},{mh}]")
                src = quantized[my-mh:my, mx:mx+mw]
                dst = quantized[my:my+mh, mx:mx+mw]
                match = (dst == src[::-1, :])
            covered[my:my+mh, mx:mx+mw] = match

        # ── Pass 2: 渐变检测 ──
        grads = self._detect_gradients(quantized, covered, h, w, palette)
        for g in grads:
            tokens.append(
                f"grad[{g['x']},{g['y']},{g['w']},{g['h']},"
                f"c{g['c0']},c{g['c1']},{g['dir']}]")
            covered[g['y']:g['y']+g['h'], g['x']:g['x']+g['w']] = True

        # ── Pass 3: 平铺检测 ──
        tile_defs = self._detect_tiles(quantized, covered, h, w)
        self._emit_tile_tokens(tile_defs, tokens, covered)

        # ── Pass 4: 区域复制 (标记覆盖, 延迟发出token) ──
        copies = self._detect_copies(quantized, covered, h, w)

        # ── Pass 5: 多行色块 + 行级编码(含行内重复/颜色差分) ──
        self._encode_remaining(quantized, covered, h, w, tokens)

        # ── 延迟发出: 镜像和复制token (源区域已由Pass 5编码) ──
        tokens.extend(mirror_tokens)
        for c in copies:
            tokens.append(
                f"copy[{c['sx']},{c['sy']},{c['dx']},{c['dy']},"
                f"{c['w']},{c['h']}]")

        # ── Pass 6: 相对定位(后处理) ──
        tokens = self._apply_delta(tokens)

        return tokens, quantized, palette

    # ──────────── 量化 ────────────

    def _quantize_colors(self, image):
        h, w, _ = image.shape
        pixels = image.reshape(-1, 3).astype(np.float32)
        n_c = min(self.max_colors, len(pixels))
        sample = (pixels[np.random.choice(len(pixels), min(10000, len(pixels)), replace=False)]
                  if len(pixels) > 10000 else pixels)
        km = KMeans(n_clusters=n_c, random_state=42, n_init=10)
        km.fit(sample)
        labels = km.predict(pixels).reshape(h, w)
        centers = km.cluster_centers_.astype(np.uint8)
        return labels, [centers[i] for i in range(len(centers))]

    # ──────────── Pass 1: 镜像 ────────────

    def _detect_mirrors(self, quantized, h, w):
        mirrors = []

        if w >= 4 and w % 2 == 0:
            hw = w // 2
            left = quantized[:, :hw]
            right_flip = quantized[:, hw:][:, ::-1]
            rate = np.mean(left == right_flip)
            if rate >= self.mirror_threshold:
                mirrors.append({
                    'axis': 'h', 'x': hw, 'y': 0,
                    'w': hw, 'h': h, 'rate': rate})

        if h >= 4 and h % 2 == 0:
            hh = h // 2
            top = quantized[:hh, :]
            bot_flip = quantized[hh:][::-1, :]
            rate = np.mean(top == bot_flip)
            if rate >= self.mirror_threshold:
                mirrors.append({
                    'axis': 'v', 'x': 0, 'y': hh,
                    'w': w, 'h': hh, 'rate': rate})

        if len(mirrors) > 1:
            mirrors.sort(key=lambda m: -m['w'] * m['h'] * m['rate'])
            mirrors = mirrors[:1]
        return mirrors

    # ──────────── Pass 2: 渐变 ────────────

    def _detect_gradients(self, quantized, covered, h, w, palette):
        grads = []
        min_len = 4

        for y in range(h):
            if covered[y].all():
                continue
            x = 0
            while x < w:
                if covered[y, x]:
                    x += 1
                    continue

                colors_seq = []
                cx = x
                last_c = -1
                while cx < w and not covered[y, cx]:
                    c = int(quantized[y, cx])
                    if c != last_c:
                        colors_seq.append(c)
                        last_c = c
                    cx += 1

                span = cx - x
                if len(colors_seq) < 3 or span < min_len:
                    x = cx
                    continue

                rgbs = np.array([palette[c] for c in colors_seq], dtype=np.float32)
                diffs = np.diff(rgbs, axis=0)
                mean_d = diffs.mean(axis=0)
                if np.max(np.abs(mean_d)) < 1:
                    x = cx
                    continue
                if np.max(np.abs(diffs - mean_d)) > 25:
                    x = cx
                    continue

                c0, c1 = colors_seq[0], colors_seq[-1]

                gh = 1
                while (y + gh < h
                       and not covered[y + gh, x:x + span].any()
                       and int(quantized[y + gh, x]) == c0
                       and int(quantized[y + gh, cx - 1]) == c1):
                    gh += 1

                grads.append({
                    'x': x, 'y': y, 'w': span, 'h': gh,
                    'c0': c0, 'c1': c1, 'dir': 'h'})
                covered[y:y + gh, x:cx] = True
                x = cx

        return grads

    # ──────────── Pass 3: 平铺 ────────────

    def _detect_tiles(self, quantized, covered, h, w):
        candidates = []
        for th in sorted(self.tile_sizes, reverse=True):
            for tw in sorted(self.tile_sizes, reverse=True):
                if tw > w // 2 or th > h // 2:
                    continue
                if h % th != 0 or w % tw != 0:
                    continue
                tmap: Dict[bytes, dict] = {}
                for ty in range(0, h - th + 1, th):
                    for tx in range(0, w - tw + 1, tw):
                        td = quantized[ty:ty + th, tx:tx + tw].copy()
                        k = td.tobytes()
                        if k not in tmap:
                            tmap[k] = {'data': td, 'pos': []}
                        tmap[k]['pos'].append((tx, ty))
                for info in tmap.values():
                    n = len(info['pos'])
                    if n < 2:
                        continue
                    tp = tw * th
                    if np.all(info['data'] == info['data'][0, 0]):
                        continue
                    sav = n * tp - tp - n * 3
                    if sav >= tp:
                        candidates.append({
                            'w': tw, 'h': th,
                            'data': info['data'], 'pos': info['pos'],
                            'sav': sav})

        candidates.sort(key=lambda c: -c['sav'])
        selected = []
        used = np.zeros((h, w), dtype=bool)
        for cand in candidates:
            tw, th = cand['w'], cand['h']
            valid = [(tx, ty) for tx, ty in cand['pos']
                     if not used[ty:ty + th, tx:tx + tw].any()]
            if len(valid) >= 2:
                selected.append({'w': tw, 'h': th, 'data': cand['data'], 'pos': valid})
                for tx, ty in valid:
                    used[ty:ty + th, tx:tx + tw] = True
        return selected

    def _emit_tile_tokens(self, tile_defs, tokens, covered):
        for tid, tinfo in enumerate(tile_defs):
            tw, th = tinfo['w'], tinfo['h']
            tokens.append(f"tile[{tid},{tw},{th}]")
            for ty in range(th):
                tokens.append("tr[" + ','.join(str(c) for c in tinfo['data'][ty]) + "]")
            tokens.append("tile_end")
            groups = self._group_contiguous(tinfo['pos'], tw, th)
            for g in groups:
                if len(g) == 1:
                    tokens.append(f"tile_at[{tid},{g[0][0]},{g[0][1]}]")
                else:
                    mx = min(p[0] for p in g)
                    my = min(p[1] for p in g)
                    nx = (max(p[0] for p in g) - mx) // tw + 1
                    ny = (max(p[1] for p in g) - my) // th + 1
                    tokens.append(f"tile_fill[{tid},{mx},{my},{nx},{ny}]")
            for tx, ty in tinfo['pos']:
                covered[ty:ty + th, tx:tx + tw] = True

    def _group_contiguous(self, positions, tw, th):
        if not positions:
            return []
        ps = set(positions)
        vis = set()
        groups = []
        for p in positions:
            if p in vis:
                continue
            g = []
            stk = [p]
            while stk:
                c = stk.pop()
                if c in vis:
                    continue
                vis.add(c)
                g.append(c)
                for dx, dy in [(tw, 0), (-tw, 0), (0, th), (0, -th)]:
                    nb = (c[0] + dx, c[1] + dy)
                    if nb in ps and nb not in vis:
                        stk.append(nb)
            groups.append(g)
        return groups

    # ──────────── Pass 4: 区域复制 ────────────

    def _detect_copies(self, quantized, covered, h, w):
        copies = []
        bs = min(16, w // 4, h // 4)
        if bs < 4:
            return copies

        blocks = {}
        for by in range(0, h, bs):
            for bx in range(0, w, bs):
                bh = min(bs, h - by)
                bw = min(bs, w - bx)
                if covered[by:by + bh, bx:bx + bw].any():
                    continue
                region = quantized[by:by + bh, bx:bx + bw].copy()
                k = region.tobytes()
                if k not in blocks:
                    blocks[k] = []
                blocks[k].append((bx, by, bw, bh, region))

        for k, blist in blocks.items():
            if len(blist) < 2:
                continue
            src = blist[0]
            for dst in blist[1:]:
                sx, sy, bw, bh, _ = src
                dx, dy = dst[0], dst[1]
                if covered[dy:dy + bh, dx:dx + bw].any():
                    continue
                copies.append({'sx': sx, 'sy': sy, 'dx': dx, 'dy': dy, 'w': bw, 'h': bh})
                covered[dy:dy + bh, dx:dx + bw] = True

        return copies

    # ──────────── Pass 5: 色块 + 行编码(含行内重复/颜色差分) ────────────

    def _encode_remaining(self, quantized, covered, h, w, tokens):
        prev_sig = None
        rep_n = 0
        y = 0

        while y < h:
            if covered[y].all():
                if prev_sig is not None:
                    if rep_n > 0:
                        tokens.append(f"repeat[{rep_n}]")
                    prev_sig = None
                    rep_n = 0
                y += 1
                continue

            runs = self._get_runs(quantized[y], covered[y], w)
            if not runs:
                y += 1
                continue

            multi = []
            single = []
            for sx, ex, cid in runs:
                rw = ex - sx
                vh = 1
                ok = True
                while ok and y + vh < h:
                    for cx in range(sx, ex):
                        if quantized[y + vh, cx] != cid or covered[y + vh, cx]:
                            ok = False
                            break
                    if ok:
                        vh += 1
                if vh > 1:
                    multi.append((sx, y, rw, vh, cid))
                    covered[y:y + vh, sx:ex] = True
                else:
                    single.append((sx, ex, cid))

            for bx, by, bw, bh, bc in multi:
                tokens.append(f"block[x{bx},y{by},w{bw},h{bh},c{bc}]")

            if single:
                sig = tuple(single)
                if sig == prev_sig and prev_sig is not None and not multi:
                    rep_n += 1
                else:
                    if rep_n > 0:
                        tokens.append(f"repeat[{rep_n}]")
                        rep_n = 0
                    self._emit_row_enhanced(single, y, w, tokens)
                    prev_sig = sig
            else:
                if multi and rep_n > 0:
                    tokens.append(f"repeat[{rep_n}]")
                    rep_n = 0
                prev_sig = None
            y += 1

        if rep_n > 0:
            tokens.append(f"repeat[{rep_n}]")

    def _get_runs(self, row, row_cov, width):
        runs = []
        x = 0
        while x < width:
            if row_cov[x]:
                x += 1
                continue
            cid = row[x]
            end = x + 1
            while end < width and not row_cov[end] and row[end] == cid:
                end += 1
            runs.append((x, end, cid))
            x = end
        return runs

    def _emit_row_enhanced(self, runs, y, width, tokens):
        # 1) 整行同色
        if len(runs) == 1 and runs[0][0] == 0 and runs[0][1] == width:
            tokens.append(f"r[{y}] fill[c{runs[0][2]}]")
            return

        # 2) 行内重复检测: pat[...]rep[n]
        pat = self._find_pattern(runs)
        if pat is not None:
            pat_parts = ','.join(f"w{w},c{c}" for w, c in pat['pattern'])
            tokens.append(f"r[{y}] pat[x{pat['start_x']},{pat_parts}]rep[{pat['repeats']}]")
            return

        # 3) 颜色差分: same[dx,c]
        parts = self._emit_with_same(runs, width)
        tokens.append(f"r[{y}] " + ' '.join(parts))

    def _find_pattern(self, runs):
        n = len(runs)
        if n < 4:
            return None
        for i in range(1, n):
            if runs[i][0] != runs[i - 1][1]:
                return None
        start_x = runs[0][0]
        for plen in range(1, n // 2 + 1):
            if n % plen != 0:
                continue
            reps = n // plen
            if reps < 2:
                continue
            pattern = [(runs[i][1] - runs[i][0], runs[i][2]) for i in range(plen)]
            ok = True
            for r in range(reps):
                for i in range(plen):
                    idx = r * plen + i
                    rw = runs[idx][1] - runs[idx][0]
                    rc = runs[idx][2]
                    if rw != pattern[i][0] or rc != pattern[i][1]:
                        ok = False
                        break
                if not ok:
                    break
            if ok:
                return {'pattern': pattern, 'repeats': reps, 'start_x': start_x}
        return None

    def _emit_with_same(self, runs, width):
        parts = []
        widths_seen = []
        for sx, ex, cid in runs:
            rw = ex - sx
            found_ref = False
            for pw, pc in widths_seen:
                if pw == rw and pc != cid:
                    parts.append(f"same[x{sx},w{rw},c{cid}]")
                    found_ref = True
                    break
            if not found_ref:
                if ex == width:
                    parts.append(f"row_end[x{sx},c{cid}]")
                else:
                    parts.append(f"blk[x{sx},w{rw},c{cid}]")
            widths_seen.append((rw, cid))
        return parts

    # ──────────── Pass 6: 相对定位 ────────────

    def _apply_delta(self, tokens):
        result = []
        last_x, last_y = 0, 0

        for t in tokens:
            if t.startswith('block[x'):
                m = re.match(r'block\[x(\d+),y(\d+),(w\d+,h\d+,c\d+)\]', t)
                if m:
                    x, y = int(m.group(1)), int(m.group(2))
                    rest = m.group(3)
                    dx, dy = x - last_x, y - last_y
                    if -99 <= dx <= 99 and -99 <= dy <= 99:
                        result.append(f"block[dx{dx},dy{dy},{rest}]")
                    else:
                        result.append(t)
                    last_x, last_y = x, y
                    continue

            elif t.startswith('r[') and not t.startswith('r[+'):
                m = re.match(r'r\[(\d+)\]', t)
                if m:
                    y = int(m.group(1))
                    dy = y - last_y
                    content = t[t.index(']') + 1:]
                    if 0 < dy <= 99:
                        result.append(f'r[+{dy}]{content}')
                    else:
                        result.append(t)
                    last_y = y
                    continue

            result.append(t)
        return result

    # ════════════════════════════════════════
    #  解码
    # ════════════════════════════════════════

    def decode(self, tokens: List[str], palette: List[np.ndarray]) -> np.ndarray:
        cw = ch = 0
        img = None
        last_row = None
        cur_y = 0
        tiles: Dict[int, np.ndarray] = {}
        tid = -1
        tr_rows = []
        tw = th = 0
        last_x, last_y = 0, 0

        for t in tokens:
            # ── canvas ──
            if t.startswith('canvas['):
                m = re.match(r'canvas\[(\d+),(\d+)\]', t)
                cw, ch = int(m.group(1)), int(m.group(2))
                img = np.full((ch, cw), -1, dtype=np.int32)

            # ── mirror ──
            elif t.startswith('mirror_h['):
                m = re.match(r'mirror_h\[(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    mx, my, mw, mh = (int(m.group(1)), int(m.group(2)),
                                      int(m.group(3)), int(m.group(4)))
                    src = img[my:my + mh, mx - mw:mx]
                    if src.shape[1] == mw:
                        mirrored = src[:, ::-1]
                        mask = (img[my:my + mh, mx:mx + mw] == -1)
                        img[my:my + mh, mx:mx + mw] = np.where(
                            mask, mirrored, img[my:my + mh, mx:mx + mw])

            elif t.startswith('mirror_v['):
                m = re.match(r'mirror_v\[(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    mx, my, mw, mh = (int(m.group(1)), int(m.group(2)),
                                      int(m.group(3)), int(m.group(4)))
                    src = img[my - mh:my, mx:mx + mw]
                    if src.shape[0] == mh:
                        mirrored = src[::-1, :]
                        mask = (img[my:my + mh, mx:mx + mw] == -1)
                        img[my:my + mh, mx:mx + mw] = np.where(
                            mask, mirrored, img[my:my + mh, mx:mx + mw])

            # ── gradient ──
            elif t.startswith('grad['):
                m = re.match(r'grad\[(\d+),(\d+),(\d+),(\d+),c(\d+),c(\d+),([hv])\]', t)
                if m and img is not None:
                    gx, gy = int(m.group(1)), int(m.group(2))
                    gw, gh = int(m.group(3)), int(m.group(4))
                    c0, c1 = int(m.group(5)), int(m.group(6))
                    d = m.group(7)
                    rgb0 = palette[c0].astype(np.float32)
                    rgb1 = palette[c1].astype(np.float32)
                    for dy in range(gh):
                        for dx in range(gw):
                            if d == 'h':
                                frac = dx / max(gw - 1, 1)
                            else:
                                frac = dy / max(gh - 1, 1)
                            rgb = rgb0 * (1 - frac) + rgb1 * frac
                            c = int(np.argmin(np.sum(
                                (np.array(palette, dtype=np.float32) - rgb) ** 2, axis=1)))
                            py, px = gy + dy, gx + dx
                            if 0 <= py < ch and 0 <= px < cw:
                                img[py, px] = c

            # ── tile definition ──
            elif t.startswith('tile['):
                m = re.match(r'tile\[(\d+),(\d+),(\d+)\]', t)
                if m:
                    tid, tw, th = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    tr_rows = []

            elif t.startswith('tr['):
                tr_rows.append([int(v) for v in t[3:-1].split(',')])

            elif t == "tile_end":
                if tid >= 0 and len(tr_rows) == th:
                    tiles[tid] = np.array(tr_rows, dtype=np.int32).reshape(th, tw)

            elif t.startswith('tile_at['):
                m = re.match(r'tile_at\[(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    t_id, tx, ty = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    self._place_tile(img, tiles.get(t_id), tx, ty, ch, cw)

            elif t.startswith('tile_fill['):
                m = re.match(r'tile_fill\[(\d+),(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    t_id = int(m.group(1))
                    sx, sy = int(m.group(2)), int(m.group(3))
                    nx, ny = int(m.group(4)), int(m.group(5))
                    td = tiles.get(t_id)
                    if td is not None:
                        tth, ttw = td.shape
                        for iy in range(ny):
                            for ix in range(nx):
                                self._place_tile(img, td, sx + ix * ttw,
                                                 sy + iy * tth, ch, cw)

            # ── copy ──
            elif t.startswith('copy['):
                m = re.match(r'copy\[(\d+),(\d+),(\d+),(\d+),(\d+),(\d+)\]', t)
                if m and img is not None:
                    sx, sy = int(m.group(1)), int(m.group(2))
                    dx, dy = int(m.group(3)), int(m.group(4))
                    bw, bh = int(m.group(5)), int(m.group(6))
                    src = img[sy:sy + bh, sx:sx + bw].copy()
                    dye, dxe = min(dy + bh, ch), min(dx + bw, cw)
                    img[dy:dye, dx:dxe] = src[:dye - dy, :dxe - dx]

            # ── block (absolute) ──
            elif t.startswith('block[x'):
                m = re.match(r'block\[x(\d+),y(\d+),w(\d+),h(\d+),c(\d+)\]', t)
                if m and img is not None:
                    bx, by = int(m.group(1)), int(m.group(2))
                    bw, bh, bc = int(m.group(3)), int(m.group(4)), int(m.group(5))
                    img[by:min(by + bh, ch), bx:min(bx + bw, cw)] = bc
                    last_x, last_y = bx, by
                    last_row = img[by, :].copy()
                    cur_y = by + bh

            # ── block (delta) ──
            elif t.startswith('block[dx'):
                m = re.match(r'block\[dx([+-]?\d+),dy([+-]?\d+),(w\d+),(h\d+),(c\d+)\]', t)
                if m and img is not None:
                    dx, dy = int(m.group(1)), int(m.group(2))
                    bw = int(m.group(3)[1:])
                    bh = int(m.group(4)[1:])
                    bc = int(m.group(5)[1:])
                    bx, by = last_x + dx, last_y + dy
                    img[by:min(by + bh, ch), bx:min(bx + bw, cw)] = bc
                    last_x, last_y = bx, by
                    last_row = img[by, :].copy()
                    cur_y = by + bh

            # ── repeat ──
            elif t.startswith('repeat['):
                m = re.match(r'repeat\[(\d+)\]', t)
                if m and img is not None:
                    for _ in range(int(m.group(1))):
                        if cur_y < ch and last_row is not None:
                            img[cur_y, :] = last_row
                            cur_y += 1

            # ── row (absolute) ──
            elif t.startswith('r[') and not t.startswith('r[+'):
                m = re.match(r'\[(\d+)\](.*)', t[1:])
                if m and img is not None:
                    y = int(m.group(1))
                    self._decode_row(img, y, cw, m.group(2).strip())
                    last_row = img[y, :].copy()
                    cur_y = y + 1
                    last_y = y

            # ── row (delta) ──
            elif t.startswith('r[+'):
                m = re.match(r'r\[\+(\d+)\](.*)', t)
                if m and img is not None:
                    dy = int(m.group(1))
                    y = last_y + dy
                    self._decode_row(img, y, cw, m.group(2).strip())
                    last_row = img[y, :].copy()
                    cur_y = y + 1
                    last_y = y

        if img is not None:
            img[img == -1] = 0
        return self._to_rgb(img, palette)

    def _place_tile(self, img, td, tx, ty, ch, cw):
        if td is None:
            return
        tth, ttw = td.shape
        ye, xe = min(ty + tth, ch), min(tx + ttw, cw)
        if ty < ch and tx < cw:
            img[ty:ye, tx:xe] = td[:ye - ty, :xe - tx]

    def _decode_row(self, img, y, width, content):
        if not content:
            return

        if content.startswith('fill['):
            m = re.match(r'fill\[c(\d+)\]', content)
            if m:
                img[y, :] = int(m.group(1))
            return

        if content.startswith('pat['):
            m = re.match(r'pat\[x(\d+),(.+)\]rep\[(\d+)\]', content)
            if m:
                x = int(m.group(1))
                pat_str = m.group(2)
                reps = int(m.group(3))
                pat_parts = pat_str.split(',')
                pattern = []
                i = 0
                while i < len(pat_parts):
                    w = int(pat_parts[i][1:])
                    c = int(pat_parts[i + 1][1:])
                    pattern.append((w, c))
                    i += 2
                for _ in range(reps):
                    for pw, pc in pattern:
                        img[y, x:x + pw] = pc
                        x += pw
                return

        for part in content.split():
            if part.startswith('row_end['):
                m = re.match(r'row_end\[x(\d+),c(\d+)\]', part)
                if m:
                    img[y, int(m.group(1)):] = int(m.group(2))
            elif part.startswith('blk['):
                m = re.match(r'blk\[x(\d+),w(\d+),c(\d+)\]', part)
                if m:
                    x, bw, c = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    img[y, x:x + bw] = c
            elif part.startswith('same['):
                m = re.match(r'same\[x(\d+),w(\d+),c(\d+)\]', part)
                if m:
                    x, w, c = int(m.group(1)), int(m.group(2)), int(m.group(3))
                    img[y, x:x + w] = c

    def _to_rgb(self, labels, palette):
        if labels is None:
            return np.zeros((1, 1, 3), dtype=np.uint8)
        h, w = labels.shape
        rgb = np.zeros((h, w, 3), dtype=np.uint8)
        for c in range(len(palette)):
            rgb[labels == c] = palette[c]
        return rgb

    # ════════════════════════════════════════
    #  文件IO
    # ════════════════════════════════════════

    def save_tokens(self, tokens, palette, path):
        with open(path, 'w', encoding='utf-8') as f:
            json.dump({'tokens': tokens, 'palette': [c.tolist() for c in palette]}, f,
                      ensure_ascii=False, indent=2)
        print(f"Token已保存: {path} ({os.path.getsize(path) / 1024:.1f} KB)")

    def load_tokens(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            d = json.load(f)
        return d['tokens'], [np.array(c, dtype=np.uint8) for c in d['palette']]

    # ════════════════════════════════════════
    #  统计
    # ════════════════════════════════════════

    def token_stats(self, tokens, image):
        s = dict.fromkeys([
            'fill_rows', 'row_ends', 'blocks', 'multi_blocks',
            'repeat_rows', 'tiles', 'tile_at', 'tile_fill',
            'mirrors', 'gradients', 'copies', 'patterns',
            'same_refs', 'delta_blocks', 'delta_rows',
        ], 0)
        s['total_tokens'] = len(tokens)

        for t in tokens:
            if t.startswith('r[') or t.startswith('r[+'):
                if 'fill[' in t:
                    s['fill_rows'] += 1
                elif 'pat[' in t:
                    s['patterns'] += 1
                elif 'same[' in t:
                    s['same_refs'] += t.count('same[')
                s['blocks'] += t.count('blk[')
                s['row_ends'] += t.count('row_end[')
                if t.startswith('r[+'):
                    s['delta_rows'] += 1
            elif t.startswith('repeat['):
                m = re.match(r'repeat\[(\d+)\]', t)
                if m:
                    s['repeat_rows'] += int(m.group(1))
            elif t.startswith('block[x'):
                m = re.match(r'block\[x\d+,y\d+,w\d+,h(\d+)', t)
                if m:
                    s['multi_blocks' if int(m.group(1)) > 1 else 'blocks'] += 1
            elif t.startswith('block[dx'):
                s['delta_blocks'] += 1
            elif t.startswith('tile['):
                s['tiles'] += 1
            elif t.startswith('tile_at['):
                s['tile_at'] += 1
            elif t.startswith('tile_fill['):
                m = re.match(r'tile_fill\[\d+,\d+,\d+,(\d+),(\d+)\]', t)
                s['tile_fill'] += int(m.group(1)) * int(m.group(2)) if m else 1
            elif t.startswith('mirror_'):
                s['mirrors'] += 1
            elif t.startswith('grad['):
                s['gradients'] += 1
            elif t.startswith('copy['):
                s['copies'] += 1

        h, w, _ = image.shape
        s['image_size'] = f'{w}x{h}'
        s['original_bytes'] = h * w * 3
        s['token_bytes'] = len('\n'.join(tokens).encode('utf-8'))
        s['compression_ratio'] = s['original_bytes'] / max(1, s['token_bytes'])
        return s


# ════════════════════════════════════════
#  可视化
# ════════════════════════════════════════

def display_comparison(original, quantized_rgb, reconstructed, stats):
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    titles = [
        f"原始\n{original.shape[:2]}",
        f"量化\n{stats['total_tokens']} tokens",
        f"重建\nPSNR: {_psnr(original, reconstructed):.1f}dB",
        f"差异(x10)",
    ]
    imgs = [original, quantized_rgb, reconstructed,
            cv2.absdiff(original, reconstructed) * 10]
    for ax, im, ti in zip(axes, imgs, titles):
        ax.imshow(im)
        ax.set_title(ti, fontsize=10)
        ax.axis('off')
    plt.suptitle(
        f"色块语义v3 | 压缩比 {stats['compression_ratio']:.1f}x",
        fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def _psnr(a, b):
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))


def print_stats(s):
    print(f"\n{'═' * 55}")
    print(f"  色块语义token统计 v3 (全语义压缩)")
    print(f"{'═' * 55}")
    print(f"  图像: {s['image_size']}  总token: {s['total_tokens']}")
    print(f"{'─' * 55}")
    print(f"  镜像 mirror:       {s['mirrors']}")
    print(f"  渐变 gradient:     {s['gradients']}")
    print(f"  平铺 tile:         {s['tiles']}  "
          f"(at:{s['tile_at']} fill:{s['tile_fill']})")
    print(f"  复制 copy:         {s['copies']}")
    print(f"  多行 block:        {s['multi_blocks']}")
    print(f"  单行 block:        {s['blocks']}")
    print(f"  整行 fill_row:     {s['fill_rows']}")
    print(f"  行尾 row_end:      {s['row_ends']}")
    print(f"  行内重复 pat:      {s['patterns']}")
    print(f"  颜色差分 same:     {s['same_refs']}")
    print(f"  行重复 repeat:     {s['repeat_rows']}")
    print(f"  相对定位 delta:    blk:{s['delta_blocks']} row:{s['delta_rows']}")
    print(f"{'─' * 55}")
    print(f"  原始: {s['original_bytes']:,}B  "
          f"Token: {s['token_bytes']:,}B  "
          f"压缩比: {s['compression_ratio']:.1f}x")
    print(f"{'═' * 55}")


# ════════════════════════════════════════
#  主流程
# ════════════════════════════════════════

if __name__ == "__main__":
    image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "img_1.png")
    if not os.path.exists(image_path):
        print(f"未找到: {image_path}")
        exit(1)

    img = Image.open(image_path).convert('RGB')
    img_array = np.array(img.resize((256, 256)))

    tokenizer = ColorBlockTokenizer(
        max_colors=16, tile_sizes=[2, 4, 8, 16], mirror_threshold=0.85)

    print("编码中...")
    tokens, quantized, palette = tokenizer.encode(img_array)
    print("解码中...")
    reconstructed = tokenizer.decode(tokens, palette)

    stats = tokenizer.token_stats(tokens, img_array)
    print_stats(stats)

    print(f"\n前20个token:")
    for i, t in enumerate(tokens):
        print(f"  {i + 1:2d}. {t}")
    if len(tokens) > 20:
        print(f"  ... 共 {len(tokens)} 个")

    qrgb = np.zeros_like(img_array)
    for c in range(len(palette)):
        qrgb[quantized == c] = palette[c]

    h, w, _ = reconstructed.shape
    r_labels = np.zeros((h, w), dtype=np.int32)
    for c in range(len(palette)):
        mask = np.all(reconstructed == palette[c], axis=2)
        r_labels[mask] = c

    diff_pixels = int(np.sum(quantized != r_labels))
    total_pixels = h * w
    q_psnr = _psnr(qrgb, reconstructed)
    print(f"\nPSNR (quantized→reconstructed): {q_psnr:.1f} dB")
    print(f"Mismatch pixels: {diff_pixels}/{total_pixels} "
          f"({100*diff_pixels/total_pixels:.2f}%)")

    display_comparison(img_array, qrgb, reconstructed, stats)
    tokenizer.save_tokens(tokens, palette, 'color_block_tokens.json')
相关推荐
沐泽__1 小时前
欧氏距离、余弦相似度(cosin)、点积 区别与用途详解(附实例)
人工智能·机器学习
victory04311 小时前
DeepSeek-V4知识点讲解记录
人工智能
Python私教1 小时前
如意Agent对话持久化与滚动记忆引擎设计:让AI记住你们聊过的每一句话
人工智能
这张生成的图像能检测吗1 小时前
(论文速读)SPR-YOLO:面向模糊场景的轻量级交通流检测算法
人工智能·yolo·计算机视觉·目标追踪
独隅1 小时前
Anaconda 配置 Keras 环境的详细流程指南
人工智能·深度学习·keras
恋猫de小郭1 小时前
AI 时代开源协议将消亡,malus 讽刺性展示了这一点
前端·人工智能·ai编程
数字化顾问1 小时前
(97页PPT)麦肯锡战略规划制定方法及模板制品(附下载方式)
人工智能·物联网
Kiyra1 小时前
我是怎么把一个普通 AI 聊天项目改造成工程化 Agent Runtime 的
人工智能
PythonFun1 小时前
WPS AI隐藏玩法!自定义指令让办公效率起飞
人工智能·wps