提交 44c87461 作者: Jeromy

a few small changes to make the dht more efficient

License: MIT
Signed-off-by: 's avatarJeromy <why@ipfs.io>
上级 c9ddc7d1
...@@ -117,34 +117,6 @@ func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID, ...@@ -117,34 +117,6 @@ func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID,
return nil return nil
} }
// putProvider sends a message to peer 'p' saying that the local node
// can provide the value of 'key'
func (dht *IpfsDHT) putProvider(ctx context.Context, p peer.ID, skey string) error {
// add self as the provider
pi := pstore.PeerInfo{
ID: dht.self,
Addrs: dht.host.Addrs(),
}
// // only share WAN-friendly addresses ??
// pi.Addrs = addrutil.WANShareableAddrs(pi.Addrs)
if len(pi.Addrs) < 1 {
// log.Infof("%s putProvider: %s for %s error: no wan-friendly addresses", dht.self, p, key.Key(key), pi.Addrs)
return fmt.Errorf("no known addresses for self. cannot put provider.")
}
pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, skey, 0)
pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]pstore.PeerInfo{pi})
err := dht.sendMessage(ctx, p, pmes)
if err != nil {
return err
}
log.Debugf("%s putProvider: %s for %s (%s)", dht.self, p, key.Key(skey), pi.Addrs)
return nil
}
var errInvalidRecord = errors.New("received invalid record") var errInvalidRecord = errors.New("received invalid record")
// getValueOrPeers queries a particular peer p for the value for // getValueOrPeers queries a particular peer p for the value for
......
...@@ -2,6 +2,7 @@ package dht ...@@ -2,6 +2,7 @@ package dht
import ( import (
"bytes" "bytes"
"fmt"
"sync" "sync"
"time" "time"
...@@ -243,13 +244,18 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key key.Key) error { ...@@ -243,13 +244,18 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key key.Key) error {
return err return err
} }
mes, err := dht.makeProvRecord(key)
if err != nil {
return err
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
for p := range peers { for p := range peers {
wg.Add(1) wg.Add(1)
go func(p peer.ID) { go func(p peer.ID) {
defer wg.Done() defer wg.Done()
log.Debugf("putProvider(%s, %s)", key, p) log.Debugf("putProvider(%s, %s)", key, p)
err := dht.putProvider(ctx, p, string(key)) err := dht.sendMessage(ctx, p, mes)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
} }
...@@ -258,6 +264,22 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key key.Key) error { ...@@ -258,6 +264,22 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key key.Key) error {
wg.Wait() wg.Wait()
return nil return nil
} }
func (dht *IpfsDHT) makeProvRecord(skey key.Key) (*pb.Message, error) {
pi := pstore.PeerInfo{
ID: dht.self,
Addrs: dht.host.Addrs(),
}
// // only share WAN-friendly addresses ??
// pi.Addrs = addrutil.WANShareableAddrs(pi.Addrs)
if len(pi.Addrs) < 1 {
return nil, fmt.Errorf("no known addresses for self. cannot put provider.")
}
pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, string(skey), 0)
pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]pstore.PeerInfo{pi})
return pmes, nil
}
// FindProviders searches until the context expires. // FindProviders searches until the context expires.
func (dht *IpfsDHT) FindProviders(ctx context.Context, key key.Key) ([]pstore.PeerInfo, error) { func (dht *IpfsDHT) FindProviders(ctx context.Context, key key.Key) ([]pstore.PeerInfo, error) {
......
...@@ -48,11 +48,11 @@ func NewRoutingTable(bucketsize int, localID ID, latency time.Duration, m pstore ...@@ -48,11 +48,11 @@ func NewRoutingTable(bucketsize int, localID ID, latency time.Duration, m pstore
// Update adds or moves the given peer to the front of its respective bucket // Update adds or moves the given peer to the front of its respective bucket
// If a peer gets removed from a bucket, it is returned // If a peer gets removed from a bucket, it is returned
func (rt *RoutingTable) Update(p peer.ID) { func (rt *RoutingTable) Update(p peer.ID) {
rt.tabLock.Lock()
defer rt.tabLock.Unlock()
peerID := ConvertPeerID(p) peerID := ConvertPeerID(p)
cpl := commonPrefixLen(peerID, rt.local) cpl := commonPrefixLen(peerID, rt.local)
rt.tabLock.Lock()
defer rt.tabLock.Unlock()
bucketID := cpl bucketID := cpl
if bucketID >= len(rt.Buckets) { if bucketID >= len(rt.Buckets) {
bucketID = len(rt.Buckets) - 1 bucketID = len(rt.Buckets) - 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论