Commit 47bdae94 authored by Francesc Campoy's avatar Francesc Campoy Committed by Francesc Campoy Flores

cmd/vet: detect defer resp.Body.Close() before error check

This check detects the code

	resp, err := http.Get("http://foo.com")
	defer resp.Body.Close()
	if err != nil {
		...
	}

For every call to a function on the net/http package or any method
on http.Client that returns (*http.Response, error), it checks
whether the next line is a defer statement that calls on the response.

Fixes #17780.

Change-Id: I9d70edcbfa2bad205bf7f45281597d074c795977
Reviewed-on: https://go-review.googlesource.com/32911Reviewed-by: default avatarRob Pike <r@golang.org>
parent 91135f27
...@@ -84,12 +84,12 @@ Flag: -copylocks ...@@ -84,12 +84,12 @@ Flag: -copylocks
Locks that are erroneously passed by value. Locks that are erroneously passed by value.
Tests and documentation examples HTTP responses used incorrectly
Flag: -tests Flag: -httpresponse
Mistakes involving tests including functions with incorrect names or signatures Mistakes deferring a function call on an HTTP response before
and example tests that document identifiers not in the package. checking whether the error returned with the response was nil.
Failure to call the cancelation function returned by WithCancel Failure to call the cancelation function returned by WithCancel
...@@ -162,6 +162,13 @@ Flag: -structtags ...@@ -162,6 +162,13 @@ Flag: -structtags
Struct tags that do not follow the format understood by reflect.StructTag.Get. Struct tags that do not follow the format understood by reflect.StructTag.Get.
Well-known encoding struct tags (json, xml) used with unexported fields. Well-known encoding struct tags (json, xml) used with unexported fields.
Tests and documentation examples
Flag: -tests
Mistakes involving tests including functions with incorrect names or signatures
and example tests that document identifiers not in the package.
Unreachable code Unreachable code
Flag: -unreachable Flag: -unreachable
......
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains the check for http.Response values being used before
// checking for errors.
package main
import (
"go/ast"
"go/types"
)
var (
httpResponseType types.Type
httpClientType types.Type
)
func init() {
if typ := importType("net/http", "Response"); typ != nil {
httpResponseType = typ
}
if typ := importType("net/http", "Client"); typ != nil {
httpClientType = typ
}
// if http.Response or http.Client are not defined don't register this check.
if httpResponseType == nil || httpClientType == nil {
return
}
register("httpresponse",
"check errors are checked before using an http Response",
checkHTTPResponse, callExpr)
}
func checkHTTPResponse(f *File, node ast.Node) {
call := node.(*ast.CallExpr)
if !isHTTPFuncOrMethodOnClient(f, call) {
return // the function call is not related to this check.
}
finder := &blockStmtFinder{node: call}
ast.Walk(finder, f.file)
stmts := finder.stmts()
if len(stmts) < 2 {
return // the call to the http function is the last statement of the block.
}
asg, ok := stmts[0].(*ast.AssignStmt)
if !ok {
return // the first statement is not assignment.
}
resp := rootIdent(asg.Lhs[0])
if resp == nil {
return // could not find the http.Response in the assignment.
}
def, ok := stmts[1].(*ast.DeferStmt)
if !ok {
return // the following statement is not a defer.
}
root := rootIdent(def.Call.Fun)
if root == nil {
return // could not find the receiver of the defer call.
}
if resp.Obj == root.Obj {
f.Badf(root.Pos(), "using %s before checking for errors", resp.Name)
}
}
// isHTTPFuncOrMethodOnClient checks whether the given call expression is on
// either a function of the net/http package or a method of http.Client that
// returns (*http.Response, error).
func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool {
fun, _ := expr.Fun.(*ast.SelectorExpr)
sig, _ := f.pkg.types[fun].Type.(*types.Signature)
if sig == nil {
return false // the call is not on of the form x.f()
}
res := sig.Results()
if res.Len() != 2 {
return false // the function called does not return two values.
}
if ptr, ok := res.At(0).Type().(*types.Pointer); !ok || !types.Identical(ptr.Elem(), httpResponseType) {
return false // the first return type is not *http.Response.
}
if !types.Identical(res.At(1).Type().Underlying(), errorType) {
return false // the second return type is not error
}
typ := f.pkg.types[fun.X].Type
if typ == nil {
id, ok := fun.X.(*ast.Ident)
return ok && id.Name == "http" // function in net/http package.
}
if types.Identical(typ, httpClientType) {
return true // method on http.Client.
}
ptr, ok := typ.(*types.Pointer)
return ok && types.Identical(ptr.Elem(), httpClientType) // method on *http.Client.
}
// blockStmtFinder is an ast.Visitor that given any ast node can find the
// statement containing it and its succeeding statements in the same block.
type blockStmtFinder struct {
node ast.Node // target of search
stmt ast.Stmt // innermost statement enclosing argument to Visit
block *ast.BlockStmt // innermost block enclosing argument to Visit.
}
// Visit finds f.node performing a search down the ast tree.
// It keeps the last block statement and statement seen for later use.
func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor {
if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() {
return nil // not here
}
switch n := node.(type) {
case *ast.BlockStmt:
f.block = n
case ast.Stmt:
f.stmt = n
}
if f.node.Pos() == node.Pos() && f.node.End() == node.End() {
return nil // found
}
return f // keep looking
}
// stmts returns the statements of f.block starting from the one including f.node.
func (f *blockStmtFinder) stmts() []ast.Stmt {
for i, v := range f.block.List {
if f.stmt == v {
return f.block.List[i:]
}
}
return nil
}
// rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found.
func rootIdent(n ast.Node) *ast.Ident {
switch n := n.(type) {
case *ast.SelectorExpr:
return rootIdent(n.X)
case *ast.Ident:
return n
default:
return nil
}
}
package testdata
import (
"log"
"net/http"
)
func goodHTTPGet() {
res, err := http.Get("http://foo.com")
if err != nil {
log.Fatal(err)
}
defer res.Body.Close()
}
func badHTTPGet() {
res, err := http.Get("http://foo.com")
defer res.Body.Close() // ERROR "using res before checking for errors"
if err != nil {
log.Fatal(err)
}
}
func badHTTPHead() {
res, err := http.Head("http://foo.com")
defer res.Body.Close() // ERROR "using res before checking for errors"
if err != nil {
log.Fatal(err)
}
}
func goodClientGet() {
client := http.DefaultClient
res, err := client.Get("http://foo.com")
if err != nil {
log.Fatal(err)
}
defer res.Body.Close()
}
func badClientPtrGet() {
client := http.DefaultClient
resp, err := client.Get("http://foo.com")
defer resp.Body.Close() // ERROR "using resp before checking for errors"
if err != nil {
log.Fatal(err)
}
}
func badClientGet() {
client := http.Client{}
resp, err := client.Get("http://foo.com")
defer resp.Body.Close() // ERROR "using resp before checking for errors"
if err != nil {
log.Fatal(err)
}
}
func badClientPtrDo() {
client := http.DefaultClient
req, err := http.NewRequest("GET", "http://foo.com", nil)
if err != nil {
log.Fatal(err)
}
resp, err := client.Do(req)
defer resp.Body.Close() // ERROR "using resp before checking for errors"
if err != nil {
log.Fatal(err)
}
}
func badClientDo() {
var client http.Client
req, err := http.NewRequest("GET", "http://foo.com", nil)
if err != nil {
log.Fatal(err)
}
resp, err := client.Do(req)
defer resp.Body.Close() // ERROR "using resp before checking for errors"
if err != nil {
log.Fatal(err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment