提交 b680f493 作者: Jeromy

fix wantlist removal accounting, add tests

License: MIT
Signed-off-by: 's avatarJeromy <jeromyj@gmail.com>
上级 e43d1317
...@@ -285,6 +285,9 @@ func (bs *Bitswap) getNextSessionID() uint64 { ...@@ -285,6 +285,9 @@ func (bs *Bitswap) getNextSessionID() uint64 {
// CancelWant removes a given key from the wantlist // CancelWant removes a given key from the wantlist
func (bs *Bitswap) CancelWants(cids []*cid.Cid, ses uint64) { func (bs *Bitswap) CancelWants(cids []*cid.Cid, ses uint64) {
if len(cids) == 0 {
return
}
bs.wm.CancelWants(context.Background(), cids, nil, ses) bs.wm.CancelWants(context.Background(), cids, nil, ses)
} }
......
...@@ -318,7 +318,7 @@ func TestBasicBitswap(t *testing.T) { ...@@ -318,7 +318,7 @@ func TestBasicBitswap(t *testing.T) {
t.Log("Test a one node trying to get one block from another") t.Log("Test a one node trying to get one block from another")
instances := sg.Instances(2) instances := sg.Instances(3)
blocks := bg.Blocks(1) blocks := bg.Blocks(1)
err := instances[0].Exchange.HasBlock(blocks[0]) err := instances[0].Exchange.HasBlock(blocks[0])
if err != nil { if err != nil {
...@@ -333,6 +333,10 @@ func TestBasicBitswap(t *testing.T) { ...@@ -333,6 +333,10 @@ func TestBasicBitswap(t *testing.T) {
} }
time.Sleep(time.Millisecond * 20) time.Sleep(time.Millisecond * 20)
wl := instances[2].Exchange.WantlistForPeer(instances[1].Peer)
if len(wl) != 0 {
t.Fatal("should have no items in other peers wantlist")
}
if len(instances[1].Exchange.GetWantlist()) != 0 { if len(instances[1].Exchange.GetWantlist()) != 0 {
t.Fatal("shouldnt have anything in wantlist") t.Fatal("shouldnt have anything in wantlist")
} }
......
...@@ -105,13 +105,10 @@ func NewEngine(ctx context.Context, bs bstore.Blockstore) *Engine { ...@@ -105,13 +105,10 @@ func NewEngine(ctx context.Context, bs bstore.Blockstore) *Engine {
} }
func (e *Engine) WantlistForPeer(p peer.ID) (out []*wl.Entry) { func (e *Engine) WantlistForPeer(p peer.ID) (out []*wl.Entry) {
e.lock.Lock() partner := e.findOrCreate(p)
partner, ok := e.ledgerMap[p] partner.lk.Lock()
if ok { defer partner.lk.Unlock()
out = partner.wantList.SortedEntries() return partner.wantList.SortedEntries()
}
e.lock.Unlock()
return out
} }
func (e *Engine) LedgerForPeer(p peer.ID) *Receipt { func (e *Engine) LedgerForPeer(p peer.ID) *Receipt {
......
...@@ -170,7 +170,7 @@ func (w *Wantlist) Remove(c *cid.Cid) bool { ...@@ -170,7 +170,7 @@ func (w *Wantlist) Remove(c *cid.Cid) bool {
} }
delete(w.set, k) delete(w.set, k)
return false return true
} }
func (w *Wantlist) Contains(k *cid.Cid) (*Entry, bool) { func (w *Wantlist) Contains(k *cid.Cid) (*Entry, bool) {
......
...@@ -48,9 +48,13 @@ func assertNotHasCid(t *testing.T, w wli, c *cid.Cid) { ...@@ -48,9 +48,13 @@ func assertNotHasCid(t *testing.T, w wli, c *cid.Cid) {
func TestBasicWantlist(t *testing.T) { func TestBasicWantlist(t *testing.T) {
wl := New() wl := New()
wl.Add(testcids[0], 5) if !wl.Add(testcids[0], 5) {
t.Fatal("expected true")
}
assertHasCid(t, wl, testcids[0]) assertHasCid(t, wl, testcids[0])
wl.Add(testcids[1], 4) if !wl.Add(testcids[1], 4) {
t.Fatal("expected true")
}
assertHasCid(t, wl, testcids[0]) assertHasCid(t, wl, testcids[0])
assertHasCid(t, wl, testcids[1]) assertHasCid(t, wl, testcids[1])
...@@ -58,7 +62,9 @@ func TestBasicWantlist(t *testing.T) { ...@@ -58,7 +62,9 @@ func TestBasicWantlist(t *testing.T) {
t.Fatal("should have had two items") t.Fatal("should have had two items")
} }
wl.Add(testcids[1], 4) if wl.Add(testcids[1], 4) {
t.Fatal("add shouldnt report success on second add")
}
assertHasCid(t, wl, testcids[0]) assertHasCid(t, wl, testcids[0])
assertHasCid(t, wl, testcids[1]) assertHasCid(t, wl, testcids[1])
...@@ -66,7 +72,10 @@ func TestBasicWantlist(t *testing.T) { ...@@ -66,7 +72,10 @@ func TestBasicWantlist(t *testing.T) {
t.Fatal("should have had two items") t.Fatal("should have had two items")
} }
wl.Remove(testcids[0]) if !wl.Remove(testcids[0]) {
t.Fatal("should have gotten true")
}
assertHasCid(t, wl, testcids[1]) assertHasCid(t, wl, testcids[1])
if _, has := wl.Contains(testcids[0]); has { if _, has := wl.Contains(testcids[0]); has {
t.Fatal("shouldnt have this cid") t.Fatal("shouldnt have this cid")
...@@ -76,12 +85,20 @@ func TestBasicWantlist(t *testing.T) { ...@@ -76,12 +85,20 @@ func TestBasicWantlist(t *testing.T) {
func TestSesRefWantlist(t *testing.T) { func TestSesRefWantlist(t *testing.T) {
wl := NewThreadSafe() wl := NewThreadSafe()
wl.Add(testcids[0], 5, 1) if !wl.Add(testcids[0], 5, 1) {
t.Fatal("should have added")
}
assertHasCid(t, wl, testcids[0]) assertHasCid(t, wl, testcids[0])
wl.Remove(testcids[0], 2) if wl.Remove(testcids[0], 2) {
t.Fatal("shouldnt have removed")
}
assertHasCid(t, wl, testcids[0]) assertHasCid(t, wl, testcids[0])
wl.Add(testcids[0], 5, 1) if wl.Add(testcids[0], 5, 1) {
t.Fatal("shouldnt have added")
}
assertHasCid(t, wl, testcids[0]) assertHasCid(t, wl, testcids[0])
wl.Remove(testcids[0], 1) if !wl.Remove(testcids[0], 1) {
t.Fatal("should have removed")
}
assertNotHasCid(t, wl, testcids[0]) assertNotHasCid(t, wl, testcids[0])
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论