提交 adfbecf3 作者: Steven Allen

use stream.Reset where appropriate

License: MIT
Signed-off-by: 's avatarSteven Allen <steven@stebalien.com>
上级 8deaaa8d
...@@ -40,6 +40,7 @@ type BitSwapNetwork interface { ...@@ -40,6 +40,7 @@ type BitSwapNetwork interface {
type MessageSender interface { type MessageSender interface {
SendMsg(context.Context, bsmsg.BitSwapMessage) error SendMsg(context.Context, bsmsg.BitSwapMessage) error
Close() error Close() error
Reset() error
} }
// Implement Receiver to receive messages from the BitSwapNetwork // Implement Receiver to receive messages from the BitSwapNetwork
......
...@@ -56,6 +56,10 @@ func (s *streamMessageSender) Close() error { ...@@ -56,6 +56,10 @@ func (s *streamMessageSender) Close() error {
return s.s.Close() return s.s.Close()
} }
func (s *streamMessageSender) Reset() error {
return s.s.Reset()
}
func (s *streamMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMessage) error { func (s *streamMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMessage) error {
return msgToStream(ctx, s.s, msg) return msgToStream(ctx, s.s, msg)
} }
...@@ -121,9 +125,14 @@ func (bsnet *impl) SendMessage( ...@@ -121,9 +125,14 @@ func (bsnet *impl) SendMessage(
if err != nil { if err != nil {
return err return err
} }
defer s.Close()
return msgToStream(ctx, s, outgoing) err = msgToStream(ctx, s, outgoing)
if err != nil {
s.Reset()
} else {
s.Close()
}
return err
} }
func (bsnet *impl) SetDelegate(r Receiver) { func (bsnet *impl) SetDelegate(r Receiver) {
...@@ -180,6 +189,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) { ...@@ -180,6 +189,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) {
defer s.Close() defer s.Close()
if bsnet.receiver == nil { if bsnet.receiver == nil {
s.Reset()
return return
} }
...@@ -188,6 +198,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) { ...@@ -188,6 +198,7 @@ func (bsnet *impl) handleNewStream(s inet.Stream) {
received, err := bsmsg.FromPBReader(reader) received, err := bsmsg.FromPBReader(reader)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
s.Reset()
go bsnet.receiver.ReceiveError(err) go bsnet.receiver.ReceiveError(err)
log.Debugf("bitswap net handleNewStream from %s error: %s", s.Conn().RemotePeer(), err) log.Debugf("bitswap net handleNewStream from %s error: %s", s.Conn().RemotePeer(), err)
} }
......
...@@ -133,6 +133,10 @@ func (mp *messagePasser) Close() error { ...@@ -133,6 +133,10 @@ func (mp *messagePasser) Close() error {
return nil return nil
} }
func (mp *messagePasser) Reset() error {
return nil
}
func (n *networkClient) NewMessageSender(ctx context.Context, p peer.ID) (bsnet.MessageSender, error) { func (n *networkClient) NewMessageSender(ctx context.Context, p peer.ID) (bsnet.MessageSender, error) {
return &messagePasser{ return &messagePasser{
net: n.network, net: n.network,
......
...@@ -172,18 +172,19 @@ func (pm *WantManager) stopPeerHandler(p peer.ID) { ...@@ -172,18 +172,19 @@ func (pm *WantManager) stopPeerHandler(p peer.ID) {
} }
func (mq *msgQueue) runQueue(ctx context.Context) { func (mq *msgQueue) runQueue(ctx context.Context) {
defer func() {
if mq.sender != nil {
mq.sender.Close()
}
}()
for { for {
select { select {
case <-mq.work: // there is work to be done case <-mq.work: // there is work to be done
mq.doWork(ctx) mq.doWork(ctx)
case <-mq.done: case <-mq.done:
if mq.sender != nil {
mq.sender.Close()
}
return return
case <-ctx.Done(): case <-ctx.Done():
if mq.sender != nil {
mq.sender.Reset()
}
return return
} }
} }
...@@ -218,7 +219,7 @@ func (mq *msgQueue) doWork(ctx context.Context) { ...@@ -218,7 +219,7 @@ func (mq *msgQueue) doWork(ctx context.Context) {
} }
log.Infof("bitswap send error: %s", err) log.Infof("bitswap send error: %s", err)
mq.sender.Close() mq.sender.Reset()
mq.sender = nil mq.sender = nil
select { select {
......
...@@ -64,7 +64,7 @@ func (p2p *P2P) Dial(ctx context.Context, addr ma.Multiaddr, peer peer.ID, proto ...@@ -64,7 +64,7 @@ func (p2p *P2P) Dial(ctx context.Context, addr ma.Multiaddr, peer peer.ID, proto
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
listener, err := manet.Listen(bindAddr) listener, err := manet.Listen(bindAddr)
if err != nil { if err != nil {
if err2 := remote.Close(); err2 != nil { if err2 := remote.Reset(); err2 != nil {
return nil, err2 return nil, err2
} }
return nil, err return nil, err
...@@ -158,7 +158,7 @@ func (p2p *P2P) registerStreamHandler(ctx2 context.Context, protocol string) (*P ...@@ -158,7 +158,7 @@ func (p2p *P2P) registerStreamHandler(ctx2 context.Context, protocol string) (*P
select { select {
case list.conCh <- s: case list.conCh <- s:
case <-ctx.Done(): case <-ctx.Done():
s.Close() s.Reset()
} }
}) })
...@@ -198,7 +198,7 @@ func (p2p *P2P) acceptStreams(listenerInfo *ListenerInfo, listener Listener) { ...@@ -198,7 +198,7 @@ func (p2p *P2P) acceptStreams(listenerInfo *ListenerInfo, listener Listener) {
local, err := manet.Dial(listenerInfo.Address) local, err := manet.Dial(listenerInfo.Address)
if err != nil { if err != nil {
remote.Close() remote.Reset()
continue continue
} }
......
...@@ -4,6 +4,8 @@ import ( ...@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"io" "io"
net "gx/ipfs/QmNa31VPzC561NWwRsJLE7nGYZYuuD2QfpK2b1q9BK54J1/go-libp2p-net"
manet "gx/ipfs/QmX3U3YXCQ6UYBxq2LVWF8dARS1hPUTEYLrSx654Qyxyw6/go-multiaddr-net"
ma "gx/ipfs/QmXY77cVe7rVRQXZZQRioukUM7aRW3BTcAgJe12MCtb3Ji/go-multiaddr" ma "gx/ipfs/QmXY77cVe7rVRQXZZQRioukUM7aRW3BTcAgJe12MCtb3Ji/go-multiaddr"
peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer" peer "gx/ipfs/QmXYjuNuxVzXKJCfWasQk1RqkhVLDM9jtUKhqc2WPQmFSB/go-libp2p-peer"
) )
...@@ -76,8 +78,8 @@ type StreamInfo struct { ...@@ -76,8 +78,8 @@ type StreamInfo struct {
RemotePeer peer.ID RemotePeer peer.ID
RemoteAddr ma.Multiaddr RemoteAddr ma.Multiaddr
Local io.ReadWriteCloser Local manet.Conn
Remote io.ReadWriteCloser Remote net.Stream
Registry *StreamRegistry Registry *StreamRegistry
} }
...@@ -90,15 +92,31 @@ func (s *StreamInfo) Close() error { ...@@ -90,15 +92,31 @@ func (s *StreamInfo) Close() error {
return nil return nil
} }
// Reset closes stream endpoints and deregisters it
func (s *StreamInfo) Reset() error {
s.Local.Close()
s.Remote.Reset()
s.Registry.Deregister(s.HandlerID)
return nil
}
func (s *StreamInfo) startStreaming() { func (s *StreamInfo) startStreaming() {
go func() { go func() {
io.Copy(s.Local, s.Remote) _, err := io.Copy(s.Local, s.Remote)
s.Close() if err != nil {
s.Reset()
} else {
s.Close()
}
}() }()
go func() { go func() {
io.Copy(s.Remote, s.Local) _, err := io.Copy(s.Remote, s.Local)
s.Close() if err != nil {
s.Reset()
} else {
s.Close()
}
}() }()
} }
......
...@@ -42,6 +42,7 @@ func (lb *Loopback) HandleStream(s inet.Stream) { ...@@ -42,6 +42,7 @@ func (lb *Loopback) HandleStream(s inet.Stream) {
pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax) pbr := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
var incoming dhtpb.Message var incoming dhtpb.Message
if err := pbr.ReadMsg(&incoming); err != nil { if err := pbr.ReadMsg(&incoming); err != nil {
s.Reset()
log.Debug(err) log.Debug(err)
return return
} }
...@@ -51,6 +52,8 @@ func (lb *Loopback) HandleStream(s inet.Stream) { ...@@ -51,6 +52,8 @@ func (lb *Loopback) HandleStream(s inet.Stream) {
pbw := ggio.NewDelimitedWriter(s) pbw := ggio.NewDelimitedWriter(s)
if err := pbw.WriteMsg(outgoing); err != nil { if err := pbw.WriteMsg(outgoing); err != nil {
return // TODO logerr s.Reset()
log.Debug(err)
return
} }
} }
...@@ -60,7 +60,7 @@ func (px *standard) Bootstrap(ctx context.Context) error { ...@@ -60,7 +60,7 @@ func (px *standard) Bootstrap(ctx context.Context) error {
func (p *standard) HandleStream(s inet.Stream) { func (p *standard) HandleStream(s inet.Stream) {
// TODO(brian): Should clients be able to satisfy requests? // TODO(brian): Should clients be able to satisfy requests?
log.Error("supernode client received (dropped) a routing message from", s.Conn().RemotePeer()) log.Error("supernode client received (dropped) a routing message from", s.Conn().RemotePeer())
s.Close() s.Reset()
} }
const replicationFactor = 2 const replicationFactor = 2
...@@ -102,9 +102,15 @@ func (px *standard) sendMessage(ctx context.Context, m *dhtpb.Message, remote pe ...@@ -102,9 +102,15 @@ func (px *standard) sendMessage(ctx context.Context, m *dhtpb.Message, remote pe
if err != nil { if err != nil {
return err return err
} }
defer s.Close()
pbw := ggio.NewDelimitedWriter(s) pbw := ggio.NewDelimitedWriter(s)
return pbw.WriteMsg(m)
err = pbw.WriteMsg(m)
if err == nil {
s.Close()
} else {
s.Reset()
}
return err
} }
// SendRequest sends the request to each remote sequentially (randomized order), // SendRequest sends the request to each remote sequentially (randomized order),
...@@ -139,17 +145,20 @@ func (px *standard) sendRequest(ctx context.Context, m *dhtpb.Message, remote pe ...@@ -139,17 +145,20 @@ func (px *standard) sendRequest(ctx context.Context, m *dhtpb.Message, remote pe
r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
w := ggio.NewDelimitedWriter(s) w := ggio.NewDelimitedWriter(s)
if err = w.WriteMsg(m); err != nil { if err = w.WriteMsg(m); err != nil {
s.Reset()
e.SetError(err) e.SetError(err)
return nil, err return nil, err
} }
response := &dhtpb.Message{} response := &dhtpb.Message{}
if err = r.ReadMsg(response); err != nil { if err = r.ReadMsg(response); err != nil {
s.Reset()
e.SetError(err) e.SetError(err)
return nil, err return nil, err
} }
// need ctx expiration? // need ctx expiration?
if response == nil { if response == nil {
s.Reset()
err := errors.New("no response to request") err := errors.New("no response to request")
e.SetError(err) e.SetError(err)
return nil, err return nil, err
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论