当前位置: 首页 > news >正文

生成任务,大模型

一个生成项目

输入:文字描述(但是给的数据集是一串数字,id,ct描述,医生描述)
输出:诊断报告

一、数据处理

import pandas as pd  #处理表格数据pre_train_file= "data/train.csv"train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #读入数据print(train_df.head())train_data = train_df.sample(frac=0.9, random_state=0, axis=0)   #采样0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)]       #干啥的,  过来用train_data.to_csv("data/pro_train_data.csv", index=False,header=False)val_data.to_csv("data/pro_val_data.csv", index=False,header=False)

主要是用于从一个CSV文件中读取数据,并将其划分为训练集和验证集,然后将这两个数据集分别保存到新的CSV文件中。

代码逐行解释

导入必要的库
import pandas as pd  # 处理表格数据
  • pandas:一个强大的数据分析和处理库,特别适合处理表格数据(如CSV文件)。
定义文件路径并读取数据
pre_train_file = "data/train.csv"train_df = pd.read_csv(pre_train_file, header=None, names=["id", "input", "tgt"])  # 读入数据print(train_df.head())
  • pre_train_file:指定要读取的CSV文件路径。
  • pd.read_csv
    • header=None:表示CSV文件没有表头(第一行不是列名)。
    • names=["id", "input", "tgt"]:为每一列指定名称。
  • print(train_df.head()):打印前五行数据,以便检查读取是否正确。
数据划分
train_data = train_df.sample(frac=0.9, random_state=0, axis=0)  # 采样0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)]  # 干啥的, 过来用
  • train_data

    • 使用 sample 方法随机采样90%的数据作为训练集。
    • frac=0.9:表示采样的比例为90%。
    • random_state=0:设置随机种子以确保结果可重复。
    • axis=0:表示沿行方向进行采样(默认行为)。
  • val_data

    • 使用 ~train_df.index.isin(train_data.index) 来获取不在训练集中的数据作为验证集。
    • isin(train_data.index) 返回一个布尔数组,指示哪些索引在训练集中。
    • ~ 取反操作符,返回不在训练集中的索引。
保存数据
train_data.to_csv("data/pro_train_data.csv", index=False, header=False)val_data.to_csv("data/pro_val_data.csv", index=False, header=False)
  • to_csv 方法
    • 将DataFrame保存为CSV文件。
    • index=False:不保存行索引。
    • header=False:不保存列名。

二、处理词表

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_argsargs = parse_args()         #设置 ,字典, 属性类  config  {}def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 训练集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datastrain_data = load_data('./data/train.csv')token2count = Counter()     #计数工具 哈希表for i in train_data:token2count.update(i)       #不需要知道原理tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0,"[PAD]")
vocab.insert(100,"[UNK]")
vocab.insert(101,"[CLS]")
vocab.insert(102,"[SEP]")
vocab.insert(103,"[MASK]")
vocab.insert(104,"[EOS]")
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab()   #获取模型词表# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab:         #mn复杂度
#     if v not in vocabs:
#         count += 1
#         new_vocabs.append(v)
# print(len(new_vocabs))
new_vocabs = vocab
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n")    #保存model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)      #模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin')
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)

1. 导入必要的库

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args
  • sys:用于系统相关的操作(如命令行参数)。
  • torch:PyTorch的核心库,用于深度学习模型。
  • Counter:来自 collections 模块,用于统计元素出现的次数。
  • BertTokenizer, BartConfig, BartForConditionalGeneration:来自 transformers 库,分别用于分词、配置和加载预训练模型。
  • parse_args:自定义函数,用于解析命令行参数或配置文件,返回一个包含配置参数的对象。

2. 解析参数

args = parse_args()  # 设置,字典,属性类 config {}
  • parse_args:调用自定义函数解析配置参数,并将其存储在 args 对象中。假设 args 包含诸如 pre_model_path 等路径信息。

3. 定义数据加载函数

def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 训练集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datas
  • load_data 函数
    • 打开指定路径的文件并读取每一行。
    • 使用 strip() 去除每行的前后空白字符,并使用 split(",") 将其按逗号分割为列表。
    • 如果列表长度为3(假设是训练集),则将第二列和第三列的数据拆分为单词列表,并合并后添加到 datas 列表中。
    • 如果列表长度不为3,则仅处理第二列的数据,并将其拆分为单词列表后添加到 datas 列表中。
    • 返回 datas 列表。

4. 加载数据

train_data = load_data('./data/train.csv')
  • 调用 load_data 函数加载训练数据,并将结果存储在 train_data 变量中。

5. 统计词频

token2count = Counter()  # 计数工具 哈希表for i in train_data:token2count.update(i)  # 不需要知道原理
  • token2count:使用 Counter 类创建一个哈希表来统计每个单词出现的次数。
  • 遍历 train_data 中的每一行数据,并使用 update 方法更新 token2count,记录每个单词出现的次数。

6. 创建词汇表

tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0, "[PAD]")
vocab.insert(100, "[UNK]")
vocab.insert(101, "[CLS]")
vocab.insert(102, "[SEP]")
vocab.insert(103, "[MASK]")
vocab.insert(104, "[EOS]")
  • tail:筛选出频率大于等于 ct 的单词,并按字母顺序排序。注意这里 ct 设为0,因此所有单词都会被包含进来。
  • vocab:将 tail 赋值给 vocab
  • 插入特殊标记:在 vocab 中插入一些特殊的标记符号(如 [PAD], [UNK], [CLS], [SEP], [MASK], [EOS]),这些标记在自然语言处理任务中具有特定含义。

7. 保存词汇表

new_vocabs = vocab
with open(args.pre_model_path + '/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n")  # 保存
  • new_vocabs:直接赋值为 vocab
  • 保存词汇表:将词汇表中的每个单词写入 vocab.txt 文件中,文件路径由 args.pre_model_path 指定。

8. 加载预训练模型并调整词汇表大小

model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)  # 模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path + '/pytorch_model.bin')bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)
  • 加载预训练模型:使用 BartForConditionalGeneration.from_pretrained 加载预训练模型。
  • 调整词汇表大小:使用 resize_token_embeddings 方法调整模型的嵌入层大小以适应新的词汇表。
  • 保存模型状态:将模型的状态字典保存到 pytorch_model.bin 文件中,文件路径由 args.pre_model_path 指定。
  • 更新配置:更新 BartConfig 中的 vocab_size 属性,并保存配置。

三、自监督预训练

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        #日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
# os.environ['CUDA_VISIBLE_DEVICES']='0'def train_and_validate(args):# 1. load data  modelmodel = preModel(args)     #加载预训练模型optimizer, scheduler = build_optimizer(args, model)# model = model.to(args.device)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)# model = BalancedDataParallel(16, model, dim=0).to(args.device)# model = model.to(args.device)#-------ema here-----------------all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True,collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs):    #开始训练了for batch in train_dataloader:model.train()loss= model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')def main():args = parse_args()           #设置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)         #LINUXtrain_and_validate(args)if __name__ == '__main__':main()

实现了一个完整的训练和验证流程,包括数据加载、模型初始化、训练循环、日志记录以及模型保存等功能

1. 导入必要的库

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        # 日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
  • PreTrainDataset, loadData, MLM_Data:自定义模块,用于数据处理。
  • DataLoader, Dataset:PyTorch提供的类,用于数据加载和管理。
  • preModel:自定义模型类。
  • logging:用于记录日志信息。
  • os:用于操作系统相关的操作(如文件路径处理)。
  • parse_args:自定义函数,解析命令行参数或配置文件。
  • setup_device, setup_seed, setup_logging, build_optimizer:自定义工具函数,分别用于设置设备、随机种子、日志记录和优化器构建。
  • torch:PyTorch核心库。
  • time:用于时间相关操作。

2. 定义训练和验证函数

def train_and_validate(args):# 1. 加载数据和模型model = preModel(args)     # 加载预训练模型optimizer, scheduler = build_optimizer(args, model)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True, collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs):    # 开始训练了for batch in train_dataloader:model.train()loss = model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')
解释
  • 加载数据和模型

    • 使用 preModel 类加载预训练模型。
    • 使用 build_optimizer 函数构建优化器和学习率调度器。
    • 如果 use_pre 为真,则从指定路径加载预训练模型的权重。
    • 根据 args.deviceargs.paral 参数决定是否使用多GPU并行训练。
  • 数据加载

    • 使用 loadData 函数加载所有数据。
    • 使用 MLM_Data 类将数据转换为适合训练的数据集格式。
    • 使用 DataLoader 创建数据加载器,支持批量加载和数据打乱。
  • 训练循环

    • 对每个epoch进行遍历。
    • 对每个batch进行前向传播计算损失,反向传播更新权重。
    • 记录训练进度和剩余时间,并在特定步数时打印日志。
    • 每隔5个epoch保存一次模型。

3. 主函数

def main():args = parse_args()           # 设置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)         # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
  • main 函数
    • 调用 parse_args 解析命令行参数。
    • 调用 setup_logging 配置日志记录。
    • 调用 setup_devicesetup_seed 分别设置设备和随机种子。
    • 创建保存模型的目录(如果不存在)。
    • 打印训练和评估参数。
    • 调用 train_and_validate 函数开始训练和验证过程。

四、微调

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer,array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES']='0'# 不需要完全理解,  知道每一块在做什么就行   知道之后,  以后再用到, 搬过去就行def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n>0 and tot>n:breaksource = source.cuda()pred = model(source[:, :args. input_l])pred = pred.cpu().detach().numpy()#print(pred.shape)for i in range(pred.shape[0]):# res.append({'image_id':tot, 'caption': [array2str(pred[i][2:], args)]})# gts[tot] = [array2str(targets[i][1:], args)]res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_scoredef train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)#-------ema here-----------------model.train()#-------------------------------# loss, results = validate(model, val_dataloader)# 3. trainingstep = 0best_score = args.best_score     #评估指标  准确率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args. input_l], targets[:, :args.output_l])loss  = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')def main():args = parse_args()setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)train_and_validate(args)if __name__ == '__main__':main()

实现了一个完整的训练和验证流程,包括数据加载、模型初始化、训练循环、验证评估以及模型保存等功能。

1. 导入必要的库

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer, array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES'] = '0'
  • logging:用于记录日志信息。
  • os:用于操作系统相关的操作(如文件路径处理)。
  • time:用于时间相关操作。
  • torch:PyTorch核心库。
  • PretrainedBartModel:来自 transformers 库的预训练模型基类。
  • parse_args:自定义函数,解析命令行参数或配置文件。
  • create_dataloaders:自定义函数,创建数据加载器。
  • myModel:自定义模型类。
  • CiderD, CE:自定义评分函数,分别用于计算CIDEr-D分数和交叉熵损失。
  • setup_device, setup_seed, setup_logging, build_optimizer, array2str:自定义工具函数,分别用于设置设备、随机种子、日志记录、构建优化器和数组转字符串。
  • autocast:用于混合精度训练。
  • tqdm:用于显示进度条。

2. 定义验证函数

def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n > 0 and tot > n:breaksource = source.cuda()pred = model(source[:, :args.input_l])pred = pred.cpu().detach().numpy()for i in range(pred.shape[0]):res.append({'image_id': tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_score
解释
  • 输入参数

    • model: 需要验证的模型。
    • loader: 数据加载器。
    • args: 命令行参数或配置对象。
    • output_file: 输出文件路径(可选)。
    • beam: 束搜索宽度(可选,默认为1)。
    • n: 验证样本数限制(可选,默认为-1,表示不限制)。
  • 逻辑

    • 初始化结果列表 res 和真实标签字典 gts
    • 使用 tqdm 显示进度条遍历数据加载器中的每个批次 (source, targets)
    • source 移动到 GPU 并进行前向传播得到预测结果 pred
    • 将预测结果和真实标签转换为字符串格式并添加到 resgts 中。
    • 使用 CiderD 计算预测结果与真实标签之间的 CIDEr-D 分数。
    • 返回 CIDEr-D 分数。

3. 定义训练和验证函数

def train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)model.train()step = 0best_score = args.best_score  # 评估指标 准确率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args.input_l], targets[:, :args.output_l])loss = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')
解释
  • 加载数据

    • 使用 create_dataloaders 函数加载训练和验证数据加载器。
  • 初始化模型和优化器

    • 使用 myModel 类加载模型。
    • 如果 use_pre 为真,则从指定路径加载预训练模型的权重。
    • 使用 build_optimizer 函数构建优化器和学习率调度器。
    • 将模型移动到指定设备(CPU或GPU)。
  • 训练循环

    • 对每个epoch进行遍历。
    • 对每个batch进行前向传播计算损失,反向传播更新权重。
    • 每个epoch结束后调用 validate 函数计算验证集上的 CIDEr-D 分数。
    • 如果当前 CIDEr-D 分数优于历史最佳分数,则保存模型。

4. 主函数

def main():args = parse_args()  # 设置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)  # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
解释
  • 主函数
    • 调用 parse_args 解析命令行参数。
    • 调用 setup_logging 配置日志记录。
    • 调用 setup_devicesetup_seed 分别设置设备和随机种子。
    • 创建保存模型的目录(如果不存在)。
    • 打印训练和评估参数。
    • 调用 train_and_validate 函数开始训练和验证过程。

五、inference

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_argsdef inference(args):test_loader = create_dataloaders(args,test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'],strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()if __name__ == '__main__':args = parse_args()inference(args)

实现了一个推理(inference)流程,包括数据加载、模型加载、前向传播以及结果保存等功能。

1. 导入必要的库

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_args
  • tqdm:用于显示进度条。
  • csv:用于处理CSV文件的读写操作。
  • to_device:自定义函数,将数据移动到指定设备(CPU或GPU)。
  • array2str:自定义函数,将数组转换为字符串。
  • myModel:自定义模型类。
  • create_dataloaders:自定义函数,创建数据加载器。
  • torch:PyTorch核心库。
  • parse_args:自定义函数,解析命令行参数或配置文件。

2. 定义推理函数

def inference(args):test_loader = create_dataloaders(args, test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'], strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()
解释
  • 加载测试数据

    • 使用 create_dataloaders 函数加载测试数据加载器,设置 test=True 表示加载测试集。
  • 初始化模型并加载权重

    • 使用 myModel 类加载模型。
    • 打印预训练模型路径 args.ckpt_file
    • 使用 torch.load 加载预训练模型的权重,并使用 load_state_dict 方法加载到模型中。
    • 将模型移动到 GPU(cuda:0),并设置为评估模式(model.eval())。
  • 推理过程

    • 打开输出 CSV 文件,并创建 CSV 写入器。
    • 使用 tqdm 显示进度条遍历测试数据加载器中的每个批次 source
    • source 移动到 GPU 并进行前向传播得到预测结果 pred
    • 将预测结果转换为 NumPy 数组,并逐个样本写入 CSV 文件。

3. 主函数

if __name__ == '__main__':args = parse_args()inference(args)
  • 主函数
    • 调用 parse_args 解析命令行参数。
    • 调用 inference 函数开始推理过程。

相关文章:

生成任务,大模型

一个生成项目 输入:文字描述(但是给的数据集是一串数字,id,ct描述,医生描述) 输出:诊断报告 一、数据处理 import pandas as pd #处理表格数据pre_train_file "data/train.csv"tr…...

下载Hugging Face模型的几种方式

1.网页下载 直接访问Hugging Face模型页面,点击“File and versions”选项卡,选择所需的文件进行下载。 2.使用huggingface-cli 首先,安装huggingface_hub: pip install huggingface_hub 然后,使用以下命令下载模型&#xff1…...

【Elasticsearch入门到落地】9、hotel数据结构分析

接上篇《8、RestClient操作索引库-基础介绍及导入demo》 上一篇我们介绍了RestClient的基础,并导入了使用Java语言编写的RestClient程序Demo以及将要分析的数据库。本篇我们就要分析导入的宾馆数据库tb_hotel表结构的具体含义,并分析如何建立其索引库。 …...

【由技及道】量子构建交响曲:Jenkinsfile流水线的十一维编程艺术【人工智障AI2077的开发日志008】

摘要:当代码提交触发时空涟漪,当构建流水线穿越量子维度——欢迎来到自动化构建的终极形态。本文将揭示如何用Jenkinsfile编写量子构建乐章,让每次代码提交都成为跨维度交响乐的音符。 动机:构建系统的量子哲学 “主人啊&#xff…...

Unity开发——CanvasGroup组件介绍和应用

CanvasGroup是Unity中用于控制UI的透明度、交互性和渲染顺序的组件。 一、常用属性的解释 1、alpha:控制UI的透明度 类型:float,0.0 ~1.0, 其中 0.0 完全透明,1.0 完全不透明。 通过调整alpha值可以实现UI的淡入淡…...

jenkins配置连接k8s集群

jenkins配置连接k8s集群 前言 我这边jenkins是在一个服务器里面,k8s集群在其他服务器,实现连接 首先jenkins下载有k8s插件 进入配置页面 获取k8s-api-server地址 对应k8s服务器执行 kubectl config view --minify -o jsonpath{.clusters[0].cluste…...

Linux 入门:常用命令速查手册

目录 一.指令 1.pwd(显示所在路径) 2.ls(列出所有子目录与文件) 3.touch(创建文件) 4.mkdir(创建目录) 5.cd(改变所处位置) 6.rm(删除&…...

【VUE】day01-vue基本使用、调试工具、指令与过滤器

【VUE】day01-vue基本使用、调试工具、指令与过滤器 1. 什么是Vue2. Vue的基本使用 1. 什么是Vue Vue(Vue.js)是一个用于构建用户界面的渐进式 JavaScript 框架,其核心设计理念是“自底向上逐层应用”,既能作为轻量级库增强现有项…...

deepseek为什么要开源

一、生态位的抢占与锁定:以 JDK 版本为例​ 在软件开发的世界里,生态位的抢占和先入为主的效应十分显著。就拿 Java 开发中的 JDK 版本来说,目前大多数开发者仍在广泛使用 JDK8。尽管 JDK17 和 JDK21 已经推出,且具备更多先进特性…...

软考 中级软件设计师 考点知识点笔记总结 day02

文章目录 3、计算机系统组成 (五大部件)3.1、主存储器3.2、运算器3.3、控制器3.4、Flynn分类法 4、指令系统4.1、七种寻址方式4.2、指令的流水处理4.3、流水线的计算 上一篇文章 软考知识点 day01 3、计算机系统组成 (五大部件) …...

Redis Cluster 客户端定位分片全解析:哈希槽与动态路由机制

Redis Cluster客户端定位分片全解析:哈希槽与动态路由机制 一、引言 Redis Cluster通过分片技术将数据分散存储在多个节点,实现水平扩展。客户端如何快速定位目标分片?本文将深入解析哈希槽算法、路由逻辑及实战技巧。 二、核心原理&#…...

基于Python+Vue的智能服装商城管理系统的设计与实现

👗 基于PythonVue的智能服装商城管理系统的设计与实现 电商级解决方案:全栈技术融合 智能推荐系统 多维度数据分析 项目亮点:课程设计优选 | 企业级架构规范 | 完整电商功能闭环 | 毕业设计选择 🌐 在线资源速览 类别地址访问方…...

提升Web可访问性的10个关键实践

在当今互联网时代,确保网站的可访问性(Accessibility)已经成为开发者和设计师的重要任务之一。Web可访问性不仅有助于残障用户更好地访问和使用网站,还能提升整体用户体验。本文将介绍10个关键的Web可访问性实践,帮助你…...

基于DeepSeek的智慧医药系统(源码+部署教程)

运行环境 智慧医药系统运行环境如下: 前端: HTMLCSS后端:Java AIGCDeepseekIDE工具:IDEA技术栈:Springboot HTMLCSS MySQL 主要角色 智慧医药系统主要分为两个角色。 游客 尚未进行注册和登录。具备登录注册、…...

yolov5训练自己数据集的全流程+踩过的坑

一,拿到yolov5数据集的第一步是什么呢,安装必要的依赖文件。在requirements.txt文件下存放 pip install -r requirements.txt二,检查是否可以正常进行检测,在detect.py,文件下,里面有默认的设置文件是可以…...

【Recon】Git源代码泄露题目解题方法

CTF中Git源代码泄露题目解题方法 1. 确认存在.git目录泄露2. 下载完整的.git目录3. 恢复Git仓库历史4. 查找Flag的常见位置5. 处理不完整的.git目录6. 其他技巧示例流程 在CTF中遇到Git源代码泄露题目时,通常可以通过以下步骤解决: 1. 确认存在.git目录泄…...

Android APP 启动流程详解(含冷启动、热启动)

目录 一、流程对比图 二、冷启动(Cold Launch) 2.1 用户点击应用图标(Launcher 触发) 2.2 AMS 处理启动请求 2.3 请求 Zygote 创建新进程 2.4 初始化应用进程 2.5 创建 Application 对象 2.6 启动目标 Activity 2.7 执行 …...

Python实现网络通信:Socket模块与TCP/IP协议全解析

Langchain系列文章目录 01-玩转LangChain:从模型调用到Prompt模板与输出解析的完整指南 02-玩转 LangChain Memory 模块:四种记忆类型详解及应用场景全覆盖 03-全面掌握 LangChain:从核心链条构建到动态任务分配的实战指南 04-玩转 LangChai…...

信奥赛CSP-J复赛集训(模拟算法专题)(1):P8813 [CSP-J 2022] 乘方

信奥赛CSP-J复赛集训(模拟算法专题)(1):P8813 [CSP-J 2022] 乘方 题目描述 小文同学刚刚接触了信息学竞赛,有一天她遇到了这样一个题:给定正整数 a a a 和 b b b,求 a b a^b ab …...

MongoDB学习笔记

MongoDB https://www.mongodb.com/download-center/community 打开客户端 mongo.exe 注意6.0版本不一样需要自行安装Mongoshell MongoDB Shell Download | MongoDB 创建数据库 use go_db; 创建集合 db.createCollection("student"); 添加MongoDB依赖 go get …...

C#模拟鼠标点击,模拟鼠标双击,模拟鼠标恒定速度移动,可以看到轨迹

C#模拟鼠标点击,模拟鼠标双击,模拟鼠标恒定速度移动,可以看到轨迹 using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks;namespa…...

时间复杂度空间复杂度

一、时间复杂度 时间复杂度(Time Complexity)表示算法运行时间随输入规模增长的变化趋势。通常用大 O 表示法(Big O Notation)来描述。 常见时间复杂度 复杂度名称例子O(1)常数时间复杂度访问数组中的某个元素。O(log n)对数时间复…...

【科研绘图系列】R语言绘制组合箱线图(grouped boxplot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据画图输出图片系统信息介绍 【科研绘图系列】R语言绘制组合箱线图(grouped boxplot) 加载R包 library(tidyverse) library(lemon) library(ggnewscale)…...

【前缀和与差分 C/C++】洛谷 P8218 求区间和

2025 - 03 - 09 - 第 72 篇 Author: 郑龙浩 / 仟濹 【前缀和与差分 C/C】 文章目录 洛谷 P8218 求区间和题目描述输入格式输出格式输入输出样例 #1输入 #1输出 #1 说明/提示思路代码 洛谷 P8218 求区间和 题目描述 给定 n n n 个正整数组成的数列 a 1 , a 2 , ⋯ , a n a_…...

数据库二三事(14)

备份与恢复数据库 备份具体内容包括数据库结构,对象与数据,造成数据丢失的原因有: 存储介质故障(硬件损耗) 用户操作错误(人工) 服务器故障(软硬都可能) 病毒侵害 …...

C++之list

list是链表的意思&#xff0c;由一个个节点组成 一、基本接口使用&#xff1a; &#xff08;1&#xff09;与vector相同&#xff0c;有个尾插&#xff0c;也可以使用迭代器遍历&#xff1a; void test_list1() {list<int> lt;lt.push_back(1);lt.push_back(2);lt.push…...

数据增强术:如何利用大模型(LLMs)来模拟不同的扰动类型以增强信息提取任务的鲁棒性

一、对抗样本库构建 1. 基于LLMs的领域针对性扰动设计对抗样本生成 替换实体、三元组和触发器(Replace Entity, Triple, and Trigger) 使用LLMs(如GPT-4)来替换句子中的实体、关系三元组或事件触发器,同时保持其类型不变,并确保其他内容不受影响: xxx名称(如“x方” →…...

《Gradio : AI awesome-demos》

《Gradio : AI awesome-demos》 This is a list of some wonderful demos & applications built with Gradio. Heres how to contribute yours! &#x1f58a;️ Natural language processing Demo name (link to demo)input type(s)output type(s)status badgeruDALL-ET…...

物联网中如何增加其可扩展性 协议 网络 设备 还包括软件层面上的

物联网(IoT)系统的可扩展性是指系统能够随着设备数量、数据流量和业务需求的增长而灵活扩展的能力。为了增加物联网的可扩展性,需要从协议、网络、设备和软件等多个层面进行优化和设计。以下是一些具体的策略和方法: 1. 协议层面的可扩展性 1.1 采用轻量级协议 轻量级协议…...

【每日学点HarmonyOS Next知识】对话框去掉圆角、数组拼接、自定义对话框依附某个控件、平移动画、页面栈管理

1、 HarmonyOS CustomDialog怎么去掉左右和底部的透明以及圆角&#xff1f; CustomDialog怎么去掉左右和底部的透明以及圆角 设置customStyle为true即可开启使用自定义样式。设置borderRadius为0去掉圆角属性。 属性用法参考文档&#xff1a;https://developer.huawei.com/c…...

Unity 通用UI界面逻辑总结

概述 在游戏开发中&#xff0c;常常会遇到一些通用的界面逻辑&#xff0c;它不论在什么类型的游戏中都会出现。为了避免重复造轮子&#xff0c;本文总结并提供了一些常用UI界面的实现逻辑。希望可以帮助大家快速开发通用界面模块&#xff0c;也可以在次基础上进行扩展修改&…...

【网络】HTTP协议、HTTPS协议

HTTP与HTTPS HTTP协议概述 HTTP(超文本传输协议):工作在OSI顶层应用层,用于客户端(浏览器)与服务器之间的通信,B/S模式 无状态:每次请求独立,服务器不保存客户端状态(通过Cookie/Session扩展状态管理)。基于TCP:默认端口80(HTTP)、443(HTTPS),保证可靠传输。请…...

计算机网络——交换机

一、什么是交换机&#xff1f; 交换机&#xff08;Switch&#xff09;是局域网&#xff08;LAN&#xff09;中的核心设备&#xff0c;负责在 数据链路层&#xff08;OSI第二层&#xff09;高效转发数据帧。它像一位“智能交通警察”&#xff0c;根据设备的 MAC地址 精准引导数…...

机器学习:愚者未完成的诗篇(零)

当算法在数据海洋中打捞支离破碎的韵律时&#xff0c;机器学习系统展现出的智慧如同断臂的维纳斯雕像——完美与残缺构成令人战栗的美学悖论。愚者&#xff0c;在词语的混沌中编织逻辑经纬&#xff0c;却总在即将触及诗性本质的瞬间&#xff0c;暴露出认知维度的致命裂隙。 一…...

解锁DeepSpeek-R1大模型微调:从训练到部署,打造定制化AI会话系统

目录 1. 前言 2.大模型微调概念简述 2.1. 按学习范式分类 2.2. 按参数更新范围分类 2.3. 大模型微调框架简介 3. DeepSpeek R1大模型微调实战 3.1.LLaMA-Factory基础环境安装 3.1大模型下载 3.2. 大模型训练 3.3. 大模型部署 3.4. 微调大模型融合基于SpirngBootVue2…...

性能测试和Jmeter

文章目录 前言性能测试理论知识什么是性能&#xff1f;什么是性能测试&#xff1f;性能测试的作用性能测试与功能测试的区别性能测试常见术语性能测试的策略基准测试负载测试稳定性测试压力测试并发测试 常见性能测试指标响应时间并发数吞吐量点击数和错误率资源使用率 性能测试…...

Linux网络之数据链路层协议

目录 数据链路层 MAC地址与IP地址 数据帧 ARP协议 NAT技术 代理服务器 正向代理 反向代理 上期我们学习了网络层中的相关协议&#xff0c;为IP协议。IP协议通过报头中的目的IP地址告知了数据最终要传送的目的主机的IP地址&#xff0c;从而指引了数据在网络中的一步…...

数据结构第八节:红黑树(初阶)

【本节要点】 红黑树概念红黑树性质红黑树结点定义红黑树结构红黑树插入操作的分析 一、红黑树的概念与性质 1.1 红黑树的概念 红黑树 &#xff0c;是一种 二叉搜索树 &#xff0c;但 在每个结点上增加一个存储位表示结点的颜色&#xff0c;可以是 Red和 Black 。 通过对 任何…...

【大模型知识点】位置编码——绝对位置编码,相对位置编码,旋转位置编码RoPE

由于Transformer 中的自注意模块具有置换不变性&#xff08;不关心输入序列的顺序&#xff09;&#xff0c;因此需要使用位置编码来注入位置信息以建模序列&#xff0c;使模型能够区分不同位置的 token&#xff0c;并捕捉序列的顺序关系。 在介绍一些位置编码方法前&#xff0…...

【大模型篇】推理模型大作战(QwQ-32B vs DeepSeek-R1)

大家好,我是大 F,深耕AI算法十余年,互联网大厂技术岗。分享AI算法干货、技术心得。 欢迎关注《大模型理论和实战》、《DeepSeek技术解析和实战》,一起探索技术的无限可能! 写在前面 当我让QwQ-32B vs DeepSeek-R1 写一封未来自己的信 大家更喜欢哪种风格? QwQ-32B 模…...

【汇编语言】单片机程序执行过程

一、任务需求 指示灯LED4闪烁&#xff0c;亮0.5秒&#xff0c;灭0.5秒&#xff0c;无限循环 二、针对硬件的编程 1、确定原理图2、确定硬件的物理关系 三、设计步骤 1.用自己的语言描述工作流程 1.1指示灯LED4亮1.2延时0.5秒1.3指示灯LED4灭1.4延时0.5秒1.5跳转到1.1步 …...

MYSQL之创建数据库和表

创建数据库db_ck &#xff08;下面的创建是最好的创建方法&#xff0c;如果数据库存在也不会报错&#xff0c;并且指定使用utf8mb4&#xff09; show databases命令可以查看所有的数据库名&#xff0c;可以找到刚刚创建的db_ck数据库 使用该数据库时&#xff0c;发现里面没有…...

MybatisPlus

1.增删改查入门案例&#xff1a; 首先导入依赖&#xff1a; <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.3.1</version></dependency> 然后这些增删改查…...

HCIE云计算学什么?怎么学?未来职业发展如何?

随着云计算成为IT行业发展的主流方向&#xff0c;HCIE云计算&#xff08;华为认证云计算专家&#xff09;作为华为认证体系中的高端认证之一&#xff0c;逐渐成为了许多网络工程师和IT从业者提升职业竞争力的重要途径。 那么&#xff0c;HCIE云计算究竟学什么内容&#xff0c;如…...

小程序 -- uni-app开发微信小程序环境搭建(HBuilder X+微信开发者工具)

目录 前言 一 软件部分 1. 微信开发者工具 2. HBuilder X 开发工具 二 配置部分 1. 关于 HBuilder X 配置 2. 关于 微信开发工具 配置 三 运行项目 1. 新建项目 2. 代码编写 3. 内置浏览器 编译 4. 配置小程序 AppID获取 注意 四 实现效果 前言 uni-app开发小程…...

多线程-线程本地变量ThreadLocal

简介 ThreadLocal是线程本地变量&#xff0c;用于存储独属于线程的变量&#xff0c;这些变量可以在同一个线程内跨方法、跨类传递。每一个ThreadLocal对象&#xff0c;只能为当前线程关联一个数据&#xff0c;如果要为当前线程关联多个数据&#xff0c;就需要使用多个ThreadLo…...

MuBlE:为机器人操作任务规划提供了逼真的视觉观察和精确的物理建模

2025-03-05&#xff0c;由华为诺亚方舟实验室、捷克技术大学和帝国理工学院联合开发的MuBlE&#xff08;MuJoCo and Blender simulation Environment&#xff09;模拟环境和基准测试。通过结合MuJoCo物理引擎和Blender高质量渲染&#xff0c;为机器人操作任务规划提供了逼真的视…...

计算机网络笔记(一)——1.1计算机网络在信息时代中的作用

21世纪的一些重要特征是数字化、网络化和信息化&#xff0c;它是一个以网络为核心的信息时代。要实现信息化就必须依靠完善的网络&#xff0c;因为网络可以迅速地传递信息。网络现在已经成为信息社会的命脉和发展知识经济的重要基础。 有三大类网络大家应该很熟悉&#xff0c;即…...

第十五届蓝桥杯省赛电子类单片机学习过程记录(客观题)

客观试题: 01.典型的BUCK电源电路包含哪些关键器件(ABCD) A. 电容 B. 二极管 C. 电感 D. MOSFET 解析: 典型的 BUCK 电源电路是一种降压型的直流-直流转换电路,它包含以下关键器件: A.电容:电容在电路中起到滤波的作用。输入电容用于平滑输入电压的波动,减少电源噪声对…...

计算机组成与体系结构-存储系统

主存编址 存储单元&#xff1a;最小存储单元&#xff0c;一般为4bit。每个存储单元有自己的二进制编号 存储器&#xff1a;多个存储单元排布而成。常见的有8*4存储器&#xff08;8个4bit的存储单元&#xff09; 编址内容&#xff1a; 按字编址&#xff1a;存储体的最小存储单…...