网络Net¶
net
类是对网络结构的抽象,完成对层的组织和数据流的控制,目前尽支持单方向无分叉的序列结构。
初始化 net(layer_stack=list(), loss_func, reg)
¶
建立net
对象需要层的顺序列表(每一项都是layer
对象),损失函数loss_func
和正则化参数reg
。
loss_func
为function
模块中定义的函数,目前支持svm_loss(x, y)
和softmax_loss(x, y)
。
loss_func
同时返回损失和反向梯度,是前向计算的终点和反向传播的起点。
forward(data_batch)
¶
forward()
方法完成前向计算部分,递归调用每一层的forward()
方法并缓存每一层的输入。
loss(X_batch, y_batch)
¶
loss()
方法是对forward()
的包装,接收label
数据计算损失并返回反向梯度。
训练时调用loss()
方法,推断时调用forward()
方法。