【多模态玩具】多模态大模型的训练和推理

一、什么是多模态?

通过融合多种数据模态(例如图片、音频、视频、文本等)来训练模型,从而提高模型的感知与理解能力,实现跨模态的信息交互与融合。

文本模态的表示:文本模态的表示方法有多种,如独热表示、低维空间表示(如通过神经网络模型学习得到的转换矩阵将单词或字映射到语义空间中)、词袋表示及其衍生出的n-grams词袋表示等。目前,主流的文本表示方法是预训练文本模型,如BERT。

视觉模态的表示:视觉模态分为图像模态和视频模态。图像模态的表示主要通过卷积神经网络(CNN)实现,如LeNet-5、AlexNet、VGG、GoogLeNet、ResNet等。视频模态的表示则结合了图像的空间属性和时间属性,通常由CNN和循环神经网络(RNN)或长短时记忆网络(LSTM)等模型共同处理。

声音模态的表示:声音模态的表示通常涉及音频信号的预处理、特征提取和表示学习等步骤,常用的模型包括深度神经网络(DNN)、卷积神经网络(CNN)和循环神经网络(RNN)等。

二、多模态数据如何融合?

以图像-文本为例,常见的融合思路有两种:

1. 统一嵌入解码器架构方法

图1 统一嵌入解码器架构

统一嵌入解码器架构方法采用单一解码器模型,类似于未经修改的 LLM 架构(例如 GPT-2 或 Llama 3.2)。在这种方法中,图像被Encoder转换为与原始文本 token 具有相同嵌入大小的 token,从而允许 LLM 在连接后同时处理文本和图像输入 token。

图2 图像编码器内部架构

图像编码器架构如上图所示,就是提取传统的预训练视觉转换器(ViT),舍去其最后的分类全连接层。Image Encoder中的Linear projection(即线性投影模块),实际上是由全连接层将256维向量的图像块投影至768维向量,实际上用卷积层会更为容易,两者也是等价的。

在提取Image Encoder之后,我们还会增加一层投影模块,将图片信息最终投影至与文本词向量输出维度相匹配。就类似于文本的向量化,只是多了一步投影对齐到同一维度。

图3 类比文本编码器

最后直接将处理过后的图像信息与文本信息相拼接,输入llm中进行推理即可。

2. 跨模态注意力架构方法

图4 跨模态注意力架构

跨模态注意力架构中,则是将图片信息经过编码与投影后,加在llm中的多头注意力层中,通过交叉注意力机制进行处理。在交叉注意力机制中,图像信息提供K和V,文本信息提供Q。

图5 交叉注意力机制

三、如何训练?

图6 多模态训练部分

与传统纯文本 LLM 的开发类似,多模态 LLM 的训练也包含两个阶段:预训练和指令微调。然而,与从零开始不同,多模态 LLM 的训练通常以一个经过预训练和指令微调的纯文本 LLM 作为基础模型。

对于图像编码器,常用CLIP编码器,并且通常在整个训练过程中保持不变(冻结参数)。在预训练阶段冻结 LLM 部分也很常见,只专注于训练投影器——一个线性层或一个小型多层感知器。鉴于投影器的学习能力有限,通常仅包含一到两层,LLM 通常会在多模态指令微调(阶段 2)期间解冻,以便进行更全面的更新。然而,在基于交叉注意力机制的模型(方法 B)中,交叉注意力层在整个训练过程中都是解冻的。

统一嵌入解码器架构(方法 A)通常更容易实现,因为它不需要对 LLM 架构本身进行任何修改。

跨模态注意力架构(方法 B)通常被认为具有更高的计算效率,因为它不会用额外的图像标记使输入上下文过载,而是稍后在交叉注意力层中引入它们。此外,如果 LLM 参数在训练期间保持冻结,则此方法可以保持原始 LLM 的纯文本性能。

四、实操

参考LLaVa的训练方式,文本编码器选用Qwen2.5-0.5B-Instruct,图像编码器用Siglip-base-patch16-224,以问答形式的图像-文本对为训练数据。根据课设的要求,采用 MedTrinity-25M的Demo为数据集,仅进行预训练,不进行指令微调。

由于原数据集仅包含图像与对应的Caption,因此人工添加描述性的Question,采用与LLaVa预训练阶段同款的预设问题。整个数据集有16w张图像文本对,跑一个batch其实差不多就过拟合了。

训练代码参考:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import os
import json
import random
import torch
import io
import pandas as pd
import pyarrow.parquet as pq
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoTokenizer,
AutoModelForCausalLM,
AutoProcessor,
AutoModel,
Trainer,
TrainingArguments
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.utils.data import Dataset
from tqdm import tqdm

# 定义模型配置类
class VLMConfig(PretrainedConfig):
model_type = "vlm_model"

def __init__(
self,
llm_model_path='model/Qwen2.5-0.5B-Instruct',
vision_model_path='model/siglip-base-patch16-224',
freeze_vision_model=True,
freeze_llm_model=True,
image_pad_num=49, # 添加image_pad_num参数
**kwargs
):
self.vision_model_path = vision_model_path
self.llm_model_path = llm_model_path
self.freeze_vision_model = freeze_vision_model
self.freeze_llm_model = freeze_llm_model
self.image_pad_num = image_pad_num
super().__init__(**kwargs)

# 定义多模态模型
class VLM(PreTrainedModel):
config_class = VLMConfig

def __init__(self, config):
super().__init__(config)
self.config = config

# 加载视觉编码器
self.vision_model = AutoModel.from_pretrained(self.config.vision_model_path)
self.processor = AutoProcessor.from_pretrained(self.config.vision_model_path)

# 加载语言模型
self.llm_model = AutoModelForCausalLM.from_pretrained(self.config.llm_model_path)
self.tokenizer = AutoTokenizer.from_pretrained(self.config.llm_model_path)

# 确保特殊token存在
if '<|image_pad|>' not in self.tokenizer.get_vocab():
self.tokenizer.add_special_tokens({'additional_special_tokens': ['<|image_pad|>']})
self.llm_model.resize_token_embeddings(len(self.tokenizer))

# 图像特征维度与编码器隐藏维度的关系
vision_hidden_size = self.vision_model.config.vision_config.hidden_size
llm_hidden_size = self.llm_model.config.hidden_size

# SiGLIP输出是(batch_size, 196, hidden_size),需要映射到更少的tokens
# 定义投影层 - 先将196个token压缩到49个token,然后投影到LLM维度
self.linear1 = nn.Linear(vision_hidden_size * 4, llm_hidden_size) # 196 -> 49 (每4个合并)
self.linear2 = nn.Linear(llm_hidden_size, llm_hidden_size)

# 冻结模型参数
if self.config.freeze_vision_model:
for param in self.vision_model.parameters():
param.requires_grad = False

if self.config.freeze_llm_model:
for param in self.llm_model.parameters():
param.requires_grad = False

# def forward(self, input_ids, attention_mask=None, labels, pixel_values):
def forward(self, input_ids, labels, pixel_values, attention_mask=None):

# 获取文本嵌入
text_embeds = self.llm_model.get_input_embeddings()(input_ids)

# 处理图像嵌入
if pixel_values is not None:
# print(f"pix is here", pixel_values.shape)
# norm = (pixel_values - 1).norm()
# print(f"x is",norm)
image_embeds = self.vision_model.vision_model(pixel_values).last_hidden_state
b, s, d = image_embeds.shape
# 压缩图片tokens: (b, 196, d) --> (b, 49, d*4)
image_embeds = image_embeds.view(b, -1, d*4)
image_features = self.linear2(F.silu(self.linear1(image_embeds)))

# 将图像特征合并到输入嵌入中
text_embeds = text_embeds.to(image_features.dtype)
inputs_embeds = self.merge_input_ids_with_image_features(image_features, text_embeds, input_ids)
else:
inputs_embeds = text_embeds

# 通过语言模型
outputs = self.llm_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels,
return_dict=True
)

return outputs

def merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
# 找到所有<|image_pad|>的位置
image_pad_id = self.tokenizer.convert_tokens_to_ids('<|image_pad|>')
batch_indices, image_indices = torch.where(input_ids == image_pad_id)

# 获取唯一的batch_indices,以处理每个样本
unique_batch_indices = batch_indices.unique()

for batch_idx in unique_batch_indices:
# 获取当前批次的所有image_pad位置
pad_positions = image_indices[batch_indices == batch_idx]

# 确保我们有足够的image_pad位置
if len(pad_positions) >= self.config.image_pad_num:
# 只使用前image_pad_num个位置
pad_positions = pad_positions[:self.config.image_pad_num]

# 将image_features的每一行放入对应的image_pad位置
for i, pos in enumerate(pad_positions):
if i < image_features.shape[1]:
inputs_embeds[batch_idx, pos] = image_features[batch_idx, i]

return inputs_embeds

# 定义数据集类
# 不同的数据集的组织形式可能不同,作者给出的组织方式也可能与实际情况不足
# 最好先试试再再写Dateset!
class MedicalImageDataset(Dataset):
def __init__(self, data_dir, tokenizer, processor, image_pad_num=49):
self.data_dir = data_dir
self.tokenizer = tokenizer
self.processor = processor
self.image_pad_num = image_pad_num
self.data = []
self.instructions = [
"Describe the following image in detail",
"Provide a detailed description of the given image",
"Give an elaborate explanation of the image you see",
"Share a comprehensive rundown of the presented image",
"Offer a thorough analysis of the image",
"Explain the various aspects of the image before you",
"Clarify the contents of the displayed image with great detail",
"Characterize the image using a well-detailed description",
"Break down the elements of the image in a detailed manner",
"Walk through the important details of the image",
"Portray the image with a rich, descriptive narrative",
"Narrate the contents of the image with precision",
"Analyze the image in a comprehensive and detailed manner",
"Illustrate the image through a descriptive explanation",
"Examine the image closely and share its details",
"Write an exhaustive depiction of the given image"
]

# 加载所有parquet文件
parquet_files = [f for f in os.listdir(data_dir) if f.endswith('.parquet')]
print(f"Found {len(parquet_files)} parquet files")

for pq_file in tqdm(parquet_files, desc="Loading data"):
file_path = os.path.join(data_dir, pq_file)
table = pq.read_table(file_path)
df = table.to_pandas()

for _, row in df.iterrows():
# 检查必要的字段是否存在
if all(field in row for field in ['image', 'caption', 'id']):
image_bytes = row['image']['bytes'] if isinstance(row['image'], dict) and 'bytes' in row['image'] else None
self.data.append({
'image': image_bytes, # 直接存储二进制图像数据
'caption': row['caption'],
'id': row['id']
})



print(f"Loaded {len(self.data)} image-caption pairs")

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
item = self.data[idx]

# 随机选择一个指令
instruction = random.choice(self.instructions)

# 构建提示,添加图像占位符
prompt = f"{instruction}\n" + "<|image_pad|>" * self.image_pad_num
answer = item['caption']

# 从二进制数据加载图像
try:
image = Image.open(io.BytesIO(item['image'])).convert('RGB')
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.squeeze(0)
# pixel_values = self.processor(text=None, images=image, return_tensors="pt")['pixel_values']
# print(f"shape is", pixel_values.shape)
if type(pixel_values) != torch.Tensor:
print("Wrong!!!!!!!!!!!!!!!!!")
except Exception as e:
print(f"Error loading image {item['id']}: {e}")
# 创建一个空白图像
image = Image.new('RGB', (224, 224), color='white')
pixel_values = self.processor(text=None, images=image, return_tensors="pt")['pixel_values']

# 简化的对话格式构建,不依赖chat_template
system = "You are a helpful medical image analysis assistant."
user_message = prompt
assistant_message = answer

# 构建输入文本,使用简单的格式
text = f"{system}\n\nUser: {user_message}\n\nAssistant: {assistant_message}"

# 分词
encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
input_ids = encoding["input_ids"][0]
attention_mask = encoding["attention_mask"][0]

# 构建标签:找到Assistant部分的开始位置
assistant_start = text.find("Assistant: ")
if assistant_start != -1:
# 计算Assistant开始前的token数量
prefix_text = text[:assistant_start]
prefix_tokens = self.tokenizer(prefix_text, add_special_tokens=False)["input_ids"]
assistant_pos = len(prefix_tokens)
# 可能需要加上特殊token的偏移量
if self.tokenizer.bos_token_id is not None:
assistant_pos += 1 # 加上BOS token
else:
# 如果找不到,估算一个位置(大约是前半部分)
assistant_pos = len(input_ids) // 2

# 构建标签
labels = input_ids.clone()
labels[:assistant_pos] = -100 # 用户部分标签设为-100
# print(f"shape is", pixel_values.shape)

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"pixel_values": pixel_values
}



class MyDataCollator:
def __init__(self, tokenizer):
self.tokenizer = tokenizer

def __call__(self, features):
# 确保所有输入都是有效的
valid_features = []
for f in features:
if f["pixel_values"] is not None and isinstance(f["pixel_values"], torch.Tensor):
valid_features.append(f)
else:
print(f"Warning: Skipping invalid sample with None or non-tensor pixel_values")

if not valid_features:
raise ValueError("No valid samples in batch!")

# 正确堆叠张量
batch = {
"input_ids": torch.stack([f["input_ids"] for f in valid_features]),
"attention_mask": torch.stack([f["attention_mask"] for f in valid_features]),
"labels": torch.stack([f["labels"] for f in valid_features]),
"pixel_values": torch.stack([f["pixel_values"] for f in valid_features])
}

return batch



def main():
# 设置配置
config = VLMConfig(
llm_model_path='model/Qwen2.5-0.5B-Instruct',
vision_model_path='model/siglip-base-patch16-224',
freeze_vision_model=True,
freeze_llm_model=True,
image_pad_num=49
)

# 初始化模型
model = VLM(config)
tokenizer = AutoTokenizer.from_pretrained(config.llm_model_path)
processor = AutoProcessor.from_pretrained(config.vision_model_path)

# 确保tokenizer有需要的token
tokenizer.add_special_tokens({'additional_special_tokens': ['<|image_pad|>']})

# 确保tokenizer有pad_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# 打印可训练参数数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"总参数量: {total_params}, 可训练参数量: {trainable_params}")

# 创建数据集
dataset = MedicalImageDataset(
data_dir="data/MedTrinity",
tokenizer=tokenizer,
processor=processor,
image_pad_num=49
)

# 训练参数
training_args = TrainingArguments(
output_dir="./output",
num_train_epochs=1,
per_device_train_batch_size=8,
gradient_accumulation_steps=4,
save_steps=500,
save_total_limit=2,
learning_rate=1e-4,
warmup_steps=500,
logging_dir="./logs",
logging_steps=50,
eval_strategy="no",
fp16=True,
dataloader_num_workers=4,
report_to="tensorboard",
)

# 初始化训练器
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=MyDataCollator(tokenizer),
)

# 开始训练
trainer.train()

# 保存模型
trainer.save_model("./final_model")
print("训练完成,模型已保存到 ./final_model")

if __name__ == "__main__":
main()

推理代码参考:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, AutoConfig
from PIL import Image
from train import VLMConfig, VLM
import torch
from torch.nn import functional as F

device = "cuda:0"
processor = AutoProcessor.from_pretrained("model/siglip-base-patch16-224")
tokenizer = AutoTokenizer.from_pretrained('model/Qwen2.5-0.5B-Instruct')
AutoConfig.register("vlm_model", VLMConfig)
AutoModelForCausalLM.register(VLMConfig, VLM)

pretrain_model = AutoModelForCausalLM.from_pretrained('output/checkpoint-5051')
pretrain_model.to(device)

# sft_model = AutoModelForCausalLM.from_pretrained('/home/user/wyf/train_multimodal_from_scratch/save/sft')
# sft_model.to(device)

pretrain_model.eval()
# sft_model.eval()
def generate(image_input, text_input, max_new_tokens = 256, temperature = 0.0, top_k = None):
q_text = tokenizer.apply_chat_template([{"role":"system", "content":'You are a helpful medical assistant.'}, {"role":"user", "content":f'{text_input}\n<image>'}], \
tokenize=False, \
add_generation_prompt=True).replace('<image>', '<|image_pad|>'*49)
input_ids = tokenizer(q_text, return_tensors='pt')['input_ids']
input_ids = input_ids.to(device)
# image = Image.open(image_input).convert("RGB")
pixel_values = processor(text=None, images=image_input).pixel_values
pixel_values = pixel_values.to(device)
eos = tokenizer.eos_token_id
s = input_ids.shape[1]
while input_ids.shape[1] < s + max_new_tokens - 1:
# if mode == 'pretrain':
model = pretrain_model
# else:
# model = sft_model
inference_res = model(input_ids, None, pixel_values)
logits = inference_res.logits
logits = logits[:, -1, :]

for token in set(input_ids.tolist()[0]):
logits[:, token] /= 1.0

if temperature == 0.0:
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')

probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1, generator=None)

if idx_next == eos:
break

input_ids = torch.cat((input_ids, idx_next), dim=1)
return tokenizer.decode(input_ids[:, s:][0])

with gr.Blocks() as demo:
with gr.Row():
# 上传图片
with gr.Column(scale=1):
image_input = gr.Image(type="pil", label="选择图片")
with gr.Column(scale=1):
# mode = gr.Radio(["pretrain", "sft"], label="选择模型")
mode = 'pretrain'
text_input = gr.Textbox(label="输入文本")
# text_output = gr.Textbox(label="输出文本")
text_output = gr.Textbox(
label="输出文本",
lines=10, # 显示时默认有 10 行高
max_lines=20, # 最多扩展到 20 行(如果内容太多)
interactive=False # 输出一般不可编辑
)
generate_button = gr.Button("生成")
generate_button.click(generate, inputs=[image_input, text_input], outputs=text_output)


if __name__ == "__main__":
demo.launch(share=False, server_name="0.0.0.0", server_port=7891)


实际推理肯定是过拟合的,就用了大量数据却只训练了一个projection,应付一下课设就先这样不管他了,况且本人没有医疗知识背景,无法对模型推理的结果进行人为判断,也没有找到良好的医疗背景下的指令微调数据集,就先这样了。

最后把模型输出的英文Caption用外部LLM转为中文,再用脚本转为PDF报告,就能做一个医疗CT图像+文本指令控制输出医疗诊断报告的流,综设这样就完事了。

参考资料:

理论指导

参考实操项目

LLaVa’s Papper

数据集 MedTrinity-25M


【多模态玩具】多模态大模型的训练和推理
https://blog.sheep0.top/2025/09/29/【多模态玩具】多模态大模型的训练和推理/
作者
Sheep0
发布于
2025年9月29日
许可协议