// Mirror of https://github.com/shanghai-edu/multissh.git
// (synced 2025-12-16 13:27:44 +00:00; 155 lines, 3.3 KiB, Go).
package main
|
||
|
||
import (
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
// "github.com/bitly/go-simplejson"
|
||
)
|
||
|
||
// VERSION is the program version string printed by the -v flag.
const (
	VERSION = "0.1.1"
)
|
||
|
||
// SSHHost describes one SSH target together with the commands to run on it
// and the output collected from that run.
type SSHHost struct {
	Host     string   // address of the remote host
	Port     int      // SSH port
	Username string   // login user
	Password string   // login password
	CmdFile  string   // path of a file holding the commands; read by main in JSON mode
	Cmd      []string // commands handed to dossh for this host
	Result   string   // accumulated command output, or a timeout notice
}
|
||
type HostJson struct {
|
||
SshHosts []SSHHost
|
||
}
|
||
|
||
func main() {
|
||
version := flag.Bool("v", false, "show version")
|
||
hosts := flag.String("hosts", "", "host address list")
|
||
cmd := flag.String("cmd", "", "cmds")
|
||
username := flag.String("u", "", "username")
|
||
password := flag.String("p", "", "password")
|
||
port := flag.Int("port", 22, "ssh port")
|
||
cmdFile := flag.String("cmdfile", "", "cmdfile path")
|
||
hostFile := flag.String("hostfile", "", "hostfile path")
|
||
ipFile := flag.String("ipfile", "", "hostfile path")
|
||
//gu
|
||
jsonFile := flag.String("j", "", "Json File Path")
|
||
outTxt := flag.Bool("outTxt", false, "write result into txt")
|
||
timeLimit := flag.Int("t", 30, "max timeout")
|
||
numLimit := flag.Int("n", 20, "max execute number")
|
||
|
||
flag.Parse()
|
||
var cmdList []string
|
||
var hostList []string
|
||
var err error
|
||
|
||
sshHosts := []SSHHost{}
|
||
var host_Struct SSHHost
|
||
timeout := time.Duration(*timeLimit) * time.Second
|
||
|
||
if *version {
|
||
fmt.Println(VERSION)
|
||
os.Exit(0)
|
||
}
|
||
|
||
if *ipFile != "" {
|
||
hostList, err = GetIpList(*ipFile)
|
||
if err != nil {
|
||
log.Println("load hostlist error: ", err)
|
||
return
|
||
}
|
||
}
|
||
|
||
if *hostFile != "" {
|
||
hostList, err = Getfile(*hostFile)
|
||
if err != nil {
|
||
log.Println("load hostfile error: ", err)
|
||
return
|
||
}
|
||
}
|
||
if *hosts != "" {
|
||
hostList = strings.Split(*hosts, ";")
|
||
}
|
||
|
||
if *cmdFile != "" {
|
||
cmdList, err = Getfile(*cmdFile)
|
||
if err != nil {
|
||
log.Println("load cmdfile error: ", err)
|
||
return
|
||
}
|
||
}
|
||
if *cmd != "" {
|
||
cmdList = strings.Split(*cmd, ";")
|
||
}
|
||
if *jsonFile == "" {
|
||
for _, host := range hostList {
|
||
host_Struct.Host = host
|
||
host_Struct.Username = *username
|
||
host_Struct.Password = *password
|
||
host_Struct.Port = *port
|
||
host_Struct.Cmd = cmdList
|
||
sshHosts = append(sshHosts, host_Struct)
|
||
}
|
||
}
|
||
//gu
|
||
if *jsonFile != "" {
|
||
sshHosts, err = GetJsonFile(*jsonFile)
|
||
if err != nil {
|
||
log.Println("load jsonFile error: ", err)
|
||
return
|
||
}
|
||
for i := 0; i < len(sshHosts); i++ {
|
||
cmdList, err = Getfile(sshHosts[i].CmdFile)
|
||
if err != nil {
|
||
log.Println("load cmdFile error: ", err)
|
||
return
|
||
}
|
||
sshHosts[i].Cmd = cmdList
|
||
}
|
||
}
|
||
|
||
chLimit := make(chan bool, *numLimit) //控制并发访问量
|
||
chs := make([]chan string, len(sshHosts))
|
||
limitFunc := func(chLimit chan bool, ch chan string, host SSHHost) {
|
||
dossh(host.Username, host.Password, host.Host, host.Cmd, host.Port, ch)
|
||
<-chLimit
|
||
}
|
||
for i, host := range sshHosts {
|
||
chs[i] = make(chan string, 1)
|
||
chLimit <- true
|
||
go limitFunc(chLimit, chs[i], host)
|
||
}
|
||
for i, ch := range chs {
|
||
fmt.Println(sshHosts[i].Host, " ssh start")
|
||
select {
|
||
case res := <-ch:
|
||
if res != "" {
|
||
fmt.Println(res)
|
||
sshHosts[i].Result += res
|
||
}
|
||
case <-time.After(timeout):
|
||
log.Println("SSH run timeout")
|
||
sshHosts[i].Result += ("SSH run timeout:" + strconv.Itoa(int(*timeLimit)) + "second.")
|
||
}
|
||
|
||
fmt.Println(sshHosts[i].Host, " ssh end")
|
||
}
|
||
|
||
//gu
|
||
if *outTxt {
|
||
for i := 0; i < len(sshHosts); i++ {
|
||
err = WriteIntoTxt(sshHosts[i])
|
||
if err != nil {
|
||
log.Println("write into txt error: ", err)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
}
|