commit 9e410eaa21beb1959cfacef95ffe72bebb93cae5 Author: Feng_Qi Date: Wed Jun 7 15:48:34 2017 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2cc6534 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.txt +*.json +*.exe \ No newline at end of file diff --git a/README.MD b/README.MD new file mode 100644 index 0000000..12f8bfc --- /dev/null +++ b/README.MD @@ -0,0 +1,104 @@ + ## multissh + +一个简单的并行 SSH 工具,可以批量的对主机通过 SSH 执行命令组合。 + +#### 编译 +``` +go get ./... +go build +``` + + +#### 命令体系 +``` +./multissh -h + -cmd string + cmds // 需要执行的命令组合,多条命令以 ; 分割 + -cmdfile string + cmdfile path //需要执行的命令组合文件,文件内命令按行分割 + -hostfile string + hostfile path // 需要执行的主机列表文件,主机列表在文件内按行分割 + -hosts string + host address list //需要执行的主机列表,多个主机以 ; 分割 + -ipfile string + hostfile path //需要执行的主机(IP)列表文件,IP可以以地址段的方式逐行写在文本内 + -p string + password // 主机的 SSH 密码 + -port int + ssh port (default 22) //主机的 SSH 端口,默认 22 + -u string + username //主机的 SSH 用户名 +``` +**cmdfile 示例** +``` +show clock +exit +``` +**hostfile 示例** +``` +10.10.15.101 +10.10.15.102 +``` +**ipfile 示例** +``` +10.10.15.101-10.10.15.102 +``` + +## 用法 +#### cmd string & host string +``` +./multissh -cmd "show clock;exit" -hosts "10.10.15.101;10.10.15.102" -u admin -p admin + +10.10.15.101 ssh start +sw-1#show clock +05:26:40.649 UTC Tue Jun 6 2017 +sw-1#exit + +10.10.15.101 ssh end + +10.10.15.102 ssh start +sw-2#show clock +05:24:38.708 UTC Tue Jun 6 2017 +sw-2#exit + +10.10.15.102 ssh end +``` + +#### cmdfile & hostfile +``` +./multissh -cmdfile cmd.txt -hostfile host.txt -u admin -p admin + +10.10.15.101 ssh start +sw-1#show clock +05:29:43.269 UTC Tue Jun 6 2017 +sw-1#exit + +10.10.15.101 ssh end + +10.10.15.102 ssh start +sw-2#show clock +05:27:41.332 UTC Tue Jun 6 2017 +sw-2#exit + +10.10.15.102 ssh end +``` + +#### ipfile +``` +./multissh -cmdfile cmd.txt -ipfile ip.txt -u admin -p admin + +10.10.15.101 ssh start +sw-1#show clock +05:29:43.269 UTC Tue Jun 6 2017 +sw-1#exit + +10.10.15.101 ssh end + +10.10.15.102 ssh start +sw-2#show clock +05:27:41.332 UTC Tue Jun 6 2017 +sw-2#exit + +10.10.15.102 ssh end +``` + diff --git a/cfg.go b/cfg.go new file mode 100644 index 0000000..eb9827a --- /dev/null +++ b/cfg.go @@ -0,0 +1,185 @@ +package main + +import ( + "encoding/binary" + "io/ioutil" + "log" + "net" + "strconv" + "strings" +) + +func GetfileAll(filePath string) ([]byte, error) { + result, err := ioutil.ReadFile(filePath) + if err != nil { + log.Println("read file ", filePath, err) + return result, err + } + return result, nil +} + +func Getfile(filePath string) ([]string, error) { + result := []string{} + b, err := ioutil.ReadFile(filePath) + if err != nil { + log.Println("read file ", filePath, err) + return result, err + } + s := string(b) + for _, lineStr := range strings.Split(s, "\n") { + lineStr = strings.TrimSpace(lineStr) + if lineStr == "" { + continue + } + result = append(result, lineStr) + } + return result, nil +} + +func GetIpList(filePath string) ([]string, error) { + res, err := Getfile(filePath) + if err != nil { + return nil, nil + } + var allIp []string + if len(res) > 0 { + for _, sip := range res { + aip := ParseIp(sip) + for _, ip := range aip { + allIp = append(allIp, ip) + } + } + } + return allIp, nil +} + +func ParseIp(ip string) []string { + var availableIPs []string + // if ip is "1.1.1.1/",trim / + ip = strings.TrimRight(ip, "/") + if strings.Contains(ip, "/") == true { + if strings.Contains(ip, "/32") == true { + aip := strings.Replace(ip, "/32", "", -1) + availableIPs = append(availableIPs, aip) + } else { + availableIPs = GetAvailableIP(ip) + } + } else if strings.Contains(ip, "-") == true { + ipRange := strings.SplitN(ip, "-", 2) + availableIPs = GetAvailableIPRange(ipRange[0], ipRange[1]) + } else { + availableIPs = append(availableIPs, ip) + } + return availableIPs +} + +func GetAvailableIPRange(ipStart, ipEnd string) []string { + var availableIPs []string + + firstIP := net.ParseIP(ipStart) + endIP := net.ParseIP(ipEnd) + if firstIP.To4() == nil || endIP.To4() == nil { + return availableIPs + } + firstIPNum := ipToInt(firstIP.To4()) + EndIPNum := ipToInt(endIP.To4()) + pos := int32(1) + + newNum := firstIPNum + + for newNum <= EndIPNum { + availableIPs = append(availableIPs, intToIP(newNum).String()) + newNum = newNum + pos + } + return availableIPs +} + +func GetAvailableIP(ipAndMask string) []string { + var availableIPs []string + + ipAndMask = strings.TrimSpace(ipAndMask) + ipAndMask = IPAddressToCIDR(ipAndMask) + _, ipnet, _ := net.ParseCIDR(ipAndMask) + + firstIP, _ := networkRange(ipnet) + ipNum := ipToInt(firstIP) + size := networkSize(ipnet.Mask) + pos := int32(1) + max := size - 2 // -1 for the broadcast address, -1 for the gateway address + + var newNum int32 + for attempt := int32(0); attempt < max; attempt++ { + newNum = ipNum + pos + pos = pos%max + 1 + availableIPs = append(availableIPs, intToIP(newNum).String()) + } + return availableIPs +} + +func IPAddressToCIDR(ipAdress string) string { + if strings.Contains(ipAdress, "/") == true { + ipAndMask := strings.Split(ipAdress, "/") + ip := ipAndMask[0] + mask := ipAndMask[1] + if strings.Contains(mask, ".") == true { + mask = IPMaskStringToCIDR(mask) + } + return ip + "/" + mask + } else { + return ipAdress + } +} + +func IPMaskStringToCIDR(netmask string) string { + netmaskList := strings.Split(netmask, ".") + var mint []int + for _, v := range netmaskList { + strv, _ := strconv.Atoi(v) + mint = append(mint, strv) + } + myIPMask := net.IPv4Mask(byte(mint[0]), byte(mint[1]), byte(mint[2]), byte(mint[3])) + ones, _ := myIPMask.Size() + return strconv.Itoa(ones) +} + +func IPMaskCIDRToString(one string) string { + oneInt, _ := strconv.Atoi(one) + mIPmask := net.CIDRMask(oneInt, 32) + var maskstring []string + for _, v := range mIPmask { + maskstring = append(maskstring, strconv.Itoa(int(v))) + } + return strings.Join(maskstring, ".") +} + +// Calculates the first and last IP addresses in an IPNet +func networkRange(network *net.IPNet) (net.IP, net.IP) { + netIP := network.IP.To4() + firstIP := netIP.Mask(network.Mask) + lastIP := net.IPv4(0, 0, 0, 0).To4() + for i := 0; i < len(lastIP); i++ { + lastIP[i] = netIP[i] | ^network.Mask[i] + } + return firstIP, lastIP +} + +// Given a netmask, calculates the number of available hosts +func networkSize(mask net.IPMask) int32 { + m := net.IPv4Mask(0, 0, 0, 0) + for i := 0; i < net.IPv4len; i++ { + m[i] = ^mask[i] + } + return int32(binary.BigEndian.Uint32(m)) + 1 +} + +// Converts a 4 bytes IP into a 32 bit integer +func ipToInt(ip net.IP) int32 { + return int32(binary.BigEndian.Uint32(ip.To4())) +} + +// Converts 32 bit integer into a 4 bytes IP address +func intToIP(n int32) net.IP { + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, uint32(n)) + return net.IP(b) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..1dea59a --- /dev/null +++ b/main.go @@ -0,0 +1,118 @@ +package main + +import ( + "flag" + "fmt" + "log" + "strings" + "time" + // "github.com/bitly/go-simplejson" +) + +type sshhost struct { + host string + port int + username string + password string + cmd []string +} + +func main() { + hosts := flag.String("hosts", "", "host address list") + cmd := flag.String("cmd", "", "cmds") + username := flag.String("u", "", "username") + password := flag.String("p", "", "password") + port := flag.Int("port", 22, "ssh port") + cmdfile := flag.String("cmdfile", "", "cmdfile path") + hostfile := flag.String("hostfile", "", "hostfile path") + ipfile := flag.String("ipfile", "", "hostfile path") + cfg := flag.String("cfg", "", "cfg path") + + flag.Parse() + + var cmdlist []string + var hostlist []string + var err error + + sshhosts := []sshhost{} + var host_struct sshhost + + if *ipfile != "" { + hostlist, err = GetIpList(*ipfile) + if err != nil { + log.Println("load hostlist error: ", err) + return + } + } + + if *hostfile != "" { + hostlist, err = Getfile(*hostfile) + if err != nil { + log.Println("load hostfile error: ", err) + return + } + } + if *hosts != "" { + hostlist = strings.Split(*hosts, ";") + + } + + if *cmdfile != "" { + cmdlist, err = Getfile(*cmdfile) + if err != nil { + log.Println("load cmdfile error: ", err) + return + } + } + if *cmd != "" { + cmdlist = strings.Split(*cmd, ";") + } + + if *cfg == "" { + for _, host := range hostlist { + host_struct.host = host + host_struct.username = *username + host_struct.password = *password + host_struct.port = *port + host_struct.cmd = cmdlist + sshhosts = append(sshhosts, host_struct) + } + } + /* + else { + cfgjson, err := GetfileAll(*cfg) + if err != nil { + log.Println("load cfg error: ", err) + return + } + + js, js_err := simplejson.NewJson(cfgjson) + if js_err != nil { + log.Println("json format error: ", js_err) + return + } + + + } + */ + //fmt.Println(sshhosts) + + chs := make([]chan string, len(sshhosts)) + for i, host := range sshhosts { + chs[i] = make(chan string, 1) + go dossh(host.username, host.password, host.host, host.cmd, host.port, chs[i]) + } + for i, ch := range chs { + fmt.Println(sshhosts[i].host, " ssh start") + select { + case res := <-ch: + if res != "" { + fmt.Println(res) + } + case <-time.After(30 * 1000 * 1000 * 1000): + log.Println("SSH run timeout") + } + fmt.Println(sshhosts[i].host, " ssh end\n") + } + +} diff --git a/ssh_test.go b/ssh_test.go new file mode 100644 index 0000000..63f43d2 --- /dev/null +++ b/ssh_test.go @@ -0,0 +1,52 @@ +package main + +import ( + // "bytes" + "os" + "testing" +) + +const ( + username = "" + password = "" + ip = "" + port = 22 + cmd = "date\n" +) + +func Test_SSH(t *testing.T) { + session, err := connect(username, password, ip, port) + if err != nil { + t.Error(err) + return + } + defer session.Close() + + //cmdlist := strings.Split(cmd, ";") + + stdinBuf, err := session.StdinPipe() + if err != nil { + t.Error(err) + return + } + // var bt bytes.Buffer + // session.Stdout = &bt + t.Log(session.Stdout) + t.Log(session.Stderr) + session.Stdout = os.Stdout + session.Stderr = os.Stderr + session.Stdin = os.Stdin + err = session.Shell() + if err != nil { + t.Error(err) + return + } + // for _, c := range cmdlist { + // c = c + "\n" + stdinBuf.Write([]byte(cmd)) + // } + session.Wait() + t.Error(err) + // t.Log(bt.String()) + return +} diff --git a/sshconnect.go b/sshconnect.go new file mode 100644 index 0000000..ecc9c93 --- /dev/null +++ b/sshconnect.go @@ -0,0 +1,83 @@ +package main + +import ( + "bytes" + "fmt" + "time" + + "net" + + "golang.org/x/crypto/ssh" +) + +func connect(user, password, host string, port int) (*ssh.Session, error) { + var ( + auth []ssh.AuthMethod + addr string + clientConfig *ssh.ClientConfig + client *ssh.Client + config ssh.Config + session *ssh.Session + err error + ) + // get auth method + auth = make([]ssh.AuthMethod, 0) + auth = append(auth, ssh.Password(password)) + + config = ssh.Config{ + Ciphers: []string{"aes128-ctr", "aes192-ctr", "aes256-ctr", "arcfour256", "arcfour128", "aes128-cbc", "3des-cbc", "blowfish-cbc", "cast128-cbc", "aes192-cbc", "aes256-cbc", "arcfour"}, + } + + clientConfig = &ssh.ClientConfig{ + User: user, + Auth: auth, + Timeout: 30 * time.Second, + Config: config, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + + // connet to ssh + addr = fmt.Sprintf("%s:%d", host, port) + + if client, err = ssh.Dial("tcp", addr, clientConfig); err != nil { + return nil, err + } + + // create session + if session, err = client.NewSession(); err != nil { + return nil, err + } + + return session, nil +} + +func dossh(username, password, ip string, cmdlist []string, port int, ch chan string) { + session, err := connect(username, password, ip, port) + if err != nil { + ch <- fmt.Sprintf("<%s>", err.Error()) + return + } + defer session.Close() + + // cmd := "ls;date;exit" + + stdinBuf, _ := session.StdinPipe() + + var outbt, errbt bytes.Buffer + session.Stdout = &outbt + + session.Stderr = &errbt + + err = session.Shell() + for _, c := range cmdlist { + c = c + "\n" + stdinBuf.Write([]byte(c)) + } + session.Wait() + ch <- (outbt.String() + errbt.String()) + + return + +}