1
0
mirror of https://github.com/1Panel-dev/1Panel.git synced 2025-01-22 01:39:18 +08:00
2024-06-28 14:04:08 +08:00

303 lines
7.1 KiB
Go

package helper
import (
"bufio"
"database/sql"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/1Panel-dev/1Panel/backend/global"
_ "github.com/go-sql-driver/mysql"
)
func init() {}
type dumpOption struct {
isData bool
tables []string
isAllTable bool
isDropTable bool
writer io.Writer
}
type DumpOption func(*dumpOption)
func WithDropTable() DumpOption {
return func(option *dumpOption) {
option.isDropTable = true
}
}
func WithData() DumpOption {
return func(option *dumpOption) {
option.isData = true
}
}
func WithWriter(writer io.Writer) DumpOption {
return func(option *dumpOption) {
option.writer = writer
}
}
func Dump(dns string, opts ...DumpOption) error {
start := time.Now()
global.LOG.Infof("dump start at %s\n", start.Format(constant.DateTimeLayout))
defer func() {
end := time.Now()
global.LOG.Infof("dump end at %s, cost %s\n", end.Format(constant.DateTimeLayout), end.Sub(start))
}()
var err error
var o dumpOption
for _, opt := range opts {
opt(&o)
}
if len(o.tables) == 0 {
o.isAllTable = true
}
if o.writer == nil {
o.writer = os.Stdout
}
buf := bufio.NewWriter(o.writer)
defer buf.Flush()
itemFile, lineNumber := "", 0
itemFile += "-- ----------------------------\n"
itemFile += "-- MySQL Database Dump\n"
itemFile += "-- Start Time: " + start.Format(constant.DateTimeLayout) + "\n"
itemFile += "-- ----------------------------\n\n\n"
db, err := sql.Open("mysql", dns)
if err != nil {
global.LOG.Errorf("open mysql db failed, err: %v", err)
return err
}
defer db.Close()
dbName, err := getDBNameFromDNS(dns)
if err != nil {
global.LOG.Errorf("get db name from dns failed, err: %v", err)
return err
}
_, err = db.Exec(fmt.Sprintf("USE `%s`", dbName))
if err != nil {
global.LOG.Errorf("exec `use %s` failed, err: %v", dbName, err)
return err
}
var tables []string
if o.isAllTable {
tmp, err := getAllTables(db)
if err != nil {
global.LOG.Errorf("get all tables failed, err: %v", err)
return err
}
tables = tmp
} else {
tables = o.tables
}
for _, table := range tables {
if o.isDropTable {
itemFile += fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table)
}
itemFile += "-- ----------------------------\n"
itemFile += fmt.Sprintf("-- Table structure for %s\n", table)
itemFile += "-- ----------------------------\n"
createTableSQL, err := getCreateTableSQL(db, table)
if err != nil {
global.LOG.Errorf("get create table sql failed, err: %v", err)
return err
}
itemFile += createTableSQL
itemFile += ";\n\n\n\n"
if o.isData {
itemFile += "-- ----------------------------\n"
itemFile += fmt.Sprintf("-- Records of %s\n", table)
itemFile += "-- ----------------------------\n"
lineRows, err := db.Query(fmt.Sprintf("SELECT * FROM `%s`", table))
if err != nil {
global.LOG.Errorf("exec `select * from %s` failed, err: %v", table, err)
return err
}
defer lineRows.Close()
var columns []string
columns, err = lineRows.Columns()
if err != nil {
global.LOG.Errorf("get columes failed, err: %v", err)
return err
}
columnTypes, err := lineRows.ColumnTypes()
if err != nil {
global.LOG.Errorf("get colume types failed, err: %v", err)
return err
}
for lineRows.Next() {
row := make([]interface{}, len(columns))
rowPointers := make([]interface{}, len(columns))
for i := range columns {
rowPointers[i] = &row[i]
}
if err = lineRows.Scan(rowPointers...); err != nil {
global.LOG.Errorf("scan row data failed, err: %v", err)
return err
}
ssql := loadDataSql(row, columnTypes, table)
if len(ssql) != 0 {
itemFile += ssql
lineNumber++
}
if lineNumber > 500 {
_, _ = buf.WriteString(itemFile)
itemFile = ""
lineNumber = 0
}
}
itemFile += "\n\n"
}
}
itemFile += "-- ----------------------------\n"
itemFile += "-- Dumped by mysqldump\n"
itemFile += "-- Cost Time: " + time.Since(start).String() + "\n"
itemFile += "-- ----------------------------\n"
_, _ = buf.WriteString(itemFile)
_ = buf.Flush()
return nil
}
func getCreateTableSQL(db *sql.DB, table string) (string, error) {
var createTableSQL string
err := db.QueryRow(fmt.Sprintf("SHOW CREATE TABLE `%s`", table)).Scan(&table, &createTableSQL)
if err != nil {
return "", err
}
createTableSQL = strings.Replace(createTableSQL, "CREATE TABLE", "CREATE TABLE IF NOT EXISTS", 1)
return createTableSQL, nil
}
func getAllTables(db *sql.DB) ([]string, error) {
var tables []string
rows, err := db.Query("SHOW TABLES")
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var table string
err = rows.Scan(&table)
if err != nil {
return nil, err
}
tables = append(tables, table)
}
return tables, nil
}
func loadDataSql(row []interface{}, columnTypes []*sql.ColumnType, table string) string {
ssql := "INSERT INTO `" + table + "` VALUES ("
for i, col := range row {
if col == nil {
ssql += "NULL"
} else {
Type := columnTypes[i].DatabaseTypeName()
Type = strings.Replace(Type, "UNSIGNED", "", -1)
Type = strings.Replace(Type, " ", "", -1)
switch Type {
case "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "INTEGER", "BIGINT":
if bs, ok := col.([]byte); ok {
ssql += string(bs)
} else {
ssql += fmt.Sprintf("%d", col)
}
case "FLOAT", "DOUBLE":
if bs, ok := col.([]byte); ok {
ssql += string(bs)
} else {
ssql += fmt.Sprintf("%f", col)
}
case "DECIMAL", "DEC":
ssql += fmt.Sprintf("%s", col)
case "DATE":
t, ok := col.(time.Time)
if !ok {
global.LOG.Errorf("the DATE type conversion failed, err value: %v", col)
return ""
}
ssql += fmt.Sprintf("'%s'", t.Format("2006-01-02"))
case "DATETIME":
t, ok := col.(time.Time)
if !ok {
global.LOG.Errorf("the DATETIME type conversion failed, err value: %v", col)
return ""
}
ssql += fmt.Sprintf("'%s'", t.Format(constant.DateTimeLayout))
case "TIMESTAMP":
t, ok := col.(time.Time)
if !ok {
global.LOG.Errorf("the TIMESTAMP type conversion failed, err value: %v", col)
return ""
}
ssql += fmt.Sprintf("'%s'", t.Format(constant.DateTimeLayout))
case "TIME":
t, ok := col.([]byte)
if !ok {
global.LOG.Errorf("the TIME type conversion failed, err value: %v", col)
return ""
}
ssql += fmt.Sprintf("'%s'", string(t))
case "YEAR":
t, ok := col.([]byte)
if !ok {
global.LOG.Errorf("the YEAR type conversion failed, err value: %v", col)
return ""
}
ssql += string(t)
case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT":
ssql += fmt.Sprintf("'%s'", strings.Replace(fmt.Sprintf("%s", col), "'", "''", -1))
case "BIT", "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB":
ssql += fmt.Sprintf("0x%X", col)
case "ENUM", "SET":
ssql += fmt.Sprintf("'%s'", col)
case "BOOL", "BOOLEAN":
if col.(bool) {
ssql += "true"
} else {
ssql += "false"
}
case "JSON":
ssql += fmt.Sprintf("'%s'", col)
default:
global.LOG.Errorf("unsupported colume type: %s", Type)
return ""
}
}
if i < len(row)-1 {
ssql += ","
}
}
ssql += ");\n"
return ssql
}