模型 model¶
model
类封装了训练和推断的逻辑。主要功能是准备数据(划分batch)、初始化网络(根据传入的net
对象)和管理训练过程。
初始化 model(net, data, **kargs)
¶
必选参数:
net
: 定义网络结构的net
对象data
: 训练或测试用数据,dict
对象,包括X_train
,X_val
,y_train
,y_val
可选参数:
update_rule
:优化器函数optim_config
:Dict, 优化器的初始配置lr_decay
:Float, 学习率调控batch_size
:Int, 批次大小,默认100num_epochs
:Int, 训练批次,默认10print_every
:Int, 打印频次,默认10verbose
:Boolean, 是否显示进度,默认True
num_train_sample
: Int, 训练用样本数,默认是1000,设置为None
则使用传入的全部数据num_val_sample
: Int,验证集样本数,默认None
,即使用全部验证样本checkpoint_name
:存档点路径及名称
warmup()
¶
model
对象建立后即可执行,传入一个batch的数据来初始化网络的习得参数。
train()
¶
train
方法完成训练过程的逻辑,首先是划分数据的批次,完成模型的一次参数更新,在一个epoch结束后在验证集上检测结果,保存存档点和当前最佳参数。
参数更新用到内部方法_step()
,核心逻辑如下:
# foward pass
loss, dout = self.net.loss(X_batch, y_batch)
# backward pass
self.net.backward(self.optimizer, dout)