2025.1.26机器学习笔记:C-RNN-GAN文献阅读
2025.1.26周报
- 文献阅读
- 题目信息
- 摘要
- Abstract
- 创新点
- 网络架构
- 实验
- 结论
- 缺点以及后续展望
- 总结
文献阅读
题目信息
- 题目: C-RNN-GAN: Continuous recurrent neural networks with adversarial training
- 会议期刊: NIPS
- 作者: Olof Mogren
- 发表时间: 2016
- 文章链接:https://arxiv.org/pdf/1611.09904
摘要
生成对抗网络(GANs)目的是生成数据,而循环神经网络(RNNs)常用于生成数据序列。目前已有研究用RNN进行音乐生成,但多使用符号表示。本论文中,作者研究了使用对抗训练生成连续数据的序列可行性,并使用古典音乐的midi文件进行评估。作者提出C-RNN-GAN(连续循环生成对抗网络)这种神经网络架构,用对抗训练来对序列的整体联合概率建模并生成高质量的数据序列。通过在古典音乐midi格式序列上训练该模型,并用音阶一致性和音域等指标进行评估,以验证生成对抗训练是一种可行的训练网络的方法,提出的模型为连续数据的生成提供了新思路。
Abstract
The purpose of Generative Adversarial Networks (GANs) is to generate data, while Recurrent Neural Networks (RNNs) are often used for generating data sequences. Currently, there have been many studies using RNNs for music generation, but most of them employ symbolic representations. In this paper, the authors investigate the feasibility of using adversarial training to generate sequences of continuous data, and evaluate it using classical music MIDI files. They propose the C-RNN-GAN (Continuous Recurrent Neural Network GAN), a neural network architecture that uses adversarial training to model the joint probability of the entire sequence and generate high-quality data sequences. By training this model on classical music MIDI format sequences and assessing it with metrics such as scale consistency and range, the authors demonstrate that adversarial training is a viable method for training networks, and the proposed model offers a new approach for the generation of continuous data.
创新点
本研究创新性在于提出C-RNN-GAN模型,作者采用对抗训练方式处理连续序列数据。作者使用四个实值标量对音乐信号进行生成,此外,还使用了反向传播算法进行端到端训练。
网络架构
提出C-RNN-GAN模型,RNN-GAN 由生成器(Generator)和判别器(Discriminator)两个主要部分组成。
如下图所示:
生成器(G)从随机输入(噪声)生成音乐序列。其包含LSTM层和全连接层。输入为随机噪声输入(如,随机向量);输出是生成的音乐序列。
判别器(D)用于区分生成的音乐序列和真实音乐序列。D由Bi-LSTM(双向长短期记忆网络)组成。输入为真实或生成的音乐序列;输出为一个概率值(表示输入序列是真实音乐的概率)。
在训练中,G与D相互对抗,生成器和判别器交替训练,生成器的目标是欺骗判别器,判别器的目标是准确区分真实和生成的音乐。
其中G与D的损失函数表达式如下:
L G = 1 m ∑ i = 1 m log ( 1 − D ( G ( z ( i ) ) ) ) L_{G}=\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right) LG=m1i=1∑mlog(1−D(G(z(i))))
L D = 1 m ∑ i = 1 m [ − log D ( x ( i ) ) − ( log ( 1 − D ( G ( z ( i ) ) ) ) ) ] L_{D}=\frac{1}{m} \sum_{i=1}^{m}\left[-\log D\left(\boldsymbol{x}^{(i)}\right)-\left(\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right)\right] LD=m1i=1∑m[−logD(x(i))−(log(1−D(G(z(i)))))]
其中, z ( i ) z^{(i)} z(i) 是 [ 0 , 1 ] k [0,1]^{k} [0,1]k 中的均匀随机向量的序列,而 x ( i ) x^(i) x(i) 是来自训练数据的序列,k 表示随机序列中的数据的维数。G 中每个单元格的输入是一个随机向量,与先前单元格的输出串联。.
其实就跟我们之前阅读的GAN差不多,这里就不在赘述了。
实验
从网络收集midi格式的古典音乐文件作为训练数据,训练数据是以midi格式的音乐文件形式从网上收集的,包含着名的古典音乐作品。 每个midi事件被加载并与其持续时间,音调,强度(速度)以及自上一音调开始以来的时间一起保存。音调数据在内部用相应的声音频率表示。所有数据归一化为每秒384点的刻度分辨率。 该数据包含来自160位不同古典音乐作曲家的3697个midi文件,最后作者通过多维度指标评估生成音乐。
实验的模型评估指标:
Polyphony(复音):衡量两个音调同时演奏的频率。
Scale consistency(音阶一致性):通过计算属于标准音阶的音调比例得出,报告最匹配音阶的数值。
Repetitions (重复度):计算样本中的重复程度,仅考虑音调及其顺序,不考虑时间。
Tone span(音域):样本中最低和最高音调之间的半音步数。
模型参数:
生成器(G)和判别器(D)中的LSTM网络深度都为2,每个LSTM单元具有350个隐藏单元。
D双向的,而G是单向的。其中,来自D中的每个LSTM单元的输出被馈送到完全连接的层,其中权重在时间步长上共享,然后每个单元的sigmoid输出被平均化。
此外,在训练中使用反向传播(BPTT)和小批量随机梯度下降。学习率设置为0.1,并且将L2正则化应用于G和D中的权重。模型预训练6个epochs,平方误差损失以预测训练序列中的下一个事件。每个LSTM单元的输入是随机向量v,与前一时间步的输出连接。 v均匀分布在 [ 0 , 1 ] k [0,1]^k [0,1]k 中,并且k被选择为每个音调中的特征数量。在预训练期间,作者使用序列长度的模式,从短序列开始,从训练数据中随机样,最终用越来越长的序列训练模型。
实验结果:
C-RNN-GAN随着训练进行,生成音乐的复杂性增加。独特音调数量有微弱上升趋势,音阶一致性在10-15个周期后趋于稳定。
3音调重复在前25个周期有上升趋势,然后保持在较低水平,其与使用的音调数量相关。
Baseline(一个类似于生成器的循环网络)变化程度未达到C-RNN-GAN的水平。使用的独特音调数量一直低很多,音阶一致性与C-RNN-GAN相似,但音域与独特音调数量的关系比C-RNN-GAN更紧密,表明其使用的音调变化性更小。
C-RNN-GAN-3(3的意思是每个LSTM单元三个音调输出)与C-RNN-GAN和Baseline模型相比,获得了更高的复音分数。
在第50 - 55个周期左右达到许多零值输出状态后,在音域、独特音调数量、强度范围和3音调重复方面达到了更高的值。
真实音乐强度范围与生成音乐相似,音阶一致性略高但变化更大,复音分数与C-RNN-GAN-3相似,3音调重复高很多,但由于歌曲长度不同难以比较(通过除以真实音乐长度与生成音乐长度之比进行了归一化)。
从实验结果可以看出对抗训练有助于模型学习更多变、音域更广、强度范围更大的音乐。其中,模型每个LSTM单元输出多于一个音调有助于生成复音分数更高的音乐。虽然生成音乐是复音的,但在实验评估的复音分数方面,C-RNN-GAN得分较低,而允许每个LSTM单元同时输出多达三个音调的模型(C-RNN-GAN-3)在复音方面得分更好。虽然样本之间的时间差异较大,但在一首曲子内大致相同。
代码:https://github.com/olofmogren/c-rnn-gan
"""
模型参数:
learning_rate - 学习率的初始值
max_grad_norm - 梯度的最大允许范数
num_layers - LSTM 层的数量
songlength - LSTM 展开的步数
hidden_size - LSTM 单元的数量
epochs_before_decay - 使用初始学习率训练的轮数
max_epoch - 训练的总轮数
keep_prob - Dropout 层中保留权重的概率
lr_decay - 在 "epochs_before_decay" 之后每个轮数的学习率衰减
batch_size - 批量大小
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport time, datetime, os, sys
import pickle as pkl
from subprocess import call, Popenimport numpy as np
import tensorflow as tf
from tensorflow.python.client import timelineimport music_data_utils
from midi_statistics import get_all_statsflags = tf.flags
logging = tf.loggingflags.DEFINE_string("datadir", None, "保存和加载 MIDI 音乐文件的目录")
flags.DEFINE_string("traindir", None, "保存检查点和 gnuplot 文件的目录")
flags.DEFINE_integer("epochs_per_checkpoint", 2, "每个检查点之间进行的训练轮数")
flags.DEFINE_boolean("log_device_placement", False, "输出设备放置的信息")
flags.DEFINE_string("call_after", None, "退出后调用的命令")
flags.DEFINE_integer("exit_after", 1440, "运行多少分钟后退出")
flags.DEFINE_integer("select_validation_percentage", None, "选择作为验证集的数据的随机百分比")
flags.DEFINE_integer("select_test_percentage", None, "选择作为测试集的数据的随机百分比")
flags.DEFINE_boolean("sample", False, "从模型中采样输出。假设训练已经完成。将采样输出保存到文件中")
flags.DEFINE_integer("works_per_composer", None, "限制每个作曲家加载的作品数量")
flags.DEFINE_boolean("disable_feed_previous", False, "在生成器中,将前一个单元的输出作为下一个单元的输入")
flags.DEFINE_float("init_scale", 0.05, "权重的初始缩放值")
flags.DEFINE_float("learning_rate", 0.1, "学习率")
flags.DEFINE_float("d_lr_factor", 0.5, "学习率衰减因子")
flags.DEFINE_float("max_grad_norm", 5.0, "梯度的最大允许范数")
flags.DEFINE_float("keep_prob", 0.5, "保留权重的概率。1表示不使用 Dropout")
flags.DEFINE_float("lr_decay", 1.0, "在 'epochs_before_decay' 之后每个轮数的学习率衰减")
flags.DEFINE_integer("num_layers_g", 2, "生成器 G 中堆叠的循环单元数量")
flags.DEFINE_integer("num_layers_d", 2, "判别器 D 中堆叠的循环单元数量")
flags.DEFINE_integer("songlength", 100, "限制歌曲输入的事件数量")
flags.DEFINE_integer("meta_layer_size", 200, "元信息模块的隐藏层大小")
flags.DEFINE_integer("hidden_size_g", 100, "生成器 G 的循环部分的隐藏层大小")
flags.DEFINE_integer("hidden_size_d", 100, "判别器 D 的循环部分的隐藏层大小,默认与 G 相同")
flags.DEFINE_integer("epochs_before_decay", 60, "开始衰减之前进行的轮数")
flags.DEFINE_integer("max_epoch", 500, "停止训练之前的总轮数")
flags.DEFINE_integer("batch_size", 20, "批量大小")
flags.DEFINE_integer("biscale_slow_layer_ticks", 8, "Biscale 慢层的刻度")
flags.DEFINE_boolean("multiscale", False, "多尺度 RNN")
flags.DEFINE_integer("pretraining_epochs", 6, "进行语言模型风格预训练的轮数")
flags.DEFINE_boolean("pretraining_d", False, "在预训练期间训练 D")
flags.DEFINE_boolean("initialize_d", False, "初始化 D 的变量,无论检查点中是否有已训练的版本")
flags.DEFINE_boolean("ignore_saved_args", False, "告诉程序忽略已保存的参数,而是使用命令行参数")
flags.DEFINE_boolean("pace_events", False, "在解析输入数据时,如果某个四分音符位置没有音符,则插入一个虚拟事件")
flags.DEFINE_boolean("minibatch_d", False, "为小批量增加核特征以提高多样性")
flags.DEFINE_boolean("unidirectional_d", False, "使用单向 RNN 而不是双向 RNN 作为 D")
flags.DEFINE_boolean("profiling", False, "性能分析。在 plots 目录中写入 timeline.json 文件")
flags.DEFINE_boolean("float16", False, "使用 float16 数据类型,否则,使用 float32")
flags.DEFINE_boolean("adam", False, "使用 Adam 优化器")
flags.DEFINE_boolean("feature_matching", False, "生成器 G 的特征匹配目标")
flags.DEFINE_boolean("disable_l2_regularizer", False, "对权重进行 L2 正则化")
flags.DEFINE_float("reg_scale", 1.0, "L2 正则化系数")
flags.DEFINE_boolean("synthetic_chords", False, "使用合成生成的和弦进行训练(每个事件三个音符)")
flags.DEFINE_integer("tones_per_cell", 1, "每个 RNN 单元输出的最大音符数量")
flags.DEFINE_string("composer", None, "指定一个作曲家,并仅在此作曲家的作品上训练模型")
flags.DEFINE_boolean("generate_meta", False, "将作曲家和流派作为输出的一部分生成")
flags.DEFINE_float("random_input_scale", 1.0, "随机输入的缩放比例(1表示与生成的特征大小相同)")
flags.DEFINE_boolean("end_classification", False, "仅在 D 的末尾进行分类。否则,在每个时间步进行分类并取平均值")FLAGS = flags.FLAGSmodel_layout_flags = ['num_layers_g', 'num_layers_d', 'meta_layer_size', 'hidden_size_g', 'hidden_size_d', 'biscale_slow_layer_ticks', 'multiscale', 'multiscale', 'disable_feed_previous', 'pace_events', 'minibatch_d', 'unidirectional_d', 'feature_matching', 'composer']def make_rnn_cell(rnn_layer_sizes,dropout_keep_prob=1.0,attn_length=0,base_cell=tf.contrib.rnn.BasicLSTMCell,state_is_tuple=True,reuse=False):
"""
根据给定的超参数创建一个RNN单元。参数:rnn_layer_sizes:一个整数列表,表示 RNN 每层的大小。dropout_keep_prob:一个浮点数,表示保留任何给定子单元输出的概率。attn_length:注意力向量的大小。base_cell:用于子单元的基础 tf.contrib.rnn.RNNCell。state_is_tuple:一个布尔值,指定是否使用隐藏矩阵和单元矩阵的元组作为状态,而不是拼接矩阵。return:一个基于给定超参数的 tf.contrib.rnn.MultiRNNCell。"""cells = []for num_units in rnn_layer_sizes:cell = base_cell(num_units, state_is_tuple=state_is_tuple, reuse=reuse)cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep_prob)cells.append(cell)cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=state_is_tuple)if attn_length:cell = tf.contrib.rnn.AttentionCellWrapper(cell, attn_length, state_is_tuple=state_is_tuple, reuse=reuse)return cell
def restore_flags(save_if_none_found=True):if FLAGS.traindir:saved_args_dir = os.path.join(FLAGS.traindir, 'saved_args')if save_if_none_found:try: os.makedirs(saved_args_dir)except: passfor arg in FLAGS.__flags:if arg not in model_layout_flags:continueif FLAGS.ignore_saved_args and os.path.exists(os.path.join(saved_args_dir, arg+'.pkl')):print('{:%Y-%m-%d %H:%M:%S}: saved_args: Found {} setting from saved state, but using CLI args ({}) and saving (--ignore_saved_args).'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))elif os.path.exists(os.path.join(saved_args_dir, arg+'.pkl')):with open(os.path.join(saved_args_dir, arg+'.pkl'), 'rb') as f:setattr(FLAGS, arg, pkl.load(f))print('{:%Y-%m-%d %H:%M:%S}: saved_args: {} from saved state ({}), ignoring CLI args.'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))elif save_if_none_found:print('{:%Y-%m-%d %H:%M:%S}: saved_args: Found no {} setting from saved state, using CLI args ({}) and saving.'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))with open(os.path.join(saved_args_dir, arg+'.pkl'), 'wb') as f:print(getattr(FLAGS, arg),arg)pkl.dump(getattr(FLAGS, arg), f)else:print('{:%Y-%m-%d %H:%M:%S}: saved_args: Found no {} setting from saved state, using CLI args ({}) but not saving.'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))# 定义数据类型
def data_type():return tf.float16 if FLAGS.float16 else tf.float32#return tf.float16def my_reduce_mean(what_to_take_mean_over):return tf.reshape(what_to_take_mean_over, shape=[-1])[0]denom = 1.0#print(what_to_take_mean_over.get_shape())for d in what_to_take_mean_over.get_shape():#print(d)if type(d) == tf.Dimension:denom = denom*d.valueelse:denom = denom*dreturn tf.reduce_sum(what_to_take_mean_over)/denomdef linear(inp, output_dim, scope=None, stddev=1.0, reuse_scope=False):norm = tf.random_normal_initializer(stddev=stddev, dtype=data_type())const = tf.constant_initializer(0.0, dtype=data_type())with tf.variable_scope(scope or 'linear') as scope:scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))if reuse_scope:scope.reuse_variables()#print('inp.get_shape(): {}'.format(inp.get_shape()))w = tf.get_variable('w', [inp.get_shape()[1], output_dim], initializer=norm, dtype=data_type())b = tf.get_variable('b', [output_dim], initializer=const, dtype=data_type())return tf.matmul(inp, w) + bdef minibatch(inp, num_kernels=25, kernel_dim=10, scope=None, msg='', reuse_scope=False):with tf.variable_scope(scope or 'minibatch_d') as scope:scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))if reuse_scope:scope.reuse_variables()inp = tf.Print(inp, [inp],'{} inp = '.format(msg), summarize=20, first_n=20)x = tf.sigmoid(linear(inp, num_kernels * kernel_dim, scope))activation = tf.reshape(x, (-1, num_kernels, kernel_dim))activation = tf.Print(activation, [activation],'{} activation = '.format(msg), summarize=20, first_n=20)diffs = tf.expand_dims(activation, 3) - \tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)diffs = tf.Print(diffs, [diffs],'{} diffs = '.format(msg), summarize=20, first_n=20)abs_diffs = tf.reduce_sum(tf.abs(diffs), 2)abs_diffs = tf.Print(abs_diffs, [abs_diffs],'{} abs_diffs = '.format(msg), summarize=20, first_n=20)minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2)minibatch_features = tf.Print(minibatch_features, [tf.reduce_min(minibatch_features), tf.reduce_max(minibatch_features)],'{} minibatch_features (min,max) = '.format(msg), summarize=20, first_n=20)return tf.concat( [inp, minibatch_features],1)class RNNGAN(object):"""定义RNN-GAN模型."""def __init__(self, is_training, num_song_features=None, num_meta_features=None):batch_size = FLAGS.batch_sizeself.batch_size = batch_sizesonglength = FLAGS.songlengthself.songlength = songlength#self.global_step= tf.Variable(0, trainable=False)print('songlength: {}'.format(self.songlength))self._input_songdata = tf.placeholder(shape=[batch_size, songlength, num_song_features], dtype=data_type())self._input_metadata = tf.placeholder(shape=[batch_size, num_meta_features], dtype=data_type())#_split = tf.split(self._input_songdata,songlength,1)[0]print("self._input_songdata",self._input_songdata, 'songlength',songlength)#print(tf.squeeze(_split,[1]))songdata_inputs = [tf.squeeze(input_, [1])for input_ in tf.split(self._input_songdata,songlength,1)]with tf.variable_scope('G') as scope:scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))#lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.hidden_size_g, forget_bias=1.0, state_is_tuple=True)if is_training and FLAGS.keep_prob < 1:#lstm_cell = tf.nn.rnn_cell.DropoutWrapper(# lstm_cell, output_keep_prob=FLAGS.keep_prob)cell = make_rnn_cell([FLAGS.hidden_size_g]*FLAGS.num_layers_g,dropout_keep_prob=FLAGS.keep_prob)else:cell = make_rnn_cell([FLAGS.hidden_size_g]*FLAGS.num_layers_g) #cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range( FLAGS.num_layers_g)], state_is_tuple=True)self._initial_state = cell.zero_state(batch_size, data_type())# TODO: (possibly temporarily) disabling meta infoif FLAGS.generate_meta:metainputs = tf.random_uniform(shape=[batch_size, int(FLAGS.random_input_scale*num_meta_features)], minval=0.0, maxval=1.0)meta_g = tf.nn.relu(linear(metainputs, FLAGS.meta_layer_size, scope='meta_layer', reuse_scope=False))meta_softmax_w = tf.get_variable("meta_softmax_w", [FLAGS.meta_layer_size, num_meta_features])meta_softmax_b = tf.get_variable("meta_softmax_b", [num_meta_features])meta_logits = tf.nn.xw_plus_b(meta_g, meta_softmax_w, meta_softmax_b)meta_probs = tf.nn.softmax(meta_logits)random_rnninputs = tf.random_uniform(shape=[batch_size, songlength, int(FLAGS.random_input_scale*num_song_features)], minval=0.0, maxval=1.0, dtype=data_type())random_rnninputs = [tf.squeeze(input_, [1]) for input_ in tf.split( random_rnninputs,songlength,1)]# REAL GENERATOR:state = self._initial_state# as we feed the output as the input to the next, we 'invent' the initial 'output'.generated_point = tf.random_uniform(shape=[batch_size, num_song_features], minval=0.0, maxval=1.0, dtype=data_type())outputs = []self._generated_features = []for i,input_ in enumerate(random_rnninputs):if i > 0: scope.reuse_variables()concat_values = [input_]if not FLAGS.disable_feed_previous:concat_values.append(generated_point)if FLAGS.generate_meta:concat_values.append(meta_probs)if len(concat_values):input_ = tf.concat(axis=1, values=concat_values)input_ = tf.nn.relu(linear(input_, FLAGS.hidden_size_g,scope='input_layer', reuse_scope=(i!=0)))output, state = cell(input_, state)outputs.append(output)#generated_point = tf.nn.relu(linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0)))generated_point = linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0))self._generated_features.append(generated_point)# PRETRAINING GENERATOR, will feed inputs, not generated outputs:scope.reuse_variables()# as we feed the output as the input to the next, we 'invent' the initial 'output'.prev_target = tf.random_uniform(shape=[batch_size, num_song_features], minval=0.0, maxval=1.0, dtype=data_type())outputs = []self._generated_features_pretraining = []for i,input_ in enumerate(random_rnninputs):concat_values = [input_]if not FLAGS.disable_feed_previous:concat_values.append(prev_target)if FLAGS.generate_meta:concat_values.append(self._input_metadata)if len(concat_values):input_ = tf.concat(axis=1, values=concat_values)input_ = tf.nn.relu(linear(input_, FLAGS.hidden_size_g, scope='input_layer', reuse_scope=(i!=0)))output, state = cell(input_, state)outputs.append(output)#generated_point = tf.nn.relu(linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0)))generated_point = linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0))self._generated_features_pretraining.append(generated_point)prev_target = songdata_inputs[i]#outputs, state = tf.nn.rnn(cell, transformed, initial_state=self._initial_state)#self._generated_features = [tf.nn.relu(linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0))) for i,output in enumerate(outputs)]self._final_state = state# These are used both for pretraining and for D/G training further down.self._lr = tf.Variable(FLAGS.learning_rate, trainable=False, dtype=data_type())self.g_params = [v for v in tf.trainable_variables() if v.name.startswith('model/G/')]if FLAGS.adam:g_optimizer = tf.train.AdamOptimizer(self._lr)else:g_optimizer = tf.train.GradientDescentOptimizer(self._lr)reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)reg_constant = 0.1 # Choose an appropriate one.reg_loss = reg_constant * sum(reg_losses)reg_loss = tf.Print(reg_loss, reg_losses,'reg_losses = ', summarize=20, first_n=20)# 预训练print(tf.transpose(tf.stack(self._generated_features_pretraining), perm=[1, 0, 2]).get_shape())print(self._input_songdata.get_shape())self.rnn_pretraining_loss = tf.reduce_mean(tf.squared_difference(x=tf.transpose(tf.stack(self._generated_features_pretraining), perm=[1, 0, 2]), y=self._input_songdata))if not FLAGS.disable_l2_regularizer:self.rnn_pretraining_loss = self.rnn_pretraining_loss+reg_losspretraining_grads, _ = tf.clip_by_global_norm(tf.gradients(self.rnn_pretraining_loss, self.g_params), FLAGS.max_grad_norm)self.opt_pretraining = g_optimizer.apply_gradients(zip(pretraining_grads, self.g_params))# The discriminator tries to tell the difference between samples from the# true data distribution (self.x) and the generated samples (self.z).## Here we create two copies of the discriminator network (that share parameters),# as you cannot use the same network with different inputs in TensorFlow.with tf.variable_scope('D') as scope:scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))# Make list of tensors. One per step in recurrence.# Each tensor is batchsize*numfeatures.# TODO: (possibly temporarily) disabling meta infoprint('self._input_songdata shape {}'.format(self._input_songdata.get_shape()))print('generated data shape {}'.format(self._generated_features[0].get_shape()))# TODO: (possibly temporarily) disabling meta infoif FLAGS.generate_meta:songdata_inputs = [tf.concat([self._input_metadata, songdata_input],1) for songdata_input in songdata_inputs]#print(songdata_inputs[0])#print(songdata_inputs[0])#print('metadata inputs shape {}'(self._input_metadata.get_shape()))#print('generated metadata shape {}'.format(meta_probs.get_shape()))self.real_d,self.real_d_features = self.discriminator(songdata_inputs, is_training, msg='real')scope.reuse_variables()# TODO: (possibly temporarily) disabling meta infoif FLAGS.generate_meta:generated_data = [tf.concat([meta_probs, songdata_input],1) for songdata_input in self._generated_features]else:generated_data = self._generated_featuresif songdata_inputs[0].get_shape() != generated_data[0].get_shape():print('songdata_inputs shape {} != generated data shape {}'.format(songdata_inputs[0].get_shape(), generated_data[0].get_shape()))self.generated_d,self.generated_d_features = self.discriminator(generated_data, is_training, msg='generated')# Define the loss for discriminator and generator networks (see the original# paper for details), and create optimizers for bothself.d_loss = tf.reduce_mean(-tf.log(tf.clip_by_value(self.real_d, 1e-1000000, 1.0)) \-tf.log(1 - tf.clip_by_value(self.generated_d, 0.0, 1.0-1e-1000000)))self.g_loss_feature_matching = tf.reduce_sum(tf.squared_difference(self.real_d_features, self.generated_d_features))self.g_loss = tf.reduce_mean(-tf.log(tf.clip_by_value(self.generated_d, 1e-1000000, 1.0)))if not FLAGS.disable_l2_regularizer:self.d_loss = self.d_loss+reg_lossself.g_loss_feature_matching = self.g_loss_feature_matching+reg_lossself.g_loss = self.g_loss+reg_lossself.d_params = [v for v in tf.trainable_variables() if v.name.startswith('model/D/')]if not is_training:returnd_optimizer = tf.train.GradientDescentOptimizer(self._lr*FLAGS.d_lr_factor)d_grads, _ = tf.clip_by_global_norm(tf.gradients(self.d_loss, self.d_params),FLAGS.max_grad_norm)self.opt_d = d_optimizer.apply_gradients(zip(d_grads, self.d_params))if FLAGS.feature_matching:g_grads, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss_feature_matching,self.g_params),FLAGS.max_grad_norm)else:g_grads, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss, self.g_params),FLAGS.max_grad_norm)self.opt_g = g_optimizer.apply_gradients(zip(g_grads, self.g_params))self._new_lr = tf.placeholder(shape=[], name="new_learning_rate", dtype=data_type())self._lr_update = tf.assign(self._lr, self._new_lr)def discriminator(self, inputs, is_training, msg=''):# RNN discriminator:#for i in xrange(len(inputs)):# print('shape inputs[{}] {}'.format(i, inputs[i].get_shape()))#inputs[0] = tf.Print(inputs[0], [inputs[0]],# '{} inputs[0] = '.format(msg), summarize=20, first_n=20)if is_training and FLAGS.keep_prob < 1:inputs = [tf.nn.dropout(input_, FLAGS.keep_prob) for input_ in inputs]#lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.hidden_size_d, forget_bias=1.0, state_is_tuple=True)if is_training and FLAGS.keep_prob < 1:#lstm_cell = tf.nn.rnn_cell.DropoutWrapper(#lstm_cell, output_keep_prob=FLAGS.keep_prob)cell_fw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d,dropout_keep_prob=FLAGS.keep_prob)cell_bw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d,dropout_keep_prob=FLAGS.keep_prob)else:cell_fw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d)cell_bw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d)#cell_fw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range( FLAGS.num_layers_d)], state_is_tuple=True)self._initial_state_fw = cell_fw.zero_state(self.batch_size, data_type())if not FLAGS.unidirectional_d:#lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.hidden_size_g, forget_bias=1.0, state_is_tuple=True)#if is_training and FLAGS.keep_prob < 1:# lstm_cell = tf.nn.rnn_cell.DropoutWrapper(# lstm_cell, output_keep_prob=FLAGS.keep_prob)#cell_bw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range( FLAGS.num_layers_d)], state_is_tuple=True)self._initial_state_bw = cell_bw.zero_state(self.batch_size, data_type())print("cell_fw",cell_fw.output_size)#print("cell_bw",cell_bw.output_size)#print("inputs",inputs)#print("initial_state_fw",self._initial_state_fw)#print("initial_state_bw",self._initial_state_bw)outputs, state_fw, state_bw = tf.contrib.rnn.static_bidirectional_rnn(cell_fw, cell_bw, inputs, initial_state_fw=self._initial_state_fw, initial_state_bw=self._initial_state_bw)#outputs[0] = tf.Print(outputs[0], [outputs[0]],# '{} outputs[0] = '.format(msg), summarize=20, first_n=20)#state = tf.concat(state_fw, state_bw)#endoutput = tf.concat(concat_dim=1, values=[outputs[0],outputs[-1]])else:outputs, state = tf.nn.rnn(cell_fw, inputs, initial_state=self._initial_state_fw)#state = self._initial_state#outputs, state = cell_fw(tf.convert_to_tensor (inputs),state)#endoutput = outputs[-1]if FLAGS.minibatch_d:outputs = [minibatch(tf.reshape(outp, shape=[FLAGS.batch_size, -1]), msg=msg, reuse_scope=(i!=0)) for i,outp in enumerate(outputs)]# decision = tf.sigmoid(linear(outputs[-1], 1, 'decision'))if FLAGS.end_classification:decisions = [tf.sigmoid(linear(output, 1, 'decision', reuse_scope=(i!=0))) for i,output in enumerate([outputs[0], outputs[-1]])]decisions = tf.stack(decisions)decisions = tf.transpose(decisions, perm=[1,0,2])print('shape, decisions: {}'.format(decisions.get_shape()))else:decisions = [tf.sigmoid(linear(output, 1, 'decision', reuse_scope=(i!=0))) for i,output in enumerate(outputs)]decisions = tf.stack(decisions)decisions = tf.transpose(decisions, perm=[1,0,2])print('shape, decisions: {}'.format(decisions.get_shape()))decision = tf.reduce_mean(decisions, reduction_indices=[1,2])decision = tf.Print(decision, [decision],'{} decision = '.format(msg), summarize=20, first_n=20)return (decision,tf.transpose(tf.stack(outputs), perm=[1,0,2]))def assign_lr(self, session, lr_value):session.run(self._lr_update, feed_dict={self._new_lr: lr_value})@propertydef generated_features(self):return self._generated_features@propertydef input_songdata(self):return self._input_songdata@propertydef input_metadata(self):return self._input_metadata@propertydef targets(self):return self._targets@propertydef initial_state(self):return self._initial_state@propertydef cost(self):return self._cost@propertydef final_state(self):return self._final_state@propertydef lr(self):return self._lr@propertydef train_op(self):return self._train_opdef run_epoch(session, model, loader, datasetlabel, eval_op_g, eval_op_d, pretraining=False, verbose=False, run_metadata=None, pretraining_d=False):"""Runs the model on the given data."""#epoch_size = ((len(data) // model.batch_size) - 1) // model.songlengthepoch_start_time = time.time()g_loss, d_loss = 10.0, 10.0g_losses, d_losses = 0.0, 0.0iters = 0#state = session.run(model.initial_state)time_before_graph = Nonetime_after_graph = Nonetimes_in_graph = []times_in_python = []#times_in_batchreading = []loader.rewind(part=datasetlabel)[batch_meta, batch_song] = loader.get_batch(model.batch_size, model.songlength, part=datasetlabel)run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)while batch_meta is not None and batch_song is not None:op_g = eval_op_gop_d = eval_op_dif datasetlabel == 'train' and not pretraining: # and not FLAGS.feature_matching:if d_loss == 0.0 and g_loss == 0.0:print('Both G and D train loss are zero. Exiting.')break#saver.save(session, checkpoint_path, global_step=m.global_step)#breakelif d_loss == 0.0:#print('D train loss is zero. Freezing optimization. G loss: {:.3f}'.format(g_loss))op_g = tf.no_op()elif g_loss == 0.0: #print('G train loss is zero. Freezing optimization. D loss: {:.3f}'.format(d_loss))op_d = tf.no_op()elif g_loss < 2.0 or d_loss < 2.0:if g_loss*.7 > d_loss:#print('G train loss is {:.3f}, D train loss is {:.3f}. Freezing optimization of D'.format(g_loss, d_loss))op_g = tf.no_op()#elif d_loss*.7 > g_loss:#print('G train loss is {:.3f}, D train loss is {:.3f}. Freezing optimization of G'.format(g_loss, d_loss))op_d = tf.no_op()#fetches = [model.cost, model.final_state, eval_op]if pretraining:if pretraining_d:fetches = [model.rnn_pretraining_loss, model.d_loss, op_g, op_d]else:fetches = [model.rnn_pretraining_loss, tf.no_op(), op_g, op_d]else:fetches = [model.g_loss, model.d_loss, op_g, op_d]feed_dict = {}feed_dict[model.input_songdata.name] = batch_songfeed_dict[model.input_metadata.name] = batch_meta#print(batch_song)#print(batch_song.shape)#for i, (c, h) in enumerate(model.initial_state):# feed_dict[c] = state[i].c# feed_dict[h] = state[i].h#cost, state, _ = session.run(fetches, feed_dict)time_before_graph = time.time()if iters > 0:times_in_python.append(time_before_graph-time_after_graph)if run_metadata:g_loss, d_loss, _, _ = session.run(fetches, feed_dict, options=run_options, run_metadata=run_metadata)else:g_loss, d_loss, _, _ = session.run(fetches, feed_dict)time_after_graph = time.time()if iters > 0:times_in_graph.append(time_after_graph-time_before_graph)g_losses += g_lossif not pretraining:d_losses += d_lossiters += 1if verbose and iters % 10 == 9:songs_per_sec = float(iters * model.batch_size)/float(time.time() - epoch_start_time)avg_time_in_graph = float(sum(times_in_graph))/float(len(times_in_graph))avg_time_in_python = float(sum(times_in_python))/float(len(times_in_python))#avg_time_batchreading = float(sum(times_in_batchreading))/float(len(times_in_batchreading))if pretraining:print("{}: {} (pretraining) batch loss: G: {:.3f}, avg loss: G: {:.3f}, speed: {:.1f} songs/s, avg in graph: {:.1f}, avg in python: {:.1f}.".format(datasetlabel, iters, g_loss, float(g_losses)/float(iters), songs_per_sec, avg_time_in_graph, avg_time_in_python))else:print("{}: {} batch loss: G: {:.3f}, D: {:.3f}, avg loss: G: {:.3f}, D: {:.3f} speed: {:.1f} songs/s, avg in graph: {:.1f}, avg in python: {:.1f}.".format(datasetlabel, iters, g_loss, d_loss, float(g_losses)/float(iters), float(d_losses)/float(iters),songs_per_sec, avg_time_in_graph, avg_time_in_python))#batchtime = time.time()[batch_meta, batch_song] = loader.get_batch(model.batch_size, model.songlength, part=datasetlabel)#times_in_batchreading.append(time.time()-batchtime)if iters == 0:return (None,None)g_mean_loss = g_losses/itersif pretraining and not pretraining_d:d_mean_loss = Noneelse:d_mean_loss = d_losses/itersreturn (g_mean_loss, d_mean_loss)def sample(session, model, batch=False):"""Samples from the generative model."""#state = session.run(model.initial_state)fetches = [model.generated_features]feed_dict = {}generated_features, = session.run(fetches, feed_dict)#print( generated_features)print( generated_features[0].shape)# The following worked when batch_size=1.# generated_features = [np.squeeze(x, axis=0) for x in generated_features]# If batch_size != 1, we just pick the first sample. Wastefull, yes.returnable = []if batch:for batchno in range(generated_features[0].shape[0]):returnable.append([x[batchno,:] for x in generated_features])else:returnable = [x[0,:] for x in generated_features]return returnabledef main(_):if not FLAGS.datadir:raise ValueError("Must set --datadir to midi music dir.")if not FLAGS.traindir:raise ValueError("Must set --traindir to dir where I can save model and plots.")restore_flags()summaries_dir = Noneplots_dir = Nonegenerated_data_dir = Nonesummaries_dir = os.path.join(FLAGS.traindir, 'summaries')plots_dir = os.path.join(FLAGS.traindir, 'plots')generated_data_dir = os.path.join(FLAGS.traindir, 'generated_data')try: os.makedirs(FLAGS.traindir)except: passtry: os.makedirs(summaries_dir)except: passtry: os.makedirs(plots_dir)except: passtry: os.makedirs(generated_data_dir)except: passdirectorynames = FLAGS.traindir.split('/')experiment_label = ''while not experiment_label:experiment_label = directorynames.pop()global_step = -1if os.path.exists(os.path.join(FLAGS.traindir, 'global_step.pkl')):with open(os.path.join(FLAGS.traindir, 'global_step.pkl'), 'r') as f:global_step = pkl.load(f)global_step += 1songfeatures_filename = os.path.join(FLAGS.traindir, 'num_song_features.pkl')metafeatures_filename = os.path.join(FLAGS.traindir, 'num_meta_features.pkl')synthetic=Noneif FLAGS.synthetic_chords:synthetic = 'chords'print('Training on synthetic chords!')if FLAGS.composer is not None:print('Single composer: {}'.format(FLAGS.composer))loader = music_data_utils.MusicDataLoader(FLAGS.datadir, FLAGS.select_validation_percentage, FLAGS.select_test_percentage, FLAGS.works_per_composer, FLAGS.pace_events, synthetic=synthetic, tones_per_cell=FLAGS.tones_per_cell, single_composer=FLAGS.composer)if FLAGS.synthetic_chords:# This is just a print out, to check the generated data.batch = loader.get_batch(batchsize=1, songlength=400)loader.get_midi_pattern([batch[1][0][i] for i in xrange(batch[1].shape[1])])num_song_features = loader.get_num_song_features()print('num_song_features:{}'.format(num_song_features))num_meta_features = loader.get_num_meta_features()print('num_meta_features:{}'.format(num_meta_features))train_start_time = time.time()checkpoint_path = os.path.join(FLAGS.traindir, "model.ckpt")songlength_ceiling = FLAGS.songlengthif global_step < FLAGS.pretraining_epochs:FLAGS.songlength = int(min(((global_step+10)/10)*10,songlength_ceiling))FLAGS.songlength = int(min((global_step+1)*4,songlength_ceiling))with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) as session:with tf.variable_scope("model", reuse=None) as scope:scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)if FLAGS.initialize_d:vars_to_restore = {}for v in tf.trainable_variables():if v.name.startswith('model/G/'):print(v.name[:-2])vars_to_restore[v.name[:-2]] = vsaver = tf.train.Saver(vars_to_restore)ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):print("Reading model parameters from %s" % ckpt.model_checkpoint_path,end=" ")saver.restore(session, ckpt.model_checkpoint_path)session.run(tf.initialize_variables([v for v in tf.trainable_variables() if v.name.startswith('model/D/')]))else:print("Created model with fresh parameters.")session.run(tf.initialize_all_variables())saver = tf.train.Saver(tf.all_variables())else:saver = tf.train.Saver(tf.all_variables())ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):print("Reading model parameters from %s" % ckpt.model_checkpoint_path)saver.restore(session, ckpt.model_checkpoint_path)else:print("Created model with fresh parameters.")session.run(tf.initialize_all_variables())run_metadata = Noneif FLAGS.profiling:run_metadata = tf.RunMetadata()if not FLAGS.sample:train_g_loss,train_d_loss = 1.0,1.0for i in range(global_step, FLAGS.max_epoch):lr_decay = FLAGS.lr_decay ** max(i - FLAGS.epochs_before_decay, 0.0)if global_step < FLAGS.pretraining_epochs:#new_songlength = int(min(((i+10)/10)*10,songlength_ceiling))new_songlength = int(min((i+1)*4,songlength_ceiling))else:new_songlength = songlength_ceilingif new_songlength != FLAGS.songlength:print('Changing songlength, now training on {} events from songs.'.format(new_songlength))FLAGS.songlength = new_songlengthwith tf.variable_scope("model", reuse=True) as scope:scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)if not FLAGS.adam:m.assign_lr(session, FLAGS.learning_rate * lr_decay)save = Falsedo_exit = Falseprint("Epoch: {} Learning rate: {:.3f}, pretraining: {}".format(i, session.run(m.lr), (i<FLAGS.pretraining_epochs)))if i<FLAGS.pretraining_epochs:opt_d = tf.no_op()if FLAGS.pretraining_d:opt_d = m.opt_dtrain_g_loss,train_d_loss = run_epoch(session, m, loader, 'train', m.opt_pretraining, opt_d, pretraining = True, verbose=True, run_metadata=run_metadata, pretraining_d=FLAGS.pretraining_d)if FLAGS.pretraining_d:try:print("Epoch: {} Pretraining loss: G: {:.3f}, D: {:.3f}".format(i, train_g_loss, train_d_loss))except:print(train_g_loss)print(train_d_loss)else:print("Epoch: {} Pretraining loss: G: {:.3f}".format(i, train_g_loss))else:train_g_loss,train_d_loss = run_epoch(session, m, loader, 'train', m.opt_d, m.opt_g, verbose=True, run_metadata=run_metadata)try:print("Epoch: {} Train loss: G: {:.3f}, D: {:.3f}".format(i, train_g_loss, train_d_loss))except:print("Epoch: {} Train loss: G: {}, D: {}".format(i, train_g_loss, train_d_loss))valid_g_loss,valid_d_loss = run_epoch(session, m, loader, 'validation', tf.no_op(), tf.no_op())try:print("Epoch: {} Valid loss: G: {:.3f}, D: {:.3f}".format(i, valid_g_loss, valid_d_loss))except:print("Epoch: {} Valid loss: G: {}, D: {}".format(i, valid_g_loss, valid_d_loss))if train_d_loss == 0.0 and train_g_loss == 0.0:print('Both G and D train loss are zero. Exiting.')save = Truedo_exit = Trueif i % FLAGS.epochs_per_checkpoint == 0:save = Trueif FLAGS.exit_after > 0 and time.time() - train_start_time > FLAGS.exit_after*60:print("%s: Has been running for %d seconds. Will exit (exiting after %d minutes)."%(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), (int)(time.time() - train_start_time), FLAGS.exit_after))save = Truedo_exit = Trueif save:saver.save(session, checkpoint_path, global_step=i)with open(os.path.join(FLAGS.traindir, 'global_step.pkl'), 'wb') as f:pkl.dump(i, f)if FLAGS.profiling:# Create the Timeline object, and write it to a jsontl = timeline.Timeline(run_metadata.step_stats)ctf = tl.generate_chrome_trace_format()with open(os.path.join(plots_dir, 'timeline.json'), 'w') as f:f.write(ctf)print('{}: Saving done!'.format(i))step_time, loss = 0.0, 0.0if train_d_loss is None: #pretrainingtrain_d_loss = 0.0valid_d_loss = 0.0valid_g_loss = 0.0if not os.path.exists(os.path.join(plots_dir, 'gnuplot-input.txt')):with open(os.path.join(plots_dir, 'gnuplot-input.txt'), 'w') as f:f.write('# global-step learning-rate train-g-loss train-d-loss valid-g-loss valid-d-loss\n')with open(os.path.join(plots_dir, 'gnuplot-input.txt'), 'a') as f:try:f.write('{} {:.4f} {:.2f} {:.2f} {:.3} {:.3f}\n'.format(i, m.lr.eval(), train_g_loss, train_d_loss, valid_g_loss, valid_d_loss))except:f.write('{} {} {} {} {} {}\n'.format(i, m.lr.eval(), train_g_loss, train_d_loss, valid_g_loss, valid_d_loss))if not os.path.exists(os.path.join(plots_dir, 'gnuplot-commands-loss.txt')):with open(os.path.join(plots_dir, 'gnuplot-commands-loss.txt'), 'a') as f:f.write('set terminal postscript eps color butt "Times" 14\nset yrange [0:400]\nset output "loss.eps"\nplot \'gnuplot-input.txt\' using ($1):($3) title \'train G\' with linespoints, \'gnuplot-input.txt\' using ($1):($4) title \'train D\' with linespoints, \'gnuplot-input.txt\' using ($1):($5) title \'valid G\' with linespoints, \'gnuplot-input.txt\' using ($1):($6) title \'valid D\' with linespoints, \n')if not os.path.exists(os.path.join(plots_dir, 'gnuplot-commands-midistats.txt')):with open(os.path.join(plots_dir, 'gnuplot-commands-midistats.txt'), 'a') as f:f.write('set terminal postscript eps color butt "Times" 14\nset yrange [0:127]\nset xrange [0:70]\nset output "midistats.eps"\nplot \'midi_stats.gnuplot\' using ($1):(100*$3) title \'Scale consistency, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($6) title \'Tone span, halftones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($10) title \'Unique tones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($23) title \'Intensity span, units\' with linespoints, \'midi_stats.gnuplot\' using ($1):(100*$24) title \'Polyphony, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($12) title \'3-tone repetitions\' with linespoints\n')try:Popen(['gnuplot','gnuplot-commands-loss.txt'], cwd=plots_dir)Popen(['gnuplot','gnuplot-commands-midistats.txt'], cwd=plots_dir)except:print('failed to run gnuplot. Please do so yourself: gnuplot gnuplot-commands.txt cwd={}'.format(plots_dir))song_data = sample(session, m, batch=True)midi_patterns = []print('formatting midi...')midi_time = time.time()for d in song_data:midi_patterns.append(loader.get_midi_pattern(d))print('done. time: {}'.format(time.time()-midi_time))filename = os.path.join(generated_data_dir, 'out-{}-{}-{}.mid'.format(experiment_label, i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))loader.save_midi_pattern(filename, midi_patterns[0])stats = []print('getting stats...')stats_time = time.time()for p in midi_patterns:stats.append(get_all_stats(p))print('done. time: {}'.format(time.time()-stats_time))#print(stats)stats = [stat for stat in stats if stat is not None]if len(stats):stats_keys_string = ['scale']stats_keys = ['scale_score', 'tone_min', 'tone_max', 'tone_span', 'freq_min', 'freq_max', 'freq_span', 'tones_unique', 'repetitions_2', 'repetitions_3', 'repetitions_4', 'repetitions_5', 'repetitions_6', 'repetitions_7', 'repetitions_8', 'repetitions_9', 'estimated_beat', 'estimated_beat_avg_ticks_off', 'intensity_min', 'intensity_max', 'intensity_span', 'polyphony_score', 'top_2_interval_difference', 'top_3_interval_difference', 'num_tones']statsfilename = os.path.join(plots_dir, 'midi_stats.gnuplot')if not os.path.exists(statsfilename):with open(statsfilename, 'a') as f:f.write('# Average numers over one minibatch of size {}.\n'.format(FLAGS.batch_size))f.write('# global-step {} {}\n'.format(' '.join([s.replace(' ', '_') for s in stats_keys_string]), ' '.join(stats_keys)))with open(statsfilename, 'a') as f:f.write('{} {} {}\n'.format(i, ' '.join(['{}'.format(stats[0][key].replace(' ', '_')) for key in stats_keys_string]), ' '.join(['{:.3f}'.format(sum([s[key] for s in stats])/float(len(stats))) for key in stats_keys])))print('Saved {}.'.format(filename))if do_exit:if FLAGS.call_after is not None:print("%s: Will call \"%s\" before exiting."%(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), FLAGS.call_after))res = call(FLAGS.call_after.split(" "))print ('{}: call returned {}.'.format(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), res))exit()sys.stdout.flush()test_g_loss,test_d_loss = run_epoch(session, m, loader, 'test', tf.no_op(), tf.no_op())print("Test loss G: %.3f, D: %.3f" %(test_g_loss, test_d_loss))song_data = sample(session, m)filename = os.path.join(generated_data_dir, 'out-{}-{}-{}.mid'.format(experiment_label, i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))loader.save_data(filename, song_data)print('Saved {}.'.format(filename))if __name__ == "__main__":tf.app.run()
结论
作者提出了一种基于生成对抗网络训练的连续数据循环神经网络C-RNN-GAN。实验结果表明对抗训练有助于模型学习更多变的模式。虽然生成音乐与训练数据中的音乐相比仍有差距,但C-RNN-GAN生成音乐更接近真实音乐。
缺点以及后续展望
模型虽能生成音乐,但与人类判断的音乐仍有差距,后续可深入探究生成音乐与真实音乐存在差距的原因。作者提出可以进一步优化模型结构,提高生成音乐的质量。此外,还可研究该模型在其他类型连续序列数据中的应用。
总结
本周我阅读了一篇关于GAN生成序列数据的论文,为下一次阅读TimeGAN论文打作铺垫。通过阅读这篇论文,我了解到C-RNN-GAN模型如何利用对抗训练来生成连续序列数据,其中,生成器(G)包含LSTM层和全连接层;判别器(D)由Bi-LSTM(双向长短期记忆网络)组成。即 D双向的,G是单向的。同时,作者也通过实验证明了C-RNN-GAN的优势,虽然模型在序列数据生成方面有一定的效果,但仍存在一些不足之处,如生成序列数据与真实序列数据之间任然存在差距、模型结构尚可优化、应用到其他场景等等。作者提出的这些不足与展望为我后续研究数据增强方向提供了参考和思路。
相关文章:
2025.1.26机器学习笔记:C-RNN-GAN文献阅读
2025.1.26周报 文献阅读题目信息摘要Abstract创新点网络架构实验结论缺点以及后续展望 总结 文献阅读 题目信息 题目: C-RNN-GAN: Continuous recurrent neural networks with adversarial training会议期刊: NIPS作者: Olof Mogren发表时间…...
嵌入式蓝桥杯电子赛嵌入式(第14届国赛真题)总结
打开systic 生成工程编译查看是否有问题同时打开对应需要的文档 修改名称的要求 5.简单浏览赛题 选择题,跟单片机有关的可以查相关手册 答题顺序 先从显示开始看 1,2 所以先打开PA1的定时器这次选TIM2 从模式、TI2FP2二通道、内部时钟、1通道设为直接2通道设置…...
【机器学习】深入探索SVM:支持向量机的原理与应用
目录 🍔 SVM引入 1.1什么是SVM? 1.2支持向量机分类 1.3 线性可分、线性和非线性的区分 🍔 小结 学习目标 知道SVM的概念 🍔 SVM引入 1.1什么是SVM? 看一个故事,故事是这样子的: 在很久以前的情人节…...
Leetcode40: 组合总和 II
题目描述: 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数字在每个组合中只能使用 一次 。 注意:解集不能包含重复的组合。 代码思路ÿ…...
项目测试之MockMvc
文章目录 基础基础概念Mockxxx一般实现文件位置 实战MockMvc与Test注解不兼容RequestParams参数RequestBody参数 基础 基础概念 定义:是Spring框架提供的一种用于测试Spring MVC控制器的工具,它允许开发者在不启动完整的web服务器的情况下,…...
网易Android开发面试题200道及参考答案 (下)
说明原码、反码、补码的概念 原码:是一种简单的机器数表示法。对于有符号数,最高位为符号位,0 表示正数,1 表示负数,其余位表示数值的绝对值。比如,对于 8 位二进制数,+5 的原码是 00000101,-5 的原码是 10000101。原码的优点是直观,容易理解,但在进行加减法运算时,…...
PHP根据IP地址获取地理位置城市和经纬度信息
/** 根据IP地址 获取地理位置*/ function getLocationByIP($ip) {$url "http://ip-api.com/json/{$ip}?langzh-CN&fieldsstatus,message,country,countryCode,region,regionName,city,lat,lon,timezone,isp,org,as";$response file_get_contents($url);$data …...
AI Agent的多轮对话:提升用户体验的关键技巧
在前面的文章中,我们讨论了 AI Agent 的各个核心系统。今天,我想聊聊如何实现一个好用的多轮对话系统。说实话,这个话题我琢磨了很久,因为它直接影响到用户体验。 从一个槽点说起 还记得我最开始做对话系统时的一个典型场景&…...
在docker上部署nacos
一、首先下载nacos的docker镜像 docker pull nacos:2.5.0 二、然后下载nacos的安装包,这里是为了拿到他的配置文件。下载完解压缩后,以备后用 https://download.nacos.io/nacos-server/nacos-server-2.5.0.zip?spm5238cd80.6a33be36.0.0.2eb81e5d7mQ…...
ComfyUI实现老照片修复——AI修复老照片(ComfyUI-ReActor / ReSwapper)解决天坑问题及加速pip下载
AI修复老照片,试试吧,不一定好~~哈哈 2023年4月曾用过ComfyUI,当时就感慨这个工具和虚幻的蓝图很像,以后肯定是专业人玩的。 2024年我写代码去了,AI做图没太关注,没想到,现在ComfyUI真的变成了工…...
Win11画图工具没了怎么重新安装
有些朋友想要简单地把图片另存为其他格式,或是进行一些编辑,但是发现自己的Win11系统里面没有画图工具,这可能是因为用户安装的是精简版的Win11系统,解决方法自然是重新安装一下画图工具,具体应该怎么做呢?…...
Git Bash 配置 zsh
博客食用更佳 博客链接 安装 zsh 安装 Zsh 安装 Oh-my-zsh github仓库 sh -c "$(curl -fsSL https://install.ohmyz.sh/)"让 zsh 成为 git bash 默认终端 vi ~/.bashrc写入: if [ -t 1 ]; thenexec zsh fisource ~/.bashrc再重启即可。 更换主题 …...
《STL基础之hashtable》
【hashtable导读】STL为大家提供了丰富的容器,hashtable也是值得大家学习和掌握的基础容器,而且面试官经常会把它和hashmap混在一起,让同学们做下区分。因此关于hashtable的一些特性,比如:底层的数据结构、插入、查找元…...
Vue3组件重构实战:从Geeker-Admin拆解DataTable的最佳实践
一、前言 背景与动机 在当前的开发实践中,我们选择了开源项目 Geeker-Admin 作为前端框架的二次开发基础。其内置的 ProTable.vue 组件虽然提供了一定程度的开箱即用性,但在实际业务场景中逐渐暴露出设计上的局限性,尤其是其将 搜索条件表单…...
小智 AI 聊天机器人
小智 AI 聊天机器人 (XiaoZhi AI Chatbot) 👉参考源项目复现 👉 ESP32SenseVoiceQwen72B打造你的AI聊天伴侣!【bilibili】 👉 手工打造你的 AI 女友,新手入门教程【bilibili】 项目目的 本…...
关于圆周率的新认知
从自然对数底 的泰勒展开, 可以得出 的展开式, 它可以被认为是,以 0 为周期的单位 1 ,以 1 为周期的单位 1 ,以 2 为周期的单位 1 等所有自然数为周期的单位 1 分阶段合成(体现为阶乘的倒数)之…...
【趋势】《2024—2026金融科技十大趋势预测》一览
本白皮书基于新华三在金融行业的前沿实践和IDC的全球研究成果,深入分析了金融科技领域的十大关键趋势,旨在为金融机构提供前瞻性的战略指导和业务创新的参考。 导言 当前,在地缘政治冲突加剧、商业经济市场环境高度不确定、数字化业务加速发展的背景下,金融行业处于深度变…...
vim 中粘贴内容时提示: -- (insert) VISUAL --
目录 问题现象:解决方法:问题原因: 问题现象: 使用 vim 打开一个文本文件,切换到编辑模式后,复制内容进行粘贴时有以下提示: 解决方法: 在命令行模式下禁用鼠标支持 :set mouse …...
CAPL高级应用
CAPL高级应用 目录 CAPL高级应用1. 引言2. 多线程编程2.1 多线程编程简介2.2 多线程编程实现3. 数据库操作3.1 数据库操作简介3.2 数据库操作实现4. 网络通信4.1 网络通信简介4.2 网络通信实现5. 案例说明5.1 案例1:多线程编程实现5.2 案例2:数据库操作实现5.3 案例3:网络通…...
基于微信小程序的网上订餐管理系统
作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…...
python的设计模式
设计模式是解决软件设计中常见问题的可重用解决方案。Python 作为一种灵活且强大的编程语言,支持多种设计模式的实现。以下是 Python 中常见的几种设计模式及其示例: 1. 单例模式(Singleton Pattern) 确保一个类只有一个实例&…...
EventBus事件总线的使用以及优缺点
EventBus EventBus (事件总线)是一种组件通信方法,基于发布/订阅模式,能够实现业务代码解耦,提高开发效率 发布/订阅模式 发布/订阅模式是一种设计模式,当一个对象的状态发生变化时,所有依赖…...
C++解决走迷宫问题:DFS、BFS算法应用
文章目录 思路:DFSBFSBFS和DFS的特点BFS 与 DFS 的区别BFS 的优点BFS 时间复杂度深度优先搜索(DFS)的优点深度优先搜索(DFS)的时间复杂度解释:空间复杂度总结:例如下面的迷宫: // 迷宫的表示:0表示可以走,1表示障碍 vector<vector<int>> maze = {{0, 0,…...
2025春招 SpringCloud 面试题汇总
大家好,我是 V 哥。SpringCloud 在面试中属于重灾区,不仅是基础概念、组件细节,还有高级特性、性能优化,关键是项目实践经验的解决方案,都是需要掌握的内容,正所谓打有准备的仗,秒杀面试官&…...
PostGIS笔记:PostgreSQL 数据库与用户 基础操作
数据库基础操作包括数据模型的实现、添加数据、查询数据、视图应用、创建日志规则等。我这里是在Ubuntu系统学习的数据库管理。Windows平台与Linux平台在命令上几乎无差异,只是说在 Windows 上虽然也能运行良好,但在性能、稳定性、功能扩展等方面&#x…...
Selenium配合Cookies实现网页免登录
文章目录 前言1 方案一:使用Chrome用户数据目录2 方案二:手动获取并保存Cookies,后续使用保存的Cookies3 注意事项 前言 在进行使用Selenium进行爬虫、网页自动化操作时,登录往往是一个必须解决的问题,但是Selenium每次…...
HarmonyOS简介:HarmonyOS核心技术理念
核心理念 一次开发、多端部署可分可合、自由流转统一生态、原生智能 一次开发、多端部署 可分可合 自由流转 自由流转可分为跨端迁移和多端协同两种情况 统一生态 支持业界主流跨平台开发框架,通过多层次的开放能力提供统一接入标准,实现三方框架快速…...
Unity URP 获取/设置 Light-Indirect Multiplier
Unity URP 获取/设置 Light-Indirect Multiplier 他喵的代码的字段名称叫:bounceIntensity ~~~~~~...
计算机网络 (60)蜂窝移动通信网
一、定义与原理 蜂窝移动通信网是指将一个服务区分为若干蜂窝状相邻小区并采用频率空间复用技术的移动通信网。其原理在于,将移动通信服务区划分成许多以正六边形为基本几何图形的覆盖区域,称为蜂窝小区。每个小区设置一个基站,负责本小区内移…...
解决.NET程序通过网盘传到Linux和macOS不能运行的问题
问题描述:.net程序用U盘传到虚拟机macOS和Linux可以正常运行,但是网盘传过去就不行。 解决方法: 这是文件权限的问题。当你通过U盘将文件传输到虚拟机的macOS和Linux系统时,文件的权限和所有权可能得到了保留或正确设置。但如果…...
LeetCode | 不同路径
一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish” )。 问总共有多少条不同的路径? 示例 1…...
渗透测试技法之口令安全
一、口令安全威胁 口令泄露途径 代码与文件存储不当:在软件开发和系统维护过程中,开发者可能会将口令以明文形式存储在代码文件、配置文件或注释中。例如,在开源代码托管平台 GitHub 上,一些开发者由于疏忽,将包含数据…...
【C语言】main函数解析
一、前言 在学习编程的过程中,我们很早就接触到了main函数。在Linux系统中,当你运行一个可执行文件(例如 ./a.out)时,如果需要传入参数,就需要了解main函数的用法。本文将详细解析main函数的参数ÿ…...
Vue3笔记——(二)
015 生命周期 组件的生命周期: 【时刻】 【调用特定的函数】 vue2生命周期 创建 beforeCreate、 created 挂载 beforeMounte、mounted 更新 beforeUpdate、updated 销毁 beforeDestroy、destroyed 生命周期、生命周期函数、生命周期钩子 vue3生命周期 创建 setup 挂…...
linux文件I/O
open 用于打开一个文件并返回一个文件描述符。文件描述符是一个整数,它在后续的文件操作中用于标识文件。 原型: int open(const char *pathname, int flags, mode_t mode);pathname:要打开的文件的路径flags:指定文件打开方式…...
利用双指针一次遍历实现”找到“并”删除“单链表倒数第K个节点(力扣题目为例)
Problem: 19. 删除链表的倒数第 N 个结点 文章目录 题目描述思路复杂度Code 题目描述 思路 1.欲找到倒数第k个节点,即是找到正数的第n-k1、其中n为单链表中节点的个数个节点。 2.为实现只遍历一次单链表,我们先可以使一个指针p1指向链表头部再让其先走k步…...
MySQL 8 不开通 CLONE 插件,建立主从关系
文章目录 前言一、主库操作二、从库操作三、主库操作四、测试总结 前言 MySQL 版本:8.0.36 MySQL 8 通过 CLONE 插件,搭建主从数据库详情参考链接文章 主库不开通 CLONE 插件,如何建立主从关系呢?本文简单介绍一下 一、主库操作…...
活动回顾和预告|微软开发者社区 Code Without Barriers 上海站首场活动成功举办!
Code Without Barriers 上海活动回顾 Code Without Barriers:AI & DATA 深入探索人工智能与数据如何变革行业 2025年1月16日,微软开发者社区 Code Without Barriers (CWB)携手 She Rewires 她原力在大中华区的首场活动“AI &…...
Direct Preference Optimization (DPO): 一种无需强化学习的语言模型偏好优化方法
论文地址:https://arxiv.org/pdf/2305.18290 1. 背景与挑战 近年来,大规模无监督语言模型(LM)在知识获取和推理能力方面取得了显著进展,但如何精确控制其行为仍是一个难题。 现有的方法通常通过**强化学习从人类反馈&…...
搜狐Android开发(安卓)面试题及参考答案
ViewModel 的作用及原理是什么? ViewModel 是 Android 架构组件中的一部分,主要作用是在 MVVM 架构中充当数据与视图之间的桥梁。它负责为视图准备数据,并处理与数据相关的业务逻辑,让视图(Activity、Fragment 等)专注于展示数据和与用户交互。比如在一个新闻应用中,Vie…...
蓝牙的一些基础知识(TODO)
前阵工作中遇到的。 iOS 和 iPadOS 支持的蓝牙描述文件 - 官方 Apple 支持 (中国) 在树莓派上定制蓝牙 Profile 通常需要修改或创建自定义的 Bluetooth 服务 (Profile) 来实现特定的功能,例如定制 Audio Sink、HID(Human Interface Device)、…...
Redis实战(黑马点评)——涉及session、redis存储验证码,双拦截器处理请求
项目整体介绍 数据库表介绍 基于session的短信验证码登录与注册 controller层 // 获取验证码PostMapping("code")public Result sendCode(RequestParam("phone") String phone, HttpSession session) {return userService.sendCode(phone, session);}// 获…...
WPF常见面试题解答
以下是WPF(Windows Presentation Foundation)面试中常见的问题及解答,涵盖基础概念、高级功能和实际应用,帮助你更好地准备面试: 基础概念 什么是WPF? WPF是微软开发的用于构建桌面应用程序的UI框架&#x…...
Nginx前端后端共用一个域名如何配置
在 Nginx 中配置前端和后端共用一个域名的情况,通常是通过路径或子路径将请求转发到不同的服务。以下是一个示例配置,假设: 前端静态文件在 /var/www/frontend/。 后端 API 服务运行在 http://127.0.0.1:5000。 域名是 example.comÿ…...
DeepSeek-R1-Distill-Qwen-1.5B:最佳小型LLM?
DeepSeek掀起了生成式AI领域的风暴。 首先推出DeepSeek-v3,现在推出DeepSeek-R1,这两款模型都打破了所有基准,并且完全开源。 但今天我们不是在讨论这两款超级模型,而是讨论DeepSeek-R1的一个蒸馏版本——DeepSeek-R1-Distill-Qwen-1.5B,它可能是今天被低估的版本,虽然…...
wampserver + phpstrom 调试配置
step 1 点击任务栏wampserver图标->php->php.ini[apache module] 在文件最后面,确保这些值被定义且跟以下的一样 xdebug.mode debug xdebug.start_with_request yes xdebug.client_port 9003 xdebug.client_host 127.0.0.1step 2 按如下配置 step3 下断点,运行即…...
MySQL分表自动化创建的实现方案(存储过程、事件调度器)
《MySQL 新年度自动分表创建项目方案》 一、项目目的 在数据库应用场景中,随着数据量的不断增长,单表存储数据可能会面临性能瓶颈,例如查询、插入、更新等操作的效率会逐渐降低。分表是一种有效的优化策略,它将数据分散存储在多…...
RabbitMQ 架构分析
文章目录 前言一、RabbitMQ架构分析1、Broker2、Vhost3、Producer4、Messages5、Connections6、Channel7、Exchange7、Queue8、Consumer 二、消息路由机制1、Direct Exchange2、Topic Exchange3、Fanout Exchange4、Headers Exchange5、notice5.1、备用交换机(Alter…...
Spring Boot 无缝集成SpringAI的函数调用模块
这是一个 完整的 Spring AI 函数调用实例,涵盖从函数定义、注册到实际调用的全流程,以「天气查询」功能为例,结合代码详细说明: 1. 环境准备 1.1 添加依赖 <!-- Spring AI OpenAI --> <dependency><groupId>o…...
如何跨互联网adb连接到远程手机-蓝牙电话集中维护
如何跨互联网adb连接到远程手机-蓝牙电话集中维护 --ADB连接专题 一、前言 随便找一个手机,安装一个App并简单设置一下,就可以跨互联网的ADB连接到这个手机,从而远程操控这个手机做各种操作。你敢相信吗?而这正是本篇想要描述的…...