提交 8d8a1dc7 作者: Juan Batiz-Benet

Merge pull request #473 from jbenet/dht-test-providers

dht fixes
...@@ -50,32 +50,44 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( ...@@ -50,32 +50,44 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) (
return nil, err return nil, err
} }
select { var connOut Conn
case <-ctx.Done(): var errOut error
maconn.Close() done := make(chan struct{})
return nil, ctx.Err()
default: // do it async to ensure we respect don contexteone
} go func() {
defer func() { done <- struct{}{} }()
c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn)
if err != nil { if err != nil {
return nil, err errOut = err
return
} }
if d.PrivateKey == nil { if d.PrivateKey == nil {
log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr)
return c, nil connOut = c
return
}
c2, err := newSecureConn(ctx, d.PrivateKey, c)
if err != nil {
errOut = err
c.Close()
return
} }
connOut = c2
}()
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.Close() maconn.Close()
return nil, ctx.Err() return nil, ctx.Err()
default: case <-done:
// whew, finished.
} }
// return c, nil return connOut, errOut
return newSecureConn(ctx, d.PrivateKey, c)
} }
// MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks. // MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks.
......
...@@ -109,7 +109,7 @@ func Listen(ctx context.Context, addr ma.Multiaddr, local peer.ID, sk ic.PrivKey ...@@ -109,7 +109,7 @@ func Listen(ctx context.Context, addr ma.Multiaddr, local peer.ID, sk ic.PrivKey
} }
l.cg.SetTeardown(l.teardown) l.cg.SetTeardown(l.teardown)
log.Infof("swarm listening on %s\n", l.Multiaddr()) log.Infof("swarm listening on %s", l.Multiaddr())
log.Event(ctx, "swarmListen", l) log.Event(ctx, "swarmListen", l)
return l, nil return l, nil
} }
...@@ -38,10 +38,11 @@ func NewIDService(n Network) *IDService { ...@@ -38,10 +38,11 @@ func NewIDService(n Network) *IDService {
func (ids *IDService) IdentifyConn(c Conn) { func (ids *IDService) IdentifyConn(c Conn) {
ids.currmu.Lock() ids.currmu.Lock()
if _, found := ids.currid[c]; found { if wait, found := ids.currid[c]; found {
ids.currmu.Unlock() ids.currmu.Unlock()
log.Debugf("IdentifyConn called twice on: %s", c) log.Debugf("IdentifyConn called twice on: %s", c)
return // already identifying it. <-wait // already identifying it. wait for it.
return
} }
ids.currid[c] = make(chan struct{}) ids.currid[c] = make(chan struct{})
ids.currmu.Unlock() ids.currmu.Unlock()
...@@ -50,10 +51,11 @@ func (ids *IDService) IdentifyConn(c Conn) { ...@@ -50,10 +51,11 @@ func (ids *IDService) IdentifyConn(c Conn) {
if err != nil { if err != nil {
log.Error("network: unable to open initial stream for %s", ProtocolIdentify) log.Error("network: unable to open initial stream for %s", ProtocolIdentify)
log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer()) log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer())
} } else {
// ok give the response to our handler. // ok give the response to our handler.
ids.ResponseHandler(s) ids.ResponseHandler(s)
}
ids.currmu.Lock() ids.currmu.Lock()
ch, found := ids.currid[c] ch, found := ids.currid[c]
......
...@@ -82,15 +82,6 @@ type Network interface { ...@@ -82,15 +82,6 @@ type Network interface {
// If ProtocolID is "", writes no header. // If ProtocolID is "", writes no header.
NewStream(ProtocolID, peer.ID) (Stream, error) NewStream(ProtocolID, peer.ID) (Stream, error)
// Peers returns the peers connected
Peers() []peer.ID
// Conns returns the connections in this Netowrk
Conns() []Conn
// ConnsToPeer returns the connections in this Netowrk for given peer.
ConnsToPeer(p peer.ID) []Conn
// BandwidthTotals returns the total number of bytes passed through // BandwidthTotals returns the total number of bytes passed through
// the network since it was instantiated // the network since it was instantiated
BandwidthTotals() (uint64, uint64) BandwidthTotals() (uint64, uint64)
...@@ -133,6 +124,15 @@ type Dialer interface { ...@@ -133,6 +124,15 @@ type Dialer interface {
// Connectedness returns a state signaling connection capabilities // Connectedness returns a state signaling connection capabilities
Connectedness(peer.ID) Connectedness Connectedness(peer.ID) Connectedness
// Peers returns the peers connected
Peers() []peer.ID
// Conns returns the connections in this Network
Conns() []Conn
// ConnsToPeer returns the connections in this Network for given peer.
ConnsToPeer(p peer.ID) []Conn
} }
// Connectedness signals the capacity for a connection with a given node. // Connectedness signals the capacity for a connection with a given node.
......
...@@ -148,7 +148,19 @@ func (n *network) DialPeer(ctx context.Context, p peer.ID) error { ...@@ -148,7 +148,19 @@ func (n *network) DialPeer(ctx context.Context, p peer.ID) error {
} }
// identify the connection before returning. // identify the connection before returning.
done := make(chan struct{})
go func() {
n.ids.IdentifyConn((*conn_)(sc)) n.ids.IdentifyConn((*conn_)(sc))
close(done)
}()
// respect context done
select {
case <-done:
case <-ctx.Done():
return ctx.Err()
}
log.Debugf("network for %s finished dialing %s", n.local, p) log.Debugf("network for %s finished dialing %s", n.local, p)
return nil return nil
} }
......
...@@ -248,15 +248,21 @@ func TestConnHandler(t *testing.T) { ...@@ -248,15 +248,21 @@ func TestConnHandler(t *testing.T) {
<-time.After(time.Millisecond) <-time.After(time.Millisecond)
// should've gotten 5 by now. // should've gotten 5 by now.
close(gotconn)
swarms[0].SetConnHandler(nil)
expect := 4 expect := 4
actual := 0 for i := 0; i < expect; i++ {
for _ = range gotconn { select {
actual++ case <-time.After(time.Second):
t.Fatal("failed to get connections")
case <-gotconn:
}
} }
if actual != expect { select {
t.Fatal("should have connected to %d swarms. got: %d", actual, expect) case <-gotconn:
t.Fatalf("should have connected to %d swarms", expect)
default:
} }
} }
...@@ -28,6 +28,10 @@ var log = eventlog.Logger("dht") ...@@ -28,6 +28,10 @@ var log = eventlog.Logger("dht")
const doPinging = false const doPinging = false
// NumBootstrapQueries is how many random-ID dht lookups a bootstrap pass
// issues in order to gather peers for the routing table.
const NumBootstrapQueries = 5
// TODO. SEE https://github.com/jbenet/node-ipfs/blob/master/submodules/ipfs-dht/index.js // TODO. SEE https://github.com/jbenet/node-ipfs/blob/master/submodules/ipfs-dht/index.js
// IpfsDHT is an implementation of Kademlia with Coral and S/Kademlia modifications. // IpfsDHT is an implementation of Kademlia with Coral and S/Kademlia modifications.
...@@ -361,25 +365,20 @@ func (dht *IpfsDHT) PingRoutine(t time.Duration) { ...@@ -361,25 +365,20 @@ func (dht *IpfsDHT) PingRoutine(t time.Duration) {
} }
// Bootstrap builds up list of peers by requesting random peer IDs // Bootstrap builds up list of peers by requesting random peer IDs
func (dht *IpfsDHT) Bootstrap(ctx context.Context) { func (dht *IpfsDHT) Bootstrap(ctx context.Context, queries int) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// bootstrap sequentially, as results will compound
for i := 0; i < NumBootstrapQueries; i++ {
id := make([]byte, 16) id := make([]byte, 16)
rand.Read(id) rand.Read(id)
pi, err := dht.FindPeer(ctx, peer.ID(id)) pi, err := dht.FindPeer(ctx, peer.ID(id))
if err != nil { if err == routing.ErrNotFound {
// NOTE: this is not an error. this is expected! // this isn't an error. this is precisely what we expect.
} else if err != nil {
log.Errorf("Bootstrap peer error: %s", err) log.Errorf("Bootstrap peer error: %s", err)
} } else {
// woah, we got a peer under a random id? it _cannot_ be valid. // woah, we got a peer under a random id? it _cannot_ be valid.
log.Errorf("dht seemingly found a peer at a random bootstrap id (%s)...", pi) log.Errorf("dht seemingly found a peer at a random bootstrap id (%s)...", pi)
}()
} }
wg.Wait() }
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
inet "github.com/jbenet/go-ipfs/net" inet "github.com/jbenet/go-ipfs/net"
peer "github.com/jbenet/go-ipfs/peer" peer "github.com/jbenet/go-ipfs/peer"
pb "github.com/jbenet/go-ipfs/routing/dht/pb" pb "github.com/jbenet/go-ipfs/routing/dht/pb"
ctxutil "github.com/jbenet/go-ipfs/util/ctx"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io" ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io"
...@@ -21,18 +22,21 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) { ...@@ -21,18 +22,21 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) {
defer s.Close() defer s.Close()
ctx := dht.Context() ctx := dht.Context()
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
w := ggio.NewDelimitedWriter(s) cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(cw)
mPeer := s.Conn().RemotePeer() mPeer := s.Conn().RemotePeer()
// receive msg // receive msg
pmes := new(pb.Message) pmes := new(pb.Message)
if err := r.ReadMsg(pmes); err != nil { if err := r.ReadMsg(pmes); err != nil {
log.Error("Error unmarshaling data") log.Errorf("Error unmarshaling data: %s", err)
return return
} }
// update the peer (on valid msgs only) // update the peer (on valid msgs only)
dht.Update(ctx, mPeer) dht.updateFromMessage(ctx, mPeer, pmes)
log.Event(ctx, "foo", dht.self, mPeer, pmes) log.Event(ctx, "foo", dht.self, mPeer, pmes)
...@@ -76,8 +80,10 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message ...@@ -76,8 +80,10 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
} }
defer s.Close() defer s.Close()
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
w := ggio.NewDelimitedWriter(s) cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(cw)
start := time.Now() start := time.Now()
...@@ -98,6 +104,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message ...@@ -98,6 +104,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
return nil, errors.New("no response to request") return nil, errors.New("no response to request")
} }
// update the peer (on valid msgs only)
dht.updateFromMessage(ctx, p, rpmes)
dht.peerstore.RecordLatency(p, time.Since(start)) dht.peerstore.RecordLatency(p, time.Since(start))
log.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes) log.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes)
return rpmes, nil return rpmes, nil
...@@ -113,7 +122,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message ...@@ -113,7 +122,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
} }
defer s.Close() defer s.Close()
w := ggio.NewDelimitedWriter(s) cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
w := ggio.NewDelimitedWriter(cw)
log.Debugf("%s writing", dht.self) log.Debugf("%s writing", dht.self)
if err := w.WriteMsg(pmes); err != nil { if err := w.WriteMsg(pmes); err != nil {
...@@ -123,3 +133,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message ...@@ -123,3 +133,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
log.Debugf("%s done", dht.self) log.Debugf("%s done", dht.self)
return nil return nil
} }
// updateFromMessage records that a valid message involving peer p was seen
// by passing p to dht.Update. The message itself (mes) is accepted but not
// inspected, and the returned error is always nil.
func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error {
	dht.Update(ctx, p)
	return nil
}
...@@ -2,7 +2,9 @@ package dht ...@@ -2,7 +2,9 @@ package dht
import ( import (
"bytes" "bytes"
"fmt"
"sort" "sort"
"sync"
"testing" "testing"
"time" "time"
...@@ -15,10 +17,22 @@ import ( ...@@ -15,10 +17,22 @@ import (
// ci "github.com/jbenet/go-ipfs/crypto" // ci "github.com/jbenet/go-ipfs/crypto"
inet "github.com/jbenet/go-ipfs/net" inet "github.com/jbenet/go-ipfs/net"
peer "github.com/jbenet/go-ipfs/peer" peer "github.com/jbenet/go-ipfs/peer"
routing "github.com/jbenet/go-ipfs/routing"
u "github.com/jbenet/go-ipfs/util" u "github.com/jbenet/go-ipfs/util"
testutil "github.com/jbenet/go-ipfs/util/testutil" testutil "github.com/jbenet/go-ipfs/util/testutil"
) )
// testCaseValues is the shared key/value fixture for the dht tests. It is
// filled in once by init below.
var testCaseValues = map[u.Key][]byte{}

// init seeds testCaseValues with the "hello" -> "world" pair plus one hundred
// generated "<i> -- key" -> "<i> -- value" pairs.
func init() {
	const generated = 100

	testCaseValues[u.Key("hello")] = []byte("world")
	for n := 0; n < generated; n++ {
		key := u.Key(fmt.Sprintf("%d -- key", n))
		testCaseValues[key] = []byte(fmt.Sprintf("%d -- value", n))
	}
}
func setupDHT(ctx context.Context, t *testing.T, addr ma.Multiaddr) *IpfsDHT { func setupDHT(ctx context.Context, t *testing.T, addr ma.Multiaddr) *IpfsDHT {
sk, pk, err := testutil.RandKeyPair(512) sk, pk, err := testutil.RandKeyPair(512)
...@@ -78,6 +92,27 @@ func connect(t *testing.T, ctx context.Context, a, b *IpfsDHT) { ...@@ -78,6 +92,27 @@ func connect(t *testing.T, ctx context.Context, a, b *IpfsDHT) {
} }
} }
// bootstrap runs Bootstrap on every dht in dhts, one after another, under a
// context derived from ctx that is cancelled when bootstrap returns.
//
// t is accepted for symmetry with the other test helpers; it is not used.
func bootstrap(t *testing.T, ctx context.Context, dhts []*IpfsDHT) {
	ctx, cancel := context.WithCancel(ctx)
	// Deferred so the derived context is released on every exit path,
	// including a panic inside Bootstrap.
	defer cancel()

	rounds := 1
	for i := 0; i < rounds; i++ {
		log.Debugf("bootstrapping round %d/%d", i, rounds)

		// tried async. sequential fares much better. compare:
		// 100 async https://gist.github.com/jbenet/56d12f0578d5f34810b2
		// 100 sync https://gist.github.com/jbenet/6c59e7c15426e48aaedd
		// probably because results compound
		for _, dht := range dhts {
			log.Debugf("bootstrapping round %d/%d -- %s", i, rounds, dht.self)
			dht.Bootstrap(ctx, 3)
		}
	}
}
func TestPing(t *testing.T) { func TestPing(t *testing.T) {
// t.Skip("skipping test to debug another") // t.Skip("skipping test to debug another")
ctx := context.Background() ctx := context.Background()
...@@ -174,26 +209,39 @@ func TestProvides(t *testing.T) { ...@@ -174,26 +209,39 @@ func TestProvides(t *testing.T) {
connect(t, ctx, dhts[1], dhts[2]) connect(t, ctx, dhts[1], dhts[2])
connect(t, ctx, dhts[1], dhts[3]) connect(t, ctx, dhts[1], dhts[3])
err := dhts[3].putLocal(u.Key("hello"), []byte("world")) for k, v := range testCaseValues {
log.Debugf("adding local values for %s = %s", k, v)
err := dhts[3].putLocal(k, v)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
bits, err := dhts[3].getLocal(u.Key("hello")) bits, err := dhts[3].getLocal(k)
if err != nil && bytes.Equal(bits, []byte("world")) { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(bits, v) {
t.Fatal("didn't store the right bits (%s, %s)", k, v)
}
}
err = dhts[3].Provide(ctx, u.Key("hello")) for k, _ := range testCaseValues {
if err != nil { log.Debugf("announcing provider for %s", k)
if err := dhts[3].Provide(ctx, k); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}
// what is this timeout for? was 60ms before. // what is this timeout for? was 60ms before.
time.Sleep(time.Millisecond * 6) time.Sleep(time.Millisecond * 6)
n := 0
for k, _ := range testCaseValues {
n = (n + 1) % 3
log.Debugf("getting providers for %s from %d", k, n)
ctxT, _ := context.WithTimeout(ctx, time.Second) ctxT, _ := context.WithTimeout(ctx, time.Second)
provchan := dhts[0].FindProvidersAsync(ctxT, u.Key("hello"), 1) provchan := dhts[n].FindProvidersAsync(ctxT, k, 1)
select { select {
case prov := <-provchan: case prov := <-provchan:
...@@ -201,11 +249,169 @@ func TestProvides(t *testing.T) { ...@@ -201,11 +249,169 @@ func TestProvides(t *testing.T) {
t.Fatal("Got back nil provider") t.Fatal("Got back nil provider")
} }
if prov.ID != dhts[3].self { if prov.ID != dhts[3].self {
t.Fatal("Got back nil provider") t.Fatal("Got back wrong provider")
} }
case <-ctxT.Done(): case <-ctxT.Done():
t.Fatal("Did not get a provider back.") t.Fatal("Did not get a provider back.")
} }
}
}
// TestBootstrap connects nDHTs dhts in a ring, bootstraps them, and then
// checks that every routing table has grown to a minimum size. It is skipped
// in -short mode.
func TestBootstrap(t *testing.T) {
	if testing.Short() {
		t.SkipNow()
	}

	ctx := context.Background()

	nDHTs := 15
	_, _, dhts := setupDHTS(ctx, nDHTs, t)
	defer func() {
		for i := 0; i < nDHTs; i++ {
			dhts[i].Close()
			// Deferred inside this closure: the networks are closed only
			// after every dht has been closed.
			defer dhts[i].network.Close()
		}
	}()

	t.Logf("connecting %d dhts in a ring", nDHTs)
	for i := 0; i < nDHTs; i++ {
		connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
	}

	<-time.After(100 * time.Millisecond)
	t.Log("bootstrapping them so they find each other")
	ctxT, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel() // release the timeout's timer once the test is done
	bootstrap(t, ctxT, dhts)

	if u.Debug {
		// the routing tables should be full now. let's inspect them.
		<-time.After(5 * time.Second)
		t.Logf("checking routing table of %d", nDHTs)
		for _, dht := range dhts {
			fmt.Printf("checking routing table of %s\n", dht.self)
			dht.routingTable.Print()
			fmt.Println("")
		}
	}

	// test "well-formed-ness": every routing table must hold at least 4 peers.
	// NOTE(review): this comment previously said ">= 3 peers" while the check
	// rejected anything below 4 -- confirm which threshold is intended.
	for _, dht := range dhts {
		rtlen := dht.routingTable.Size()
		if rtlen < 4 {
			t.Errorf("routing table for %s only has %d peers", dht.self, rtlen)
		}
	}
}
func TestProvidesMany(t *testing.T) {
t.Skip("this test doesn't work")
// t.Skip("skipping test to debug another")
ctx := context.Background()
nDHTs := 40
_, _, dhts := setupDHTS(ctx, nDHTs, t)
defer func() {
for i := 0; i < nDHTs; i++ {
dhts[i].Close()
defer dhts[i].network.Close()
}
}()
t.Logf("connecting %d dhts in a ring", nDHTs)
for i := 0; i < nDHTs; i++ {
connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
}
<-time.After(100 * time.Millisecond)
t.Logf("bootstrapping them so they find each other", nDHTs)
ctxT, _ := context.WithTimeout(ctx, 5*time.Second)
bootstrap(t, ctxT, dhts)
if u.Debug {
// the routing tables should be full now. let's inspect them.
<-time.After(5 * time.Second)
t.Logf("checking routing table of %d", nDHTs)
for _, dht := range dhts {
fmt.Printf("checking routing table of %s\n", dht.self)
dht.routingTable.Print()
fmt.Println("")
}
}
var providers = map[u.Key]peer.ID{}
d := 0
for k, v := range testCaseValues {
d = (d + 1) % len(dhts)
dht := dhts[d]
providers[k] = dht.self
t.Logf("adding local values for %s = %s (on %s)", k, v, dht.self)
err := dht.putLocal(k, v)
if err != nil {
t.Fatal(err)
}
bits, err := dht.getLocal(k)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bits, v) {
t.Fatal("didn't store the right bits (%s, %s)", k, v)
}
t.Logf("announcing provider for %s", k)
if err := dht.Provide(ctx, k); err != nil {
t.Fatal(err)
}
}
// what is this timeout for? was 60ms before.
time.Sleep(time.Millisecond * 6)
errchan := make(chan error)
ctxT, _ = context.WithTimeout(ctx, 5*time.Second)
var wg sync.WaitGroup
getProvider := func(dht *IpfsDHT, k u.Key) {
defer wg.Done()
expected := providers[k]
provchan := dht.FindProvidersAsync(ctxT, k, 1)
select {
case prov := <-provchan:
actual := prov.ID
if actual == "" {
errchan <- fmt.Errorf("Got back nil provider (%s at %s)", k, dht.self)
} else if actual != expected {
errchan <- fmt.Errorf("Got back wrong provider (%s != %s) (%s at %s)",
expected, actual, k, dht.self)
}
case <-ctxT.Done():
errchan <- fmt.Errorf("Did not get a provider back (%s at %s)", k, dht.self)
}
}
for k, _ := range testCaseValues {
// everyone should be able to find it...
for _, dht := range dhts {
log.Debugf("getting providers for %s at %s", k, dht.self)
wg.Add(1)
go getProvider(dht, k)
}
}
// we need this because of printing errors
go func() {
wg.Wait()
close(errchan)
}()
for err := range errchan {
t.Error(err)
}
} }
func TestProvidesAsync(t *testing.T) { func TestProvidesAsync(t *testing.T) {
...@@ -291,18 +497,20 @@ func TestLayeredGet(t *testing.T) { ...@@ -291,18 +497,20 @@ func TestLayeredGet(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
time.Sleep(time.Millisecond * 60) time.Sleep(time.Millisecond * 6)
t.Log("interface was changed. GetValue should not use providers.")
ctxT, _ := context.WithTimeout(ctx, time.Second) ctxT, _ := context.WithTimeout(ctx, time.Second)
val, err := dhts[0].GetValue(ctxT, u.Key("/v/hello")) val, err := dhts[0].GetValue(ctxT, u.Key("/v/hello"))
if err != nil { if err != routing.ErrNotFound {
t.Fatal(err) t.Error(err)
} }
if string(val) == "world" {
if string(val) != "world" { t.Error("should not get value.")
t.Fatal("Got incorrect value.") }
if len(val) > 0 && string(val) != "world" {
t.Error("worse, there's a value and its not even the right one.")
} }
} }
func TestFindPeer(t *testing.T) { func TestFindPeer(t *testing.T) {
......
...@@ -73,7 +73,7 @@ func TestGetFailures(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestGetFailures(t *testing.T) {
}) })
// This one should fail with NotFound // This one should fail with NotFound
ctx2, _ := context.WithTimeout(context.Background(), time.Second) ctx2, _ := context.WithTimeout(context.Background(), 3*time.Second)
_, err = d.GetValue(ctx2, u.Key("test")) _, err = d.GetValue(ctx2, u.Key("test"))
if err != nil { if err != nil {
if err != routing.ErrNotFound { if err != routing.ErrNotFound {
......
...@@ -148,7 +148,7 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, p peer.ID, pmes *pb.Mess ...@@ -148,7 +148,7 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, p peer.ID, pmes *pb.Mess
} }
if closest == nil { if closest == nil {
log.Errorf("handleFindPeer: could not find anything.") log.Debugf("handleFindPeer: could not find anything.")
return resp, nil return resp, nil
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
todoctr "github.com/jbenet/go-ipfs/util/todocounter" todoctr "github.com/jbenet/go-ipfs/util/todocounter"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
ctxgroup "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-ctxgroup"
) )
var maxQueryConcurrency = AlphaValue var maxQueryConcurrency = AlphaValue
...@@ -78,9 +79,8 @@ type dhtQueryRunner struct { ...@@ -78,9 +79,8 @@ type dhtQueryRunner struct {
// peersRemaining is a counter of peers remaining (toQuery + processing) // peersRemaining is a counter of peers remaining (toQuery + processing)
peersRemaining todoctr.Counter peersRemaining todoctr.Counter
// context // context group
ctx context.Context cg ctxgroup.ContextGroup
cancel context.CancelFunc
// result // result
result *dhtQueryResult result *dhtQueryResult
...@@ -93,16 +93,13 @@ type dhtQueryRunner struct { ...@@ -93,16 +93,13 @@ type dhtQueryRunner struct {
} }
func newQueryRunner(ctx context.Context, q *dhtQuery) *dhtQueryRunner { func newQueryRunner(ctx context.Context, q *dhtQuery) *dhtQueryRunner {
ctx, cancel := context.WithCancel(ctx)
return &dhtQueryRunner{ return &dhtQueryRunner{
ctx: ctx,
cancel: cancel,
query: q, query: q,
peersToQuery: queue.NewChanQueue(ctx, queue.NewXORDistancePQ(q.key)), peersToQuery: queue.NewChanQueue(ctx, queue.NewXORDistancePQ(q.key)),
peersRemaining: todoctr.NewSyncCounter(), peersRemaining: todoctr.NewSyncCounter(),
peersSeen: peer.Set{}, peersSeen: peer.Set{},
rateLimit: make(chan struct{}, q.concurrency), rateLimit: make(chan struct{}, q.concurrency),
cg: ctxgroup.WithContext(ctx),
} }
} }
...@@ -120,11 +117,13 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { ...@@ -120,11 +117,13 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
// add all the peers we got first. // add all the peers we got first.
for _, p := range peers { for _, p := range peers {
r.addPeerToQuery(p, "") // don't have access to self here... r.addPeerToQuery(r.cg.Context(), p, "") // don't have access to self here...
} }
// go do this thing. // go do this thing.
go r.spawnWorkers() // do it as a child func to make sure Run exits
// ONLY AFTER spawn workers has exited.
r.cg.AddChildFunc(r.spawnWorkers)
// so workers are working. // so workers are working.
...@@ -133,7 +132,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { ...@@ -133,7 +132,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
select { select {
case <-r.peersRemaining.Done(): case <-r.peersRemaining.Done():
r.cancel() // ran all and nothing. cancel all outstanding workers. r.cg.Close()
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
...@@ -141,10 +140,10 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { ...@@ -141,10 +140,10 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
err = r.errs[0] err = r.errs[0]
} }
case <-r.ctx.Done(): case <-r.cg.Closed():
r.RLock() r.RLock()
defer r.RUnlock() defer r.RUnlock()
err = r.ctx.Err() err = r.cg.Context().Err() // collect the error.
} }
if r.result != nil && r.result.success { if r.result != nil && r.result.success {
...@@ -154,7 +153,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { ...@@ -154,7 +153,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
return nil, err return nil, err
} }
func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) { func (r *dhtQueryRunner) addPeerToQuery(ctx context.Context, next peer.ID, benchmark peer.ID) {
// if new peer is ourselves... // if new peer is ourselves...
if next == r.query.dialer.LocalPeer() { if next == r.query.dialer.LocalPeer() {
return return
...@@ -180,43 +179,48 @@ func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) { ...@@ -180,43 +179,48 @@ func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) {
r.peersSeen[next] = struct{}{} r.peersSeen[next] = struct{}{}
r.Unlock() r.Unlock()
log.Debugf("adding peer to query: %v\n", next) log.Debugf("adding peer to query: %v", next)
// do this after unlocking to prevent possible deadlocks. // do this after unlocking to prevent possible deadlocks.
r.peersRemaining.Increment(1) r.peersRemaining.Increment(1)
select { select {
case r.peersToQuery.EnqChan <- next: case r.peersToQuery.EnqChan <- next:
case <-r.ctx.Done(): case <-ctx.Done():
} }
} }
func (r *dhtQueryRunner) spawnWorkers() { func (r *dhtQueryRunner) spawnWorkers(parent ctxgroup.ContextGroup) {
for { for {
select { select {
case <-r.peersRemaining.Done(): case <-r.peersRemaining.Done():
return return
case <-r.ctx.Done(): case <-r.cg.Closing():
return return
case p, more := <-r.peersToQuery.DeqChan: case p, more := <-r.peersToQuery.DeqChan:
if !more { if !more {
return // channel closed. return // channel closed.
} }
log.Debugf("spawning worker for: %v\n", p) log.Debugf("spawning worker for: %v", p)
go r.queryPeer(p)
// do it as a child func to make sure Run exits
// ONLY AFTER spawn workers has exited.
parent.AddChildFunc(func(cg ctxgroup.ContextGroup) {
r.queryPeer(cg, p)
})
} }
} }
} }
func (r *dhtQueryRunner) queryPeer(p peer.ID) { func (r *dhtQueryRunner) queryPeer(cg ctxgroup.ContextGroup, p peer.ID) {
log.Debugf("spawned worker for: %v", p) log.Debugf("spawned worker for: %v", p)
// make sure we rate limit concurrency. // make sure we rate limit concurrency.
select { select {
case <-r.rateLimit: case <-r.rateLimit:
case <-r.ctx.Done(): case <-cg.Closing():
r.peersRemaining.Decrement(1) r.peersRemaining.Decrement(1)
return return
} }
...@@ -233,8 +237,10 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) { ...@@ -233,8 +237,10 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
}() }()
// make sure we're connected to the peer. // make sure we're connected to the peer.
err := r.query.dialer.DialPeer(r.ctx, p) if conns := r.query.dialer.ConnsToPeer(p); len(conns) == 0 {
if err != nil { log.Infof("worker for: %v -- not connected. dial start", p)
if err := r.query.dialer.DialPeer(cg.Context(), p); err != nil {
log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err) log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err)
r.Lock() r.Lock()
r.errs = append(r.errs, err) r.errs = append(r.errs, err)
...@@ -242,8 +248,11 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) { ...@@ -242,8 +248,11 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
return return
} }
log.Infof("worker for: %v -- not connected. dial success!", p)
}
// finally, run the query against this peer // finally, run the query against this peer
res, err := r.query.qfunc(r.ctx, p) res, err := r.query.qfunc(cg.Context(), p)
if err != nil { if err != nil {
log.Debugf("ERROR worker for: %v %v", p, err) log.Debugf("ERROR worker for: %v %v", p, err)
...@@ -256,14 +265,20 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) { ...@@ -256,14 +265,20 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
r.Lock() r.Lock()
r.result = res r.result = res
r.Unlock() r.Unlock()
r.cancel() // signal to everyone that we're done. go r.cg.Close() // signal to everyone that we're done.
// must be async, as we're one of the children, and Close blocks.
} else if len(res.closerPeers) > 0 { } else if len(res.closerPeers) > 0 {
log.Debugf("PEERS CLOSER -- worker for: %v (%d closer peers)", p, len(res.closerPeers)) log.Debugf("PEERS CLOSER -- worker for: %v (%d closer peers)", p, len(res.closerPeers))
for _, next := range res.closerPeers { for _, next := range res.closerPeers {
// add their addresses to the dialer's peerstore // add their addresses to the dialer's peerstore
conns := r.query.dialer.ConnsToPeer(next.ID)
if len(conns) == 0 {
log.Infof("PEERS CLOSER -- worker for %v FOUND NEW PEER: %s %s", p, next.ID, next.Addrs)
}
r.query.dialer.Peerstore().AddAddresses(next.ID, next.Addrs) r.query.dialer.Peerstore().AddAddresses(next.ID, next.Addrs)
r.addPeerToQuery(next.ID, p) r.addPeerToQuery(cg.Context(), next.ID, p)
log.Debugf("PEERS CLOSER -- worker for: %v added %v (%v)", p, next.ID, next.Addrs) log.Debugf("PEERS CLOSER -- worker for: %v added %v (%v)", p, next.ID, next.Addrs)
} }
} else { } else {
......
...@@ -223,8 +223,16 @@ func (rt *RoutingTable) ListPeers() []peer.ID { ...@@ -223,8 +223,16 @@ func (rt *RoutingTable) ListPeers() []peer.ID {
func (rt *RoutingTable) Print() { func (rt *RoutingTable) Print() {
fmt.Printf("Routing Table, bs = %d, Max latency = %d\n", rt.bucketsize, rt.maxLatency) fmt.Printf("Routing Table, bs = %d, Max latency = %d\n", rt.bucketsize, rt.maxLatency)
rt.tabLock.RLock() rt.tabLock.RLock()
peers := rt.ListPeers()
for i, p := range peers { for i, b := range rt.Buckets {
fmt.Printf("%d) %s %s\n", i, p.Pretty(), rt.metrics.LatencyEWMA(p).String()) fmt.Printf("\tbucket: %d\n", i)
b.lk.RLock()
for e := b.list.Front(); e != nil; e = e.Next() {
p := e.Value.(peer.ID)
fmt.Printf("\t\t- %s %s\n", p.Pretty(), rt.metrics.LatencyEWMA(p).String())
}
b.lk.RUnlock()
} }
rt.tabLock.RUnlock()
} }
package ctxutil
import (
"io"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
)
// ioret bundles the two results of one Read or Write call so that they can
// be sent over a channel as a single value.
type ioret struct {
	n   int   // byte count reported by the call
	err error // error reported by the call, nil on success
}
// Writer is the interface of this package's writers; it consists of
// io.Writer and nothing else.
type Writer interface {
	io.Writer
}
// ctxWriter pairs an io.Writer with the Context that bounds how long a
// Write on it may block.
type ctxWriter struct {
	w   io.Writer       // destination of the actual writes
	ctx context.Context // Write gives up once this context is done
}
// NewWriter returns a writer that wraps w and stops waiting on a blocked
// Write as soon as ctx is cancelled, in which case the Write reports n=0 and
// err=ctx.Err(). A nil ctx is replaced with context.Background().
//
// Caveat: the wrapper CANNOT interrupt a write that has already been handed
// to w -- the standard go io interfaces offer no way to do so -- so that
// write still completes (or hangs) in the background. Use this sparingly,
// and make sure something else unblocks w once the context is up (e.g.
// closing the connection it belongs to).
//
// So that the caller's memory is never read _after_ the context has been
// cancelled, each Write first makes a **copy** of the buffer it is given.
func NewWriter(ctx context.Context, w io.Writer) *ctxWriter {
	cw := &ctxWriter{w: w, ctx: ctx}
	if cw.ctx == nil {
		cw.ctx = context.Background()
	}
	return cw
}
// Write hands a private copy of buf to the wrapped writer on a separate
// goroutine and waits for whichever comes first: the write finishing, or the
// writer's context being done.
//
// If the write finishes first, its (n, err) pair is returned unchanged. If
// the context is done first, Write returns 0 and the context's error; the
// goroutine is not interrupted and keeps writing from the copy, never from
// the caller's buf.
func (w *ctxWriter) Write(buf []byte) (int, error) {
	// Snapshot the caller's bytes so the background write never reads buf.
	snapshot := make([]byte, len(buf))
	copy(snapshot, buf)

	// Capacity 1 lets the goroutine deliver its result without a receiver.
	result := make(chan ioret, 1)
	go func() {
		defer close(result)
		var out ioret
		out.n, out.err = w.w.Write(snapshot)
		result <- out
	}()

	select {
	case <-w.ctx.Done():
		return 0, w.ctx.Err()
	case out := <-result:
		return out.n, out.err
	}
}
// Reader is the reader interface of this package. It embeds io.Reader and
// adds no methods of its own.
type Reader interface {
	io.Reader
}

// ctxReader couples an io.Reader with the Context that bounds how long a
// caller is willing to wait on it.
type ctxReader struct {
	r   io.Reader
	ctx context.Context
}

// NewReader wraps a reader to make it respect the given Context.
// If there is a blocking read, the returned Reader will return
// whenever the context is cancelled (the return values are n=0
// and err=ctx.Err()).
//
// A nil ctx is treated as context.Background(), matching NewWriter.
//
// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
// read -- there is no way to do that with the standard go io
// interface. So the read _will_ happen or hang. Use this sparingly,
// and make sure to cancel the read as necessary (e.g. by closing a
// connection whose context is up).
//
// Furthermore, in order to protect your memory from being written to
// _after_ you've cancelled the context, this io.Reader reads into a
// buffer of its own of the same size, and **copies** into the client's
// buffer only if the read succeeds in time.
func NewReader(ctx context.Context, r io.Reader) *ctxReader {
	// Without this guard a nil ctx would make every later Read panic when
	// it calls r.ctx.Done().
	if ctx == nil {
		ctx = context.Background()
	}
	return &ctxReader{ctx: ctx, r: r}
}
// Read reads from the wrapped reader, giving up as soon as the reader's
// context is done.
//
// The underlying Read runs on its own goroutine against a private scratch
// buffer of len(buf) bytes. If it finishes first, the bytes it produced are
// copied into buf and its (n, err) pair is returned unchanged. If the
// context is done first, Read returns 0 and the context's error without
// touching buf; the goroutine is not interrupted, and whatever it eventually
// reads lands in the scratch buffer and is discarded.
func (r *ctxReader) Read(buf []byte) (int, error) {
	scratch := make([]byte, len(buf))

	// Capacity 1 lets the goroutine deliver its result without a receiver,
	// so it can finish once the underlying Read returns.
	c := make(chan ioret, 1)
	go func() {
		n, err := r.r.Read(scratch)
		c <- ioret{n, err}
		close(c)
	}()

	select {
	case ret := <-c:
		// Only scratch[:n] is data. Copying the whole scratch buffer would
		// overwrite the caller's bytes past n with zeros. The clamp guards
		// against a wrapped reader that reports an out-of-range count.
		n := ret.n
		if n > len(scratch) {
			n = len(scratch)
		}
		if n > 0 {
			copy(buf, scratch[:n])
		}
		return ret.n, ret.err
	case <-r.ctx.Done():
		return 0, r.ctx.Err()
	}
}
package ctxutil
import (
"bytes"
"io"
"testing"
"time"
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
)
// TestReader checks that, with a context that is never cancelled, the
// wrapped reader hands back the source bytes in order and then reports
// io.EOF.
func TestReader(t *testing.T) {
	src := []byte("abcdef")
	chunk := make([]byte, 3)
	r := NewReader(context.Background(), bytes.NewReader(src))

	// The six source bytes should arrive as two full three-byte reads.
	for _, want := range [][]byte{src[:3], src[3:6]} {
		n, err := r.Read(chunk)
		if n != 3 {
			t.Error("n should be 3")
		}
		if err != nil {
			t.Error("should have no error")
		}
		if !bytes.Equal(chunk, want) {
			t.Error("incorrect contents")
		}
	}

	// The source is now drained: expect zero bytes and EOF.
	n, err := r.Read(chunk)
	if n != 0 {
		t.Error("n should be 0", n)
	}
	if err != io.EOF {
		t.Error("should be EOF", err)
	}
}
// TestWriter checks that, with a context that is never cancelled, successive
// writes through the wrapper accumulate in the underlying buffer.
func TestWriter(t *testing.T) {
	var sink bytes.Buffer
	w := NewWriter(context.Background(), &sink)

	// Each step writes three bytes and names the buffer contents expected
	// once that write has completed.
	steps := []struct {
		payload string
		total   string
	}{
		{payload: "abc", total: "abc"},
		{payload: "def", total: "abcdef"},
	}

	for _, step := range steps {
		n, err := w.Write([]byte(step.payload))
		if n != 3 {
			t.Error("n should be 3")
		}
		if err != nil {
			t.Error("should have no error")
		}
		if sink.String() != step.total {
			t.Error("incorrect contents")
		}
	}
}
// TestReaderCancel first verifies that a read through the wrapper completes
// normally while the context is live, then that a read blocked on an idle
// pipe is released with an error once the context is cancelled.
func TestReaderCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	reader := NewReader(ctx, pr)

	dst := make([]byte, 10)
	results := make(chan ioret)

	// readAsync issues one Read on its own goroutine and reports the outcome
	// on results.
	readAsync := func() {
		go func() {
			var out ioret
			out.n, out.err = reader.Read(dst)
			results <- out
		}()
	}

	// Phase 1: data is available, so the read must succeed in full.
	readAsync()
	pw.Write([]byte("abcdefghij"))

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to read")
	case got := <-results:
		if got.n != 10 {
			t.Error("ret.n should be 10", got.n)
		}
		if got.err != nil {
			t.Error("ret.err should be nil", got.err)
		}
		if string(dst) != "abcdefghij" {
			t.Error("read contents differ")
		}
	}

	// Phase 2: nothing is written, so only cancellation can release the read.
	readAsync()
	cancel()

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to stop reading after cancel")
	case got := <-results:
		if got.n != 0 {
			t.Error("ret.n should be 0", got.n)
		}
		if got.err == nil {
			t.Error("ret.err should be ctx error", got.err)
		}
	}
}
// TestWriterCancel first verifies that a write through the wrapper completes
// normally while the context is live, then that a write blocked on an unread
// pipe is released with an error once the context is cancelled.
func TestWriterCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	writer := NewWriter(ctx, pw)

	received := make([]byte, 10)
	results := make(chan ioret)

	// writeAsync issues one Write on its own goroutine and reports the
	// outcome on results.
	writeAsync := func() {
		go func() {
			var out ioret
			out.n, out.err = writer.Write([]byte("abcdefghij"))
			results <- out
		}()
	}

	// Phase 1: the pipe is drained, so the write must succeed in full.
	writeAsync()
	pr.Read(received)

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to write")
	case got := <-results:
		if got.n != 10 {
			t.Error("ret.n should be 10", got.n)
		}
		if got.err != nil {
			t.Error("ret.err should be nil", got.err)
		}
		if string(received) != "abcdefghij" {
			t.Error("write contents differ")
		}
	}

	// Phase 2: nobody reads the pipe, so only cancellation can release the
	// write.
	writeAsync()
	cancel()

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to stop writing after cancel")
	case got := <-results:
		if got.n != 0 {
			t.Error("ret.n should be 0", got.n)
		}
		if got.err == nil {
			t.Error("ret.err should be ctx error", got.err)
		}
	}
}
// TestReadPostCancel verifies that once a read has been abandoned because of
// cancellation, data arriving later on the underlying reader does not end up
// in the caller's buffer.
func TestReadPostCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	reader := NewReader(ctx, pr)

	dst := make([]byte, 10)
	results := make(chan ioret)

	go func() {
		var out ioret
		out.n, out.err = reader.Read(dst)
		results <- out
	}()
	cancel()

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to stop reading after cancel")
	case got := <-results:
		if got.n != 0 {
			t.Error("ret.n should be 0", got.n)
		}
		if got.err == nil {
			t.Error("ret.err should be ctx error", got.err)
		}
	}

	// Feed the pipe after the cancelled Read has returned; dst was allocated
	// zeroed and must still be all zeros afterwards.
	pw.Write([]byte("abcdefghij"))

	untouched := make([]byte, len(dst))
	if !bytes.Equal(dst, untouched) {
		t.Fatal("buffer should have not been written to")
	}
}
// TestWritePostCancel verifies that once a write has been abandoned because
// of cancellation, the bytes that eventually reach the underlying writer are
// the ones passed to Write, not whatever the caller put in its buffer
// afterwards.
func TestWritePostCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	writer := NewWriter(ctx, pw)

	payload := []byte("abcdefghij")
	received := make([]byte, 10)
	results := make(chan ioret)

	// writeAsync issues one Write of payload on its own goroutine and
	// reports the outcome on results.
	writeAsync := func() {
		go func() {
			var out ioret
			out.n, out.err = writer.Write(payload)
			results <- out
		}()
	}

	// Phase 1: the pipe is drained, so the write must succeed in full.
	writeAsync()
	pr.Read(received)

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to write")
	case got := <-results:
		if got.n != 10 {
			t.Error("ret.n should be 10", got.n)
		}
		if got.err != nil {
			t.Error("ret.err should be nil", got.err)
		}
		if string(received) != "abcdefghij" {
			t.Error("write contents differ")
		}
	}

	// Phase 2: nobody reads the pipe, so only cancellation can release the
	// write.
	writeAsync()
	cancel()

	select {
	case <-time.After(20 * time.Millisecond):
		t.Fatal("failed to stop writing after cancel")
	case got := <-results:
		if got.n != 0 {
			t.Error("ret.n should be 0", got.n)
		}
		if got.err == nil {
			t.Error("ret.err should be ctx error", got.err)
		}
	}

	// Scribble over the caller's buffer, then drain the pipe: the pending
	// write must deliver the original bytes, not the scribbled ones.
	copy(payload, "aaaaaaaaaa")
	pr.Read(received)

	switch string(received) {
	case "abcdefghij":
		// expected: the write used the bytes as they were at call time
	case "aaaaaaaaaa":
		t.Error("buffer was read from after ctx cancel")
	default:
		t.Error("write contents differ from expected")
	}
}
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"sync"
"testing" "testing"
ci "github.com/jbenet/go-ipfs/crypto" ci "github.com/jbenet/go-ipfs/crypto"
...@@ -49,17 +50,24 @@ func RandLocalTCPAddress() ma.Multiaddr { ...@@ -49,17 +50,24 @@ func RandLocalTCPAddress() ma.Multiaddr {
// most ports above 10000 aren't in use by long running processes, so yay. // most ports above 10000 aren't in use by long running processes, so yay.
// (maybe there should be a range of "loopback" ports that are guaranteed // (maybe there should be a range of "loopback" ports that are guaranteed
// to be open for the process, but naturally can only talk to self.) // to be open for the process, but naturally can only talk to self.)
if lastPort == 0 {
lastPort = 10000 + SeededRand.Intn(50000) lastPort.Lock()
if lastPort.port == 0 {
lastPort.port = 10000 + SeededRand.Intn(50000)
} }
lastPort++ port := lastPort.port
lastPort.port++
lastPort.Unlock()
addr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", lastPort) addr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", port)
maddr, _ := ma.NewMultiaddr(addr) maddr, _ := ma.NewMultiaddr(addr)
return maddr return maddr
} }
var lastPort = 0 var lastPort = struct {
port int
sync.Mutex
}{}
// PeerNetParams is a struct to bundle together the four things // PeerNetParams is a struct to bundle together the four things
// you need to run a connection with a peer: id, 2keys, and addr. // you need to run a connection with a peer: id, 2keys, and addr.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论