提交 db56c0f1 作者: Juan Batiz-Benet

Merge pull request #1037 from torarnv/harden-shutdown-logic

Harden shutdown logic
...@@ -80,6 +80,15 @@ func daemonFunc(req cmds.Request, res cmds.Response) { ...@@ -80,6 +80,15 @@ func daemonFunc(req cmds.Request, res cmds.Response) {
// let the user know we're going. // let the user know we're going.
fmt.Printf("Initializing daemon...\n") fmt.Printf("Initializing daemon...\n")
ctx := req.Context()
go func() {
select {
case <-ctx.Context.Done():
fmt.Println("Received interrupt signal, shutting down...")
}
}()
// first, whether user has provided the initialization flag. we may be // first, whether user has provided the initialization flag. we may be
// running in an uninitialized state. // running in an uninitialized state.
initialize, _, err := req.Option(initOptionKwd).Bool() initialize, _, err := req.Option(initOptionKwd).Bool()
...@@ -111,7 +120,6 @@ func daemonFunc(req cmds.Request, res cmds.Response) { ...@@ -111,7 +120,6 @@ func daemonFunc(req cmds.Request, res cmds.Response) {
return return
} }
ctx := req.Context()
cfg, err := ctx.GetConfig() cfg, err := ctx.GetConfig()
if err != nil { if err != nil {
res.SetError(err, cmds.ErrNormal) res.SetError(err, cmds.ErrNormal)
...@@ -149,7 +157,19 @@ func daemonFunc(req cmds.Request, res cmds.Response) { ...@@ -149,7 +157,19 @@ func daemonFunc(req cmds.Request, res cmds.Response) {
res.SetError(err, cmds.ErrNormal) res.SetError(err, cmds.ErrNormal)
return return
} }
defer node.Close()
defer func() {
// We wait for the node to close first, as the node has children
// that it will wait for before closing, such as the API server.
node.Close()
select {
case <-ctx.Context.Done():
log.Info("Gracefully shut down daemon")
default:
}
}()
req.Context().ConstructNode = func() (*core.IpfsNode, error) { req.Context().ConstructNode = func() (*core.IpfsNode, error) {
return node, nil return node, nil
} }
...@@ -262,9 +282,6 @@ func daemonFunc(req cmds.Request, res cmds.Response) { ...@@ -262,9 +282,6 @@ func daemonFunc(req cmds.Request, res cmds.Response) {
corehttp.VersionOption(), corehttp.VersionOption(),
} }
// our global interrupt handler can now try to stop the daemon
close(req.Context().InitDone)
if rootRedirect != nil { if rootRedirect != nil {
opts = append(opts, rootRedirect) opts = append(opts, rootRedirect)
} }
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
...@@ -39,7 +40,6 @@ const ( ...@@ -39,7 +40,6 @@ const (
cpuProfile = "ipfs.cpuprof" cpuProfile = "ipfs.cpuprof"
heapProfile = "ipfs.memprof" heapProfile = "ipfs.memprof"
errorFormat = "ERROR: %v\n\n" errorFormat = "ERROR: %v\n\n"
shutdownMessage = "Received interrupt signal, shutting down..."
) )
type cmdInvocation struct { type cmdInvocation struct {
...@@ -132,15 +132,10 @@ func main() { ...@@ -132,15 +132,10 @@ func main() {
os.Exit(1) os.Exit(1)
} }
// our global interrupt handler may try to stop the daemon
// before the daemon is ready to be stopped; this dirty
// workaround is for the daemon only; other commands are always
// ready to be stopped
if invoc.cmd != daemonCmd {
close(invoc.req.Context().InitDone)
}
// ok, finally, run the command invocation. // ok, finally, run the command invocation.
intrh, ctx := invoc.SetupInterruptHandler(ctx)
defer intrh.Close()
output, err := invoc.Run(ctx) output, err := invoc.Run(ctx)
if err != nil { if err != nil {
printErr(err) printErr(err)
...@@ -157,8 +152,6 @@ func main() { ...@@ -157,8 +152,6 @@ func main() {
} }
func (i *cmdInvocation) Run(ctx context.Context) (output io.Reader, err error) { func (i *cmdInvocation) Run(ctx context.Context) (output io.Reader, err error) {
// setup our global interrupt handler.
i.setupInterruptHandler()
// check if user wants to debug. option OR env var. // check if user wants to debug. option OR env var.
debug, _, err := i.req.Option("debug").Bool() debug, _, err := i.req.Option("debug").Bool()
...@@ -226,7 +219,6 @@ func (i *cmdInvocation) Parse(ctx context.Context, args []string) error { ...@@ -226,7 +219,6 @@ func (i *cmdInvocation) Parse(ctx context.Context, args []string) error {
if err != nil { if err != nil {
return err return err
} }
i.req.Context().Context = ctx
repoPath, err := getRepoPath(i.req) repoPath, err := getRepoPath(i.req)
if err != nil { if err != nil {
...@@ -279,6 +271,8 @@ func callCommand(ctx context.Context, req cmds.Request, root *cmds.Command, cmd ...@@ -279,6 +271,8 @@ func callCommand(ctx context.Context, req cmds.Request, root *cmds.Command, cmd
log.Info(config.EnvDir, " ", req.Context().ConfigRoot) log.Info(config.EnvDir, " ", req.Context().ConfigRoot)
var res cmds.Response var res cmds.Response
req.Context().Context = ctx
details, err := commandDetails(req.Path(), root) details, err := commandDetails(req.Path(), root)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -474,59 +468,70 @@ func writeHeapProfileToFile() error { ...@@ -474,59 +468,70 @@ func writeHeapProfileToFile() error {
return pprof.WriteHeapProfile(mprof) return pprof.WriteHeapProfile(mprof)
} }
// listen for and handle SIGTERM // IntrHandler helps set up an interrupt handler that can
func (i *cmdInvocation) setupInterruptHandler() { // be cleanly shut down through the io.Closer interface.
type IntrHandler struct {
sig chan os.Signal
wg sync.WaitGroup
}
ctx := i.req.Context() func NewIntrHandler() *IntrHandler {
sig := allInterruptSignals() ih := &IntrHandler{}
ih.sig = make(chan os.Signal, 1)
return ih
}
go func() { func (ih *IntrHandler) Close() error {
// first time, try to shut down. close(ih.sig)
ih.wg.Wait()
return nil
}
// loop because we may be
for count := 0; ; count++ {
<-sig
// if we're still initializing, cannot use `ctx.GetNode()` // Handle starts handling the given signals, and will call the handler
select { // callback function each time a signal is catched. The function is passed
default: // initialization not done // the number of times the handler has been triggered in total, as
fmt.Println(shutdownMessage) // well as the handler itself, so that the handling logic can use the
os.Exit(-1) // handler's wait group to ensure clean shutdown when Close() is called.
case <-ctx.InitDone: func (ih *IntrHandler) Handle(handler func(count int, ih *IntrHandler), sigs ...os.Signal) {
signal.Notify(ih.sig, sigs...)
ih.wg.Add(1)
go func() {
defer ih.wg.Done()
count := 0
for _ = range ih.sig {
count++
handler(count, ih)
} }
signal.Stop(ih.sig)
}()
}
func (i *cmdInvocation) SetupInterruptHandler(ctx context.Context) (io.Closer, context.Context) {
intrh := NewIntrHandler()
ctx, cancelFunc := context.WithCancel(ctx)
handlerFunc := func(count int, ih *IntrHandler) {
switch count { switch count {
case 0: case 1:
fmt.Println(shutdownMessage) fmt.Println() // Prevent un-terminated ^C character in terminal
if ctx.Online {
ih.wg.Add(1)
go func() { go func() {
// TODO cancel the command context instead defer ih.wg.Done()
n, err := ctx.GetNode() cancelFunc()
if err != nil {
log.Error(err)
fmt.Println(shutdownMessage)
os.Exit(-1)
}
n.Close()
log.Info("Gracefully shut down.")
}() }()
} else {
os.Exit(0)
}
default: default:
fmt.Println("Received another interrupt before graceful shutdown, terminating...") fmt.Println("Received another interrupt before graceful shutdown, terminating...")
os.Exit(-1) os.Exit(-1)
} }
} }
}()
}
func allInterruptSignals() chan os.Signal { intrh.Handle(handlerFunc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, return intrh, ctx
syscall.SIGTERM)
return sigc
} }
func profileIfEnabled() (func(), error) { func profileIfEnabled() (func(), error) {
......
...@@ -82,25 +82,44 @@ func (c *client) Send(req cmds.Request) (cmds.Response, error) { ...@@ -82,25 +82,44 @@ func (c *client) Send(req cmds.Request) (cmds.Response, error) {
version := config.CurrentVersionNumber version := config.CurrentVersionNumber
httpReq.Header.Set("User-Agent", fmt.Sprintf("/go-ipfs/%s/", version)) httpReq.Header.Set("User-Agent", fmt.Sprintf("/go-ipfs/%s/", version))
ec := make(chan error, 1)
rc := make(chan cmds.Response, 1)
dc := req.Context().Context.Done()
go func() {
httpRes, err := http.DefaultClient.Do(httpReq) httpRes, err := http.DefaultClient.Do(httpReq)
if err != nil { if err != nil {
return nil, err ec <- err
return
} }
// using the overridden JSON encoding in request // using the overridden JSON encoding in request
res, err := getResponse(httpRes, req) res, err := getResponse(httpRes, req)
if err != nil { if err != nil {
return nil, err ec <- err
return
} }
rc <- res
}()
for {
select {
case <-dc:
log.Debug("Context cancelled, cancelling HTTP request...")
tr := http.DefaultTransport.(*http.Transport)
tr.CancelRequest(httpReq)
dc = nil // Wait for ec or rc
case err := <-ec:
return nil, err
case res := <-rc:
if found && len(previousUserProvidedEncoding) > 0 { if found && len(previousUserProvidedEncoding) > 0 {
// reset to user provided encoding after sending request // reset to user provided encoding after sending request
// NB: if user has provided an encoding but it is the empty string, // NB: if user has provided an encoding but it is the empty string,
// still leave it as JSON. // still leave it as JSON.
req.SetOption(cmds.EncShort, previousUserProvidedEncoding) req.SetOption(cmds.EncShort, previousUserProvidedEncoding)
} }
return res, nil return res, nil
}
}
} }
func getQuery(req cmds.Request) (string, error) { func getQuery(req cmds.Request) (string, error) {
...@@ -162,6 +181,8 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error ...@@ -162,6 +181,8 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error
dec := json.NewDecoder(httpRes.Body) dec := json.NewDecoder(httpRes.Body)
outputType := reflect.TypeOf(req.Command().Type) outputType := reflect.TypeOf(req.Command().Type)
ctx := req.Context().Context
for { for {
var v interface{} var v interface{}
var err error var err error
...@@ -175,6 +196,14 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error ...@@ -175,6 +196,14 @@ func getResponse(httpRes *http.Response, req cmds.Request) (cmds.Response, error
fmt.Println(err.Error()) fmt.Println(err.Error())
return return
} }
select {
case <-ctx.Done():
close(outChan)
return
default:
}
if err == io.EOF { if err == io.EOF {
close(outChan) close(outChan)
return return
......
...@@ -30,7 +30,6 @@ type Context struct { ...@@ -30,7 +30,6 @@ type Context struct {
node *core.IpfsNode node *core.IpfsNode
ConstructNode func() (*core.IpfsNode, error) ConstructNode func() (*core.IpfsNode, error)
InitDone chan bool
} }
// GetConfig returns the config of the current Command exection // GetConfig returns the config of the current Command exection
...@@ -288,7 +287,7 @@ func NewRequest(path []string, opts OptMap, args []string, file files.File, cmd ...@@ -288,7 +287,7 @@ func NewRequest(path []string, opts OptMap, args []string, file files.File, cmd
optDefs = make(map[string]Option) optDefs = make(map[string]Option)
} }
ctx := Context{Context: context.TODO(), InitDone: make(chan bool)} ctx := Context{Context: context.TODO()}
values := make(map[string]interface{}) values := make(map[string]interface{})
req := &request{path, opts, args, file, cmd, ctx, optDefs, values, os.Stdin} req := &request{path, opts, args, file, cmd, ctx, optDefs, values, os.Stdin}
err := req.ConvertOptions() err := req.ConvertOptions()
......
...@@ -2,6 +2,7 @@ package corehttp ...@@ -2,6 +2,7 @@ package corehttp
import ( import (
"net/http" "net/http"
"time"
manners "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/braintree/manners" manners "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/braintree/manners"
ma "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" ma "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr"
...@@ -63,6 +64,9 @@ func listenAndServe(node *core.IpfsNode, addr ma.Multiaddr, handler http.Handler ...@@ -63,6 +64,9 @@ func listenAndServe(node *core.IpfsNode, addr ma.Multiaddr, handler http.Handler
var serverError error var serverError error
serverExited := make(chan struct{}) serverExited := make(chan struct{})
node.Children().Add(1)
defer node.Children().Done()
go func() { go func() {
serverError = server.ListenAndServe(host, handler) serverError = server.ListenAndServe(host, handler)
close(serverExited) close(serverExited)
...@@ -75,8 +79,22 @@ func listenAndServe(node *core.IpfsNode, addr ma.Multiaddr, handler http.Handler ...@@ -75,8 +79,22 @@ func listenAndServe(node *core.IpfsNode, addr ma.Multiaddr, handler http.Handler
// if node being closed before server exits, close server // if node being closed before server exits, close server
case <-node.Closing(): case <-node.Closing():
log.Infof("server at %s terminating...", addr) log.Infof("server at %s terminating...", addr)
// make sure keep-alive connections do not keep the server running
server.InnerServer.SetKeepAlivesEnabled(false)
server.Shutdown <- true server.Shutdown <- true
<-serverExited // now, DO wait until server exit
outer:
for {
// wait until server exits
select {
case <-serverExited:
break outer
case <-time.After(5 * time.Second):
log.Infof("waiting for server at %s to terminate...", addr)
}
}
} }
log.Infof("server at %s terminated", addr) log.Infof("server at %s terminated", addr)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论