Compare commits

...

4 Commits

2 changed files with 83 additions and 86 deletions

View File

@ -3,6 +3,7 @@ package updater
import ( import (
"archive/tar" "archive/tar"
"archive/zip" "archive/zip"
"bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
@ -32,6 +33,17 @@ const (
typeTarGzip typeTarGzip
) )
func (t compressionType) String() string {
switch t {
case typeZip:
return "zip"
case typeTarGzip:
return "tar.gz"
default:
return "unknown"
}
}
var DefaultUiUpdater = &UIUpdater{} var DefaultUiUpdater = &UIUpdater{}
func NewUiUpdater(externalUI, externalUIURL, externalUIName string) *UIUpdater { func NewUiUpdater(externalUI, externalUIURL, externalUIName string) *UIUpdater {
@ -99,48 +111,35 @@ func detectFileType(data []byte) compressionType {
} }
func (u *UIUpdater) downloadUI() error { func (u *UIUpdater) downloadUI() error {
err := u.prepareUIPath()
if err != nil {
return fmt.Errorf("prepare UI path failed: %w", err)
}
data, err := downloadForBytes(u.externalUIURL) data, err := downloadForBytes(u.externalUIURL)
if err != nil { if err != nil {
return fmt.Errorf("can't download file: %w", err) return fmt.Errorf("can't download file: %w", err)
} }
fileType := detectFileType(data) tmpDir := C.Path.Resolve("downloadUI.tmp")
if fileType == typeUnknown { defer os.RemoveAll(tmpDir)
return fmt.Errorf("unknown or unsupported file type") extractedFolder, err := extract(data, tmpDir)
if err != nil {
return fmt.Errorf("can't extract compressed file: %w", err)
} }
ext := ".zip" log.Debugln("cleanupFolder: %s", u.externalUIPath)
if fileType == typeTarGzip { err = cleanup(u.externalUIPath) // cleanup files in dir don't remove dir itself
ext = ".tgz"
}
saved := path.Join(C.Path.HomeDir(), "download"+ext)
log.Debugln("compression Type: %s", ext)
if err = saveFile(data, saved); err != nil {
return fmt.Errorf("can't save compressed file: %w", err)
}
defer os.Remove(saved)
err = cleanup(u.externalUIPath)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return fmt.Errorf("cleanup exist file error: %w", err) return fmt.Errorf("cleanup exist file error: %w", err)
} }
} }
extractedFolder, err := extract(saved, C.Path.HomeDir()) err = u.prepareUIPath()
if err != nil { if err != nil {
return fmt.Errorf("can't extract compressed file: %w", err) return fmt.Errorf("prepare UI path failed: %w", err)
} }
err = os.Rename(extractedFolder, u.externalUIPath) log.Debugln("moveFolder from %s to %s", extractedFolder, u.externalUIPath)
err = moveDir(extractedFolder, u.externalUIPath) // move files from tmp to target
if err != nil { if err != nil {
return fmt.Errorf("rename UI folder failed: %w", err) return fmt.Errorf("move UI folder failed: %w", err)
} }
return nil return nil
} }
@ -155,12 +154,11 @@ func (u *UIUpdater) prepareUIPath() error {
return nil return nil
} }
func unzip(src, dest string) (string, error) { func unzip(data []byte, dest string) (string, error) {
r, err := zip.OpenReader(src) r, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
if err != nil { if err != nil {
return "", err return "", err
} }
defer r.Close()
// check whether or not only exists singleRoot dir // check whether or not only exists singleRoot dir
rootDir := "" rootDir := ""
@ -199,17 +197,7 @@ func unzip(src, dest string) (string, error) {
log.Debugln("extractedFolder: %s", extractedFolder) log.Debugln("extractedFolder: %s", extractedFolder)
} else { } else {
log.Debugln("Match the multiRoot") log.Debugln("Match the multiRoot")
// or put the files/dirs into new dir extractedFolder = dest
baseName := filepath.Base(src)
baseName = strings.TrimSuffix(baseName, filepath.Ext(baseName))
extractedFolder = filepath.Join(dest, baseName)
for i := 1; ; i++ {
if _, err := os.Stat(extractedFolder); os.IsNotExist(err) {
break
}
extractedFolder = filepath.Join(dest, fmt.Sprintf("%s_%d", baseName, i))
}
log.Debugln("extractedFolder: %s", extractedFolder) log.Debugln("extractedFolder: %s", extractedFolder)
} }
@ -221,13 +209,17 @@ func unzip(src, dest string) (string, error) {
fpath = filepath.Join(extractedFolder, f.Name) fpath = filepath.Join(extractedFolder, f.Name)
} }
if !strings.HasPrefix(fpath, filepath.Clean(dest)+string(os.PathSeparator)) { if !inDest(fpath, dest) {
return "", fmt.Errorf("invalid file path: %s", fpath) return "", fmt.Errorf("invalid file path: %s", fpath)
} }
if f.FileInfo().IsDir() { info := f.FileInfo()
if info.IsDir() {
os.MkdirAll(fpath, os.ModePerm) os.MkdirAll(fpath, os.ModePerm)
continue continue
} }
if info.Mode()&os.ModeSymlink != 0 {
continue // disallow symlink
}
if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil { if err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm); err != nil {
return "", err return "", err
} }
@ -249,14 +241,8 @@ func unzip(src, dest string) (string, error) {
return extractedFolder, nil return extractedFolder, nil
} }
func untgz(src, dest string) (string, error) { func untgz(data []byte, dest string) (string, error) {
file, err := os.Open(src) gzr, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return "", err
}
defer file.Close()
gzr, err := gzip.NewReader(file)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -299,8 +285,7 @@ func untgz(src, dest string) (string, error) {
isSingleRoot = false isSingleRoot = false
} }
file.Seek(0, 0) _ = gzr.Reset(bytes.NewReader(data))
gzr, _ = gzip.NewReader(file)
tr = tar.NewReader(gzr) tr = tar.NewReader(gzr)
var extractedFolder string var extractedFolder string
@ -310,17 +295,7 @@ func untgz(src, dest string) (string, error) {
log.Debugln("extractedFolder: %s", extractedFolder) log.Debugln("extractedFolder: %s", extractedFolder)
} else { } else {
log.Debugln("Match the multiRoot") log.Debugln("Match the multiRoot")
baseName := filepath.Base(src) extractedFolder = dest
baseName = strings.TrimSuffix(baseName, filepath.Ext(baseName))
baseName = strings.TrimSuffix(baseName, ".tar")
extractedFolder = filepath.Join(dest, baseName)
for i := 1; ; i++ {
if _, err := os.Stat(extractedFolder); os.IsNotExist(err) {
break
}
extractedFolder = filepath.Join(dest, fmt.Sprintf("%s_%d", baseName, i))
}
log.Debugln("extractedFolder: %s", extractedFolder) log.Debugln("extractedFolder: %s", extractedFolder)
} }
@ -340,7 +315,7 @@ func untgz(src, dest string) (string, error) {
fpath = filepath.Join(extractedFolder, cleanTarPath(header.Name)) fpath = filepath.Join(extractedFolder, cleanTarPath(header.Name))
} }
if !strings.HasPrefix(fpath, filepath.Clean(dest)+string(os.PathSeparator)) { if !inDest(fpath, dest) {
return "", fmt.Errorf("invalid file path: %s", fpath) return "", fmt.Errorf("invalid file path: %s", fpath)
} }
@ -367,16 +342,16 @@ func untgz(src, dest string) (string, error) {
return extractedFolder, nil return extractedFolder, nil
} }
func extract(src, dest string) (string, error) { func extract(data []byte, dest string) (string, error) {
srcLower := strings.ToLower(src) fileType := detectFileType(data)
switch { log.Debugln("compression Type: %s", fileType)
case strings.HasSuffix(srcLower, ".tar.gz") || switch fileType {
strings.HasSuffix(srcLower, ".tgz"): case typeZip:
return untgz(src, dest) return unzip(data, dest)
case strings.HasSuffix(srcLower, ".zip"): case typeTarGzip:
return unzip(src, dest) return untgz(data, dest)
default: default:
return "", fmt.Errorf("unsupported file format: %s", src) return "", fmt.Errorf("unknown or unsupported file type")
} }
} }
@ -398,22 +373,40 @@ func cleanTarPath(path string) string {
} }
func cleanup(root string) error { func cleanup(root string) error {
if _, err := os.Stat(root); os.IsNotExist(err) { dirEntryList, err := os.ReadDir(root)
return nil if err != nil {
return err
} }
return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
for _, dirEntry := range dirEntryList {
err = os.RemoveAll(filepath.Join(root, dirEntry.Name()))
if err != nil { if err != nil {
return err return err
} }
if info.IsDir() { }
if err := os.RemoveAll(path); err != nil { return nil
return err }
}
} else { func moveDir(src string, dst string) error {
if err := os.Remove(path); err != nil { dirEntryList, err := os.ReadDir(src)
return err if err != nil {
} return err
} }
return nil
}) for _, dirEntry := range dirEntryList {
err = os.Rename(filepath.Join(src, dirEntry.Name()), filepath.Join(dst, dirEntry.Name()))
if err != nil {
return err
}
}
return nil
}
func inDest(fpath, dest string) bool {
if rel, err := filepath.Rel(dest, fpath); err == nil {
if filepath.IsLocal(rel) {
return true
}
}
return false
} }

View File

@ -7,6 +7,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"net/url" "net/url"
"path/filepath"
"strings" "strings"
"time" "time"
_ "unsafe" _ "unsafe"
@ -759,6 +760,9 @@ func parseController(cfg *RawConfig) (*Controller, error) {
if path := cfg.ExternalUI; path != "" && !C.Path.IsSafePath(path) { if path := cfg.ExternalUI; path != "" && !C.Path.IsSafePath(path) {
return nil, C.Path.ErrNotSafePath(path) return nil, C.Path.ErrNotSafePath(path)
} }
if uiName := cfg.ExternalUIName; uiName != "" && !filepath.IsLocal(uiName) {
return nil, fmt.Errorf("external UI name is not local: %s", uiName)
}
return &Controller{ return &Controller{
ExternalController: cfg.ExternalController, ExternalController: cfg.ExternalController,
ExternalUI: cfg.ExternalUI, ExternalUI: cfg.ExternalUI,