Python----深度学习(基于深度学习Pytroch簇分类,圆环分类,月牙分类)
一、引言
深度学习的重要性
深度学习是一种通过模拟人脑神经元结构来进行数据学习和模式识别的技术,在分类任务中展现出强大的能力。
分类任务的多样性
分类任务涵盖了各种场景,例如簇分类、圆环分类和月牙分类,每种任务都有不同的特征和应用。
二、分类任务详解
2.1、簇分类
- 定义
簇分类旨在将数据点分为多个簇或类别,目标是在特征空间中找到数据点的天然聚集。 - 数据特性
通常数据聚集在不同的区域形成簇,这些簇可能具有不同的形状和大小。 - 应用场景
数据挖掘、市场细分、社交网络分析等。
簇分类数据
class1_points = np.array([[3.2, 3.0], [2.6, 3.4], [3.5, 4.9], [2.5, 3.4], [1.8, 2.7], [1.3, 1.9], [1.1, 3.4], [1.0, 4.0],[1.2, 5.0], [2.8, 4.1],[2.7, 3.1], [2.6, 4.5], [2.1, 3.3], [2.3, 2.4], [2.6, 3.1], [1.9, 3.0], [0.7, 4.2], [1.4, 3.3],[1.6, 4.6], [2.3, 2.0],[1.3, 4.2], [1.9, 3.8], [3.6, 6.0], [1.2, 3.1], [1.6, 3.1], [3.5, 4.1], [1.7, 2.6], [2.4, 3.3],[0.8, 2.2], [1.5, 4.3],[1.3, 3.9], [1.6, 5.4], [3.4, 3.7], [2.3, 3.4], [2.6, 2.4], [1.8, 2.5], [1.1, 4.1], [1.8, 2.8],[0.7, 4.4], [1.1, 3.4],[1.9, 3.6], [1.5, 4.9], [1.0, 3.3], [1.4, 3.6], [2.8, 3.3], [3.1, 4.2], [2.7, 3.8], [3.3, 2.6],[3.0, 2.7], [0.8, 3.0],[1.1, 3.8], [1.8, 3.5], [1.9, 2.8], [0.7, 3.1], [2.5, 2.6], [1.3, 2.5], [2.9, 2.9], [3.1, 2.3],[2.4, 2.8], [1.5, 4.0],[1.2, 3.8], [2.4, 2.3], [2.1, 1.9], [2.6, 4.2], [2.1, 2.8], [1.6, 2.6], [0.9, 3.8], [1.5, 2.1],[1.7, 3.0], [3.0, 2.9],[2.3, 2.6], [1.5, 2.9], [2.9, 2.9], [1.9, 2.7], [0.9, 2.7], [1.0, 4.9], [3.3, 4.0], [2.3, 2.7],[2.2, 4.0], [1.7, 4.2],[1.5, 3.4], [2.1, 3.5], [2.7, 3.9], [1.0, 4.8], [2.4, 2.8], [1.5, 2.6], [2.2, 3.2], [2.5, 2.6],[3.9, 2.8], [2.9, 4.1],[2.1, 4.3], [1.9, 3.4], [1.3, 1.9], [0.7, 3.3], [1.8, 4.2], [1.7, 3.2], [3.9, 2.9], [1.6, 4.2],[2.4, 4.4], [1.8, 1.3],[3.5, 2.0], [2.2, 3.1], [3.0, 3.5], [2.9, 3.3], [1.9, 2.9], [1.6, 2.7], [2.8, 3.6], [3.0, 2.7],[2.9, 4.4], [3.1, 3.4],[1.9, 1.2], [3.0, 1.6], [2.0, 3.7], [1.3, 3.1], [2.8, 2.4], [1.5, 2.6], [2.2, 3.1], [3.0, 3.7],[0.9, 4.3], [3.4, 3.6],[1.0, 2.4], [2.1, 3.3], [0.7, 2.3], [2.9, 2.3], [2.7, 3.5], [1.3, 2.6], [1.7, 4.2], [2.5, 4.1],[2.2, 3.4], [3.3, 3.0],[2.2, 3.5], [1.7, 3.1], [1.9, 2.8], [1.7, 2.9], [3.4, 3.0], [1.6, 4.9], [2.8, 3.7], [1.3, 3.7],[2.6, 2.6], [4.1, 3.5],[4.1, 3.1], [1.2, 2.6], [2.5, 3.0], [1.8, 4.0], [3.6, 4.0], [2.1, 4.3], [1.8, 3.2], [3.3, 1.9],[2.4, 3.5], [1.4, 3.9]])
class2_points = np.array([[8.8, 7.2], [7.8, 7.3], [6.8, 7.8], [8.1, 7.5], [7.8, 5.4], [7.6, 8.1], [8.3, 7.5], [6.9, 8.5],[8.0, 8.2], [8.7, 7.2],[8.8, 7.0], [8.2, 8.3], [7.7, 7.6], [8.3, 8.1], [8.3, 7.7], [8.0, 7.7], [6.7, 6.2], [8.4, 7.8],[7.6, 7.3], [6.4, 8.3],[8.0, 6.6], [7.0, 6.1], [8.2, 6.5], [6.7, 6.4], [7.1, 8.4], [6.6, 7.6], [7.9, 7.6], [8.0, 8.0],[7.3, 8.6], [8.7, 7.5],[7.8, 9.2], [7.3, 6.1], [7.7, 7.4], [8.0, 7.3], [8.2, 7.3], [6.5, 8.4], [6.7, 7.0], [7.9, 8.2],[6.0, 7.1], [7.9, 7.6],[7.1, 7.8], [9.0, 7.4], [7.2, 8.5], [9.1, 6.5], [7.3, 8.6], [7.2, 7.7], [8.8, 7.3], [7.0, 6.5],[6.7, 8.4], [7.4, 8.3],[9.2, 6.3], [7.8, 8.0], [9.4, 7.3], [8.0, 6.5], [6.8, 7.3], [8.5, 7.4], [6.6, 7.4], [8.6, 8.4],[9.8, 6.9], [6.7, 9.5],[6.5, 8.0], [8.1, 7.6], [7.4, 8.0], [8.8, 6.1], [7.1, 9.3], [7.3, 7.7], [7.9, 6.7], [7.2, 9.8],[8.7, 7.8], [7.8, 9.0],[7.2, 7.3], [9.2, 8.9], [7.3, 7.3], [8.3, 6.7], [7.2, 8.2], [8.1, 7.6], [7.5, 9.7], [6.8, 6.9],[8.8, 7.5], [7.6, 7.0],[7.9, 8.7], [8.8, 7.8], [7.5, 7.0], [8.2, 8.2], [6.9, 6.7], [8.1, 7.8], [8.9, 7.4], [9.4, 7.1],[5.8, 7.9], [7.2, 8.0],[8.0, 7.2], [7.2, 9.0], [7.3, 7.4], [7.3, 7.9], [9.0, 7.0], [7.9, 7.8], [7.2, 6.9], [8.4, 6.7],[8.4, 6.2], [8.4, 7.9],[7.6, 6.5], [6.3, 7.0], [8.1, 7.2], [7.2, 7.9], [7.9, 7.0], [7.7, 7.0], [7.1, 7.4], [8.9, 7.7],[7.5, 6.3], [7.3, 7.4],[8.1, 6.9], [5.4, 8.1], [7.7, 7.1], [7.8, 7.8], [7.3, 8.1], [9.1, 7.5], [7.4, 7.1], [6.6, 7.2],[7.7, 7.8], [7.7, 8.8],[6.5, 8.4], [8.5, 8.0], [5.9, 8.3], [6.9, 6.4], [7.7, 6.8], [8.5, 6.5], [8.6, 6.5], [8.4, 7.2],[8.0, 7.9], [8.3, 8.4],[9.2, 7.7], [8.6, 8.0], [7.2, 8.3], [7.6, 8.7], [6.7, 7.5], [6.6, 7.1], [8.7, 8.0], [7.0, 7.8],[8.4, 8.9], [6.6, 7.8],[8.3, 6.7], [6.7, 7.8], [6.6, 7.1], [8.3, 7.2], [8.9, 8.0], [6.8, 6.6], [8.0, 7.7], [6.3, 7.4],[7.2, 8.8], [7.7, 7.4]])
模型训练效果
2.2、圆环分类
- 定义
圆环分类任务涉及在特征空间中识别环形结构的数据分布。 - 数据特性
数据点围绕某个中心形成多个同心圆,每个环对应不同的类别。 - 应用场景
图像分类、手写数字识别、模式识别等。
圆环分类数据
class1_points = np.array([[1.7, 4.6], [5.4, 7.7], [3.8, 1.9], [3.5, 2.2], [2.2, 2.5], [4.1, 8.1], [3.7, 7.3], [1.8, 4.2],[6.8, 2.7], [6.9, 3.1],[7.9, 6.9], [8.1, 5.0], [7.2, 7.0], [7.9, 3.8], [6.3, 2.2], [5.0, 2.6], [4.9, 7.6], [6.1, 1.6],[3.0, 6.6], [3.3, 6.7],[1.8, 4.9], [3.2, 7.5], [7.8, 3.7], [7.3, 2.5], [7.1, 6.7], [1.6, 6.0], [2.6, 2.8], [1.9, 4.3],[2.5, 2.8], [7.3, 3.3],[7.7, 5.1], [2.7, 7.4], [6.2, 7.7], [5.6, 7.6], [6.4, 7.2], [7.1, 6.6], [3.8, 8.1], [2.4, 6.3],[7.5, 3.7], [1.6, 2.9],[3.9, 7.8], [7.2, 6.9], [7.4, 4.8], [7.5, 4.4], [2.0, 5.2], [2.0, 4.0], [7.3, 3.8], [5.5, 7.6],[7.5, 5.9], [4.0, 2.4],[6.9, 7.1], [5.3, 2.0], [3.3, 7.0], [4.0, 2.3], [2.7, 2.7], [5.9, 7.8], [5.7, 2.1], [7.8, 5.9],[2.6, 7.0], [5.4, 2.1],[7.0, 2.7], [5.4, 7.4], [7.0, 6.4], [7.5, 5.3], [4.2, 2.1], [3.7, 7.7], [7.7, 5.3], [6.1, 7.3],[1.6, 4.3], [3.3, 2.4],[1.9, 6.4], [1.9, 6.2], [7.7, 6.0], [4.2, 8.4], [4.7, 1.6], [3.0, 3.3], [2.1, 3.6], [1.8, 6.7],[4.8, 7.7], [6.8, 2.7],[3.3, 2.5], [5.6, 7.5], [5.9, 7.9], [2.3, 4.6], [2.2, 6.2], [4.8, 1.7], [1.9, 4.2], [1.4, 4.1],[3.5, 7.1], [5.9, 7.8],[6.6, 6.8], [2.3, 5.3], [4.0, 7.6], [3.9, 7.2], [4.6, 2.4], [3.0, 2.2], [7.3, 2.7], [1.6, 5.3],[2.8, 2.8], [2.5, 5.7],[7.7, 5.6], [4.6, 1.3], [3.1, 7.3], [2.0, 3.1], [7.1, 3.7], [6.1, 7.7], [3.1, 1.9], [6.5, 6.3],[2.1, 3.6], [7.3, 5.2],[1.7, 6.0], [2.2, 5.0], [7.4, 2.7], [2.2, 6.4], [5.0, 8.2], [2.6, 2.8], [2.6, 2.5], [7.5, 4.0],[1.7, 3.7], [3.8, 7.7],[2.9, 6.2], [4.9, 1.8], [1.9, 5.3], [6.8, 6.7], [5.2, 1.6], [5.7, 2.3], [3.8, 8.1], [6.7, 3.0],[2.3, 3.1], [8.3, 5.8],[2.1, 4.5], [5.3, 1.7], [3.2, 1.9], [7.0, 3.1], [6.3, 2.0], [4.2, 7.2], [6.1, 7.4], [2.3, 6.5],[5.4, 1.5], [5.7, 7.2],[4.5, 7.5], [2.4, 6.8], [7.6, 4.5], [3.3, 2.0], [1.8, 3.6], [1.8, 4.3], [7.5, 4.9], [4.6, 8.3],[6.9, 6.8], [7.4, 3.4],[3.6, 7.9], [7.6, 4.4], [7.8, 6.1], [6.0, 2.2], [6.4, 2.7], [4.9, 7.6], [1.7, 6.4], [7.7, 5.7],[6.8, 6.8], [3.1, 2.9],[2.0, 2.5], [4.5, 2.3], [6.7, 7.2], [7.5, 7.1], [1.9, 5.5], [5.5, 1.7], [6.6, 2.2], [6.1, 7.2],[3.9, 2.1], [2.5, 6.6],[7.7, 3.9], [7.4, 5.5], [7.6, 3.8], [3.7, 2.2], [2.3, 7.3], [5.0, 2.2], [5.5, 1.4], [2.9, 7.0],[6.7, 2.4], [2.0, 5.6],[6.4, 2.6], [7.3, 4.9], [4.0, 1.6], [3.3, 2.3], [7.6, 5.1], [3.5, 1.5], [4.7, 7.9], [6.1, 7.4],[2.2, 6.2], [6.9, 2.6],[2.2, 2.7], [4.1, 7.5], [8.2, 4.4], [3.5, 7.8], [2.4, 6.5], [2.1, 3.8], [1.8, 5.1], [2.3, 2.6],[6.4, 2.7], [7.0, 2.6],[7.4, 3.6], [5.9, 1.7], [8.3, 5.8], [7.8, 3.6], [7.7, 5.1], [8.0, 3.9], [1.3, 5.3], [3.4, 7.1],[4.7, 7.8], [2.1, 3.8],[7.1, 6.0], [7.5, 4.1], [7.1, 3.5], [7.3, 6.9], [6.6, 2.3], [7.5, 3.3], [7.1, 6.5], [8.0, 5.8],[8.0, 4.2], [3.6, 7.7],[1.9, 5.0], [2.6, 2.8], [5.1, 7.0], [6.9, 7.2], [2.0, 6.0], [7.5, 2.5], [4.0, 2.1], [2.9, 7.0],[4.2, 7.2], [5.3, 1.8],[2.6, 6.8], [3.1, 2.3], [3.6, 2.3], [5.5, 1.3], [1.3, 4.2], [6.2, 1.9], [2.5, 3.1], [1.8, 4.5],[1.7, 5.5], [5.7, 7.8],[8.2, 4.8], [2.0, 3.4], [1.4, 4.4], [5.5, 7.9], [4.0, 1.7], [7.8, 4.7], [6.3, 7.2], [2.5, 2.3],[7.4, 4.4], [5.1, 7.9]])
class2_points = np.array([[5.7, 4.8], [4.8, 5.0], [4.7, 4.6], [4.6, 5.3], [5.0, 5.5], [4.3, 4.9], [4.2, 5.9], [6.0, 5.0],[4.1, 5.2], [5.4, 5.0],[4.9, 5.4], [4.5, 6.2], [5.3, 5.5], [4.2, 5.0], [4.0, 4.9], [5.9, 4.9], [4.3, 6.1], [4.5, 4.3],[5.1, 5.8], [5.6, 4.5],[4.9, 4.3], [5.5, 5.7], [5.4, 5.0], [4.7, 4.9], [5.6, 5.3], [5.8, 4.8], [4.8, 5.6], [5.3, 5.3],[5.1, 4.7], [5.0, 5.3],[4.0, 4.4], [5.9, 5.2], [5.7, 4.7], [5.8, 5.2], [5.1, 4.0], [5.8, 5.9], [5.3, 6.0], [5.5, 4.8],[5.1, 4.7], [4.7, 4.3],[5.7, 5.0], [4.3, 4.7], [5.7, 4.9], [4.7, 4.0], [4.9, 4.9], [5.2, 4.6], [4.6, 5.6], [5.2, 5.3],[4.8, 5.9], [4.5, 4.7],[5.3, 5.2], [4.7, 4.3], [4.7, 5.7], [4.7, 4.2], [4.7, 5.3], [5.3, 5.4], [5.4, 5.9], [4.6, 4.1],[4.1, 5.8], [5.6, 5.1],[5.2, 4.5], [5.6, 4.7], [5.0, 4.8], [5.7, 4.3], [4.5, 5.7], [4.4, 5.7], [5.5, 5.3], [4.7, 5.4],[5.1, 5.7], [5.2, 4.3],[4.6, 4.9], [4.7, 5.5], [4.5, 4.2], [5.2, 4.5], [5.4, 3.9], [4.0, 5.0], [4.4, 4.0], [5.0, 4.2],[5.8, 5.6], [5.8, 5.2],[4.7, 4.6], [4.7, 5.8], [5.6, 4.5], [5.8, 4.9], [4.6, 5.5], [5.6, 4.5], [5.1, 4.5], [4.2, 4.8],[4.9, 5.3], [5.0, 5.2],[4.0, 4.8], [5.5, 4.8], [6.0, 4.7], [4.4, 5.1], [4.3, 4.9], [5.1, 5.6], [4.7, 5.6], [5.1, 4.9],[4.2, 5.4], [4.4, 4.6],[5.5, 5.9], [4.1, 4.8], [5.0, 4.6], [5.2, 5.0], [4.1, 5.5], [4.6, 5.1], [5.2, 5.5], [5.1, 4.0],[4.4, 4.5], [5.3, 5.3],[4.8, 5.3], [5.2, 4.6], [5.7, 4.4], [4.3, 5.0], [5.1, 4.9], [4.6, 5.0], [5.4, 5.6], [5.3, 4.4],[4.6, 4.3], [5.2, 5.6],[5.0, 4.3], [4.4, 4.4], [5.5, 4.9], [4.3, 5.5], [5.0, 5.3], [4.8, 4.9], [5.3, 5.6], [4.1, 4.7],[4.6, 5.2], [5.5, 4.6],[4.6, 4.6], [4.5, 5.4], [4.6, 4.2], [5.1, 4.3], [5.2, 4.3], [5.1, 5.6], [5.5, 4.5], [5.1, 4.0],[4.5, 5.1], [4.8, 3.7],[4.3, 5.1], [4.6, 5.4], [5.2, 3.9], [4.6, 5.1], [4.2, 5.1], [4.5, 5.2], [5.6, 5.3], [5.6, 5.1],[5.9, 5.2], [5.0, 4.1],[5.1, 4.3], [4.8, 6.0], [5.3, 5.5], [5.3, 4.4], [4.4, 5.1], [5.2, 5.0], [4.9, 4.4], [5.3, 5.2],[5.2, 6.1], [5.6, 5.9],[4.7, 4.2], [6.1, 5.6], [4.6, 5.7], [5.5, 5.0], [4.5, 5.1], [4.8, 6.0], [4.8, 5.0], [5.5, 4.3],[4.1, 4.9], [3.9, 4.6],[4.9, 5.3], [4.4, 4.1], [4.6, 5.3], [5.0, 4.7], [5.3, 5.9], [5.1, 5.4], [5.3, 5.3], [4.9, 4.5],[5.6, 5.1], [5.2, 4.5],[5.3, 4.6], [5.5, 5.6], [5.0, 6.1], [4.5, 5.3], [4.8, 5.6], [4.7, 4.9], [4.7, 5.6], [4.6, 4.3],[5.8, 5.0], [4.9, 4.8],[5.6, 5.3], [5.5, 5.2], [4.8, 5.3], [4.6, 4.5], [5.2, 4.9], [5.5, 5.6], [6.2, 4.1], [5.6, 5.3],[5.3, 5.4], [5.4, 5.0],[5.5, 4.8], [5.1, 4.6], [4.8, 5.4], [4.8, 5.3], [5.8, 4.8], [4.5, 4.8], [4.6, 4.9], [4.3, 3.9],[4.6, 5.3], [5.1, 5.3],[5.4, 5.7], [4.3, 5.2], [4.8, 4.9], [5.6, 4.7], [4.2, 5.0], [5.3, 5.6], [4.9, 4.0], [5.1, 4.7],[5.0, 5.4], [6.0, 5.5],[5.5, 4.6], [5.7, 5.3], [4.5, 4.7], [5.5, 5.0], [5.9, 4.9], [5.5, 4.6], [4.9, 5.6], [5.4, 5.3],[5.2, 4.4], [4.3, 4.5],[5.1, 4.2], [4.3, 5.1], [5.6, 5.7], [4.8, 5.0], [5.1, 5.5], [5.7, 5.2], [5.9, 4.9], [5.1, 4.3],[5.3, 5.2], [4.4, 4.7],[5.2, 5.8], [6.3, 5.1], [4.0, 5.4], [5.4, 4.7], [4.2, 5.3], [5.7, 4.9], [5.4, 5.5], [4.8, 5.2],[5.4, 5.8], [4.6, 5.0]])
模型训练效果
2.3、月牙分类
- 定义
月牙分类任务要求识别流形或不规则的形状,数据分布呈现出像月牙形状的特征。 - 数据特性
数据集中的点通常呈现出一种弯曲的形态,具有独特的边界。 - 应用场景
生物医学影像分析、信号处理、推荐系统等。
月牙分类数据
class1_points = np.array([[6.5, 4.3], [4.5, 6.4], [1.3, 5.1], [1.7, 4.4], [4.8, 5.7], [5.4, 5.6], [1.8, 4.9], [1.2, 3.8],[2.8, 5.7], [6.4, 3.8],[4.5, 5.9], [5.3, 6.0], [5.9, 5.0], [1.7, 4.6], [2.3, 5.7], [3.4, 6.1], [5.9, 4.4], [5.4, 5.1],[5.2, 5.2], [5.6, 5.4],[4.2, 6.2], [1.4, 3.7], [3.6, 6.3], [4.8, 6.0], [4.8, 6.0], [5.0, 6.1], [5.8, 5.1], [1.6, 4.5],[1.5, 5.1], [2.2, 6.0],[5.1, 5.8], [3.8, 6.3], [2.0, 5.7], [2.1, 5.6], [2.0, 5.1], [1.0, 4.9], [3.0, 6.3], [6.0, 4.2],[2.3, 6.3], [4.8, 6.1],[1.8, 5.1], [2.2, 5.7], [6.3, 4.3], [5.7, 5.3], [5.6, 5.5], [3.0, 6.1], [6.1, 3.7], [6.3, 4.7],[3.4, 6.1], [5.2, 5.7],[5.8, 3.7], [0.7, 4.6], [4.9, 6.2], [1.8, 5.1], [4.6, 5.9], [1.5, 5.0], [1.4, 4.4], [4.0, 6.4],[5.3, 5.8], [4.6, 6.1],[3.5, 6.0], [6.2, 4.6], [4.5, 6.0], [2.6, 6.1], [5.9, 5.0], [2.8, 6.4], [2.4, 6.0], [5.3, 6.0],[2.0, 5.7], [1.2, 3.7],[2.8, 5.9], [2.5, 5.5], [6.3, 4.6], [1.2, 3.7], [6.3, 4.4], [6.0, 4.8], [1.5, 4.2], [6.4, 4.2],[1.3, 4.6], [2.0, 5.2],[1.9, 5.2], [1.6, 5.4], [5.5, 5.7], [3.5, 6.6], [1.7, 5.0], [6.2, 4.6], [6.1, 4.5], [4.1, 5.9],[6.1, 4.9], [1.7, 5.2],[3.5, 6.2], [2.9, 6.4], [5.0, 5.8], [2.5, 5.8], [3.1, 6.0], [2.0, 5.1], [2.6, 5.7], [6.1, 4.0],[6.5, 4.4], [5.4, 6.1],[5.9, 4.1], [4.7, 5.9], [2.4, 6.5], [4.5, 6.4], [5.9, 4.6], [0.9, 3.9], [3.6, 6.3], [3.7, 6.3],[1.6, 4.3], [6.0, 5.7],[4.2, 6.3], [1.8, 5.2], [2.7, 5.9], [2.4, 5.5], [6.4, 3.8], [5.2, 6.1], [6.2, 4.7], [4.2, 6.5],[5.7, 3.6], [3.9, 6.1],[1.1, 4.6], [5.5, 5.3], [2.0, 5.9], [5.2, 5.4], [5.7, 5.2], [5.3, 5.0], [1.4, 4.1], [2.8, 6.6],[3.6, 6.3], [1.1, 4.3],[5.5, 5.2], [3.9, 6.9], [6.2, 4.2], [5.5, 5.5], [1.6, 4.1], [1.1, 3.9], [1.4, 4.9], [4.5, 6.1],[1.7, 5.0], [1.9, 4.7],[5.8, 5.7], [4.8, 5.6], [3.2, 5.7], [6.3, 4.0], [1.6, 4.2], [1.8, 5.1], [1.9, 5.5], [2.9, 5.6],[1.0, 3.8], [5.9, 5.5],[2.6, 5.6], [5.3, 5.4], [1.5, 5.0], [3.2, 6.1], [1.0, 4.1], [1.9, 5.8], [3.3, 6.2], [6.1, 3.9],[2.9, 5.8], [4.8, 5.9],[6.0, 4.4], [3.6, 6.2], [1.6, 5.1], [5.6, 5.0], [4.0, 6.2], [6.2, 4.3], [4.2, 6.4], [4.0, 6.1],[5.5, 5.1], [4.3, 6.1],[4.5, 5.8], [3.7, 6.7], [1.6, 5.6], [5.7, 4.6], [1.6, 4.9], [6.2, 5.7], [2.8, 6.2], [2.1, 5.7],[5.8, 6.2], [1.5, 5.0],[5.6, 5.6], [4.1, 5.7], [1.8, 4.6], [6.4, 4.1], [1.2, 3.8], [2.4, 6.0], [1.5, 5.2], [6.0, 3.9],[5.9, 4.7], [1.9, 5.5],[2.3, 5.5], [6.1, 4.4], [2.0, 5.2], [1.8, 5.5], [4.6, 6.3], [3.4, 6.2], [4.7, 6.3], [3.1, 6.1],[3.8, 6.3], [5.7, 5.5],[1.9, 5.4], [4.7, 5.9], [6.0, 4.2], [4.5, 6.5], [1.3, 4.2], [5.1, 6.0], [1.8, 5.2], [4.0, 6.4],[5.8, 5.6], [1.2, 3.9],[6.1, 5.4], [1.7, 4.9], [6.3, 5.0], [5.2, 5.0], [3.0, 6.4], [1.6, 4.8], [1.5, 5.2], [4.7, 6.3],[1.5, 4.8], [5.3, 5.8],[4.3, 5.9], [3.2, 6.3], [2.4, 5.5], [2.6, 5.4], [1.2, 3.9], [4.8, 6.3], [6.2, 4.6], [1.3, 5.3],[6.6, 4.1], [2.9, 6.3],[3.3, 6.1], [6.0, 5.3], [1.5, 4.9], [5.6, 5.7], [5.9, 4.5], [4.9, 6.1], [6.0, 4.6], [5.0, 5.4],[3.4, 6.1], [5.9, 4.9],[2.8, 5.4], [1.9, 5.3], [3.2, 5.8], [1.2, 4.7], [3.1, 6.3], [1.2, 4.0], [6.0, 5.7], [2.7, 6.0],[3.4, 6.0], [5.9, 5.4]])
class2_points = np.array([[6.5, 2.5], [6.4, 2.3], [6.6, 2.8], [7.0, 2.6], [4.3, 2.9], [4.1, 3.7], [3.9, 3.3], [7.2, 2.7],[3.8, 4.5], [4.0, 4.7],[4.0, 3.9], [8.3, 3.8], [6.5, 3.1], [8.0, 3.6], [7.9, 3.4], [6.8, 2.5], [4.0, 4.4], [7.0, 2.6],[7.7, 3.1], [6.0, 2.1],[6.7, 2.7], [8.7, 4.2], [4.0, 3.9], [5.9, 2.2], [6.3, 2.7], [7.3, 2.9], [5.0, 2.6], [8.1, 3.9],[4.2, 4.0], [5.1, 2.5],[8.2, 3.3], [7.1, 2.9], [5.0, 3.0], [7.1, 2.3], [4.8, 3.1], [3.5, 4.4], [8.3, 3.3], [5.2, 3.0],[6.1, 2.2], [6.8, 2.2],[3.9, 4.9], [8.6, 3.6], [6.0, 2.3], [4.1, 4.0], [5.2, 2.8], [8.2, 3.5], [8.1, 3.4], [8.7, 4.9],[5.0, 2.4], [5.0, 2.6],[8.0, 3.0], [8.4, 4.3], [5.3, 2.7], [8.7, 5.1], [5.6, 2.5], [5.4, 2.7], [3.8, 4.5], [9.1, 4.3],[8.8, 4.1], [4.7, 3.3],[8.4, 4.6], [8.3, 4.5], [7.0, 2.7], [6.4, 2.3], [5.2, 2.5], [7.0, 2.2], [8.6, 3.3], [7.5, 3.0],[4.0, 3.9], [7.6, 3.0],[7.0, 2.7], [4.3, 3.1], [5.7, 2.8], [3.8, 4.3], [4.9, 3.1], [4.1, 3.3], [7.0, 2.3], [5.1, 2.9],[8.9, 4.5], [6.0, 2.7],[7.4, 2.6], [8.7, 4.7], [8.6, 4.5], [7.7, 3.0], [8.9, 5.0], [4.1, 4.0], [3.9, 4.8], [3.7, 3.8],[5.5, 2.3], [7.5, 3.4],[4.2, 3.3], [4.1, 3.5], [7.8, 3.1], [3.8, 4.7], [5.2, 3.3], [3.5, 4.7], [3.5, 4.8], [3.9, 4.2],[6.7, 3.1], [7.9, 3.0],[8.6, 4.1], [8.5, 4.4], [7.3, 2.6], [3.4, 4.7], [8.7, 3.9], [7.6, 3.0], [4.6, 3.1], [4.8, 2.7],[4.5, 2.5], [7.4, 2.9],[5.1, 2.7], [6.9, 2.7], [7.6, 2.6], [9.0, 5.0], [7.1, 2.2], [5.0, 2.7], [5.6, 2.4], [3.6, 4.8],[6.0, 2.4], [6.9, 2.9],[8.3, 4.9], [3.9, 4.0], [4.9, 3.1], [8.7, 3.9], [6.3, 2.4], [6.8, 2.5], [5.8, 2.1], [4.5, 4.1],[4.7, 3.2], [6.3, 2.6],[8.8, 4.8], [8.6, 4.1], [4.5, 3.8], [3.6, 4.3], [8.8, 5.0], [4.2, 3.9], [8.6, 4.4], [8.8, 4.0],[5.0, 3.4], [6.4, 2.5],[4.6, 2.6], [6.0, 2.6], [8.1, 3.5], [8.7, 4.5], [4.8, 2.8], [5.9, 2.7], [6.8, 2.6], [8.9, 4.6],[6.4, 2.6], [6.9, 2.5],[8.8, 3.3], [3.7, 4.0], [8.3, 4.0], [3.6, 4.3], [7.2, 2.2], [8.8, 4.4], [8.7, 4.7], [3.8, 4.4],[8.1, 3.4], [3.5, 4.7],[8.7, 4.1], [4.3, 3.8], [3.6, 4.0], [5.0, 2.7], [7.7, 3.2], [8.4, 3.2], [4.3, 3.7], [8.6, 4.3],[7.5, 3.2], [8.3, 3.8],[4.9, 2.9], [5.4, 2.4], [3.9, 4.9], [8.9, 3.6], [8.3, 3.4], [8.2, 3.3], [7.8, 2.8], [8.2, 3.2],[8.9, 4.8], [8.6, 3.8],[3.9, 5.3], [4.4, 4.6], [7.8, 3.0], [6.9, 2.7], [7.7, 3.0], [3.7, 3.7], [6.6, 3.0], [5.3, 2.6],[4.4, 4.1], [8.1, 3.6],[8.5, 3.4], [8.0, 3.7], [5.2, 2.7], [7.3, 2.8], [4.1, 4.0], [8.5, 3.6], [7.5, 2.4], [3.9, 3.8],[5.9, 2.5], [6.6, 2.9],[4.4, 3.4], [4.8, 3.3], [4.4, 3.1], [8.7, 4.8], [6.2, 2.7], [5.0, 3.2], [5.6, 2.7], [8.5, 4.2],[4.2, 3.5], [4.0, 3.1],[3.8, 4.1], [5.3, 2.2], [4.9, 3.3], [5.7, 3.1], [4.4, 3.5], [5.3, 2.8], [4.2, 3.3], [8.4, 3.6],[8.1, 3.5], [3.8, 4.4],[3.6, 4.3], [4.3, 4.6], [7.9, 3.1], [8.9, 4.9], [7.8, 3.2], [4.1, 3.7], [4.8, 3.1], [3.7, 4.3],[8.5, 3.8], [5.2, 2.7],[7.3, 2.8], [6.5, 2.6], [8.4, 4.3], [8.2, 4.0], [7.2, 2.9], [3.7, 4.2], [7.6, 2.6], [4.3, 4.7],[4.5, 3.5], [4.0, 4.2],[6.4, 2.7], [6.3, 2.6], [8.9, 3.9], [5.8, 2.3], [6.1, 2.6], [4.1, 3.7], [8.2, 3.1], [9.1, 4.5],[3.7, 4.1], [6.3, 2.7]])
模型训练效果
三、PyTorch实现
以月牙分类为例
划分数据集
# 将 point1 分割为训练集和测试集
np.random.shuffle(class1_points) # 随机打乱数据
split_index = int(0.1 * len(class1_points)) # 取前 10% 的数据作为测试集class1_train_points = class1_points[split_index:]
class2_train_points = class2_points[split_index:]
class1_test_points = class1_points[:split_index]
class2_test_points = class2_points[:split_index]# 合并两类点
train_points = np.concatenate((class1_train_points, class2_train_points))
# 标签 0表示类别1,1表示类别2
train_labels1 = np.zeros(len(class1_train_points))
train_labels2 = np.ones(len(class2_train_points))
train_labels = np.concatenate((train_labels1, train_labels2))
# 合并两类点
test_points = np.concatenate((class1_test_points, class2_test_points))
# 标签 0表示类别1,1表示类别2
test_labels1 = np.zeros(len(class1_test_points))
test_labels2 = np.ones(len(class2_test_points))
test_labels = np.concatenate((test_labels1, test_labels2))
构建模型
class ModelClass(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(2, 8)self.layer2 = nn.Linear(8, 16)self.layer3 = nn.Linear(16, 32)self.layer4 = nn.Linear(32, 16)self.layer5 = nn.Linear(16, 8)self.layer6 = nn.Linear(8, 2)def forward(self, x):x = torch.tanh(self.layer1(x))x = torch.tanh(self.layer2(x))x = torch.tanh(self.layer3(x))x = torch.tanh(self.layer4(x))x = torch.tanh(self.layer5(x))x = torch.softmax(self.layer6(x),dim=1)return xmodel = ModelClass()
创建损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.005)
模型训练
for n in range(1,2001):# 将numpy数据转换为torch tensorinputs = torch.tensor(train_points, dtype=torch.float32)train_labels = torch.tensor(train_labels, dtype=torch.long)# 前向传播outputs = model(inputs)loss = criterion(outputs, train_labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if n % 100== 0 or n == 1:print(n,loss.item())
可视化
# 创建等高线绘图的网格点
x_min, x_max = 0, 10
y_min, y_max = 0, 10
step_size = 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, step_size),np.arange(y_min, y_max, step_size))
grid_points = np.c_[xx.ravel(), yy.ravel()]# 创建三维图形和右侧的二维子图
fig = plt.figure(figsize=(10, 5))ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)step_list = []
loss_list = []
test_step_list = []
test_loss_list = []# 开始迭代
for n in range(1,2001):# 将numpy数据转换为torch tensorinputs = torch.tensor(train_points, dtype=torch.float32)train_labels = torch.tensor(train_labels, dtype=torch.long)# 前向传播outputs = model(inputs)loss = criterion(outputs, train_labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 更新右侧的损失图数据并绘制step_list.append(n)loss_list.append(loss.detach())# 显示频率设置frequency_display = 50# 显示与输出if n % 100== 0 or n == 1:# 使用训练好的模型预测网格点的标签grid_points_tensor = torch.tensor(grid_points, dtype=torch.float32)Z = model(grid_points_tensor).detach().numpy()Z = Z[:, 1] # 取正类的概率值Z = Z.reshape(xx.shape)# 绘制2D图ax1 = plt.subplot(121)ax1.clear()ax1.scatter(class1_train_points[:, 0], class1_train_points[:, 1], c='blue', label='label1')ax1.scatter(class2_train_points[:, 0], class2_train_points[:, 1], c='red', label='label2')ax1.contour(xx, yy, Z, levels=[0.5], colors='black')# 计算测试集损失test_inputs = torch.tensor(test_points, dtype=torch.float32)y_pred_test = model(test_inputs)test_labels = torch.tensor(test_labels, dtype=torch.long)loss_test = criterion(y_pred_test, test_labels)test_step_list.append(n)test_loss_list.append(loss_test.detach())ax2 = plt.subplot(122)ax2.clear()ax2.plot(step_list, loss_list, 'r-', label='Train Loss')ax2.plot(test_step_list, test_loss_list, 'b-', label='Test Loss') # 绘制测试集损失ax2.set_xlabel("Step")ax2.set_ylabel("Loss")ax2.legend()plt.show()
完整代码
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.init as init# 创造数据,数据集
class1_points = np.array([[6.5, 4.3], [4.5, 6.4], [1.3, 5.1], [1.7, 4.4], [4.8, 5.7], [5.4, 5.6], [1.8, 4.9], [1.2, 3.8],[2.8, 5.7], [6.4, 3.8],[4.5, 5.9], [5.3, 6.0], [5.9, 5.0], [1.7, 4.6], [2.3, 5.7], [3.4, 6.1], [5.9, 4.4], [5.4, 5.1],[5.2, 5.2], [5.6, 5.4],[4.2, 6.2], [1.4, 3.7], [3.6, 6.3], [4.8, 6.0], [4.8, 6.0], [5.0, 6.1], [5.8, 5.1], [1.6, 4.5],[1.5, 5.1], [2.2, 6.0],[5.1, 5.8], [3.8, 6.3], [2.0, 5.7], [2.1, 5.6], [2.0, 5.1], [1.0, 4.9], [3.0, 6.3], [6.0, 4.2],[2.3, 6.3], [4.8, 6.1],[1.8, 5.1], [2.2, 5.7], [6.3, 4.3], [5.7, 5.3], [5.6, 5.5], [3.0, 6.1], [6.1, 3.7], [6.3, 4.7],[3.4, 6.1], [5.2, 5.7],[5.8, 3.7], [0.7, 4.6], [4.9, 6.2], [1.8, 5.1], [4.6, 5.9], [1.5, 5.0], [1.4, 4.4], [4.0, 6.4],[5.3, 5.8], [4.6, 6.1],[3.5, 6.0], [6.2, 4.6], [4.5, 6.0], [2.6, 6.1], [5.9, 5.0], [2.8, 6.4], [2.4, 6.0], [5.3, 6.0],[2.0, 5.7], [1.2, 3.7],[2.8, 5.9], [2.5, 5.5], [6.3, 4.6], [1.2, 3.7], [6.3, 4.4], [6.0, 4.8], [1.5, 4.2], [6.4, 4.2],[1.3, 4.6], [2.0, 5.2],[1.9, 5.2], [1.6, 5.4], [5.5, 5.7], [3.5, 6.6], [1.7, 5.0], [6.2, 4.6], [6.1, 4.5], [4.1, 5.9],[6.1, 4.9], [1.7, 5.2],[3.5, 6.2], [2.9, 6.4], [5.0, 5.8], [2.5, 5.8], [3.1, 6.0], [2.0, 5.1], [2.6, 5.7], [6.1, 4.0],[6.5, 4.4], [5.4, 6.1],[5.9, 4.1], [4.7, 5.9], [2.4, 6.5], [4.5, 6.4], [5.9, 4.6], [0.9, 3.9], [3.6, 6.3], [3.7, 6.3],[1.6, 4.3], [6.0, 5.7],[4.2, 6.3], [1.8, 5.2], [2.7, 5.9], [2.4, 5.5], [6.4, 3.8], [5.2, 6.1], [6.2, 4.7], [4.2, 6.5],[5.7, 3.6], [3.9, 6.1],[1.1, 4.6], [5.5, 5.3], [2.0, 5.9], [5.2, 5.4], [5.7, 5.2], [5.3, 5.0], [1.4, 4.1], [2.8, 6.6],[3.6, 6.3], [1.1, 4.3],[5.5, 5.2], [3.9, 6.9], [6.2, 4.2], [5.5, 5.5], [1.6, 4.1], [1.1, 3.9], [1.4, 4.9], [4.5, 6.1],[1.7, 5.0], [1.9, 4.7],[5.8, 5.7], [4.8, 5.6], [3.2, 5.7], [6.3, 4.0], [1.6, 4.2], [1.8, 5.1], [1.9, 5.5], [2.9, 5.6],[1.0, 3.8], [5.9, 5.5],[2.6, 5.6], [5.3, 5.4], [1.5, 5.0], [3.2, 6.1], [1.0, 4.1], [1.9, 5.8], [3.3, 6.2], [6.1, 3.9],[2.9, 5.8], [4.8, 5.9],[6.0, 4.4], [3.6, 6.2], [1.6, 5.1], [5.6, 5.0], [4.0, 6.2], [6.2, 4.3], [4.2, 6.4], [4.0, 6.1],[5.5, 5.1], [4.3, 6.1],[4.5, 5.8], [3.7, 6.7], [1.6, 5.6], [5.7, 4.6], [1.6, 4.9], [6.2, 5.7], [2.8, 6.2], [2.1, 5.7],[5.8, 6.2], [1.5, 5.0],[5.6, 5.6], [4.1, 5.7], [1.8, 4.6], [6.4, 4.1], [1.2, 3.8], [2.4, 6.0], [1.5, 5.2], [6.0, 3.9],[5.9, 4.7], [1.9, 5.5],[2.3, 5.5], [6.1, 4.4], [2.0, 5.2], [1.8, 5.5], [4.6, 6.3], [3.4, 6.2], [4.7, 6.3], [3.1, 6.1],[3.8, 6.3], [5.7, 5.5],[1.9, 5.4], [4.7, 5.9], [6.0, 4.2], [4.5, 6.5], [1.3, 4.2], [5.1, 6.0], [1.8, 5.2], [4.0, 6.4],[5.8, 5.6], [1.2, 3.9],[6.1, 5.4], [1.7, 4.9], [6.3, 5.0], [5.2, 5.0], [3.0, 6.4], [1.6, 4.8], [1.5, 5.2], [4.7, 6.3],[1.5, 4.8], [5.3, 5.8],[4.3, 5.9], [3.2, 6.3], [2.4, 5.5], [2.6, 5.4], [1.2, 3.9], [4.8, 6.3], [6.2, 4.6], [1.3, 5.3],[6.6, 4.1], [2.9, 6.3],[3.3, 6.1], [6.0, 5.3], [1.5, 4.9], [5.6, 5.7], [5.9, 4.5], [4.9, 6.1], [6.0, 4.6], [5.0, 5.4],[3.4, 6.1], [5.9, 4.9],[2.8, 5.4], [1.9, 5.3], [3.2, 5.8], [1.2, 4.7], [3.1, 6.3], [1.2, 4.0], [6.0, 5.7], [2.7, 6.0],[3.4, 6.0], [5.9, 5.4]])
class2_points = np.array([[6.5, 2.5], [6.4, 2.3], [6.6, 2.8], [7.0, 2.6], [4.3, 2.9], [4.1, 3.7], [3.9, 3.3], [7.2, 2.7],[3.8, 4.5], [4.0, 4.7],[4.0, 3.9], [8.3, 3.8], [6.5, 3.1], [8.0, 3.6], [7.9, 3.4], [6.8, 2.5], [4.0, 4.4], [7.0, 2.6],[7.7, 3.1], [6.0, 2.1],[6.7, 2.7], [8.7, 4.2], [4.0, 3.9], [5.9, 2.2], [6.3, 2.7], [7.3, 2.9], [5.0, 2.6], [8.1, 3.9],[4.2, 4.0], [5.1, 2.5],[8.2, 3.3], [7.1, 2.9], [5.0, 3.0], [7.1, 2.3], [4.8, 3.1], [3.5, 4.4], [8.3, 3.3], [5.2, 3.0],[6.1, 2.2], [6.8, 2.2],[3.9, 4.9], [8.6, 3.6], [6.0, 2.3], [4.1, 4.0], [5.2, 2.8], [8.2, 3.5], [8.1, 3.4], [8.7, 4.9],[5.0, 2.4], [5.0, 2.6],[8.0, 3.0], [8.4, 4.3], [5.3, 2.7], [8.7, 5.1], [5.6, 2.5], [5.4, 2.7], [3.8, 4.5], [9.1, 4.3],[8.8, 4.1], [4.7, 3.3],[8.4, 4.6], [8.3, 4.5], [7.0, 2.7], [6.4, 2.3], [5.2, 2.5], [7.0, 2.2], [8.6, 3.3], [7.5, 3.0],[4.0, 3.9], [7.6, 3.0],[7.0, 2.7], [4.3, 3.1], [5.7, 2.8], [3.8, 4.3], [4.9, 3.1], [4.1, 3.3], [7.0, 2.3], [5.1, 2.9],[8.9, 4.5], [6.0, 2.7],[7.4, 2.6], [8.7, 4.7], [8.6, 4.5], [7.7, 3.0], [8.9, 5.0], [4.1, 4.0], [3.9, 4.8], [3.7, 3.8],[5.5, 2.3], [7.5, 3.4],[4.2, 3.3], [4.1, 3.5], [7.8, 3.1], [3.8, 4.7], [5.2, 3.3], [3.5, 4.7], [3.5, 4.8], [3.9, 4.2],[6.7, 3.1], [7.9, 3.0],[8.6, 4.1], [8.5, 4.4], [7.3, 2.6], [3.4, 4.7], [8.7, 3.9], [7.6, 3.0], [4.6, 3.1], [4.8, 2.7],[4.5, 2.5], [7.4, 2.9],[5.1, 2.7], [6.9, 2.7], [7.6, 2.6], [9.0, 5.0], [7.1, 2.2], [5.0, 2.7], [5.6, 2.4], [3.6, 4.8],[6.0, 2.4], [6.9, 2.9],[8.3, 4.9], [3.9, 4.0], [4.9, 3.1], [8.7, 3.9], [6.3, 2.4], [6.8, 2.5], [5.8, 2.1], [4.5, 4.1],[4.7, 3.2], [6.3, 2.6],[8.8, 4.8], [8.6, 4.1], [4.5, 3.8], [3.6, 4.3], [8.8, 5.0], [4.2, 3.9], [8.6, 4.4], [8.8, 4.0],[5.0, 3.4], [6.4, 2.5],[4.6, 2.6], [6.0, 2.6], [8.1, 3.5], [8.7, 4.5], [4.8, 2.8], [5.9, 2.7], [6.8, 2.6], [8.9, 4.6],[6.4, 2.6], [6.9, 2.5],[8.8, 3.3], [3.7, 4.0], [8.3, 4.0], [3.6, 4.3], [7.2, 2.2], [8.8, 4.4], [8.7, 4.7], [3.8, 4.4],[8.1, 3.4], [3.5, 4.7],[8.7, 4.1], [4.3, 3.8], [3.6, 4.0], [5.0, 2.7], [7.7, 3.2], [8.4, 3.2], [4.3, 3.7], [8.6, 4.3],[7.5, 3.2], [8.3, 3.8],[4.9, 2.9], [5.4, 2.4], [3.9, 4.9], [8.9, 3.6], [8.3, 3.4], [8.2, 3.3], [7.8, 2.8], [8.2, 3.2],[8.9, 4.8], [8.6, 3.8],[3.9, 5.3], [4.4, 4.6], [7.8, 3.0], [6.9, 2.7], [7.7, 3.0], [3.7, 3.7], [6.6, 3.0], [5.3, 2.6],[4.4, 4.1], [8.1, 3.6],[8.5, 3.4], [8.0, 3.7], [5.2, 2.7], [7.3, 2.8], [4.1, 4.0], [8.5, 3.6], [7.5, 2.4], [3.9, 3.8],[5.9, 2.5], [6.6, 2.9],[4.4, 3.4], [4.8, 3.3], [4.4, 3.1], [8.7, 4.8], [6.2, 2.7], [5.0, 3.2], [5.6, 2.7], [8.5, 4.2],[4.2, 3.5], [4.0, 3.1],[3.8, 4.1], [5.3, 2.2], [4.9, 3.3], [5.7, 3.1], [4.4, 3.5], [5.3, 2.8], [4.2, 3.3], [8.4, 3.6],[8.1, 3.5], [3.8, 4.4],[3.6, 4.3], [4.3, 4.6], [7.9, 3.1], [8.9, 4.9], [7.8, 3.2], [4.1, 3.7], [4.8, 3.1], [3.7, 4.3],[8.5, 3.8], [5.2, 2.7],[7.3, 2.8], [6.5, 2.6], [8.4, 4.3], [8.2, 4.0], [7.2, 2.9], [3.7, 4.2], [7.6, 2.6], [4.3, 4.7],[4.5, 3.5], [4.0, 4.2],[6.4, 2.7], [6.3, 2.6], [8.9, 3.9], [5.8, 2.3], [6.1, 2.6], [4.1, 3.7], [8.2, 3.1], [9.1, 4.5],[3.7, 4.1], [6.3, 2.7]])# 将 class1_points 分割为训练集和测试集
np.random.shuffle(class1_points) # 随机打乱数据
split_index = int(0.1 * len(class1_points)) # 取前10%的数据作为测试集 # 将 class1 和 class2 中的数据分为训练和测试集
class1_train_points = class1_points[split_index:] # 90%的 class1 数据作为训练集
class2_train_points = class2_points[split_index:] # 90%的 class2 数据作为训练集
class1_test_points = class1_points[:split_index] # 10%的 class1 数据作为测试集
class2_test_points = class2_points[:split_index] # 10%的 class2 数据作为测试集 # 合并训练集
train_points = np.concatenate((class1_train_points, class2_train_points)) # 合并两个类别的训练点
# 创建训练标签,类别1用0表示,类别2用1表示
train_labels1 = np.zeros(len(class1_train_points)) # 类别1的标签
train_labels2 = np.ones(len(class2_train_points)) # 类别2的标签
train_labels = np.concatenate((train_labels1, train_labels2)) # 合并所有训练标签 # 合并测试集
test_points = np.concatenate((class1_test_points, class2_test_points)) # 合并两个类别的测试点
# 创建测试标签
test_labels1 = np.zeros(len(class1_test_points)) # 类别1的标签
test_labels2 = np.ones(len(class2_test_points)) # 类别2的标签
test_labels = np.concatenate((test_labels1, test_labels2)) # 合并所有测试标签 # 2. 定义前向模型
class YourModelClass(nn.Module): def __init__(self): super(YourModelClass, self).__init__() # 定义六层的全连接神经网络结构 self.layer1 = nn.Linear(2, 8) # 输入层到第一隐藏层 self.layer2 = nn.Linear(8, 16) # 第一隐藏层到第二隐藏层 self.layer3 = nn.Linear(16, 32) # 第二隐藏层到第三隐藏层 self.layer4 = nn.Linear(32, 16) # 第三隐藏层到第四隐藏层 self.layer5 = nn.Linear(16, 8) # 第四隐藏层到第五隐藏层 self.layer6 = nn.Linear(8, 2) # 第五隐藏层到输出层 def forward(self, x): # 前向传播函数 x = torch.tanh(self.layer1(x)) # 使用tanh激活函数 x = torch.tanh(self.layer2(x)) x = torch.tanh(self.layer3(x)) x = torch.tanh(self.layer4(x)) x = torch.tanh(self.layer5(x)) x = torch.softmax(self.layer6(x), dim=1) # 使用softmax激活函数进行分类 return x # 实例化模型
model = YourModelClass() # 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.005) # Adam优化器,学习率和权重衰减 # 创建等高线绘图的网格点
x_min, x_max = 0, 10
y_min, y_max = 0, 10
step_size = 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, step_size), np.arange(y_min, y_max, step_size)) # 生成网格点
grid_points = np.c_[xx.ravel(), yy.ravel()] # 将网格点展平为二维数组 # 创建图形和子图
fig = plt.figure(figsize=(10, 5)) ax1 = fig.add_subplot(121) # 左侧图
ax2 = fig.add_subplot(122) # 右侧图 step_list = [] # 存储训练步数
loss_list = [] # 存储训练损失
test_step_list = [] # 存储测试步数
test_loss_list = [] # 存储测试损失 # 4. 开始迭代
num_iterations = 2000
for n in range(num_iterations + 1): # 将numpy数据转换为torch tensor inputs = torch.tensor(train_points, dtype=torch.float32) # 将训练点转换为张量 train_labels = torch.tensor(train_labels, dtype=torch.long) # 将训练标签转换为张量 # 前向传播 outputs = model(inputs) # 得到模型输出 loss = criterion(outputs, train_labels) # 计算损失 # 反向传播和优化 optimizer.zero_grad() # 清除梯度 loss.backward() # 反向传播计算梯度 optimizer.step() # 更新参数 # 更新损失图数据 step_list.append(n) # 记录当前步数 loss_list.append(loss.detach()) # 记录当前损失值 # 5. 显示频率设置 frequency_display = 50 # 每50步输出一次信息 # 6. 显示与输出 if n % frequency_display == 0 or n == 1: # 使用训练好的模型预测网格点的标签 grid_points_tensor = torch.tensor(grid_points, dtype=torch.float32) # 将网格点转换为张量 Z = model(grid_points_tensor).detach().numpy() # 得到予测输出 Z = Z[:, 1] # 取类别2的概率值(1的列) Z = Z.reshape(xx.shape) # 调整Z的形状以适应网格 # 绘制2D图形 ax1.clear() # 清除当前图 ax1.scatter(class1_train_points[:, 0], class1_train_points[:, 1], c='blue', label='label1') # 类别1的点 ax1.scatter(class2_train_points[:, 0], class2_train_points[:, 1], c='red', label='label2') # 类别2的点 ax1.contour(xx, yy, Z, levels=[0.5], colors='black') # 绘制等高线 # 计算测试集损失 test_inputs = torch.tensor(test_points, dtype=torch.float32) # 将测试点转换为张量 y_pred_test = model(test_inputs) # 得到模型输出 test_labels = torch.tensor(test_labels, dtype=torch.long) # 将测试标签转换为张量 loss_test = criterion(y_pred_test, test_labels) # 计算测试集损失 test_step_list.append(n) # 记录测试步数 test_loss_list.append(loss_test.detach()) # 记录测试损失 ax2.clear() # 清除当前损失图 ax2.plot(step_list, loss_list, 'r-', label='Train Loss') # 绘制训练损失 ax2.plot(test_step_list, test_loss_list, 'b-', label='Test Loss') # 绘制测试损失 ax2.set_xlabel("Step") # x轴标签 ax2.set_ylabel("Loss") # y轴标签 ax2.legend() # 显示图例 plt.show() # 展示图形
相关文章:
Python----深度学习(基于深度学习Pytroch簇分类,圆环分类,月牙分类)
一、引言 深度学习的重要性 深度学习是一种通过模拟人脑神经元结构来进行数据学习和模式识别的技术,在分类任务中展现出强大的能力。 分类任务的多样性 分类任务涵盖了各种场景,例如簇分类、圆环分类和月牙分类,每种任务都有不同的…...
Python图像处理——基于Retinex算法的低光照图像增强系统
1.项目内容 (1)算法介绍 ①MSRCR (Multi-Scale Retinex with Color Restoration) MSRCR 是多尺度 Retinex 算法(MSR)的扩展版,引入了色彩恢复机制以进一步提升图像增强质量。MSR 能有效地压缩图像动态范围ÿ…...
【网络】MQTT协议
MQTT协议全称是(Message Queuing Telemetry Transport),即消息队列遥测传输协议 是一种基于发布/订阅(publish/subscribe)模式的“轻量级”通讯协议,该协议构建于TCP/IP协议上 MQTT通信模型 特点: 1、客户端使用它连…...
python基础-requests结合AI实现自动化数据抓取
Python Requests高级指南:从入门到精通 概述 Requests 是一个基于 urllib3 封装的 Python HTTP 客户端库,提供了极其简洁且人性化的接口,使得发送 HTTP 请求和处理响应变得轻而易举。它支持常见的 HTTP 方法(GET、POST、PUT、DE…...
边界凸台建模与实例
文章目录 边界凸台特征耳机案例瓶子 边界凸台特征 两侧对称拉伸最上面的圆柱 同过两点一基准面画草图,在基准面上画椭圆 隐藏无关的实体和草图,以便椭圆的端点能与线给穿透约束,下面的点与下面的线也给穿透,短轴长给35(…...
Kafka和Spark-Streaming
Kafka和Spark-Streaming 一、Kafka 1、Kafka和Flume的整合 ① 需求1:利用flume监控某目录中新生成的文件,将监控到的变更数据发送给kafka,kafka将收到的数据打印到控制台: 在flume/conf下添加.conf文件, vi flume…...
5.2 AutoGen:支持多Agent对话的开源框架,适合自动化任务
AutoGen作为由Microsoft开发的开源框架,已成为构建多Agent对话系统和自动化任务的领先工具。其核心在于通过自然语言和代码驱动的多Agent对话,支持复杂任务的自治执行或结合人类反馈优化,广泛应用于客服自动化、金融分析、供应链优化和医疗诊…...
探索亚马逊云科技:开启您的云计算之旅
前言 在当今数字化时代,云计算已成为企业和个人不可或缺的技术基础设施。作为全球领先的云服务提供商,亚马逊云科技(Amazon Web Services)为您提供强大、可靠且安全的云计算解决方案。 想要立即体验亚马逊云科技的强大功能&#x…...
2023年第十四届蓝桥杯Scratch02月stema选拔赛真题——算式题
完整题目可点击下方地址查看,支持在线编程,支持源码和素材获取: 算式题_scratch_少儿编程题库学习中心-嗨信奥https://www.hixinao.com/tiku/scratch/show-4267.html?_shareid3 程序演示可点击下方地址查看,支持源码和素材获取&…...
霍格软件测试-JMeter高级性能测试一期
课程大小:32.2G 课程下载:https://download.csdn.net/download/m0_66047725/90631395 更多资源下载:关注我 当下BAT、TMD等互联网一线企业已几乎不再招募传统测试工程师,而只招测试开发工程师!在软件测试技术栈迭代…...
django.db.utils.OperationalError: (1050, “Table ‘你的表名‘ already exists“)
这个错误意味着 Django 尝试执行迁移时,发现数据库中已经有一张叫 你的表名的表了,但这张表不是通过 Django 当前的迁移系统管理的,或者迁移状态和数据库实际状态不一致。 🧠 可能出现这个问题的几种情况: 1.你手动创…...
分布式ID生成方案详解
分布式ID生成方案详解 一、问题背景 分库分表场景下,传统自增ID会导致不同库/表的ID重复,需要分布式ID生成方案解决以下核心需求: •全局唯一性:跨数据库/表的ID不重复 •有序性:利于索引优化和范围查询 •高性能&…...
短视频矩阵系统可视化剪辑功能开发,支持OEM
在短视频营销与内容创作竞争日益激烈的当下,矩阵系统中的可视化剪辑功能成为提升内容产出效率与质量的关键模块。它以直观的操作界面和强大的编辑能力,帮助创作者快速将创意转化为优质视频。本文将结合实际开发经验,从需求分析、技术选型到核…...
使用开源免费雷池WAF防火墙,接入保护你的网站
使用开源免费雷池WAF防火墙,接入保护你的网站 大家好,我是星哥,昨天介绍了《开源免费WEB防火墙,不让黑客越雷池一步!》链接:https://mp.weixin.qq.com/s/9TOXth3128N6PtXhaWI5aw 今天讲一下如何把网站接入…...
Python-Agent调用多个Server-FastAPI版本
Python-Agent调用多个Server-FastAPI版本 Agent调用多个McpServer进行工具调用 1-核心知识点 fastAPI的快速使用agent调用多个server 2-思路整理 1)先把每个子服务搭建起来2)再暴露一个Agent 3-参考网址 VSCode配置Python开发环境:https:/…...
spark-standalone模式
Spark Standalone模式是Spark集群的一种部署方式,即在没有使用其他资源管理器(如YARN或Mesos)的情况下,在Spark自身提供的集群管理器中部署和运行Spark应用程序。 在Spark Standalone模式下,有一个主节点(…...
3、LangChain基础:LangChain Chat Model
Prompt templates: Few shot、Example selector Few shot(少量示例) 创建少量示例的格式化程序 创建一个简单的提示模板,用于在生成时向模型提供示例输入和输出。向LLM提供少量这样的示例被称为少量示例,这是一种简单但强大的指导生成的方式,在某些情况下可以显著提高模型…...
信创时代开发工具选择指南:国产替代背景下的技术生态与实践路径
🧑 博主简介:CSDN博客专家、CSDN平台优质创作者,高级开发工程师,数学专业,10年以上C/C, C#, Java等多种编程语言开发经验,拥有高级工程师证书;擅长C/C、C#等开发语言,熟悉Java常用开…...
Coze高阶玩法 | 使用Coze制作思维认知提升视频,效率提升300%!(附保姆级教程)
目录 一、工作流整体设计 二、制作工作流 2.1 开始节点 2.2 大模型 2.3 文本处理 2.4 代码 2.5 批处理 2.6 选择器 2.7 画板_视频模板 2.8 合成音频 2.9 图片与音频合并视频 2.10 视频合并 2.11 结束节点 三、智能体应用体验 中午吃饭的时候,刷到了一个思维认知…...
数据湖DataLake和传统数据仓库Datawarehouse的主要区别是什么?优缺点是什么?
数据湖和传统数据仓库的主要区别 以下是数据湖和传统数据仓库的主要区别,以表格形式展示: 特性数据湖传统数据仓库数据类型支持结构化、半结构化及非结构化数据主要处理结构化数据架构设计扁平化架构,所有数据存储在一个大的“池”中多层架…...
GStreamer 简明教程(十一):插件开发,以一个音频生成(Audio Source)插件为例
系列文章目录 GStreamer 简明教程(一):环境搭建,运行 Basic Tutorial 1 Hello world! GStreamer 简明教程(二):基本概念介绍,Element 和 Pipeline GStreamer 简明教程(三…...
chrome://inspect/#devices 调试 HTTP/1.1 404 Not Found 如何解决
使用chrome是需要翻墙的,可以换个浏览器进行使用 可以使用edge浏览器,下载地址如下 微软官方edge浏览器|Mac版:浏览更智能,工作更高效 下载Edge浏览器 edge://inspect/#devices 点击inspect即可 qq浏览器 1. 下载qq浏览器 2. …...
RFID使用指南
## 什么是RFID? RFID(Radio Frequency Identification)即射频识别技术,是一种通过无线电波进行非接触式数据交换的技术。 ## RFID系统的主要组成部分 1. **RFID标签(Tag)** - 包含芯片和天线 - 分为有源标…...
初识Redis · 哨兵机制
目录 前言: 引入哨兵 模拟哨兵机制 配置docker环境 基于docker环境搭建哨兵环境 对比三种配置文件 编排主从节点和sentinel 主从节点 sentinel 模拟哨兵 前言: 在前文我们介绍了Redis的主从复制有一个最大的缺点就是,主节点挂了之…...
JAVA设计模式——(七)代理模式
JAVA设计模式——(七)代理模式 介绍理解实现抽象主题角色具体主题角色代理类测试 应用 介绍 代理模式和装饰模式还是挺像的。装饰模式是抽象类对装饰对象的实现,在继承装饰对象。代理模式则是直接对代理对象的实现。 理解 代理模式可以看成…...
Redis 原子操作
文章目录 前言✅ 一、什么是「原子操作」?🔍 二、怎么判断一个操作是否原子?🧪 三、项目中的原子 vs 非原子案例(秒杀系统)✅ 原子性(OK)❌ 非原子性(高风险)…...
待办事项日历组件实现
待办事项日历组件实现 今天积累一个简易的待办事项日历组件的实现方法。 需求: 修改样式,变成符合项目要求的日历样式日历上展示待办事项提示(有未完成待办:展示黄点,有已完成待办:展示绿点)…...
Flask 请求数据获取方法详解
一、工作原理 在 Flask 中,所有客户端请求的数据都通过全局的 request 对象访问。该对象是 请求上下文 的一部分,仅在请求处理期间存在。Flask 在收到请求时自动创建 request 对象,并根据请求类型(如 GET、POST)和内容…...
PicoVR眼镜在XR融合现实显示模式下无法显示粒子问题
PicoVR眼镜开启XR融合现实显示模式下,Unity3D粒子效果无法显示问题,其原因是XR融合显示模式下,Unity3D应用显示层在最终合成到眼镜显示器时,驱动层先渲染摄像机画面,再以Alpha透明方式渲染应用层画面,问题就…...
vue-lottie的使用和配置
一、vue-lottie 简介 vue-lottie 是一个 Vue 组件,用于在 Vue 项目中集成 Airbnb 的 Lottie 动画库。它通过 JSON 文件渲染 After Effects 动画,适用于复杂矢量动画的高效展示。 二、安装与基础使用 1. 安装 npm install vue-lottielatest # 或 yarn…...
PyTorch 实现食物图像分类实战:从数据处理到模型训练
一、简介 在计算机视觉领域,图像分类是一项基础且重要的任务,广泛应用于智能安防、医疗诊断、电商推荐等场景。本文将以食物图像分类为例,基于 PyTorch 框架,详细介绍从数据准备、模型构建到训练测试的全流程,帮助读者…...
传统中台的重生——云原生如何重塑政务系统后端架构
📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 一、引言:传统后端架构的“痛”与“变” 在过去十年中,无数企业和机构纷纷构建中台系统,尤其是政务、金融、交通、教育等领域。这些中台系统一般基于 Java EE 单体架构,集中部署于虚拟机上,靠人…...
jQuery AJAX、Axios与Fetch
jQuery AJAX、Axios与Fetch对比 #mermaid-svg-FRNqb7d4i2fmbavm {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-FRNqb7d4i2fmbavm .error-icon{fill:#552222;}#mermaid-svg-FRNqb7d4i2fmbavm .error-text{fill:#552…...
【Hive入门】Hive数据导出完全指南:从HDFS到本地文件系统的专业实践
目录 引言 1 Hive数据导出概述 1.1 数据导出的核心概念 1.2 典型导出场景 2 Hive到HDFS导出详解 2.1 INSERT OVERWRITE DIRECTORY方法 2.2 多目录导出技术 2.3 动态分区导出 3 HDFS到本地文件系统转移 3.1 hadoop fs命令操作 3.2 分布式拷贝工具DistCp 4 直接导出到…...
stack __ queue(栈和队列)
1. stack的介绍和使用 栈和队列里面都叫容器适配器 存储数据就要交给别的容器 通过封装别的容器,可以进行相应的操作,来达到目的 适配的本质就是复用 这就没有迭代器了,不支持随便遍历 2. queue的介绍和使用 下面用一些题来深入理解 栈…...
UML 类图基础和类关系辨析
UML 类图 目录 1 概述 2 类图MerMaid基本表示法 3 类关系详解 3.1 实现和继承 3.1.1 实现(Realization)3.1.2 继承/泛化(Inheritance/Generalization) 3.2 聚合和组合 3.2.1 组合(Composition)3.2.2 聚…...
STM32F103C8T6信息
STM32F103C8T6 完整参数列表 一、核心参数 内核架构 ARM Cortex-M3 32位RISC处理器 最大主频:72 MHz(基于APB总线时钟) 运算性能:1.25 DMIPS/MHz(Dhrystone 2.1基准) 总线与存储 总线宽度ÿ…...
unity 读取csv
1.读取代码 string filePath Application.streamingAssetsPath "\\data.csv"; public List<MovieData> movieData new List<MovieData>(); private void ReadCSV(string filePath) { List<List<string>> data new List<…...
那些年踩过的坑之Arrays.asList
一、前言 熟悉开发的兄弟都知道,在写新增和删除功能的时候,大多数时候会写成批量的,原因也很简单,批量既支持单个也支持多个对象的操作,事情也是发生在这个批量方法的调用上,下面我简单说一下这个事情。 二…...
ASP.NET Core 自动识别 appsettings.json的机制解析
ASP.NET Core 自动识别 appsettings.json 的机制解析 在 ASP.NET Core 中,IConfiguration 能自动识别 appsettings.json 并直接读取值的机制,是通过框架的 “约定优于配置” 设计和 依赖注入系统 共同实现的。以下是详细原理: 默认配置源的自…...
深入解析Mlivus Cloud核心架构:rootcoord组件的最佳实践与调优指南
作为大禹智库的向量数据库高级研究员,同时也是《向量数据库指南》的作者,我在过去30年的向量数据库和AI应用实战中见证了这项技术的演进与革新。今天,我将以专业视角为您深入剖析Mlivus Cloud的核心组件之一——rootcoord,这个组件在系统架构中扮演着至关重要的角色。如果您…...
ApplicationEventPublisher用法-笔记
1.ApplicationEventPublisher简介 org.springframework.context.ApplicationEventPublisher 是 Spring 框架中用于发布自定义事件的核心接口。它允许你在 Spring 应用上下文中触发事件,并由其他组件(监听器)进行响应。 ApplicationEventPub…...
数字孪生:从概念到实践,重构未来产业的“虚拟镜像”
一、开篇:为什么数字孪生是下一个技术风口? 现象级案例引入: “特斯拉用数字孪生技术将电池故障预测准确率提升40%;西门子通过虚拟工厂模型缩短30%产品研发周期;波音777X飞机设计全程零实物原型……” 数据支撑&#…...
Python笔记:VS2013编译Python-3.5.10
注:本文是编译老版本,有点麻烦,测试了编译新版,基本上是傻瓜是操作即可 1. python官网下载源码 https://www.python.org/ftp/python/3.5.10/Python-3.5.10.tgz 2. 编译前查看目录中相关文档 源码目录结构 看README文档 经过查…...
STM32八股【6】-----CortexM3的双堆栈(MSP、PSP)设计
STM32的线程模式(Thread Mode)和内核模式(Handler Mode)以及其对应的权级和堆栈指针 线程模式: 正常代码执行时的模式(如 main 函数、FreeRTOS任务) 可以是特权级(使用MSPÿ…...
MySQL触法器
1. 什么是触发器及其特点 MySQL数据库中触发器是一个特殊的存储过程,不同的是执行存储过程要使用 CALL 语句来调用,而触发器的执行不需要使用 CALL 语句来调用,也不需要手工启动,只要一个预定义的事件发生就会被 MySQL自动调用。…...
金仓数据库征文-政务领域国产化数据库更替:金仓 KingbaseES 应用实践
目录 一.金仓数据库介绍 二.政务领域数据库替换的时代需求 三.金仓数据库 KingbaseES 在政务领域的替换优势 1.强大的兼容性与迁移能力 2.高安全性与稳定性保障 3.良好的国产化适配性 四.金仓数据库 KingbaseES 在政务领域的典型应用实践 1.电子政务办公系…...
微服务架构在云原生后端的深度融合与实践路径
📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 一、引言:后端架构的演变,走向云原生与微服务融合 过去十余年,后端架构经历了从单体应用(Monolithic)、垂直切分(Modularization)、到微服务(Microservices)的演进,每一次变化都是为了解决…...
北斗导航 | 北斗卫星导航单点定位与深度学习结合提升精度
以下是北斗卫星导航单点定位(SPP)与深度学习结合提升精度的关键方法总结,综合了误差建模、信号识别、动态环境适应等技术方向: 一、非直射信号(NLOS)抑制与权重修正 1. 双自注意力网络(Dual Self-Attention Network) 原理:通过同时建模卫星信号的空间环境特征(如天空…...
AlarmClock4.8.4(官方版)桌面时钟工具软件下载安装教程
1.软件名称:AlarmClock 2.软件版本:4.8.4 3.软件大小:187 MB 4.安装环境:win7/win10/win11(64位) 5.下载地址: https://www.kdocs.cn/l/cdZMwizD2ZL1?RL1MvMTM%3D 提示:先转存后下载,防止资…...