网络Net¶
net类是对网络结构的抽象,完成对层的组织和数据流的控制,目前尽支持单方向无分叉的序列结构。
初始化 net(layer_stack=list(), loss_func, reg)¶
建立net对象需要层的顺序列表(每一项都是layer对象),损失函数loss_func和正则化参数reg。
loss_func为function模块中定义的函数,目前支持svm_loss(x, y)和softmax_loss(x, y)。
loss_func同时返回损失和反向梯度,是前向计算的终点和反向传播的起点。
forward(data_batch)¶
forward()方法完成前向计算部分,递归调用每一层的forward()方法并缓存每一层的输入。
loss(X_batch, y_batch)¶
loss()方法是对forward()的包装,接收label数据计算损失并返回反向梯度。
训练时调用loss()方法,推断时调用forward()方法。