模型 model¶
model类封装了训练和推断的逻辑。主要功能是准备数据(划分batch)、初始化网络(根据传入的net对象)和管理训练过程。
初始化 model(net, data, **kargs)¶
必选参数:
net: 定义网络结构的net对象data: 训练或测试用数据,dict对象,包括X_train,X_val,y_train,y_val
可选参数:
update_rule:优化器函数optim_config:Dict, 优化器的初始配置lr_decay:Float, 学习率调控batch_size:Int, 批次大小,默认100num_epochs:Int, 训练批次,默认10print_every:Int, 打印频次,默认10verbose:Boolean, 是否显示进度,默认Truenum_train_sample: Int, 训练用样本数,默认是1000,设置为None则使用传入的全部数据num_val_sample: Int,验证集样本数,默认None,即使用全部验证样本checkpoint_name:存档点路径及名称
warmup()¶
model对象建立后即可执行,传入一个batch的数据来初始化网络的习得参数。
train()¶
train方法完成训练过程的逻辑,首先是划分数据的批次,完成模型的一次参数更新,在一个epoch结束后在验证集上检测结果,保存存档点和当前最佳参数。
参数更新用到内部方法_step(),核心逻辑如下:
# foward pass
loss, dout = self.net.loss(X_batch, y_batch)
# backward pass
self.net.backward(self.optimizer, dout)