完成读取json文件、结果导出为txt文件、超时处理(时限参数可输入) 控制并发访问

This commit is contained in:
alen 2017-06-08 10:06:13 +08:00
parent 9e410eaa21
commit 47366d7696
3 changed files with 141 additions and 45 deletions

32
cfg.go
View file

@ -1,10 +1,13 @@
package main package main
import ( import (
"bufio"
"encoding/binary" "encoding/binary"
"encoding/json"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"os"
"strconv" "strconv"
"strings" "strings"
) )
@ -36,6 +39,35 @@ func Getfile(filePath string) ([]string, error) {
return result, nil return result, nil
} }
//gu
func GetJsonFile(filePath string) ([]SSHHost, error) {
result := []SSHHost{}
b, err := ioutil.ReadFile(filePath)
if err != nil {
log.Println("read file ", filePath, err)
return result, err
}
var m HostJson
json.Unmarshal(b, &m)
result = m.SshHosts
return result, nil
}
func WriteIntoTxt(sshHost SSHHost) error {
outputFile, outputError := os.OpenFile(sshHost.Host+".txt", os.O_WRONLY|os.O_CREATE, 0666)
if outputError != nil {
return outputError
}
defer outputFile.Close()
outputWriter := bufio.NewWriter(outputFile)
//var outputString string
outputString := sshHost.Result
outputWriter.WriteString(outputString)
outputWriter.Flush()
return nil
}
func GetIpList(filePath string) ([]string, error) { func GetIpList(filePath string) ([]string, error) {
res, err := Getfile(filePath) res, err := Getfile(filePath)
if err != nil { if err != nil {

140
main.go
View file

@ -4,17 +4,23 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"strconv"
"strings" "strings"
"time" "time"
// "github.com/bitly/go-simplejson" // "github.com/bitly/go-simplejson"
) )
type sshhost struct { type SSHHost struct {
host string Host string
port int Port int
username string Username string
password string Password string
cmd []string CmdFile string
Cmd []string
Result string
}
type HostJson struct {
SshHosts []SSHHost
} }
func main() { func main() {
@ -22,62 +28,100 @@ func main() {
cmd := flag.String("cmd", "", "cmds") cmd := flag.String("cmd", "", "cmds")
username := flag.String("u", "", "username") username := flag.String("u", "", "username")
password := flag.String("p", "", "password") password := flag.String("p", "", "password")
port := flag.Int("port", 22, "ssh port") port := flag.String("port", "", "ssh port")
cmdfile := flag.String("cmdfile", "", "cmdfile path") cmdFile := flag.String("cmdfile", "", "cmdfile path")
hostfile := flag.String("hostfile", "", "hostfile path") hostFile := flag.String("hostfile", "", "hostfile path")
ipfile := flag.String("ipfile", "", "hostfile path") ipFile := flag.String("ipfile", "", "hostfile path")
cfg := flag.String("cfg", "", "cfg path") cfg := flag.String("cfg", "", "cfg path")
//gu
jsonFile := flag.String("j", "ssh.json", "Json File Path")
outTxt := flag.Bool("outTxt", false, "write result into txt")
timeLimit := flag.Duration("t", 30, "max timeout")
numLimit := flag.Int("n", 20, "max execute number")
flag.Parse() flag.Parse()
var cmdList []string
var cmdlist []string var hostList []string
var hostlist []string
var err error var err error
//gu
var usernameList []string
var passwordList []string
var portList []string
sshhosts := []sshhost{} sshHosts := []SSHHost{}
var host_struct sshhost var host_Struct SSHHost
if *ipfile != "" { if *ipFile != "" {
hostlist, err = GetIpList(*ipfile) hostList, err = GetIpList(*ipFile)
if err != nil { if err != nil {
log.Println("load hostlist error: ", err) log.Println("load hostlist error: ", err)
return return
} }
} }
if *hostfile != "" { if *hostFile != "" {
hostlist, err = Getfile(*hostfile) hostList, err = Getfile(*hostFile)
if err != nil { if err != nil {
log.Println("load hostfile error: ", err) log.Println("load hostfile error: ", err)
return return
} }
} }
if *hosts != "" { if *hosts != "" {
hostlist = strings.Split(*hosts, ";") hostList = strings.Split(*hosts, ";")
} }
if *cmdfile != "" { //gu
cmdlist, err = Getfile(*cmdfile) if *username != "" {
usernameList = strings.Split(*username, ";")
}
if *password != "" {
passwordList = strings.Split(*password, ";")
}
if *port != "" {
portList = strings.Split(*port, ";")
}
////
if *cmdFile != "" {
cmdList, err = Getfile(*cmdFile)
if err != nil { if err != nil {
log.Println("load cmdfile error: ", err) log.Println("load cmdfile error: ", err)
return return
} }
} }
if *cmd != "" { if *cmd != "" {
cmdlist = strings.Split(*cmd, ";") cmdList = strings.Split(*cmd, ";")
} }
if *cfg == "" { if *cfg == "" {
for _, host := range hostlist { for pos, host := range hostList {
host_struct.host = host host_Struct.Host = host
host_struct.username = *username host_Struct.Username = usernameList[pos]
host_struct.password = *password host_Struct.Password = passwordList[pos]
host_struct.port = *port host_Struct.Port, _ = strconv.Atoi(portList[pos])
host_struct.cmd = cmdlist host_Struct.Cmd = cmdList
sshhosts = append(sshhosts, host_struct) sshHosts = append(sshHosts, host_Struct)
} }
} }
//gu
if *jsonFile != "" {
sshHosts, err = GetJsonFile(*jsonFile)
if err != nil {
log.Println("load jsonFile error: ", err)
return
}
for i := 0; i < len(sshHosts); i++ {
cmdList, err = Getfile(sshHosts[i].CmdFile)
if err != nil {
log.Println("load cmdFile error: ", err)
return
}
//fmt.Println(cmdList)
sshHosts[i].Cmd = cmdList
}
//为什么不能用for range
}
/* /*
else { else {
cfgjson, err := GetfileAll(*cfg) cfgjson, err := GetfileAll(*cfg)
@ -96,23 +140,41 @@ func main() {
} }
*/ */
//fmt.Println(sshhosts) //fmt.Println(sshhosts)
chLimit := make(chan bool, numLimit)
chs := make([]chan string, len(sshhosts)) chs := make([]chan string, len(sshHosts))
for i, host := range sshhosts { limitFunc := func(chLimit chan bool, ch chan string, host SSHHost) {
dossh(host.Username, host.Password, host.Host, host.Cmd, host.Port, ch)
<-chLimit
}
for i, host := range sshHosts {
chs[i] = make(chan string, 1) chs[i] = make(chan string, 1)
go dossh(host.username, host.password, host.host, host.cmd, host.port, chs[i]) chLimit <- true
go limitFunc(chLimit, chs[i], host)
} }
for i, ch := range chs { for i, ch := range chs {
fmt.Println(sshhosts[i].host, " ssh start") fmt.Println(sshHosts[i].Host, " ssh start")
select { select {
case res := <-ch: case res := <-ch:
if res != "" { if res != "" {
fmt.Println(res) fmt.Println(res)
sshHosts[i].Result += res
} }
case <-time.After(30 * 1000 * 1000 * 1000): case <-time.After(*timeLimit * 1000 * 1000 * 1000):
log.Println("SSH run timeout") log.Println("SSH run timeout")
sshHosts[i].Result += ("SSH run timeout" + strconv.Itoa(int(*timeLimit)) + "second.")
}
fmt.Println(sshHosts[i].Host, " ssh end")
}
//gu
if *outTxt {
for i := 0; i < len(sshHosts); i++ {
err = WriteIntoTxt(sshHosts[i])
if err != nil {
log.Println("write into txt error: ", err)
return
}
} }
fmt.Println(sshhosts[i].host, " ssh end\n")
} }
} }

View file

@ -3,9 +3,9 @@ package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"time"
"net" "net"
//"os"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -55,29 +55,31 @@ func connect(user, password, host string, port int) (*ssh.Session, error) {
func dossh(username, password, ip string, cmdlist []string, port int, ch chan string) { func dossh(username, password, ip string, cmdlist []string, port int, ch chan string) {
session, err := connect(username, password, ip, port) session, err := connect(username, password, ip, port)
if err != nil { if err != nil {
ch <- fmt.Sprintf("<%s>", err.Error()) ch <- fmt.Sprintf("<%s>", err.Error())
//<-chLimit
return return
} }
defer session.Close() defer session.Close()
// cmd := "ls;date;exit" // cmd := "ls;date;exit"
stdinBuf, _ := session.StdinPipe() stdinBuf, _ := session.StdinPipe()
//fmt.Fprintf(os.Stdout, "%s", stdinBuf)
var outbt, errbt bytes.Buffer var outbt, errbt bytes.Buffer
session.Stdout = &outbt session.Stdout = &outbt
session.Stderr = &errbt session.Stderr = &errbt
err = session.Shell() err = session.Shell()
for _, c := range cmdlist { for _, c := range cmdlist {
c = c + "\n" c = c + "\n"
stdinBuf.Write([]byte(c)) stdinBuf.Write([]byte(c))
} }
session.Wait() session.Wait()
ch <- (outbt.String() + errbt.String()) ch <- (outbt.String() + errbt.String())
//<-chLimit
return return
} }