代码之家  ›  专栏  ›  技术社区  ›  ytrewq

从火炬中删除项目。张量

  •  0
  • ytrewq  · 技术社区  · 7 年前

    我在lua中编写了以下代码。

    scores 以及相应的分数。

    看起来我必须迭代地从 分数

    nqs=dataset['question']:size(1);
    scores=torch.Tensor(nqs,noutput);
    qids=torch.LongTensor(nqs);
    for i=1,nqs,batch_size do
        xlua.progress(i, nqs)
        r=math.min(i+batch_size-1,nqs);
        scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r);
    --    print(scores)
    end
    
    tmp,pred=torch.max(scores,2);
    
    1 回复  |  直到 7 年前
        1
  •  1
  •   Ash    7 年前

     sr=scores:view(-1,scores:size(1)*scores:size(2))
     val,id=sr:sort()
     --val is a row vector with the values stored in increasing order
     --id will be the corresponding index in sr
     --now you can slice val and id from the end to find the N values you want, then you can recover the original index in the scores matrix simply with
     col=(index-1)%scores:size(2)+1
     row=math.ceil(index/scores:size(2))