来源论文: https://arxiv.org/abs/2606.03199v1 生成时间: Jun 03, 2026 07:29
变革性晶体结构预测:基于单晶胞流匹配与无三角形注意力机制的 CLARI 深度解析
0. 执行摘要
有机晶体结构预测(Crystal Structure Prediction, CSP)是计算化学、制药工程、有机光电及材料科学领域的核心瓶颈之一。传统基于物理化学的第一性原理搜索方法(如随机搜索、演化算法)需要耗费数千个 CPU-核年(CPU-years)来寻找单一分子的亚稳态多晶型(Polymorphism)。近年来,以 OXtal 为代表的端到端生成式深度学习模型虽然极大地缩减了晶体空间采样的时间,但由于其沿袭了 AlphaFold3 的重型设计——即在空间上对晶体进行大范围的批量扩增(Bulk Crops),并依赖具有三次方计算复杂度的三角形更新层(Triangle-Update Layers),导致其预测单个分子仍需数分钟的计算时间,难以满足数百万量级虚拟化学库的高通量筛选(High-Throughput Screening)需求。
为了彻底打破这一算力枷锁,多伦多大学的 Alán Aspuru-Guzik 教授团队联手 MIT CSAIL 及 NVIDIA 等顶尖机构,提出了全新的无冗余单晶胞生成模型 CLARI。CLARI 摒弃了复杂的体相堆叠表示,直接在无冗余的单晶胞(Unit Cell)上建立坐标与晶格矢量的联合流匹配(Flow Matching)生成轨迹。同时,CLARI 用纯双偏置注意力机制(Pair-Bias Attention)替代了极其昂贵的三角形更新层。这一颠覆性的设计不仅使得 CLARI 在预测精度(Solve Rate)上超越了 OXtal,更在预测速度上实现了惊人的 15–30倍 提升,将单个分子的 CSP 预测时间缩短至秒级。此外,CLARI 直接建模了包含显式氢原子(Explicit Hydrogens)在内的全原子系统,使得生成结构可以直接通过通用的机器学习力场(如 UMA)进行零弛豫、零修饰的能量重排(Inference-time Scaling),在维持 5–8倍 端到端加速的同时,显著提升了结构预测的成功率。本博文将从核心科学问题、理论基础、技术实现、实验评测以及未来局限性等五个维度,对 CLARI 进行全面而深入的技术剖析。
1. 核心科学问题,理论基础,技术难点,方法细节
1.1 核心科学问题与技术难点
有机晶体由分子的周期性三维排列构成。由于分子间相互作用力(如范德华力、氢键、$\\pi-\\pi$ 堆积)极其微弱且高度非线性,其能量势能面(Potential Energy Surface, PES)呈现出极度崎岖和多局域极小的特征。这种物理特性导致了**多晶型现象(Polymorphism)**的普遍存在——相同的分子在不同的温度、压力或溶剂条件下会结晶出完全不同的晶体结构,而这些不同的多晶型往往对应着迥异的溶解度、稳定性和光电性能。因此,CSP 的科学任务不仅是寻找能量最低的点,更是要高效地采样出所有可能在实验中观测到的物理亚稳态结构。
传统的生成式 CSP 模型(如 OXtal)面临三大技术难点:
- 周期性边界条件的表征冗余:为了避免处理晶格参数的奇点,先前的模型倾向于在空间中对分子进行对称性复制,生成包含多个晶胞的巨型「体相裁剪」(Bulk Crops)。这种表示方法导致网络需要处理成百上千个原子的几何图,计算开销极大。
- 高阶空间关联特征计算极其昂贵:AlphaFold3 的成功让很多生物分子生成模型继承了其“三角形边更新”(Triangle Updates)的设计。然而,三角形更新在节点数为 $N$ 的图中需要进行 $O(N^3)$ 的张量运算,成为了制约模型向大分子、大晶胞体系扩展的主要瓶颈。
- 缺乏鲁棒的晶格先验:如果仅使用标准各向同性的高斯分布作为流匹配的初始分布(Prior),模型在去噪早期极易生成密度过高(导致原子重叠爆炸)或密度过低(生成无意义的稀疏晶格)的无效结构,使优化过程极难收敛。
1.2 联合晶胞表征(Joint Unit Cell Representation)
CLARI 的核心突破在于其直接在**单晶胞(Unit Cell)**上定义生成任务。为了实现这一点,CLARI 将晶格参数与原子坐标融合成一个统一的联合表示对象。
令一个晶胞结构表示为元组 $(\\mathbf{L}, \\mathbf{C}, \\mathbf{F}, \\mathbf{E})$,其中:
- $\\mathbf{L} \\in \\mathbb{R}^{3 \\times 3}$ 是晶格矩阵(Lattice Matrix),其每一行代表一个原初基矢(Primitive Lattice Vectors);
- $\\mathbf{C} \\in \\mathbb{R}^{N \\times 3}$ 是 $N$ 个原子的笛卡尔坐标矩阵,且其重心已被平移至原点(Zeroed Centroid);
- $\\mathbf{F} \\in \\mathbb{R}^{N \\times d}$ 是原子的化学特征(如原子序数等);
- $\\mathbf{E} \\in \\mathbb{R}^{N \\times N}$ 是分子内共价键的邻接矩阵。
为了让晶格与坐标共同参与流匹配轨迹的演化,CLARI 将晶格的 3 个基矢视为 3 个特殊的**“虚拟原子”(Virtual Points)**。具体而言,它将它们沿行方向拼接到原子坐标矩阵中,构造了一个统一的联合状态矩阵 $\\mathbf{x}$:
$$\\mathbf{x} = \\frac{1}{\\sigma} \\begin{pmatrix} \\frac{1}{2}\\mathbf{L} \\ \\mathbf{C} \\end{pmatrix} \\in \\mathbb{R}^{(3+N) \\times 3}$$其中 $\\sigma$ 是用于将整个数据集的方差归一化为单位方差的缩放因子。在这一表示下,前 3 行代表缩放后的晶格基矢,后 $N$ 行代表原子的空间笛卡尔坐标。该表示完美地保证了晶格与原子在网络中的对称性,不需要为晶格和原子分别设计复杂的耦合机制。
这一表示对以下物理对称性保持不变性或等变性:
- 晶格行的带符号排列:交换 $\\mathbf{L}$ 的基矢顺序或对其取反,仅代表晶胞在空间中的不同基底选择;
- 原子行的排列:满足分子图的自同构变换;
- 整体旋转与平移:整体状态 $\\mathbf{x}$ 对三维欧氏群 $SE(3)$ 保持等变性;
- 独立周期性平移:晶胞内任意独立的物理单体(Body/Component)可平移任意个晶格矢量,其生成的物理体相结构完全一致。
1.3 连续时间流匹配理论(Flow Matching)
CLARI 采用流匹配(Flow Matching)框架来建模从简单已知分布 $p_0$ 逐渐过渡到复杂晶体分布 $p_1$ 的连续概率矢量场。在给定初始状态 $\\mathbf{x}_0 \\sim p_0$ 和目标状态 $\\mathbf{x}_1 \\sim p_1$ 的情况下,定义最简的线性插值路径(Linear Interpolant):
$$\\mathbf{x}_t = (1 - t)\\mathbf{x}_0 + t\\mathbf{x}_1$$对应的时间依赖向量场为固定值 $\\mathbf{u}_t(\\mathbf{x}) = \\mathbf{x}_1 - \\mathbf{x}_0$。我们的生成网络 $\\mathbf{v}_\\theta(\\mathbf{x}_t, t)$ 的目标是拟合该向量场,其损失函数形式为:
$$\\mathcal{L}_{FM} = \\mathbb{E}_{t \\sim p(t), \\mathbf{x}_0 \\sim p_0, \\mathbf{x}_1 \\sim p_1} \\left[ \\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - (\\mathbf{x}_1 - \\mathbf{x}_0) \\|^2 \\right]$$在实际训练中,为了平衡晶格与坐标的收敛速度,CLARI 将流匹配损失解耦为晶格项 $\\mathcal{L}_{FM}^{\\mathbf{L}}$ 和原子坐标项 $\\mathcal{L}_{FM}^{\\mathbf{C}}$:
$$\\mathcal{L}_{FM} = \\mathcal{L}_{FM}^{\\mathbf{L}} + \\mathcal{L}_{FM}^{\\mathbf{C}}$$两项均采用均方误差(MSE),并给予相同的权重。采样时,我们从 $p_0$ 中抽取一个样本 $\\mathbf{x}_0$,通过常微分方程(ODE)求解器沿着 $\\mathrm{d}\\mathbf{x}_t = \\mathbf{v}_\\theta(\\mathbf{x}_t, t)\\mathrm{d}t$ 积分至 $t=1$,即可获得高保真的晶胞结构。
1.4 数据启发式先验(Data-Informed Prior $p_0$)
若采用标准的各向同性高斯分布 $\\mathcal{N}(0, \\mathbf{I})$ 作为先验 $p_0$,则采样的晶格矢量的行列式(即体积)有极大概率接近于 0(导致严重的原子重叠碰撞)或极大(导致原子极度稀疏)。
CLARI 设计了一种数据启发式先验 $p_0$。其核心思想是将晶格 $\\mathbf{L}$ 的基矢 $(l_1, l_2, l_3)$ 解耦为三个具有明确物理意义的独立分量,并从真实的剑桥晶体数据中心(CSD)训练集中统计并拟合高斯分布:
- 原子密度(Atom Density) $\\rho$: $$\\rho = \\frac{N}{V} \\sim \\mathcal{N}(\\mu_\\rho, \\sigma_\\rho^2)$$ 其中 $V = |\\det \\mathbf{L}|$ 是晶胞体积。这保证了初始状态的原子密度始终处于物理合理的范围内。
- 晶胞夹角(Cell Angles) $\\alpha, \\beta, \\gamma$: $$\\alpha, \\beta, \\gamma \\sim \\mathcal{N}(\\mu_\\circ, \\sigma_\\circ^2) \\quad \\text{i.i.d.}$$ 其中 $\\alpha = \\text{angle}(l_1, l_2)$,以此类推。
- 归一化晶格长度(Normalized Cell Lengths) $(a, b, c)$: 定义 $a = \\|l_1\\|_2 / V^{1/3}$($b, c$ 类似),在排序满足 $a \\le b \\le c$ 的约束下,拟合多元高斯分布: $$(a, b, c) \\sim \\mathcal{N}(\\mu_\\ell, \\mathbf{\\Sigma}_\\ell)$$ 注意,由于排序约束,协方差矩阵 $\\mathbf{\\Sigma}_\\ell$ 不是对角阵。
在采样时,我们独立地从上述三个统计高斯分布中抽取 $(\\rho, \\alpha, \\beta, \\gamma, a, b, c)$,唯一重构出具有物理合理空间比例的初始晶格 $\\mathbf{L}$。为了破除重构晶格的手性与方向偏置,最后对重构矩阵应用一个随机的旋转矩阵 $R \\in SO(3)$ 和带符号的排列矩阵。这一设计完美地消除了早期去噪过程中的“密度崩溃”现象(见图3b对比)。
1.5 纯双偏置注意力机制架构(Pair-Bias DiT)
为了消除三次方复杂度的三角形通道,CLARI 采用扩散变换器(Diffusion Transformer, DiT)架构,并在其多头自注意力(MHA)中引入了双偏置注意力(Pair-Bias Attention)。
1.5.1 特征的构建与调制
CLARI 的节点特征对应原子,而前 3 个特殊节点对应晶格特征。输入的 2D 分子图包含了原子类型、电荷、共价键等,这些一维节点特征被映射为序列特征 $\\mathbf{h} \\in \\mathbb{R}^{N \\times d}$。二维偏置特征 $\\mathbf{z} \\in \\mathbb{R}^{N \\times N \\times d_z}$ 则由分子内的拓扑距离、共价键类型、当前的 3D 笛卡尔距离和三维周期性距离通过正弦/余弦嵌入后拼接生成。
1.5.2 双偏置多头注意力机制(Pair-Bias MHA)
对于注意力头 $k$,其注意力机制的计算公式为:
$$\\mathbf{A}_{ij}^{(k)} = \\text{Softmax} \\left( \\frac{\\mathbf{Q}_i^{(k)} (\\mathbf{K}_j^{(k)})^T}{\\sqrt{d_k}} + \\text{Linear}^{(k)}(\\mathbf{z}_{ij}) \\right)$$通过将二维空间几何信息 $\\mathbf{z}_{ij}$ 线性投影,直接作为偏置项(Pair Bias)加到自注意力矩阵(Attention Logits)中,CLARI 实现了空间几何信息的高效双向传递,而无需任何昂贵的三角形张量乘法。这种设计将网络层级的单步时间复杂度成功压降到了 $O(N^2)$。
1.5.3 调制与正则化(AdaLN-Zero & SwiGLU)
模型引入了来自 DiT 的 AdaLN-Zero(Adaptive Layer Normalization with Zero Initialization)模块,利用全局条件信息 $\\mathbf{c}$(包含当前时间步 $t$、晶格常数、分子化学式等嵌入)对序列特征进行动态缩放和移位:
$$\\text{AdaLN}(\\mathbf{h}, \\mathbf{c}) = \\gamma(\\mathbf{c}) \\odot \\text{LN}(\\mathbf{h}) + \\beta(\\mathbf{c})$$同时在 MLP 部分采用了 SwiGLU 激活函数以进一步增强网络的表达能力,并引入 QKNorm 稳定大模型在大规模数据集上的多头注意力训练。
1.6 最优输运耦合与空间对称性对齐(Optimal Transport Coupling)
在流匹配中,独立的 $\\mathbf{x}_0$ 与 $\\mathbf{x}_1$ 组合会导致去噪轨迹在空间中发生交叉和扭曲。为了使插值轨迹尽可能平直(Straight Trajectories),必须在训练时对真实结构 $\\mathbf{x}_1$ 和初始噪声 $\\mathbf{x}_0$ 进行对称性对齐(Alignment),即寻找一种近似的最优输运(Optimal Transport, OT)映射。
CLARI 针对晶体对称性设计了三个分步对齐算子:
- 极速晶格对齐:寻找一个带符号的排列矩阵 $\\mathbf{\\Pi}$ 和三维旋转矩阵 $\\mathbf{R}$,使得变换后的晶格最接近初始晶格。由于带符号排列矩阵空间极小,可通过暴力枚举 $\\mathbf{\\Pi}$,并对每种情况用 Kabsch 算法求解最优的 $\\mathbf{R}$,最终选取弗罗贝尼乌斯范数最小的组合: $$\\min_{\\mathbf{\\Pi}, \\mathbf{R}} \\| \\mathbf{\\Pi} \\mathbf{L}_1 \\mathbf{R}^T - \\mathbf{L}_0 \\|_F$$
- 图自同构原子对齐:当晶格对齐后,需要对原子顺序进行重排。CLARI 预先在数据预处理阶段缓存了分子图的自同构映射(Isomorphism)。通过匈牙利算法(Hungarian Algorithm),在保持图拓扑不变的自同构限制下,快速求解原子间距离均方根偏差(RMSD)最小的行列重排矩阵。
- 最终姿态对齐:利用加权 Kabsch 算法对整体进行 $SO(3)$ 旋转对齐,其中 3 个虚拟晶格矢量节点与 $N$ 个真实原子节点具有相同的权重,确保晶格与原子的姿态在全局意义下达到高度一致。
1.7 辅助物理损失(Auxiliary Physical Losses)
纯流匹配的 MSE 损失只关注均值,往往忽略了宏观物理量的正则化,容易在生成中产生严重的“原子重叠碰撞(Clash)”或错误的体积。CLARI 在目标函数中加入了两个极具创新的物理辅助损失:
- 体积相对误差损失(Relative Volume Loss) $\\mathcal{L}_{vol}$: $$\\mathcal{L}_{vol} = \\left| \\frac{|\\det \\hat{\\mathbf{L}}_1|}{|\\det \\mathbf{L}_1|} - 1 \\right|$$ 该损失强制要求模型在去噪的第一步预测出的单步估计晶格 $\\hat{\\mathbf{L}}_1$ 的体积与真实值 $\\mathbf{L}_1$ 高度匹配,有效防止了晶胞体积的过度膨胀或收缩。
- 成对周期性距离损失(Pairwise Periodic Distance Loss) $\\mathcal{L}_{pair}$: 该损失包含两部分:对合理距离的拉近,以及对碰撞区原子的强力推开(排斥力)。定义重合集 $\\Lambda = \\{i \\neq j \\mid d_{ij} < 15 \\text{ Å or } \\hat{d}_{ij} < \\alpha_{ij}\\}$,其公式为: $$\\mathcal{L}_{pair} = \\sum_{(i,j) \\in \\Lambda} \\left[ \\left| \\hat{d}_{ij} - d_{ij} \\right| + 5 \\cdot \\max(0, \\alpha_{ij} - \\hat{d}_{ij}) \\right]$$ 其中 $\\alpha_{ij}$ 是根据原子类型指定的碰撞阈值。若两原子间距离小于其共价半径之和,损失函数将施加极其严厉的惩罚,从根本上杜绝了生成晶体中的非物理穿插现象。
2. 关键 Benchmark 体系,计算所得数据,性能数据
2.1 评测数据集与实验设置
为了全面、严苛地评估 CLARI 的生成表现,研究团队选用了以下三个极具代表性的 Benchmark 体系:
- OXtal Test Sets(50个 Rigid 体系 + 50个 Flexible 体系):这是目前端到端 CSP 模型的通用基准,涵盖了刚性分子晶体和具有多个旋转自由度的柔性分子晶体。
- CSP Blind Tests(CSP5-7):来自剑桥晶体数据中心(CCDC)举办的第5至第7届国际晶体结构预测盲测体系(共19个具有极高挑战性的分子)。这些分子是检验 CSP 算法工业界实战能力的金标准。
- CSD Teaching Subset(773个复杂分子体系):为了挑战模型的极限泛化能力,研究团队创新地引入了包含富勒烯(Fullerenes)、硼烷(Boranes)、金属过渡配合物、原子簇等超复杂体系的新测试集。这些体系化学性质多样,且大都无法被 RDKit 软件进行常规的化学键合理性清洗(Sanitization),极大挑战了传统 CSP 软件的极限。
2.2 核心性能评测结果对比
为了直观地展现 CLARI 的优越性,下面整理了论文中 CLARI(Medium-88M 及 Large-173M)在各大公开数据集上的预测成功率(Solve Rate, $\\text{Sol}@k$)以及平均能耗/耗时,并与 SOTA 基线模型 OXtal 以及经典的 DFT 盲测参与者平均水平($\\text{DFT}_{avg}$)进行全方位对比。
表 1:各大 Benchmark 上的 $\\text{Sol}@k$ 成功率对比表($\\ge 8/15$ 匹配标准)
| 算法模型 | 生成预算 $n_s$ | 保留样本数 $k$ | Rigid (50) | Flexible (50) | CSP5 (6) | CSP6 (5) | CSP7 (8) | Teaching (773) |
|---|---|---|---|---|---|---|---|---|
| OXtal | 30 | 30 | 0.300 | 0.220 | 0.167 | 0.200 | 0.125 | — |
| CLARI-M | 30 | 30 | 0.697 | 0.241 | 0.554 | 0.311 | 0.210 | 0.442 |
| CLARI-L | 30 | 30 | 0.731 | 0.287 | 0.681 | 0.355 | 0.245 | 0.461 |
| CLARI-L(能效重排) | 150 | 30 | 0.772 | 0.346 | 0.789 | 0.480 | 0.263 | 0.484 |
| CLARI-L(高预算) | 400 | 200 | 0.919 | 0.596 | 0.975 | 0.729 | 0.566 | 0.669 |
| CLARI-L(极限采样) | 1000 | 1000 | 0.940 | 0.760 | 1.000 | 0.800 | 0.875 | 0.763 |
| DFT 盲测平均 | — | — | — | — | 0.544 | 0.496 | 0.421 | — |
关键数据结论解读:
- 无能效重排下的绝对超越:在不启动任何推理端能效筛选($n_s = k = 30$)的同等条件下,CLARI-M 仅凭借 88M 的参数量,就在所有数据集上全面碾压了上代 SOTA 模型 OXtal。在 Rigid 刚性集上,CLARI-M 的成功率从 30% 跃升至 69.7%(提升超两倍);在极具挑战的 CSP5 盲测集上,CLARI-L 更是取得了 68.1% 的绝佳表现。
- 端到端超越 DFT 平均表现:当采用最佳实践配置(生成 150 个样本,通过 UMA 机器学习力场计算能效,保留能量最低的前 30 个进行匹配)时,CLARI-L 的成功率在 CSP5 盲测集上达到了 78.9%,在 CSP7 上达到了 26.3%。这一表现不仅大幅超越了 OXtal,更全面超越了传统物理化学方法中 DFT 第一性原理计算的盲测平均参与水平($\\text{DFT}_{avg}$)。
2.3 消融实验:哪些模块决定了 CLARI 的高物理品质?
为了厘清 CLARI 中每个设计细节对最终生成质量的定量贡献,团队在 CSD 验证集上进行了系统的消融实验。主要评测指标包括:原子碰撞率(Clash Rate, % ↓)、PoseBusters 合格率(% ↑)、体积相对误差(Vol. Error, % ↓)及粒子分布相似度(EMD PDD, ↓)。
表 2:CLARI 各核心设计模块在 CSD 验证集上的消融表现
| 实验代号 | 模型配置细节 | Clash Rate (%) ↓ | PoseBusters (%) ↑ | Vol. Error (%) ↓ | EMD PDD ↓ |
|---|---|---|---|---|---|
| — | CSD 真实地基线 | 0.80 | 92.42 | — | — |
| A | 基础 DiT + 均值池化晶格预测 | 25.00 | 78.14 | 2.07 | 11.01 |
| B | A + 晶格虚拟原子 Token 化 | 23.25 | 79.04 | 2.22 | 11.09 |
| C | B + $\\mathcal{L}_{vol}$ & $\\mathcal{L}_{pair}$ 辅助物理损失 | 20.33 | 80.11 | 1.79 | 10.37 |
| D | C + 伴随自调节(Self-conditioning) | 9.79 | 85.23 | 1.55 | 9.66 |
| E | D + 退化为各向同性标准高斯先验 $p_0$ | 9.32 | 84.89 | 2.88 | 10.71 |
| Clari-M | D + 数据启发式先验 $p_0$ + Beta(1.8, 1) 时间采样 | 9.56 | 87.34 | 1.59 | 9.56 |
| Clari-L | Clari-M 的大规模参数版本 (173M) | 7.69 | 85.89 | 1.50 | 9.28 |
消融实验的关键洞察:
- 虚拟原子表示的统一性:对比 A 与 B 可知,将晶格基矢直接融入原子序列进行端到端特征自注意力,能够简化网络架构而不损失精度。
- 辅助物理损失的必要性:对比 B 与 C,体积损失和碰撞惩罚使体积误差从 2.22% 显著下降至 1.79%,有效压制了非物理穿插。
- 自调节(Self-Conditioning)的威力:对比 C 与 D,引入自调节让碰撞率(Clash Rate)直接从 20.33% 暴跌至 9.79%,PoseBusters 的物理合理性大增。这证明将上一步生成的预测终点 $\\hat{\\mathbf{x}}_1$ 作为下一步的输入,极大地稳定了去噪路径,帮助网络形成了强大的局域纠偏能力。
- 数据启发式先验 $p_0$ 的无可替代性:对比 E 与 Clari-M,若退化为普通高斯噪声,虽然碰撞率略有降低(因生成了大量极度稀疏、无意义的超大晶胞),但体积相对误差暴增至 2.88%,且表征晶体拓扑几何的 EMD PDD 指标大幅恶化(10.71)。数据启发式先验确保了起点处的晶格比例与真实有机晶体高度契合。
2.4 计算效率与时间花费对比
除了精度大幅提升,CLARI 最引人瞩目的突破在于其极其低廉的算力成本。在 H100 GPU 平台上,我们对生成 150 个样本并进行 UMA 能效重排的端到端耗时进行了严苛测试。
- 单分子极速预测:平均而言,CLARI-L 生成 150 个全原子晶胞仅需 2.2 秒!
- 包含能效重排的完整流程:即使加上用 UMA 机器学习力场对这 150 个样本进行单点能计算并筛选出 Top-30 的重排时间,整体端到端耗时也仅为 6.0 秒。
- 与 SOTA 基线对比:相比之下,OXtal 生成同样数量的样本并进行重排需要数分钟的算力。在 CSP 盲测集上,CLARI 实现了 15–30倍 的单体采样加速,以及 5–8倍 的端到端能效重排流程加速(详见图4时间直方图对比)。这使得在单卡上进行百万级别分子库的晶体结构虚拟筛选成为了现实。
3. 代码实现细节,复现指南,所用的软件包及开源 repo link
3.1 开源仓库与核心依赖
CLARI 的官方实现代码已在 GitHub 开源。该仓库基于 PyTorch 生态体系构建,融合了图神经网络与最前沿的流匹配采样器。
- 官方 GitHub 开源链接:https://github.com/aspuru-guzik-group/clari
- 核心依赖软件包(Software Stack):
- PyTorch ($\\ge 2.1$):基础张量计算框架。
- PyTorch Geometric (PyG):用于处理分子 2D 图的拓扑结构与节点/边表征。
- RDKit:用于读取分子的 SMILES,提取基本的化学键性质(注意:CLARI 的核心生成阶段并不强依赖 RDKit 的 Sanitization,这也让它能够预测非 RDKit 兼容的复杂无机/金属体系)。
- CSD Python API:用于直接与剑桥晶体数据库交互,进行格式转换与数据过滤。
- UMA (Universal Model for Atoms):用于执行零温下的非物理松弛能效快速重排。
3.2 特征输入流水线(Featurization Pipeline)
如表4所示,CLARI 在每个时间步 $t$ 会注入极为丰富的多维特征。复现时,需要通过数据加载器(Dataloader)拼装以下三个流(Streams):
- 序列特征(Sequence Features):
- 实时 3D 笛卡尔坐标 $\\mathbf{x}_t$(进行线性与正弦位置编码);
- 实时三维周期性分数坐标 $\\mathbf{s}_t = \\mathbf{x}_t \\mathbf{L}_t^{-1} \\pmod 1$;
- 原子序数(Element Embedding)与 VDW/共价半径;
- 形式电荷、配位数(Atomic degrees)以及分子内的拓扑环结构。
- 成对边特征(Pair Features):
- 分子内共价键类型嵌入(单键、双键、三键、芳香键或无键);
- 2D 拓扑距离嵌入;
- 三维周期性欧氏距离嵌入 $\\mathbf{d}_{ij} = \\min_{\\mathbf{z} \\in \\mathbb{Z}^3} \\| \\mathbf{p}_i - \\mathbf{p}_j - \\mathbf{L}^T \\mathbf{z} \\|_2$(通过 RBF 径向基函数划分为 128 个 Bin 进行嵌入)。
- 全局条件特征(Global Conditioning Features):
- 时间步 $t$(通过 Transformer 常用的正弦编码);
- 当前晶格矩阵 $\\mathbf{L}$ 的行列式 $\\det \\mathbf{L}$、逆矩阵 $\\mathbf{L}^{-1}$ 和度量张量 $\\mathbf{L}^T\\mathbf{L}$ 的展平投影;
- 整个晶胞的分子式向量。
3.3 训练与采样伪代码解析
为了展示 CLARI 运行的核心逻辑,以下给出了其最核心的架构前向传播与流匹配向量预测的伪代码实现,可直接用于理解其模型文件 model.py 的工程逻辑:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
class PairBiasAttentionBlock(nn.Module):
def __init__(self, d_model, d_pair, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.pair_bias_proj = nn.Linear(d_pair, n_heads)
def forward(self, h, z):
# h: [B, N, d_model] - 序列节点特征
# z: [B, N, N, d_pair] - 二维几何边特征偏置
B, N, _ = h.shape
# 线性投影 Q, K, V
Q = self.q_proj(h).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
K = self.k_proj(h).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
V = self.v_proj(h).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
# 计算经典多头注意力得分
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head ** 0.5)
# 计算双偏置张量并融合
bias = self.pair_bias_proj(z).permute(0, 3, 1, 2) # [B, n_heads, N, N]
attn_logits = scores + bias
attn_weights = torch.softmax(attn_logits, dim=-1)
out = torch.matmul(attn_weights, V) # [B, n_heads, N, d_head]
out = out.transpose(1, 2).contiguous().view(B, N, -1)
return out
class ClariFlowMatchingTrunk(nn.Module):
def __init__(self, d_model, d_pair, d_cond, depth):
super().__init__()
self.blocks = nn.ModuleList([
nn.ModuleDict({
"attn": PairBiasAttentionBlock(d_model, d_pair, n_heads=8),
"adaLN_1": nn.Linear(d_cond, d_model * 6), # 预测 Scale, Shift, Gate 参数
"mlp": nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.SiLU(),
nn.Linear(d_model * 4, d_model)
),
"adaLN_2": nn.Linear(d_cond, d_model * 3)
}) for _ in range(depth)
])
self.velocity_head = nn.Linear(d_model, 3) # 预测 dx_t / dt
def forward(self, h, z, c):
# c: 全局条件特征嵌入 [B, d_cond]
for block in self.blocks:
# 块 1:带双偏置自注意力的特征提取与 AdaLN 调制
ada_params_1 = block["adaLN_1"](c).chunk(6, dim=-1)
shift_ms, scale_ms, gate_ms, shift_mlp, scale_mlp, gate_mlp = ada_params_1
norm_h = scale_ms * h + shift_ms # 模拟自适应归一化
attn_out = block["attn"](norm_h, z)
h = h + gate_ms * attn_out
# 块 2:前馈变换与调制
norm_h2 = scale_mlp * h + shift_mlp
mlp_out = block["mlp"](norm_h2)
h = h + gate_mlp * mlp_out
return self.velocity_head(h) # 最终速度预测
3.4 极速复现步骤与避坑指南
想要在本地成功复现 CLARI 并达到论文中的收敛性能,建议严格遵循以下流程:
步骤 1:环境配置与依赖安装
conda create -n clari_env python=3.10
conda activate clari_env
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda install pyg -c pyg
pip install rdkit-pypi scipy pandas pyyaml
# 注册并安装 CSD Python API (需学术或商业授权)
步骤 2:下载预训练模型与评估
在仓库的 checkpoints/ 目录下放置官方下载的 clari_large_csd.pt。执行以下命令,对 Rigid 测试集进行生成采样:
python sample.py --config configs/clari_large.yaml \
--checkpoint checkpoints/clari_large_csd.pt \
--test_split data/splits/ox_rigid.json \
--num_samples 150 \
--batch_size 128 \
--output_dir results/rigid_generation/
步骤 3:防 OOM 内存溢出机制(Inference Batching)
由于随着原子数 $N$ 增加,成对特征矩阵 $N \times N$ 的空间占用会呈平方级上升。CLARI 内部设计了一种非常贴心的动态自适应批处理机制(Inference Batching)。当遇到超大晶胞($N > 500$)发生 CUDA Out of Memory 错误时,DataLoader 会捕获该异常,自动将当前 Batch 拆分减半,并自动重试。这一逻辑在 utils/batching.py 中有完整实现,复现大分子预测时请务必开启该功能。
4. 关键引用文献,以及你对这项工作局限性的评论
4.1 关键引用文献
本工作建立在近年来生成式化学、流匹配算法以及生物大分子三维建模的坚实地基之上。以下是 CLARI 赖以构建的 5 篇最关键的科学文献:
- OXtal(Jin et al., 2026):OXtal: An all-atom diffusion model for organic crystal structure prediction。这是首个将扩散模型(Diffusion Models)引入全原子 CSP 的工作。CLARI 继承了其“全原子建模(含氢)”的设定,但彻底重构了其低效的体相堆叠和三角形更新机制。
- Flow Matching (Lipman et al., 2023):Flow matching for generative modeling。奠定了连续时间向量场回归的数学底座,相比传统扩散概率模型(DDPM),流匹配提供了更快的收敛速度和更平直的生成轨迹。
- AlphaFold3 (Abramson et al., 2024):Accurate structure prediction of biomolecular interactions with AlphaFold 3。展示了不具备显式旋转等变性、但通过大规模数据增强与自注意力机制训练的 Transformer 能够极其精准地捕捉复杂分子几何学。CLARI 延续了这一“去显式等变,拥抱大参数扩展”的哲学。
- UMA (Wood et al., 2025):UMA: A family of universal models for atoms。通用多元素机器学习力场。CLARI 直接生成含氢原子的完整晶胞,因此能与 UMA 无缝结合,在推理阶段进行实时的物理能量重排,极大地提升了最终筛选精度。
- COMPACK (Chisholm and Motherwell, 2005):COMPACK: A program for identifying crystal structure similarity using distances。晶体结构几何相似度金标准。CLARI 的 $\\text{Sol}@k$ 精度指标完全基于 COMPACK 对预测与真实晶体中 15 组分子堆叠簇的 RMSD15 进行严格比对而得。
4.2 局限性深度评论(Critique of Limitations)
尽管 CLARI 凭借秒级的预测速度和超越 DFT 的表现树立了生成式 CSP 的全新里程碑,但从严苛的量子化学与工业级应用视角审视,该模型仍存在以下不可忽视的局限性与改进空间:
1. 2D 分子图输入的“手性与立体化学盲区”
CLARI 的输入仅为 2D 分子化学图(SMILES),模型本身在初始阶段并不显式感知分子的手性(Chirality)和顺反异构等三维构象偏置。对于手性药物(Chiral Drugs)结晶而言,相同的 2D 图对应完全不同的三维对映异构体,而 CLARI 可能会在同一个轨迹中混淆手性。未来的迭代必须在输入的节点特征中,引入明确的手性三维构象标签或直接在等变空间中进行条件约束。
2. 必须预先指定晶胞内分子拷贝数(Formula Units $Z$)
这是一个极大的物理限制。CLARI 在生成前,必须知道目标晶胞中包含多少个分子单体(即确定的 $Z$ 值)。在实际科研或新药开发中,目标分子的最稳定结晶往往对应未知的 $Z$ 值(最常见的是 $Z=4, 2, 1$)。虽然可以通过在推理时对这些常用 $Z$ 值进行暴力网格搜索来解决,但由于这引入了离散的额外维度,仍显不够优雅。未来若能将 $Z$ 作为连续扩散/流匹配的一部分进行联合生成,将实现真正的盲测预测。
3. 缺乏显式的物理守恒约束(Force-Free Generation)
流匹配纯粹是一个数据驱动的几何生成器。尽管加入了体积损失和碰撞惩罚,CLARI 预测出的晶胞在微观上仍可能存在微小的键角偏折或范德华半径轻微侵入。这也是为什么在追求极端精细度时,仍然需要借助 UMA 力场进行简单的单点能重排。如果能在训练损失中,通过物理神经网络(PINN)的设计,直接引入解析的晶格动力学梯度约束(例如力场势能的一阶导数即受力为 0),模型输出的亚稳态物理品质将获得质的提升。
4. 复杂多柔性环状分子的构象塌陷(Conformational Collapse)
对于含有多个大柔性大环、多旋转二面角的分子,CLARI 在进行流匹配线性插值时,分子内部的旋转势垒非常高。单纯的线性空间坐标流匹配可能会导致分子在去噪中期被迫穿过极高能量的旋转过渡态,导致构象畸变或生成失败。引入基于刚体动力学(Rigid Body Dynamics)或者在李群(Lie Groups)流形上定义的扭转角流匹配(Torsional Flow Matching),是解决这一难题的终极方向。
5. 其他必要的补充(量子化学视角与行业影响)
5.1 量子化学视角下的生成式 CSP 范式革命
从经典量子化学的视角来看,晶体结构预测本质上是一个在极高维度的势能面上寻找全局和局域能效极小值(Energy Minima)的寻优问题。传统方法之所以极其昂贵,是因为它们把绝大部分算力浪费在了对那些高能、不稳定、甚至根本无法结晶的“非物理空间”的计算和评估上。
CLARI 的诞生代表着一种范式转移(Paradigm Shift):
- 主动空间过滤:CLARI 通过海量真实晶体数据(来自 CSD 百万量级的实验晶体)的学习,隐式地构建了晶体结构的“物理合理几何空间先验”。
- “生成即合理”:CLARI 的 ODE 积分轨迹实际上是在这一物理合理空间内进行“流形漫步(Manifold Walk)”。它直接绕过了能量极高的物理荒漠,只在极具亚稳态潜力的局部能谷附近输出结构。
- 极小化后道工序:这使得原本需要进行高昂 DFT 结构弛豫(Relaxation)的后道工序,被极速的机器学习力场(MLIP)单点能评估或轻量级微调所取代。这极大地解放了计算化学家的双手和服务器的电力消耗。
5.2 广泛的工业界与工程学影响
CLARI 带来的秒级高通量预测能力,将在以下高技术产业产生极其深远的变革性影响:
- 药物多晶型虚拟筛选(Pharmaceutical Polymorph Screening): 在新药研发(R&D)中,一种活性药物成分(API)若在后期生产中突然出现更稳定的全新多晶型(如经典的“利托那韦(Ritonavir)事件”),会导致原有药剂溶解度骤降、失效乃至召回。CLARI 能够让制药企业在合成新药分子之前,在几小时内对数千个候选分子进行详尽的晶型空间扫描,提前排查“多晶型陷阱”,极大地缩短药物安全性评估周期。
- 高能材料分子设计(Energetic Materials): 对于炸药、推进剂等高能密度材料,晶胞的密度与分子排布直接决定了爆速、爆压以及热感度等安全性指标。CLARI 与生俱来的高精度晶格密度预测,能够协助科研人员在电脑中极速设计并筛选出具有超高密度、且具有合理晶格堆积的高安全新型能效分子。
- 有机光电半导体筛选(Organic Optoelectronics): 有机发光二极管(OLED)和有机太阳能电池(OPV)中,有机半导体分子的电荷迁移率(Charge Mobility)高度依赖于晶体中的 $\\pi-\\pi$ 堆积距离和相对取向。CLARI 能够实现秒级的晶胞生成,使半导体材料学家得以对数百万个π共轭分子库进行大规模计算筛选,直接发掘出具有卓越电荷传输性能的明星结晶材料。
5.3 结语与未来展望
CLARI 的成功无可辩驳地证明了:精细化的单晶胞表征 + 数据启发的物理先验 + 去冗余的纯双偏置自注意力机制,足以在晶体预测领域击败那些一味模仿生物大分子、层级臃肿的重型神经网络架构。这不仅标志着生成式晶体结构预测进入了“秒级时代”,更为未来的无机晶体合成、合金固态相图预测等更广阔的材料发现任务,开辟了一条极具想象力的效率飞跃之路。随着后续手性等变机制与 PINN 物理能量梯度的深度融入,CLARI 及其演化版本,终将成为每位量子化学与材料学家不可或缺的AI超级助手。