《Keras 3 在 TPU 上的肺炎分类》
Keras 3 在 TPU 上的肺炎分类
作者:Amy MiHyun Jang
创建日期:2020/07/28
最后修改时间:2024/02/12
描述:TPU 上的医学图像分类。
在 Colab 中查看
GitHub 源
简介 + 设置
本教程将介绍如何构建 X 射线图像分类模型 预测 X 线扫描是否显示肺炎的存在。
import re
import os
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plttry:tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()print("Device:", tpu.master())strategy = tf.distribute.TPUStrategy(tpu)
except:strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)
Device: grpc://10.0.27.122:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470 INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Finished initializing TPU system. INFO:tensorflow:Finished initializing TPU system. WARNING:absl:[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is deprecated, please use the non experimental symbol [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) instead. INFO:tensorflow:Found TPU system: INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) Number of replicas: 8
我们需要一个指向我们数据的 Google Cloud 链接,以便使用 TPU 加载数据。 下面,我们定义了我们将在此示例中使用的关键配置参数。 要在 TPU 上运行,此示例必须在 Colab 上,并选择 TPU 运行时。
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 25 * strategy.num_replicas_in_sync
IMAGE_SIZE = [180, 180]
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
加载数据
我们使用的 Cell 的胸部 X 光数据将数据分为 training 和 test 文件。让我们首先加载训练 TFRecords。
train_images = tf.data.TFRecordDataset("gs://download.tensorflow.org/data/ChestXRay2017/train/images.tfrec"
)
train_paths = tf.data.TFRecordDataset("gs://download.tensorflow.org/data/ChestXRay2017/train/paths.tfrec"
)ds = tf.data.Dataset.zip((train_images, train_paths))
让我们数一数我们有多少次健康/正常的胸部 X 光片,以及有多少 肺炎胸部 X 光片我们有:
COUNT_NORMAL = len([filenamefor filename in train_pathsif "NORMAL" in filename.numpy().decode("utf-8")]
)
print("Normal images count in training set: " + str(COUNT_NORMAL))COUNT_PNEUMONIA = len([filenamefor filename in train_pathsif "PNEUMONIA" in filename.numpy().decode("utf-8")]
)
print("Pneumonia images count in training set: " + str(COUNT_PNEUMONIA))
Normal images count in training set: 1349 Pneumonia images count in training set: 3883
请注意,被归类为肺炎的图像比正常情况多得多。这 显示我们的数据不平衡。我们稍后会纠正这种不平衡 在我们的笔记本中。
我们想将每个文件名映射到相应的 (image, label) 对。以下内容 方法将帮助我们做到这一点。
由于我们只有两个标签,因此我们将对标签进行编码,以便 或 肺炎和/或表示正常。1
True
0
False
def get_label(file_path):# convert the path to a list of path componentsparts = tf.strings.split(file_path, "/")# The second to last is the class-directoryif parts[-2] == "PNEUMONIA":return 1else:return 0def decode_img(img):# convert the compressed string to a 3D uint8 tensorimg = tf.image.decode_jpeg(img, channels=3)# resize the image to the desired size.return tf.image.resize(img, IMAGE_SIZE)def process_path(image, path):label = get_label(path)# load the raw data from the file as a stringimg = decode_img(image)return img, labelds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
让我们将数据拆分为训练和验证数据集。
ds = ds.shuffle(10000)
train_ds = ds.take(4200)
val_ds = ds.skip(4200)
让我们可视化 (image, label) 对的形状。
for image, label in train_ds.take(1):print("Image shape: ", image.numpy().shape)print("Label: ", label.numpy())
Image shape: (180, 180, 3) Label: False
同时加载测试数据并设置其格式。
test_images = tf.data.TFRecordDataset("gs://download.tensorflow.org/data/ChestXRay2017/test/images.tfrec"
)
test_paths = tf.data.TFRecordDataset("gs://download.tensorflow.org/data/ChestXRay2017/test/paths.tfrec"
)
test_ds = tf.data.Dataset.zip((test_images, test_paths))test_ds = test_ds.map(process_path, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
可视化数据集
首先,让我们使用缓冲预取,这样我们就可以在没有 I/O 的情况下从磁盘生成数据 变为阻塞。
请注意,大型图像数据集不应缓存在内存中。我们在这里做 因为数据集不是很大,我们想在 TPU 上训练。
def prepare_for_training(ds, cache=True):# This is a small dataset, only load it once, and keep it in memory.# use `.cache(filename)` to cache preprocessing work for datasets that don't# fit in memory.if cache:if isinstance(cache, str):ds = ds.cache(cache)else:ds = ds.cache()ds = ds.batch(BATCH_SIZE)# `prefetch` lets the dataset fetch batches in the background while the model# is training.ds = ds.prefetch(buffer_size=AUTOTUNE)return ds
调用训练数据的下一个批次迭代。
train_ds = prepare_for_training(train_ds)
val_ds = prepare_for_training(val_ds)image_batch, label_batch = next(iter(train_ds))
定义在批处理中显示图像的方法。
def show_batch(image_batch, label_batch):plt.figure(figsize=(10, 10))for n in range(25):ax = plt.subplot(5, 5, n + 1)plt.imshow(image_batch[n] / 255)if label_batch[n]:plt.title("PNEUMONIA")else:plt.title("NORMAL")plt.axis("off")
由于该方法将 NumPy 数组作为其参数,因此请在 batches 以 NumPy 数组形式返回张量。
show_batch(image_batch.numpy(), label_batch.numpy())
构建 CNN
为了使我们的模型更加模块化和更容易理解,让我们定义一些块。如 我们正在构建一个卷积神经网络,我们将创建一个卷积块和一个密集的 layer 块。
此 CNN 的体系结构受到本文的启发。
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'import keras
from keras import layersdef conv_block(filters, inputs):x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(inputs)x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(x)x = layers.BatchNormalization()(x)outputs = layers.MaxPool2D()(x)return outputsdef dense_block(units, dropout_rate, inputs):x = layers.Dense(units, activation="relu")(inputs)x = layers.BatchNormalization()(x)outputs = layers.Dropout(dropout_rate)(x)return outputs
以下方法将定义函数来为我们构建模型。
图像最初的值范围为 [0, 255]。CNN 与较小的 CNN 配合得更好 numbers 来调整它,以便根据我们的输入进行缩小。
Dropout 图层很重要,因为它们 降低模型过拟合的可能性。我们希望用一个具有一个节点的层来结束模型,因为这将是确定 X 射线是否显示的二进制输出 存在肺炎。Dense
def build_model():inputs = keras.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))x = layers.Rescaling(1.0 / 255)(inputs)x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)x = layers.MaxPool2D()(x)x = conv_block(32, x)x = conv_block(64, x)x = conv_block(128, x)x = layers.Dropout(0.2)(x)x = conv_block(256, x)x = layers.Dropout(0.2)(x)x = layers.Flatten()(x)x = dense_block(512, 0.7, x)x = dense_block(128, 0.5, x)x = dense_block(64, 0.3, x)outputs = layers.Dense(1, activation="sigmoid")(x)model = keras.Model(inputs=inputs, outputs=outputs)return model
更正数据不平衡
在这个例子的前面部分,我们看到数据不平衡,分类的图像更多 作为肺炎比正常。我们将通过使用类加权来纠正这个问题:
initial_bias = np.log([COUNT_PNEUMONIA / COUNT_NORMAL])
print("Initial bias: {:.5f}".format(initial_bias[0]))TRAIN_IMG_COUNT = COUNT_NORMAL + COUNT_PNEUMONIA
weight_for_0 = (1 / COUNT_NORMAL) * (TRAIN_IMG_COUNT) / 2.0
weight_for_1 = (1 / COUNT_PNEUMONIA) * (TRAIN_IMG_COUNT) / 2.0class_weight = {0: weight_for_0, 1: weight_for_1}print("Weight for class 0: {:.2f}".format(weight_for_0))
print("Weight for class 1: {:.2f}".format(weight_for_1))
Initial bias: 1.05724 Weight for class 0: 1.94 Weight for class 1: 0.67
类别 (Normal) 的权重比类别 (Pneumonia) 的权重高得多。由于法线图像较少,因此将对每个法线图像进行加权 more 来平衡数据,因为 CNN 在训练数据平衡时效果最佳。0
1
训练模型
定义回调
checkpoint 回调保存了模型的最佳权重,因此下次我们想使用 模型,我们不必花时间训练它。提前停止回调停止 当模型开始停滞时,甚至更糟糕的是,当 模型开始过拟合。
checkpoint_cb = keras.callbacks.ModelCheckpoint("xray_model.keras", save_best_only=True)early_stopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True
)
我们还希望调整我们的学习率。学习率过高会导致模型 发散。学习速率太小会导致模型太慢。我们 实现下面的指数学习率调度方法。
initial_learning_rate = 0.015
lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
拟合模型
对于我们的指标,我们希望包括 precision 和 recall,因为它们将为 更了解我们的模型有多好。准确率告诉我们 labels 是正确的。由于我们的数据不平衡,准确性可能会给人一种歪曲的感觉 一个好的模型(即始终预测 PNEUMONIA 的模型将准确率为 74%,但并非如此 一个很好的模型)。
精度是 TP 和假阳性之和的真阳性 (TP) 数 (FP) 的 Shell。它显示标记的阳性实际正确的比例。
召回率是 TP 和假负数 (FN) 之和的 TP 数。它显示了什么 实际阳性的比例是正确的。
由于图像只有两个可能的标签,因此我们将使用 二进制交叉熵损失。当我们拟合模型时,请记住指定类权重 我们之前定义过。因为我们使用的是 TPU,所以训练会很快 - 小于 2 分钟。
with strategy.scope():model = build_model()METRICS = [keras.metrics.BinaryAccuracy(),keras.metrics.Precision(name="precision"),keras.metrics.Recall(name="recall"),]model.compile(optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),loss="binary_crossentropy",metrics=METRICS,)history = model.fit(train_ds,epochs=100,validation_data=val_ds,class_weight=class_weight,callbacks=[checkpoint_cb, early_stopping_cb],
)
Epoch 1/100 WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. 21/21 [==============================] - 12s 568ms/step - loss: 0.5857 - binary_accuracy: 0.6960 - precision: 0.8887 - recall: 0.6733 - val_loss: 34.0149 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 2/100 21/21 [==============================] - 3s 128ms/step - loss: 0.2916 - binary_accuracy: 0.8755 - precision: 0.9540 - recall: 0.8738 - val_loss: 97.5194 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 3/100 21/21 [==============================] - 4s 167ms/step - loss: 0.2384 - binary_accuracy: 0.9002 - precision: 0.9663 - recall: 0.8964 - val_loss: 27.7902 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 4/100 21/21 [==============================] - 4s 173ms/step - loss: 0.2046 - binary_accuracy: 0.9145 - precision: 0.9725 - recall: 0.9102 - val_loss: 10.8302 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 5/100 21/21 [==============================] - 4s 174ms/step - loss: 0.1841 - binary_accuracy: 0.9279 - precision: 0.9733 - recall: 0.9279 - val_loss: 3.5860 - val_binary_accuracy: 0.7103 - val_precision: 0.7162 - val_recall: 0.9879 Epoch 6/100 21/21 [==============================] - 4s 185ms/step - loss: 0.1600 - binary_accuracy: 0.9362 - precision: 0.9791 - recall: 0.9337 - val_loss: 0.3014 - val_binary_accuracy: 0.8895 - val_precision: 0.8973 - val_recall: 0.9555 Epoch 7/100 21/21 [==============================] - 3s 130ms/step - loss: 0.1567 - binary_accuracy: 0.9393 - precision: 0.9798 - recall: 0.9372 - val_loss: 0.6763 - val_binary_accuracy: 0.7810 - val_precision: 0.7760 - val_recall: 0.9771 Epoch 8/100 21/21 [==============================] - 3s 131ms/step - loss: 0.1532 - binary_accuracy: 0.9421 - precision: 0.9825 - recall: 0.9385 - val_loss: 0.3169 - val_binary_accuracy: 0.8895 - val_precision: 0.8684 - val_recall: 0.9973 Epoch 9/100 21/21 [==============================] - 4s 184ms/step - loss: 0.1457 - binary_accuracy: 0.9431 - precision: 0.9822 - recall: 0.9401 - val_loss: 0.2064 - val_binary_accuracy: 0.9273 - val_precision: 0.9840 - val_recall: 0.9136 Epoch 10/100 21/21 [==============================] - 3s 132ms/step - loss: 0.1201 - binary_accuracy: 0.9521 - precision: 0.9869 - recall: 0.9479 - val_loss: 0.4364 - val_binary_accuracy: 0.8605 - val_precision: 0.8443 - val_recall: 0.9879 Epoch 11/100 21/21 [==============================] - 3s 127ms/step - loss: 0.1200 - binary_accuracy: 0.9510 - precision: 0.9863 - recall: 0.9469 - val_loss: 0.5197 - val_binary_accuracy: 0.8508 - val_precision: 1.0000 - val_recall: 0.7922 Epoch 12/100 21/21 [==============================] - 4s 186ms/step - loss: 0.1077 - binary_accuracy: 0.9581 - precision: 0.9870 - recall: 0.9559 - val_loss: 0.1349 - val_binary_accuracy: 0.9486 - val_precision: 0.9587 - val_recall: 0.9703 Epoch 13/100 21/21 [==============================] - 4s 173ms/step - loss: 0.0918 - binary_accuracy: 0.9650 - precision: 0.9914 - recall: 0.9611 - val_loss: 0.0926 - val_binary_accuracy: 0.9700 - val_precision: 0.9837 - val_recall: 0.9744 Epoch 14/100 21/21 [==============================] - 3s 130ms/step - loss: 0.0996 - binary_accuracy: 0.9612 - precision: 0.9913 - recall: 0.9559 - val_loss: 0.1811 - val_binary_accuracy: 0.9419 - val_precision: 0.9956 - val_recall: 0.9231 Epoch 15/100 21/21 [==============================] - 3s 129ms/step - loss: 0.0898 - binary_accuracy: 0.9643 - precision: 0.9901 - recall: 0.9614 - val_loss: 0.1525 - val_binary_accuracy: 0.9486 - val_precision: 0.9986 - val_recall: 0.9298 Epoch 16/100 21/21 [==============================] - 3s 128ms/step - loss: 0.0941 - binary_accuracy: 0.9621 - precision: 0.9904 - recall: 0.9582 - val_loss: 0.5101 - val_binary_accuracy: 0.8527 - val_precision: 1.0000 - val_recall: 0.7949 Epoch 17/100 21/21 [==============================] - 3s 125ms/step - loss: 0.0798 - binary_accuracy: 0.9636 - precision: 0.9897 - recall: 0.9607 - val_loss: 0.1239 - val_binary_accuracy: 0.9622 - val_precision: 0.9875 - val_recall: 0.9595 Epoch 18/100 21/21 [==============================] - 3s 126ms/step - loss: 0.0821 - binary_accuracy: 0.9657 - precision: 0.9911 - recall: 0.9623 - val_loss: 0.1597 - val_binary_accuracy: 0.9322 - val_precision: 0.9956 - val_recall: 0.9096 Epoch 19/100 21/21 [==============================] - 3s 143ms/step - loss: 0.0800 - binary_accuracy: 0.9657 - precision: 0.9917 - recall: 0.9617 - val_loss: 0.2538 - val_binary_accuracy: 0.9109 - val_precision: 1.0000 - val_recall: 0.8758 Epoch 20/100 21/21 [==============================] - 3s 127ms/step - loss: 0.0605 - binary_accuracy: 0.9738 - precision: 0.9950 - recall: 0.9694 - val_loss: 0.6594 - val_binary_accuracy: 0.8566 - val_precision: 1.0000 - val_recall: 0.8003 Epoch 21/100 21/21 [==============================] - 4s 167ms/step - loss: 0.0726 - binary_accuracy: 0.9733 - precision: 0.9937 - recall: 0.9701 - val_loss: 0.0593 - val_binary_accuracy: 0.9816 - val_precision: 0.9945 - val_recall: 0.9798 Epoch 22/100 21/21 [==============================] - 3s 126ms/step - loss: 0.0577 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 0.1087 - val_binary_accuracy: 0.9729 - val_precision: 0.9931 - val_recall: 0.9690 Epoch 23/100 21/21 [==============================] - 3s 125ms/step - loss: 0.0652 - binary_accuracy: 0.9729 - precision: 0.9924 - recall: 0.9707 - val_loss: 1.8465 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 24/100 21/21 [==============================] - 3s 124ms/step - loss: 0.0538 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 1.5769 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 25/100 21/21 [==============================] - 4s 167ms/step - loss: 0.0549 - binary_accuracy: 0.9776 - precision: 0.9954 - recall: 0.9743 - val_loss: 0.0590 - val_binary_accuracy: 0.9777 - val_precision: 0.9904 - val_recall: 0.9784 Epoch 26/100 21/21 [==============================] - 3s 131ms/step - loss: 0.0677 - binary_accuracy: 0.9719 - precision: 0.9924 - recall: 0.9694 - val_loss: 2.6008 - val_binary_accuracy: 0.6928 - val_precision: 0.9977 - val_recall: 0.5735 Epoch 27/100 21/21 [==============================] - 3s 127ms/step - loss: 0.0469 - binary_accuracy: 0.9833 - precision: 0.9971 - recall: 0.9804 - val_loss: 1.0184 - val_binary_accuracy: 0.8605 - val_precision: 0.9983 - val_recall: 0.8070 Epoch 28/100 21/21 [==============================] - 3s 126ms/step - loss: 0.0501 - binary_accuracy: 0.9790 - precision: 0.9961 - recall: 0.9755 - val_loss: 0.3737 - val_binary_accuracy: 0.9089 - val_precision: 0.9954 - val_recall: 0.8772 Epoch 29/100 21/21 [==============================] - 3s 128ms/step - loss: 0.0548 - binary_accuracy: 0.9798 - precision: 0.9941 - recall: 0.9784 - val_loss: 1.2928 - val_binary_accuracy: 0.7907 - val_precision: 1.0000 - val_recall: 0.7085 Epoch 30/100 21/21 [==============================] - 3s 129ms/step - loss: 0.0370 - binary_accuracy: 0.9860 - precision: 0.9980 - recall: 0.9829 - val_loss: 0.1370 - val_binary_accuracy: 0.9612 - val_precision: 0.9972 - val_recall: 0.9487 Epoch 31/100 21/21 [==============================] - 3s 125ms/step - loss: 0.0585 - binary_accuracy: 0.9819 - precision: 0.9951 - recall: 0.9804 - val_loss: 1.1955 - val_binary_accuracy: 0.6870 - val_precision: 0.9976 - val_recall: 0.5655 Epoch 32/100 21/21 [==============================] - 3s 140ms/step - loss: 0.0813 - binary_accuracy: 0.9695 - precision: 0.9934 - recall: 0.9652 - val_loss: 1.0394 - val_binary_accuracy: 0.8576 - val_precision: 0.9853 - val_recall: 0.8138 Epoch 33/100 21/21 [==============================] - 3s 128ms/step - loss: 0.1111 - binary_accuracy: 0.9555 - precision: 0.9870 - recall: 0.9524 - val_loss: 4.9438 - val_binary_accuracy: 0.5911 - val_precision: 1.0000 - val_recall: 0.4305 Epoch 34/100 21/21 [==============================] - 3s 130ms/step - loss: 0.0680 - binary_accuracy: 0.9726 - precision: 0.9921 - recall: 0.9707 - val_loss: 2.8822 - val_binary_accuracy: 0.7267 - val_precision: 0.9978 - val_recall: 0.6208 Epoch 35/100 21/21 [==============================] - 4s 187ms/step - loss: 0.0784 - binary_accuracy: 0.9712 - precision: 0.9892 - recall: 0.9717 - val_loss: 0.3940 - val_binary_accuracy: 0.9390 - val_precision: 0.9942 - val_recall: 0.9204
可视化模型性能
让我们绘制训练集和验证集的模型准确率和损失。请注意, 没有为此笔记本指定随机种子。对于您的笔记本,可能会有轻微的 方差。
fig, ax = plt.subplots(1, 4, figsize=(20, 3))
ax = ax.ravel()for i, met in enumerate(["precision", "recall", "binary_accuracy", "loss"]):ax[i].plot(history.history[met])ax[i].plot(history.history["val_" + met])ax[i].set_title("Model {}".format(met))ax[i].set_xlabel("epochs")ax[i].set_ylabel(met)ax[i].legend(["train", "val"])
我们看到模型的准确率约为 95%。
预测和评估结果
让我们根据测试数据评估模型!
model.evaluate(test_ds, return_dict=True)
4/4 [==============================] - 3s 708ms/step - loss: 0.9718 - binary_accuracy: 0.7901 - precision: 0.7524 - recall: 0.9897 {'binary_accuracy': 0.7900640964508057, 'loss': 0.9717951416969299, 'precision': 0.752436637878418, 'recall': 0.9897436499595642}
我们看到,测试数据的准确性低于验证的准确性 设置。这可能表示过拟合。
我们的召回率大于我们的精确率,这表明几乎所有的肺炎图像都是 识别正确,但一些正常图像被错误识别。我们应该致力于 提高我们的精度。
for image, label in test_ds.take(1):plt.imshow(image[0] / 255.0)plt.title(CLASS_NAMES[label[0].numpy()])prediction = model.predict(test_ds.take(1))[0]
scores = [1 - prediction, prediction]for score, name in zip(scores, CLASS_NAMES):print("This image is %.2f percent %s" % ((100 * score), name))
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index This is separate from the ipykernel package so we can avoid doing imports until This image is 47.19 percent NORMAL This image is 52.81 percent PNEUMONIA
相关文章:
《Keras 3 在 TPU 上的肺炎分类》
Keras 3 在 TPU 上的肺炎分类 作者:Amy MiHyun Jang创建日期:2020/07/28最后修改时间:2024/02/12描述:TPU 上的医学图像分类。 (i) 此示例使用 Keras 3 在 Colab 中查看 GitHub 源 简介 设置 本教程将介…...
团体程序设计天梯赛-练习集——L1-012 计算指数
前言 这道题简单至极,几行代码就全都解决了。这次多来几个写法; L1-012 计算指数 真的没骗你,这道才是简单题 —— 对任意给定的不超过 10 的正整数 n,要求你输出 2 的n次方 。不难吧? 输入格式: 输入…...
【机器学习】制造业转型:机器学习如何推动工业 4.0 的深度发展
我的个人主页 我的领域:人工智能篇,希望能帮助到大家!!!👍点赞 收藏❤ 引言 在当今科技飞速发展的时代,制造业正经历着前所未有的变革,工业4.0的浪潮席卷而来。工业4.0旨在通过将…...
深度学习篇---数据集分类
文章目录 前言第一部分:VOC数据集标签、COCO数据集格式1.VOC数据集标签的特点及优缺点特点优点缺点 2.COCO数据集标签的特点及优缺点特点优点缺点 3.YOLO数据集标签的特点及优缺点特点优点缺点 第二部分:VOC格式和YOLO格式1.VOC格式3.YOLO格式3.区别(1)文…...
力扣hot100之螺旋矩阵
class Solution:def spiralOrder(self, matrix: List[List[int]]) -> List[int]:# 用四个数对应4个遍历的方向[0, 1, 2, 3] - [右,下,左,上]go_state 0 # 起始必须向右# record_matrix [[0] * n for _ in range(m)]n_0, n_1, n_2, n_3 …...
嵌入式知识点总结 C/C++ 专题提升(一)-关键字
针对于嵌入式软件杂乱的知识点总结起来,提供给读者学习复习对下述内容的强化。 目录 1.C语言宏中"#“和"##"的用法 1.1.(#)字符串化操作符 1.2.(##)符号连接操作符 2.关键字volatile有什么含意?并举出三个不同的例子? 2.1.并行设备的硬件寄存…...
Android 高版本如何获取App安装列表?
有个需求需要获取App内的安装列表,但是现在在高版本Android中,只能获取到一部分App效果,我获取的代码如下: val calendar Calendar.getInstance()val packageManager context.packageManagerval usageStatsManager context.getSystemService(Context.USAGE_STATS_SERVICE) …...
StarGAN:原理、用途及最新发展
一、引言 StarGAN是一种具有广泛应用的生成模型,具有同时生成多种类别数据的能力。它由Yunjey Choi等人在2017年提出,旨在实现图像多域间迁移,尤其适用于人脸属性转换。StarGAN的提出,标志着生成对抗网络(Generative A…...
TCP报文格式与核心机制
TCP与UDP都是传输层的重要协议,TCP的特性包括有连接、可靠传输、面向字节流、全双工。 TCP的报文格式如下: 一、报文格式 1.源端口号、目的端口号 源端口和目的端口是五元组中重要的两个性质,源端口即数据是从哪里来的,目的端…...
【2024年华为OD机试】 (B卷,100分)- 金字塔,BOSS的收入(Java JS PythonC/C++)
一、问题描述 微商模式收入计算 题目描述 微商模式中,下级每赚 100 元就要上交 15 元。给定每个级别的收入,求出金字塔尖上的人的收入。 输入描述 第一行输入 N,表示有 N 个代理商上下级关系。接下来输入 N 行,每行三个数&am…...
缓存、数据库双写一致性解决方案
双写一致性问题的核心是确保数据库和缓存之间的数据同步,以避免缓存与数据库数据不同步的问题,尤其是在高并发和异步环境下。本文将探讨双写一致性面临的主要问题和解决方案,重点关注最终一致性。 本文讨论的是最终一致性问题 双写一致性面…...
开放银行数据保护与合规实践案例
总体原则 开放银行的数据处理基本原则指的是数据处理者在数据生命周期的各阶段进行各种数 据处理时均应遵循的根本准则,是指导监管机构制定规范、进行管理以及开放银行进 行具体数据处理行为的纲领。根据《民法典》《个人信息保护法》《数据安全法》 《网络安全法…...
51c自动驾驶~合集47
我自己的原文哦~ https://blog.51cto.com/whaosoft/13083194 #DreamDrive 性能爆拉30%!英伟达:时空一致下的生成重建大一统新方案~ 从自车的驾驶轨迹中生成真实的视觉图像是实现自动驾驶模型可扩展训练的关键一步。基于重建的方法从log中生成3D场景…...
2024年AI与大数据技术趋势洞察:跨领域创新与社会变革
目录 引言 技术洞察 1. 大模型技术的创新与开源推动 2. AI Agent 智能体平台技术 3. 多模态技术的兴起:跨领域应用的新风口 4. 强化学习与推荐系统:智能化决策的底层驱动 5. 开源工具与平台的快速发展:赋能技术创新 6. 技术安全与伦理:AI技术的双刃剑 7. 跨领域技…...
【protobuf】二、proto3语法详解①
文章目录 前言Ⅰ. 字段规则Ⅱ. 消息类型的定义和使用1、定义2、使用1️⃣消息类型可作为字段类型使⽤2️⃣可导入其他 .proto 文件的消息并使用 -- import 3、创建通讯录 2.0 版本的 .proto 文件4、通讯录 2.0 版本的读写实现 -- 第一种验证方式5、decode选项 -- 第二种验证方式…...
React 中hooks之useLayoutEffect 用法总结以及与useEffect的区别
React useLayoutEffect 1. useLayoutEffect 基本概念 useLayoutEffect 是 React 的一个 Hook,它的函数签名与 useEffect 完全相同,但它会在所有的 DOM 变更之后同步调用 effect。它可以用来读取 DOM 布局并同步触发重渲染。 2. useLayoutEffect vs us…...
实战经验:使用 Python 的 PyPDF 进行 PDF 操作
文章目录 1. 为什么选择 PyPDF?2. 安装 PyPDF3. PDF 文件的合并与拆分3.1 合并 PDF 文件3.2 拆分 PDF 文件 4. 提取 PDF 文本5. 修改 PDF 元信息6. PDF 加密与解密6.1 加密 PDF6.2 解密 PDF 7. 页面旋转与裁剪7.1 旋转页面7.2 裁剪页面 8. 实战经验总结 PDF 是一种非…...
数据结构与算法之排序: LeetCode 15. 三数之和 (Ts版)
三数之和 https://leetcode.cn/problems/3sum/description/ 描述 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0请你返回所有和为 0 且不重复的三元…...
51c嵌入式~单片机~合集6
我自己的原文哦~ https://blog.51cto.com/whaosoft/13127816 一、STM32单片机的知识点总结 本文将以STM32F10x为例,对标准库开发进行概览。主要分为三块内容: STM32系统结构寄存器通过点灯案例,详解如何基于标准库构建STM32工程 STM3…...
SQL Server Management Studio 表内数据查询与删除指令
查询指令 //select * from 表名称 where 列名称 数据名称 select * from Card_Info where num CC3D4D删除指令,删除数据库有风险,操作不可逆,建议删除前备份,以免删错。 //Delete * from 表名称 where 列名称 数据名称 Delete f…...
Timesheet.js - 轻松打造炫酷时间表
Timesheet.js - 轻松打造炫酷时间表 前言 在现代网页设计中,时间表是一个常见的元素,用于展示项目进度、历史事件、个人经历等信息。 然而,创建一个既美观又功能强大的时间表并非易事。 幸运的是,Timesheet.js 这款神奇的 Java…...
产品经理面试题总结2025【其一】
一、产品理解与定位 1、你如何理解产品经理这个角色? 作为一名互联网产品经理,我理解这个角色的核心在于成为产品愿景的制定者和执行的推动者。具体来说,产品经理是连接市场、用户和技术团队之间的桥梁,负责理解市场需求、用户痛…...
第16章:Python TDD实现多币种货币运算
写在前面 这本书是我们老板推荐过的,我在《价值心法》的推荐书单里也看到了它。用了一段时间 Cursor 软件后,我突然思考,对于测试开发工程师来说,什么才更有价值呢?如何让 AI 工具更好地辅助自己写代码,或许…...
【Web】2025-SUCTF个人wp
目录 SU_blog SU_photogallery SU_POP SU_blog 先是注册功能覆盖admin账号 以admin身份登录,拿到读文件的权限 ./article?filearticles/..././..././..././..././..././..././etc/passwd ./article?filearticles/..././..././..././..././..././..././proc/1…...
C++实现矩阵Matrix类 实现基本运算
本系列文章致力于实现“手搓有限元,干翻Ansys的目标”,基本框架为前端显示使用QT实现交互,后端计算采用Visual Studio C。 目录 Matrix类 1、public function 1.1、构造函数与析构函数 1.2、获取矩阵数值 1.3、设置矩阵 1.4、矩阵转置…...
【GORM】初探gorm模型,字段标签与go案例
GORM是什么? GORM 是一个Go 语言 ORM(对象关系映射)库,它让我们可以使用结构体来操作数据库,而无需编写SQL 语句 GORM 模型与字段标签详解 在 GORM 中,模型是数据库表的抽象表示,字段标签&am…...
html全局遮罩,通过websocket来实现实时发布公告
1.index.html代码示例 <div id"websocket" style"display:none;position: absolute;color:red;background-color: black;width: 100%;height: 100%;z-index: 100; opacity: 0.9; padding-top: 30%;padding-left: 30%; padding-border:1px; "onclick&q…...
基于VSCode+CMake+debootstrap搭建Ubuntu交叉编译开发环境
基于VSCodeCMakedebootstrap搭建Ubuntu交叉编译开发环境 1 基于debootstrap搭建目标系统环境1.1 安装必要软件包1.2 创建sysroot目录1.3 运行debootstrap1.4 挂载必要的虚拟文件系统1.5 复制 QEMU 静态二进制文件1.6 进入目标系统1.7 使用目标系统(以安装zlog为例&a…...
C#中System.Text.Json:从入门到精通的实用指南
一、引言 在当今数字化时代,数据的高效交换与处理成为软件开发的核心环节。JSON(JavaScript Object Notation)凭借其简洁、轻量且易于读写的特性,已然成为数据交换领域的中流砥柱。无论是前后端数据交互,还是配置文件…...
【深度学习】Huber Loss详解
文章目录 1. Huber Loss 原理详解2. Pytorch 代码详解3.与 MSELoss、MAELoss 区别及各自优缺点3.1 MSELoss 均方误差损失3.2 MAELoss 平均绝对误差损失3.3 Huber Loss 4. 总结4.1 优化平滑4.2 梯度较好4.3 为什么说 MSE 是平滑的 1. Huber Loss 原理详解 Huber Loss 是一种结合…...
Maven下载配置
目录 Win下载配置maven的环境变量 Mac下载安装配置环境变量 MavenSetting.xml文件配置 Win 下载 https://maven.apache.org/ 在主页面点击Download 点击archives 最好不要下载使用新版本,我使用的是maven-3.6.3,我们点击页面下方的archives࿰…...
JS基础(5):运算符和语句
一.运算符 1.赋值运算符 加减乘除都是一样的,,-,*,/ 2.一元运算符:经常用来计数 自增: 每次只能加一 自减:-- 前置自增 后置自增 结…...
游戏引擎学习第81天
仓库:https://gitee.com/mrxiao_com/2d_game_2 或许我们应该尝试在地面上添加一些绘图 在这段时间的工作中,讨论了如何改进地面渲染的问题。虽然之前并没有专注于渲染部分,因为当时主要的工作重心不在这里,但在实现过程中,发现地…...
网络安全 | 什么是正向代理和反向代理?
关注:CodingTechWork 引言 在现代网络架构中,代理服务器扮演着重要的角色。它们在客户端和服务器之间充当中介,帮助管理、保护和优化数据流。根据代理的工作方向和用途,代理服务器可分为正向代理和反向代理。本文将深入探讨这两种…...
前缀和——模板 二维前缀和
一.题目描述 【模板】二维前缀和_牛客题霸_牛客网 二.题目解析 这道题和上一道题有点异曲同工之妙。输入一个m行n列的矩阵,然后进行q次操作,每次操作输入4个数,作为两个点的坐标,计算这两个点为对角线的矩阵的和。 三.算法原理 …...
oracle使用case when报错ORA-12704字符集不匹配原因分析及解决方法
问题概述 使用oracle的case when函数时,报错提示ORA-12704字符集不匹配,如下图,接下来分析报错原因并提出解决方法。 样例演示 现在有一个TESTTABLE表,本表包含的字段如下图所示,COL01字段是NVARCHAR2类型࿰…...
高等数学学习笔记 ☞ 定积分与积分公式
1. 定积分的基本概念 1.1 定积分的定义 1. 定义:设函数在闭区间上有界。在闭区间上任意插入若干个分点,即, 此时每个小区间的长度记作(不一定是等分的)。然后在每个小区间上任意取,对应的函数值为。 为保证每段的值(即矩形面积)无…...
MLMs之Agent:Phidata的简介、安装和使用方法、案例应用之详细攻略
MLMs之Agent:Phidata的简介、安装和使用方法、案例应用之详细攻略 目录 Phidata简介 Phidata安装和使用方法 1、安装 2、使用方法 (1)、认证 (2)、创建 Agent (3)、运行 Agent (4)、Agent Playground Phidata 案例应用 1、多工具 Agent 2、多模态 Agent …...
如何在不暴露MinIO地址的情况下,用Spring Boot与KKFileView实现文件预览
在现代Web应用中,文件预览是一项常见且重要的功能。它允许用户在不上传或下载文件的情况下,直接在浏览器中查看文件内容。然而,直接将文件存储服务(如MinIO)暴露给前端可能会带来安全风险。本文将介绍如何在不暴露MinI…...
ESP8266固件烧录
一、烧录原理 1、引脚布局 2、引脚定义 3、尺寸封装 4、环境要求 5、接线方式 ESP8266系列模块集成了高速GPI0和外围接口,这可能会导致严重的开关噪声。如果某些应用需要高功率和EMI特性,建议在数字I/0线上串联10到100欧姆。这可以在切换电源时抑制过冲…...
Python 模拟真人鼠标轨迹算法 - 防止游戏检测
一.简介 鼠标轨迹算法是一种模拟人类鼠标操作的程序,它能够模拟出自然而真实的鼠标移动路径。 鼠标轨迹算法的底层实现采用C/C语言,原因在于C/C提供了高性能的执行能力和直接访问操作系统底层资源的能力。 鼠标轨迹算法具有以下优势: 模拟…...
三天急速通关Java基础知识:Day1 基本语法
三天急速通关JAVA基础知识:Day1 基本语法 0 文章说明1 关键字 Keywords2 注释 Comments2.1 单行注释2.2 多行注释2.3 文档注释 3 数据类型 Data Types3.1 基本数据类型3.2 引用数据类型 4 变量与常量 Variables and Constant5 运算符 Operators6 字符串 String7 输入…...
免费使用 Adobe 和 JetBrains 软件的秘密
今天想和大家聊聊如何利用 Edu 教育邮箱免费使用 Photoshop、Illustrator 等 Adobe 系列软件,以及 JetBrains 开发工具。 首先,Adobe 的软件是设计师的必备工具。无论是处理图像的 Photoshop,还是进行矢量设计的 Illustrator,它们…...
Pytorch 自学笔记(三):利用自定义文本数据集构建Dataset和DataLoader
Pytorch 自学笔记(三) 1. Dataset与DataLoader1.1 torch.utils.data.Dataset1.2 torch.utils.data.DataLoader Pytorch 自学笔记系列的第三篇。针对Pytorch的Dataset和DataLoader进行简单的介绍,同时,介绍如何使用自定义文本数据集…...
gradle项目的创建和基本结构简介
文章目录 创建gradle项目(命令行)创建gradle项目(IDEA)项目基本结构和功能Gradle 构建流程测试类体验 创建gradle项目(命令行) yangMacdeMac-mini gradleStudy % gradle init Starting a Gradle Daemon (s…...
wow-agent---Day3 Zigent 智能代理开发框架
这个框架是课程讲解的,但资料比较少,觉得框架比较小众,所以这里只分析代码,打算把更多的精力放在metagpt的学习上,毕竟还是要学教为主流的框架,这对后续维护升级都有帮助,当然感兴趣作为研究&am…...
python 入门
1. Python 概述 1.1 简介 python 是一种面向对象的解释型编程语言,由吉多范罗苏姆开发; 1991 年,公开发行版发布; 因其可以将其他语言制作的模块轻松联接在一起,又被称作胶水语言; 1.2 优点 简单易学&…...
sentinel微服务保护
学习链接 SpringCloudRabbitMQDockerRedis搜索分布式 文章目录 学习链接1.初识Sentinel1.1.雪崩问题及解决方案1.1.1.雪崩问题1.1.2.超时处理1.1.3.仓壁模式1.1.4.断路器1.1.5.限流1.1.6.总结 1.2.服务保护技术对比1.3.Sentinel介绍和安装1.3.1.初识Sentinel官网地址github地址…...
接口测试Day10-封装IHRM登录
-登录接口 普通方式实现 登录接口对象层 思路: 动态变化的,写入参数固定不变,直接写到方法实现内响应结果,通过返回值 return 分析: 封装实现: 登录接口测试用例层 封装断言方法 1、创建 文件 assert_uti…...
什么是IP地址、子网掩码、网关、DNS
简单解释 IP地址在网络中用于标识一个节点(或者网络设备的接口) IP地址用于IP报文在网络中的寻址 一个IPv4地址有32 bit。 IPv4地址通常采用“点分十进制”表示。 IPv4地址范围:0.0.0.0~255.255.255.255 网络部分:用来标识一个网络,代表IP地址所属网络。 主机部分:…...