diff --git a/component/updater/update_ui.go b/component/updater/update_ui.go index 3cb603819..e1368aea1 100644 --- a/component/updater/update_ui.go +++ b/component/updater/update_ui.go @@ -118,7 +118,10 @@ func (u *UIUpdater) downloadUI() error { tmpDir := C.Path.Resolve("downloadUI.tmp") defer os.RemoveAll(tmpDir) - extractedFolder, err := extract(data, tmpDir) + + os.RemoveAll(tmpDir) // cleanup tmp dir before extract + log.Debugln("extractedFolder: %s", tmpDir) + err = extract(data, tmpDir) if err != nil { return fmt.Errorf("can't extract compressed file: %w", err) } @@ -136,8 +139,8 @@ func (u *UIUpdater) downloadUI() error { return fmt.Errorf("prepare UI path failed: %w", err) } - log.Debugln("moveFolder from %s to %s", extractedFolder, u.externalUIPath) - err = moveDir(extractedFolder, u.externalUIPath) // move files from tmp to target + log.Debugln("moveFolder from %s to %s", tmpDir, u.externalUIPath) + err = moveDir(tmpDir, u.externalUIPath) // move files from tmp to target if err != nil { return fmt.Errorf("move UI folder failed: %w", err) } @@ -154,63 +157,19 @@ func (u *UIUpdater) prepareUIPath() error { return nil } -func unzip(data []byte, dest string) (string, error) { +func unzip(data []byte, dest string) error { r, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) if err != nil { - return "", err + return err } // check whether or not only exists singleRoot dir - rootDir := "" - isSingleRoot := true - rootItemCount := 0 - for _, f := range r.File { - parts := strings.Split(strings.Trim(f.Name, "/"), "/") - if len(parts) == 0 { - continue - } - - if len(parts) == 1 { - isDir := strings.HasSuffix(f.Name, "/") - if !isDir { - isSingleRoot = false - break - } - - if rootDir == "" { - rootDir = parts[0] - } - rootItemCount++ - } - } - - if rootItemCount != 1 { - isSingleRoot = false - } - - // build the dir of extraction - var extractedFolder string - if isSingleRoot && rootDir != "" { - // if the singleRoot, use it directly - log.Debugln("Match the singleRoot") - extractedFolder = filepath.Join(dest, rootDir) - log.Debugln("extractedFolder: %s", extractedFolder) - } else { - log.Debugln("Match the multiRoot") - extractedFolder = dest - log.Debugln("extractedFolder: %s", extractedFolder) - } for _, f := range r.File { - var fpath string - if isSingleRoot && rootDir != "" { - fpath = filepath.Join(dest, f.Name) - } else { - fpath = filepath.Join(extractedFolder, f.Name) - } + fpath := filepath.Join(dest, f.Name) if !inDest(fpath, dest) { - return "", fmt.Errorf("invalid file path: %s", fpath) + return fmt.Errorf("invalid file path: %s", fpath) } info := f.FileInfo() if info.IsDir() { @@ -221,128 +180,77 @@ func unzip(data []byte, dest string) (string, error) { continue // disallow symlink } if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil { - return "", err + return err } outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { - return "", err + return err } rc, err := f.Open() if err != nil { - return "", err + return err } _, err = io.Copy(outFile, rc) outFile.Close() rc.Close() if err != nil { - return "", err + return err } } - return extractedFolder, nil + return nil } -func untgz(data []byte, dest string) (string, error) { +func untgz(data []byte, dest string) error { gzr, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { - return "", err + return err } defer gzr.Close() tr := tar.NewReader(gzr) - rootDir := "" - isSingleRoot := true - rootItemCount := 0 - for { - header, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - return "", err - } - - parts := strings.Split(cleanTarPath(header.Name), string(os.PathSeparator)) - if len(parts) == 0 { - continue - } - - if len(parts) == 1 { - isDir := header.Typeflag == tar.TypeDir - if !isDir { - isSingleRoot = false - break - } - - if rootDir == "" { - rootDir = parts[0] - } - rootItemCount++ - } - } - - if rootItemCount != 1 { - isSingleRoot = false - } - _ = gzr.Reset(bytes.NewReader(data)) tr = tar.NewReader(gzr) - var extractedFolder string - if isSingleRoot && rootDir != "" { - log.Debugln("Match the singleRoot") - extractedFolder = filepath.Join(dest, rootDir) - log.Debugln("extractedFolder: %s", extractedFolder) - } else { - log.Debugln("Match the multiRoot") - extractedFolder = dest - log.Debugln("extractedFolder: %s", extractedFolder) - } - for { header, err := tr.Next() if err == io.EOF { break } if err != nil { - return "", err + return err } - var fpath string - if isSingleRoot && rootDir != "" { - fpath = filepath.Join(dest, cleanTarPath(header.Name)) - } else { - fpath = filepath.Join(extractedFolder, cleanTarPath(header.Name)) - } + fpath := filepath.Join(dest, header.Name) if !inDest(fpath, dest) { - return "", fmt.Errorf("invalid file path: %s", fpath) + return fmt.Errorf("invalid file path: %s", fpath) } switch header.Typeflag { case tar.TypeDir: if err = os.MkdirAll(fpath, os.FileMode(header.Mode)); err != nil { - return "", err + return err } case tar.TypeReg: if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil { - return "", err + return err } outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)) if err != nil { - return "", err + return err } if _, err := io.Copy(outFile, tr); err != nil { outFile.Close() - return "", err + return err } outFile.Close() } } - return extractedFolder, nil + return nil } -func extract(data []byte, dest string) (string, error) { +func extract(data []byte, dest string) error { fileType := detectFileType(data) log.Debugln("compression Type: %s", fileType) switch fileType { @@ -351,7 +259,7 @@ func extract(data []byte, dest string) (string, error) { case typeTarGzip: return untgz(data, dest) default: - return "", fmt.Errorf("unknown or unsupported file type") + return fmt.Errorf("unknown or unsupported file type") } } @@ -393,6 +301,15 @@ func moveDir(src string, dst string) error { return err } + if len(dirEntryList) == 1 && dirEntryList[0].IsDir() { + src = filepath.Join(src, dirEntryList[0].Name()) + log.Debugln("match the singleRoot: %s", src) + dirEntryList, err = os.ReadDir(src) + if err != nil { + return err + } + } + for _, dirEntry := range dirEntryList { err = os.Rename(filepath.Join(src, dirEntry.Name()), filepath.Join(dst, dirEntry.Name())) if err != nil {