1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for a batch of predictions.

    Args:
        output: Score tensor ranked along its last axis by
            ``tf.math.top_k``. Must be rank 2 (the transpose below uses
            ``perm=[1, 0]``); presumably ``(batch, num_classes)``.
        target: Integer class labels, one per sample; its first dimension
            is taken as the batch size.
        topk: Collection of ``k`` values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``, each the
        percentage (0-100) of samples whose label is among the ``k``
        highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.shape[0]

    # Indices of the maxk best classes per sample, transposed so that
    # row i holds every sample's rank-i prediction.
    pred = tf.math.top_k(output, maxk).indices
    pred = tf.transpose(pred, perm=[1, 0])

    # tf.equal requires both operands to share a dtype; top_k yields int32
    # indices while labels are often int64, so align the labels first.
    target_ = tf.broadcast_to(tf.cast(target, pred.dtype), pred.shape)
    correct = tf.cast(tf.equal(pred, target_), dtype=tf.float32)

    # Summing the first k rows counts samples whose label is in the top k.
    # float() needs a concrete value, so this relies on eager execution.
    return [
        float(tf.reduce_sum(correct[:k]) * (100.0 / batch_size))
        for k in topk
    ]
# Demo: build random softmax scores for 4 samples over 3 classes, draw random
# labels, and report top-1, top-2 and top-3 accuracy.
output = tf.random.normal([4, 3])
output = tf.math.softmax(output, axis=1)
print('output:', output.numpy())

target = tf.random.uniform([4], maxval=3, dtype=tf.int32)
print('target:', target.numpy())

acc = accuracy(output, target, topk=(1, 2, 3))
# The label names the k values actually requested above (1 through 3).
print('top1-3 acc:', acc)
|