网络Net

net类是对网络结构的抽象,完成对层的组织和数据流的控制,目前尽支持单方向无分叉的序列结构。

初始化 net(layer_stack=list(), loss_func, reg)

建立net对象需要层的顺序列表(每一项都是layer对象),损失函数loss_func和正则化参数reg

loss_funcfunction模块中定义的函数,目前支持svm_loss(x, y)softmax_loss(x, y)

loss_func同时返回损失和反向梯度,是前向计算的终点和反向传播的起点。

forward(data_batch)

forward()方法完成前向计算部分,递归调用每一层的forward()方法并缓存每一层的输入。

loss(X_batch, y_batch)

loss()方法是对forward()的包装,接收label数据计算损失并返回反向梯度。

训练时调用loss()方法,推断时调用forward()方法。

backward(optimizer, dout)

backward()方法完成反向传播过程,流数据为loss()传过来的反向梯度和forward()过程缓存的每一层输入,调用layer对象的grad()update()方法来计算梯度并更新习得参数。

了解更多有关net的作用,参见设计理念