昇思25天学习打卡营第7天之二 | 模型保存与加载

1. 保存与加载

在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,本章节我们将介绍如何保存与加载模型。

1.1 导入依赖

# 导入numpy库,并将其重命名为np,以便在代码中引用
import numpy as np

# 导入MindSpore库,这是华为推出的一个开源深度学习框架,用于构建和训练神经网络
import mindspore

# 从MindSpore库中导入nn模块,这个模块包含了构建神经网络所需的各种层和函数
from mindspore import nn

# 从MindSpore库中导入Tensor模块,Tensor是MindSpore中用于表示张量的类
from mindspore import Tensor

1.1定义神经网络模型

# 定义一个函数,该函数创建一个简单的全连接神经网络模型
def network():
    
	# 使用nn.SequentialCell创建一个层序列,这是一个容器类,可以包含多个层
    model = nn.SequentialCell(
        
        # 第一个层是一个Flatten层,用于将输入的二维图像数据展平为一维向量
        nn.Flatten(),
        # 第二个层是一个全连接层,将28x28的输入节点映射到512个节点
        nn.Dense(28*28, 512),
        # 第三个层是一个ReLU激活函数,用于非线性变换
        nn.ReLU(),
        # 第四个层是一个全连接层,将512个节点映射到512个节点
        nn.Dense(512, 512),
        # 第五个层是一个ReLU激活函数,用于非线性变换
        nn.ReLU(),
        # 第六个层是一个全连接层,将512个节点映射到10个节点,对应于10个类别的输出
        nn.Dense(512, 10)
    )
    # 返回创建好的模型
    return model

1.2 保存和加载模型权重

1.2.1 保存模型

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

# 创建一个神经网络模型实例
model = network()

# 使用MindSpore的save_checkpoint函数将模型的检查点保存到文件
# 第一个参数是模型对象
# 第二个参数是文件名,这里保存为"model.ckpt"
mindspore.save_checkpoint(model, "model.ckpt")
# 打印模型结构
print(model)

输出:

SequentialCell<
  (0): Flatten<>
  (1): Dense<input_channels=784, output_channels=512, has_bias=True>
  (2): ReLU<>
  (3): Dense<input_channels=512, output_channels=512, has_bias=True>
  (4): ReLU<>
  (5): Dense<input_channels=512, output_channels=10, has_bias=True>
  >

模型大小估算:
model_capacity ≈ 模型参数 * 数据精度(默认是int32类型)大小 = [(784512+512) + (512512+512) + (512*10 +10)] *32bit/8(bit/Byte)= 669704 *4 = 2678824 Byte
可以看到,模型参数量约为67W,占用空间大小应约为2678824字节
实际该模型文件大小为2679017。可以说非常接近了,剩下的字节应该就是文件类型描述符加模型结构描述符之类的内容了。
所以当我们已知一个模型的参数量和参数精度后,实际就可以估算出模型占用的磁盘空间大小了。

1.2.2 加载模型

要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

# 创建一个神经网络模型
model = network()

# 使用MindSpore的load_checkpoint函数从文件中加载模型的参数和优化器状态
# 参数是检查点的文件名,这里加载的文件名为"model.ckpt"
param_dict = mindspore.load_checkpoint("model.ckpt")

# 使用MindSpore的load_param_into_net函数将加载的参数字典加载到模型中
# 第一个参数是模型对象
# 第二个参数是参数字典
# 返回值是一个元组,第一个元素是未加载的参数列表,第二个元素是加载的参数列表
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)

# 打印未加载的参数列表,如果加载成功,这个列表应该是空的
print(param_not_load)

输出:

[]

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

1.3 保存和加载MindIR

除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

# 创建网络模型
model = network()
# 创建一个Tensor对象,它包含一个大小为[1, 1, 28, 28]的矩阵,所有元素都是1,数据类型为float32
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
# 使用mindspore.export函数将模型导出为MINDIR格式
# 第一个参数是模型对象
# 第二个参数是输入数据,这里使用了一个Tensor对象作为示例
# 第三个参数是文件名,这里导出的文件名为"model"
# 第四个参数是文件格式,这里设置为"MINDIR",表示导出的模型格式
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。

nn.GraphCell仅支持图模式。

# 设置MindSpore的执行模式为GRAPH_MODE
mindspore.set_context(mode=mindspore.GRAPH_MODE)
# 加载之前导出的MINDIR模型
graph = mindspore.load("model.mindir")
# 创建一个GraphCell对象,它将graph作为其成员
model = nn.GraphCell(graph)
# 使用模型对输入数据进行前向计算,得到输出
outputs = model(inputs)
# 打印输出的形状
print(outputs.shape)

输出:
模型加载f

2. 小结

本文主要介绍了模型的保存和加载,都包括检查点checkpoint和统一中间表示MindIR(Intermediate Representation)两种方法,还介绍了模型大小的估算方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/759340.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】--分支和循环(1)

&#x1f37f;个人主页: 起名字真南 &#x1f9c7;个人专栏:【数据结构初阶】 【C语言】 目录 前言1 if 语句1.1 if1.2 else1.3 嵌套if1.4 悬空else 前言 C语言是结构化的程序设计语言&#xff0c;这里的结构指的是顺序结构、选择结构、循环结构。 我们可以用if、switch实现分支…

51单片机第6步_stdlib.h库函数

本章重点学习stdlib.h库函数。 #include <REG51.h> //包含头文件REG51.h,使能51内部寄存器; #include <stdlib.h> //float atof (char *s1); //参数s1字符串可包含正负号,小数点或E(e)来表示指数部分,如123.456或123e-2; //若首字符是非数据字符,或为正负号…

力扣每日一题 6/30 记忆化搜索/动态规划

博客主页&#xff1a;誓则盟约系列专栏&#xff1a;IT竞赛 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ 494.目标和【中等】 题目&#xff1a; 给你一个非负整数数组 nums 和一个…

VMware中的三种虚拟网络模式

虚拟机网络模式 1 主机网络环境2 VMware中的三种虚拟网络模式2.1 桥接模式NAT模式仅主机模式网络模式选择1 VMware虚拟网络配置2 虚拟机选择网络模式3 Windows主机网络配置 配置静态IP 虚拟机联网方式为桥接模式&#xff0c;这种模式下&#xff0c;虚拟机通过主机的物理网卡&am…

mysql8.0-学习

文章目录 mysql8.0基础知识-学习安装mysql_8.0登录mysql8.0的体系结构与管理体系结构图连接mysqlmysql8.0的 “新姿势” mysql的日常管理用户安全权限练习查看用户的权限回收:revoke角色 mysql的多种连接方式socket显示系统中当前运行的所有线程 tcp/ip客户端工具基于SSL的安全…

2024最新boss直聘岗位数据爬虫,并进行可视化分析

前言 近年来,随着互联网的发展和就业市场的变化,数据科学与爬虫技术在招聘信息分析中的应用变得越来越重要。通过对招聘信息的爬取和可视化分析,我们可以更好地了解当前的就业市场动态、职位需求和薪资水平,从而为求职者和招聘企业提供有价值的数据支持。本文将介绍如何使…

Linux系统编程--进程间通信

目录 1. 介绍 1.1 进程间通信的目的 1.2 进程间通信的分类 2. 管道 2.1 什么是管道 2.2 匿名管道 2.2.1 接口 2.2.2 步骤--以父子进程通信为例 2.2.3 站在文件描述符角度-深度理解 2.2.4 管道代码 2.2.5 读写特征 2.2.6 管道特征 2.3 命名管道 2.3.1 接口 2.3.2…

【驱动篇】龙芯LS2K0300之i2c设备驱动

实验背景 由于官方内核i2c的BSP有问题&#xff08;怀疑是设备树这块&#xff09;&#xff0c;本次实验将不通过设备树来驱动aht20&#xff08;i2c&#xff09;模块&#xff0c;大致的操作过程如下&#xff1a; 模块连接&#xff0c;查看aht20设备地址编写device驱动&#xff…

K8S之网络深度剖析(一)(持续更新ing)

K8S之网络深度剖析 一 、关于K8S的网络模型 在K8s的世界上,IP是以Pod为单位进行分配的。一个Pod内部的所有容器共享一个网络堆栈(相当于一个网络命名空间,它们的IP地址、网络设备、配置等都是共享的)。按照这个网络原则抽象出来的为每个Pod都设置一个IP地址的模型也被称作为I…

忍法:声音克隆之术

前言&#xff1a; 最近因为一直在给肚子里面的宝宝做故事胎教&#xff0c;每天&#xff08;其实是看自己心情抽空讲下故事&#xff09;都要给宝宝讲故事&#xff0c;心想反正宝宝也看不见我&#xff0c;只听我的声音&#xff0c;干脆偷个懒&#xff0c;克隆自己的声音&#xf…

信息学奥赛初赛天天练-40-CSP-J2021基础题-组合数学-缩倍法、平均分组、2进制转10进制、面向过程/面向对象语言应用

PDF文档公众号回复关键字:20240630 2021 CSP-J 选择题 单项选择题&#xff08;共15题&#xff0c;每题2分&#xff0c;共计30分&#xff1a;每题有且仅有一个正确选项&#xff09; 1.以下不属于面向对象程序设计语言是( ) A. C B. Python C. Java D. C 2.以下奖项与计…

R包的4种安装方式及常见问题解决方法

R包的4种安装方式及常见问题解决方法 R包的四种安装方式1. install.packages()2. 从Bioconductor安装3. 从本地源码安装4. 从github安装 常见问题的解决1. 版本问题2. 网络/镜像问题3.缺少Rtools R包的四种安装方式 1. install.packages() 对于R自带的包的安装一般都可以通过…

HarmonyOS--路由管理--组件导航 (Navigation)

文档中心 什么是组件导航 (Navigation) &#xff1f; 1、Navigation是路由容器组件&#xff0c;一般作为首页的根容器&#xff0c;包括单栏(Stack)、分栏(Split)和自适应(Auto)三种显示模式 2、Navigation组件适用于模块内和跨模块的路由切换&#xff0c;一次开发&#xff0…

实现点击按钮导出页面pdf

在Vue 3 Vite项目中&#xff0c;你可以使用html2canvas和jspdf库来实现将页面某部分导出为PDF文档的功能。以下是一个简单的实现方式&#xff1a; 1.安装html2canvas和jspdf&#xff1a; pnpm install html2canvas jspdf 2.在Vue组件中使用这些库来实现导出功能&#xff1a;…

网线直连电脑可以上网,网线连tplink路由器上不了网

家里wifi网络连不上好几天了&#xff0c;用网线直连电脑可以上网&#xff0c;但网线连tplink路由器wan口上不了网&#xff0c;无Internet连接&#xff0c;网线连lan口可以电脑上网&#xff0c;手机上不了。 后来发现网线的主路由用的192.168.0.1&#xff0c;我的路由器wan口自…

在node环境使用MySQL

什么是Sequelize? Sequelize是一个基于Promise的NodeJS ORM模块 什么是ORM? ORM(Object-Relational-Mapping)是对象关系映射 对象关系映射可以把JS中的类和对象&#xff0c;和数据库中的表和数据进行关系映射。映射之后我们就可以直接通过类和对象来操作数据表和数据了, 就…

【大数据导论】大数据序言

各位大佬好 &#xff0c;这里是阿川的博客&#xff0c;祝您变得更强 个人主页&#xff1a;在线OJ的阿川 大佬的支持和鼓励&#xff0c;将是我成长路上最大的动力 阿川水平有限&#xff0c;如有错误&#xff0c;欢迎大佬指正 目录 数据概念及类型及可用及组织形式数据概念数据…

golang项目基于gorm框架从postgre数据库迁移到达梦数据库的实践

一、安装达梦数据库 1、登录达梦数据库官网&#xff0c;下载对应系统版本的安装包。 2、下载地址为&#xff1a;https://www.dameng.com/list_103.html 3、达梦数据库对大小写敏感&#xff0c;在安装初始化数据库实例时建议忽略大小写&#xff1b;具体安装教程可参考以下博客: …

python办公自动化之pandas

用到的库&#xff1a;pandas 实现效果&#xff1a;创建一张空白的表同时往里面插入准备好的数据 代码&#xff1a; import pandas # 准备好要写入的数据&#xff0c;字典格式 data{日期:[7.2,7.3],产品型号:[ca,ce],成交量:[500,600]} dfpandas.DataFrame(data) # 把数据写入…

Java代码基础算法练习-计算被 3 或 5 整除数之和-2024.06.29

任务描述&#xff1a; 计算 1 到 n 之间能够被 3 或者 5 整除的数之和。 解决思路&#xff1a; 输入的数字为 for 循环总次数&#xff0c;每次循环就以当前的 i 进行 3、5 的取余操作&#xff0c;都成立计入总数sum中&#xff0c;循环结束&#xff0c;输出 sum 的值 代码示例&…