diff --git a/README.MD b/README.MD index 12f8bfc..bc46240 100644 --- a/README.MD +++ b/README.MD @@ -28,6 +28,15 @@ go build ssh port (default 22) //主机的 SSH 端口,默认 22 -u string username //主机的 SSH 用户名 + -j string + jsonFile //保存大量主机,包括主机地址,SSH用户名,SSH密码,SSH端口,所需执行的cmd指令文件地址 + -outTxt bool + outTxt (default false) //是否允许把结果保存到文件中,true为允许 false为默认值 + -t duration + timeLimit (default 30) //最大并发访问时间 默认为30s + -n int + numLimit (default 20) //最大并发访问量 默认为20 + ``` **cmdfile 示例** ``` @@ -102,3 +111,20 @@ sw-2#exit 10.10.15.102 ssh end ``` +#### ipfile +``` +./multissh -j jsonSample.json -t 30 -n 20 -outTxt true +10.10.15.101 ssh start +sw-1#show clock +05:29:43.269 UTC Tue Jun 6 2017 +sw-1#exit + +10.10.15.101 ssh end + +10.10.15.102 ssh start +sw-2#show clock +05:27:41.332 UTC Tue Jun 6 2017 +sw-2#exit + +10.10.15.102 ssh end +``` \ No newline at end of file diff --git a/cfg.go b/cfg.go index eb9827a..7aa4d88 100644 --- a/cfg.go +++ b/cfg.go @@ -1,10 +1,13 @@ package main import ( + "bufio" "encoding/binary" + "encoding/json" "io/ioutil" "log" "net" + "os" "strconv" "strings" ) @@ -36,6 +39,35 @@ func Getfile(filePath string) ([]string, error) { return result, nil } +//gu +func GetJsonFile(filePath string) ([]SSHHost, error) { + result := []SSHHost{} + b, err := ioutil.ReadFile(filePath) + if err != nil { + log.Println("read file ", filePath, err) + return result, err + } + var m HostJson + json.Unmarshal(b, &m) + result = m.SshHosts + return result, nil +} +func WriteIntoTxt(sshHost SSHHost) error { + outputFile, outputError := os.OpenFile(sshHost.Host+".txt", os.O_WRONLY|os.O_CREATE, 0666) + if outputError != nil { + return outputError + } + defer outputFile.Close() + + outputWriter := bufio.NewWriter(outputFile) + //var outputString string + + outputString := sshHost.Result + outputWriter.WriteString(outputString) + outputWriter.Flush() + return nil +} + func GetIpList(filePath string) ([]string, error) { res, err := Getfile(filePath) if err != nil { diff --git a/cmdSample.txt b/cmdSample.txt new file mode 100644 index 0000000..332fea8 --- /dev/null +++ b/cmdSample.txt @@ -0,0 +1,4 @@ +date +sleep 10 +date +exit \ No newline at end of file diff --git a/jsonSample.json b/jsonSample.json new file mode 100644 index 0000000..0d17385 --- /dev/null +++ b/jsonSample.json @@ -0,0 +1 @@ +{"SshHosts":[{"Host":"1.1.1.1","Port":11,"Username":"xxx","Password":"xxx","CmdFile":"cmd1.txt"},{"Host":"1.1.1.1","Port":11,"Username":"yyy","Password":"yyy","CmdFile":"cmd2.txt"}]} diff --git a/main.go b/main.go index 1dea59a..73bfe28 100644 --- a/main.go +++ b/main.go @@ -4,17 +4,24 @@ import ( "flag" "fmt" "log" + "strconv" "strings" "time" + // "github.com/bitly/go-simplejson" ) -type sshhost struct { - host string - port int - username string - password string - cmd []string +type SSHHost struct { + Host string + Port int + Username string + Password string + CmdFile string + Cmd []string + Result string +} +type HostJson struct { + SshHosts []SSHHost } func main() { @@ -23,61 +30,82 @@ func main() { username := flag.String("u", "", "username") password := flag.String("p", "", "password") port := flag.Int("port", 22, "ssh port") - cmdfile := flag.String("cmdfile", "", "cmdfile path") - hostfile := flag.String("hostfile", "", "hostfile path") - ipfile := flag.String("ipfile", "", "hostfile path") + cmdFile := flag.String("cmdfile", "", "cmdfile path") + hostFile := flag.String("hostfile", "", "hostfile path") + ipFile := flag.String("ipfile", "", "hostfile path") cfg := flag.String("cfg", "", "cfg path") + //gu + jsonFile := flag.String("j", "", "Json File Path") + outTxt := flag.Bool("outTxt", false, "write result into txt") + timeLimit := flag.Duration("t", 30, "max timeout") + numLimit := flag.Int("n", 20, "max execute number") flag.Parse() - - var cmdlist []string - var hostlist []string + var cmdList []string + var hostList []string var err error - sshhosts := []sshhost{} - var host_struct sshhost + sshHosts := []SSHHost{} + var host_Struct SSHHost - if *ipfile != "" { - hostlist, err = GetIpList(*ipfile) + if *ipFile != "" { + hostList, err = GetIpList(*ipFile) if err != nil { log.Println("load hostlist error: ", err) return } } - if *hostfile != "" { - hostlist, err = Getfile(*hostfile) + if *hostFile != "" { + hostList, err = Getfile(*hostFile) if err != nil { log.Println("load hostfile error: ", err) return } } if *hosts != "" { - hostlist = strings.Split(*hosts, ";") - + hostList = strings.Split(*hosts, ";") } - if *cmdfile != "" { - cmdlist, err = Getfile(*cmdfile) + if *cmdFile != "" { + cmdList, err = Getfile(*cmdFile) if err != nil { log.Println("load cmdfile error: ", err) return } } if *cmd != "" { - cmdlist = strings.Split(*cmd, ";") + cmdList = strings.Split(*cmd, ";") } - if *cfg == "" { - for _, host := range hostlist { - host_struct.host = host - host_struct.username = *username - host_struct.password = *password - host_struct.port = *port - host_struct.cmd = cmdlist - sshhosts = append(sshhosts, host_struct) + for _, host := range hostList { + host_Struct.Host = host + host_Struct.Username = *username + host_Struct.Password = *password + host_Struct.Port = *port + host_Struct.Cmd = cmdList + sshHosts = append(sshHosts, host_Struct) } } + //gu + if *jsonFile != "" { + sshHosts, err = GetJsonFile(*jsonFile) + if err != nil { + log.Println("load jsonFile error: ", err) + return + } + for i := 0; i < len(sshHosts); i++ { + cmdList, err = Getfile(sshHosts[i].CmdFile) + if err != nil { + log.Println("load cmdFile error: ", err) + return + } + //fmt.Println(cmdList) + sshHosts[i].Cmd = cmdList + } + //为什么不能用for range + } + /* else { cfgjson, err := GetfileAll(*cfg) @@ -96,23 +124,42 @@ func main() { } */ //fmt.Println(sshhosts) - - chs := make([]chan string, len(sshhosts)) - for i, host := range sshhosts { + chLimit := make(chan bool, *numLimit) //控制并发访问量 + chs := make([]chan string, len(sshHosts)) + limitFunc := func(chLimit chan bool, ch chan string, host SSHHost) { + dossh(host.Username, host.Password, host.Host, host.Cmd, host.Port, ch) + <-chLimit + } + for i, host := range sshHosts { chs[i] = make(chan string, 1) - go dossh(host.username, host.password, host.host, host.cmd, host.port, chs[i]) + chLimit <- true + go limitFunc(chLimit, chs[i], host) } for i, ch := range chs { - fmt.Println(sshhosts[i].host, " ssh start") + fmt.Println(sshHosts[i].Host, " ssh start") select { case res := <-ch: if res != "" { fmt.Println(res) + sshHosts[i].Result += res } - case <-time.After(30 * 1000 * 1000 * 1000): + case <-time.After(*timeLimit * 1000 * 1000 * 1000): log.Println("SSH run timeout") + sshHosts[i].Result += ("SSH run timeout:" + strconv.Itoa(int(*timeLimit)) + "second.") + } + + fmt.Println(sshHosts[i].Host, " ssh end") + } + + //gu + if !*outTxt { + for i := 0; i < len(sshHosts); i++ { + err = WriteIntoTxt(sshHosts[i]) + if err != nil { + log.Println("write into txt error: ", err) + return + } } - fmt.Println(sshhosts[i].host, " ssh end\n") } } diff --git a/sshconnect.go b/sshconnect.go index ecc9c93..c11ad2b 100644 --- a/sshconnect.go +++ b/sshconnect.go @@ -3,9 +3,10 @@ package main import ( "bytes" "fmt" - "time" - "net" + //"os" + + "time" "golang.org/x/crypto/ssh" ) @@ -55,29 +56,32 @@ func connect(user, password, host string, port int) (*ssh.Session, error) { func dossh(username, password, ip string, cmdlist []string, port int, ch chan string) { session, err := connect(username, password, ip, port) + if err != nil { ch <- fmt.Sprintf("<%s>", err.Error()) + //<-chLimit return + } defer session.Close() // cmd := "ls;date;exit" - stdinBuf, _ := session.StdinPipe() - + //fmt.Fprintf(os.Stdout, "%s", stdinBuf) var outbt, errbt bytes.Buffer session.Stdout = &outbt session.Stderr = &errbt - err = session.Shell() for _, c := range cmdlist { c = c + "\n" stdinBuf.Write([]byte(c)) + } session.Wait() - ch <- (outbt.String() + errbt.String()) + ch <- (outbt.String() + errbt.String()) + //<-chLimit return }