『MXNet』第八彈_物體檢測之SSD
阿新 • • 發佈:2018-05-30
out can RR AS upd 全部 ask 類別 clu
預、API介紹
mxnet.metric
from mxnet import metric

# Classification metric: how often the predicted class of an anchor matches
# its target class.
cls_metric = metric.Accuracy()
# Swap the last two axes so the class dimension ends up on axis 1
# (presumably the axis Accuracy reduces over by default -- confirm against
# the MXNet docs).
cls_metric.update(labels=[cls_target], preds=[class_preds.transpose((0, 2, 1))])

# Box-regression metric: mean absolute error between the target offsets and
# the predictions. Multiplying by box_mask (presumably 0 for anchors that have
# no regression target) removes those anchors' predictions before comparison.
box_metric = metric.MAE()
box_metric.update(labels=[box_target], preds=[box_preds * box_mask])

# Query the accumulated metric values (the results are not stored here).
cls_metric.get()
box_metric.get()
gluon.loss.Loss
class FocalLoss(gluon.loss.Loss):
    """Focal loss for multi-class classification.

    For each element the loss is ``-alpha * (1 - p_j) ** gamma * log(p_j)``,
    where ``p_j`` is the softmax probability assigned to the true class.
    The loss is then averaged over every axis except ``batch_axis``.

    Parameters
    ----------
    axis : int, default -1
        Axis of ``output`` that holds the class scores. Both the softmax
        normalisation and the true-class lookup use this axis.
    alpha : float, default 0.25
        Constant weighting factor applied to the loss.
    gamma : float, default 2
        Focusing exponent; larger values shrink the loss of samples whose
        true-class probability is already high.
    batch_axis : int, default 0
        Axis that indexes samples; the result keeps one value per sample.
    **kwargs
        Forwarded to ``gluon.loss.Loss``.
    """

    def __init__(self, axis=-1, alpha=0.25, gamma=2, batch_axis=0, **kwargs):
        # weight=None: no extra global scaling of the loss.
        super(FocalLoss, self).__init__(None, batch_axis, **kwargs)
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma

    def hybrid_forward(self, F, output, label):
        """Return the per-sample focal loss.

        ``F`` is either ``mx.nd`` or ``mx.sym``. Using ``F`` instead of
        naming one of them explicitly lets the same code run imperatively and
        after hybridization; a plain (non-hybrid) ``forward`` has no such
        parameter.

        ``output`` holds unnormalised class scores. ``label`` holds, for each
        position, the index of the true class along ``self._axis``.
        """
        # Normalise over the SAME axis the label is picked from; the original
        # softmax always used axis -1, which is wrong for any other `axis`.
        # log_softmax is also numerically stable, avoiding the -inf that
        # softmax(...).log() yields when the true-class probability underflows.
        log_pj = F.pick(F.log_softmax(output, axis=self._axis), label,
                        axis=self._axis, keepdims=True)
        pj = F.exp(log_pj)
        loss = -self._alpha * ((1 - pj) ** self._gamma) * log_pj
        return F.mean(loss, axis=self._batch_axis, exclude=True)
mxnet.contrib.ndarray.MultiBoxTarget
def training_targets(anchors, class_preds, labels, **kwargs):
    """Build SSD training targets by matching anchors against ground truth.

    Thin wrapper around ``MultiBoxTarget``.

    Parameters
    ----------
    anchors : NDArray
        Coordinates of all generated anchor boxes, e.g. shape (1, 5444, 4).
    class_preds : NDArray
        Per-class scores predicted for every anchor, laid out as
        (batch, num_anchors, num_classes), e.g. shape (1, 5444, 3).
    labels : NDArray
        Ground-truth classes with their box coordinates, e.g. shape
        (1, 1, 5): one object per image, class id plus 4 coordinates.
    **kwargs
        Extra keyword options forwarded unchanged to ``MultiBoxTarget``
        (for example ``overlap_threshold`` or ``negative_mining_ratio``).
        With no extras the call is identical to the original behaviour.

    Returns
    -------
    The outputs of ``MultiBoxTarget``. Presumably these are box-offset
    targets, box masks and class targets, in that order; confirm against the
    MXNet docs.
    """
    # Reorder (batch, num_anchors, num_classes) -> (batch, num_classes,
    # num_anchors), the class-prediction layout handed to MultiBoxTarget.
    class_preds = class_preds.transpose(axes=(0, 2, 1))
    return MultiBoxTarget(anchors, labels, class_preds, **kwargs)


# Example shapes:
#   anchors:           (1, 5444, 4)
#   class predictions: (1, 5444, 3)
#   batch.label:       (1, 1, 5)
out = training_targets(anchors, class_preds, batch.label[0][0:1])
mxnet.contrib.ndarray.MultiBoxDetection
『MXNet』第八彈_物體檢測之SSD