From 71636fff17b5fcc0f14f31ee7af0b9ea889ec55d Mon Sep 17 00:00:00 2001 From: Alexander Larsson Date: Mon, 3 Oct 2022 12:12:39 +0200 Subject: [PATCH] Initial quadlet version integrated in golang Based on the initial port in https://github.com/containers/quadlet/pull/41 Signed-off-by: Alexander Larsson --- Makefile | 10 + cmd/quadlet/main.go | 254 +++++++++ pkg/quadlet/podmancmdline.go | 60 ++ pkg/quadlet/quadlet.go | 687 +++++++++++++++++++++++ pkg/quadlet/quadlet_test.go | 45 ++ pkg/quadlet/ranges.go | 247 ++++++++ pkg/quadlet/ranges_test.go | 242 ++++++++ pkg/quadlet/subuids.go | 67 +++ pkg/systemdparser/split.go | 505 +++++++++++++++++ pkg/systemdparser/unitfile.go | 866 +++++++++++++++++++++++++++++ pkg/systemdparser/unitfile_test.go | 245 ++++++++ 11 files changed, 3228 insertions(+) create mode 100644 cmd/quadlet/main.go create mode 100644 pkg/quadlet/podmancmdline.go create mode 100644 pkg/quadlet/quadlet.go create mode 100644 pkg/quadlet/quadlet_test.go create mode 100644 pkg/quadlet/ranges.go create mode 100644 pkg/quadlet/ranges_test.go create mode 100644 pkg/quadlet/subuids.go create mode 100644 pkg/systemdparser/split.go create mode 100644 pkg/systemdparser/unitfile.go create mode 100644 pkg/systemdparser/unitfile_test.go diff --git a/Makefile b/Makefile index f622739652da..b63608ee7cd9 100644 --- a/Makefile +++ b/Makefile @@ -337,6 +337,16 @@ podman: bin/podman .PHONY: podman-remote podman-remote: $(SRCBINDIR)/podman$(BINSFX) +$(SRCBINDIR)/quadlet: $(SOURCES) go.mod go.sum + $(GOCMD) build \ + $(BUILDFLAGS) \ + $(GO_LDFLAGS) '$(LDFLAGS_PODMAN)' \ + -tags "${BUILDTAGS}" \ + -o $@ ./cmd/quadlet + +.PHONY: quadlet +quadlet: bin/quadlet + PHONY: podman-remote-static podman-remote-static: $(SRCBINDIR)/podman-remote-static diff --git a/cmd/quadlet/main.go b/cmd/quadlet/main.go new file mode 100644 index 000000000000..fe8aae7d4d01 --- /dev/null +++ b/cmd/quadlet/main.go @@ -0,0 +1,254 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "os" + "path" + "path/filepath" + "strings" + + "github.com/containers/podman/v4/pkg/quadlet" + "github.com/containers/podman/v4/pkg/systemdparser" +) + +var ( + verboseFlag bool // True if -v passed + isUser bool // True if run as quadlet-user-generator executable +) + +var ( + // data saved between logToKmsg calls + noKmsg = false + kmsgFile *os.File +) + +func logToKmsg(s string) bool { + if noKmsg { + return false + } + + if kmsgFile == nil { + f, err := os.OpenFile("/dev/kmsg", os.O_WRONLY, 0644) + if err != nil { + noKmsg = true + return false + } + kmsgFile = f + } + + if _, err := kmsgFile.Write([]byte(s)); err != nil { + kmsgFile.Close() + kmsgFile = nil + return false + } + + return true +} + +func Logf(format string, a ...interface{}) { + s := fmt.Sprintf(format, a...) + line := fmt.Sprintf("quadlet-generator[%d]: %s", os.Getpid(), s) + + if !logToKmsg(line) { + // If we can't log, print to stderr + fmt.Fprintf(os.Stderr, "%s\n", line) + os.Stderr.Sync() + } +} + +var debugEnabled = false + +func enableDebug() { + debugEnabled = true +} + +func Debugf(format string, a ...interface{}) { + if debugEnabled { + Logf(format, a...) + } +} + +func getUnitDirs(user bool) []string { + unitDirsEnv := os.Getenv("QUADLET_UNIT_DIRS") + if len(unitDirsEnv) > 0 { + return strings.Split(unitDirsEnv, ":") + } + + dirs := make([]string, 0) + if user { + if configDir, err := os.UserConfigDir(); err == nil { + dirs = append(dirs, path.Join(configDir, "containers/systemd")) + } + } else { + dirs = append(dirs, quadlet.UnitDirAdmin) + dirs = append(dirs, quadlet.UnitDirDistro) + } + return dirs +} + +func loadUnitsFromDir(sourcePath string, units map[string]*systemdparser.UnitFile) { + files, err := os.ReadDir(sourcePath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + Logf("Can't read \"%s\": %s", sourcePath, err) + } + return + } + + for _, file := range files { + name := file.Name() + if units[name] == nil && + (strings.HasSuffix(name, ".container") || + strings.HasSuffix(name, ".volume")) { + path := path.Join(sourcePath, name) + + Debugf("Loading source unit file %s", path) + + if f, err := systemdparser.ParseUnitFile(path); err != nil { + Logf("Error loading '%s', ignoring: %s", path, err) + } else { + units[name] = f + } + } + } +} + +func generateServiceFile(service *systemdparser.UnitFile) error { + Debugf("writing '%s'", service.Path) + + service.PrependComment("", + "Automatically generated by quadlet-generator", + "") + + f, err := os.Create(service.Path) + if err != nil { + return err + } + + defer f.Close() + + err = service.Write(f) + if err != nil { + return err + } + + err = f.Sync() + if err != nil { + return err + } + + return nil +} + +func enableServiceFile(outputPath string, service *systemdparser.UnitFile) { + symlinks := make([]string, 0) + + aliases := service.LookupAllStrv(quadlet.InstallGroup, "Alias") + for _, alias := range aliases { + symlinks = append(symlinks, filepath.Clean(alias)) + } + + wantedBy := service.LookupAllStrv(quadlet.InstallGroup, "WantedBy") + for _, wantedByUnit := range wantedBy { + // Only allow filenames, not paths + if !strings.Contains(wantedByUnit, "/") { + symlinks = append(symlinks, fmt.Sprintf("%s.wants/%s", wantedByUnit, service.Filename)) + } + } + + requiredBy := service.LookupAllStrv(quadlet.InstallGroup, "RequiredBy") + for _, requiredByUnit := range requiredBy { + // Only allow filenames, not paths + if !strings.Contains(requiredByUnit, "/") { + symlinks = append(symlinks, fmt.Sprintf("%s.requires/%s", requiredByUnit, service.Filename)) + } + } + + for _, symlinkRel := range symlinks { + target, err := filepath.Rel(path.Dir(symlinkRel), service.Filename) + if err != nil { + Logf("Can't create symlink %s: %s", symlinkRel, err) + continue + } + symlinkPath := path.Join(outputPath, symlinkRel) + + symlinkDir := path.Dir(symlinkPath) + err = os.MkdirAll(symlinkDir, os.ModePerm) + if err != nil { + Logf("Can't create dir %s: %s", symlinkDir, err) + continue + } + + Debugf("Creating symlink %s -> %s", symlinkPath, target) + _ = os.Remove(symlinkPath) // overwrite existing symlinks + err = os.Symlink(target, symlinkPath) + if err != nil { + Logf("Failed creating symlink %s: %s", symlinkPath, err) + } + } +} + +func main() { + prgname := path.Base(os.Args[0]) + isUser = strings.Contains(prgname, "user") + + flag.Parse() + + if verboseFlag { + enableDebug() + } + + if flag.NArg() < 1 { + Logf("Missing output directory argument") + os.Exit(1) + } + + outputPath := flag.Arg(0) + + Debugf("Starting quadlet-generator, output to: %s", outputPath) + + sourcePaths := getUnitDirs(isUser) + + units := make(map[string]*systemdparser.UnitFile) + for _, d := range sourcePaths { + loadUnitsFromDir(d, units) + } + + err := os.MkdirAll(outputPath, os.ModePerm) + if err != nil { + Logf("Can't create dir %s: %s", outputPath, err) + os.Exit(1) + } + + for name, unit := range units { + var service *systemdparser.UnitFile + var err error + + switch { + case strings.HasSuffix(name, ".container"): + service, err = quadlet.ConvertContainer(unit, isUser) + case strings.HasSuffix(name, ".volume"): + service, err = quadlet.ConvertVolume(unit, name) + default: + Logf("Unsupported file type '%s'", name) + continue + } + + if err != nil { + Logf("Error converting '%s', ignoring: %s", name, err) + } else { + service.Path = path.Join(outputPath, service.Filename) + + if err := generateServiceFile(service); err != nil { + Logf("Error writing '%s'o: %s", service.Path, err) + } + enableServiceFile(outputPath, service) + } + } +} + +func init() { + flag.BoolVar(&verboseFlag, "v", false, "Print debug information") +} diff --git a/pkg/quadlet/podmancmdline.go b/pkg/quadlet/podmancmdline.go new file mode 100644 index 000000000000..adec5b2c239c --- /dev/null +++ b/pkg/quadlet/podmancmdline.go @@ -0,0 +1,60 @@ +package quadlet + +import ( + "fmt" + "sort" +) + +/* This is a helper for constructing podman commandlines */ +type PodmanCmdline struct { + Args []string +} + +func (c *PodmanCmdline) add(args ...string) { + c.Args = append(c.Args, args...) +} + +func (c *PodmanCmdline) addf(format string, a ...interface{}) { + c.add(fmt.Sprintf(format, a...)) +} + +func (c *PodmanCmdline) addKeys(arg string, keys map[string]string) { + ks := make([]string, 0, len(keys)) + for k := range keys { + ks = append(ks, k) + } + sort.Strings(ks) + + for _, k := range ks { + c.add(arg, fmt.Sprintf("%s=%s", k, keys[k])) + } +} + +func (c *PodmanCmdline) addEnv(env map[string]string) { + c.addKeys("--env", env) +} + +func (c *PodmanCmdline) addLabels(labels map[string]string) { + c.addKeys("--label", labels) +} + +func (c *PodmanCmdline) addAnnotations(annotations map[string]string) { + c.addKeys("--annotation", annotations) +} + +func (c *PodmanCmdline) addIDMap(argPrefix string, containerIDStart, hostIDStart, numIDs uint32) { + if numIDs != 0 { + c.add(argPrefix) + c.addf("%d:%d:%d", containerIDStart, hostIDStart, numIDs) + } +} + +func NewPodmanCmdline(args ...string) *PodmanCmdline { + c := &PodmanCmdline{ + Args: make([]string, 0), + } + + c.add("/usr/bin/podman") + c.add(args...) + return c +} diff --git a/pkg/quadlet/quadlet.go b/pkg/quadlet/quadlet.go new file mode 100644 index 000000000000..576029190da0 --- /dev/null +++ b/pkg/quadlet/quadlet.go @@ -0,0 +1,687 @@ +package quadlet + +import ( + "fmt" + "math" + "os" + "regexp" + "strings" + "unicode" + + "github.com/containers/podman/v4/pkg/systemdparser" +) + +const ( + UnitDirAdmin = "/etc/containers/systemd" + UnitDirDistro = "/usr/share/containers/systemd" + + UnitGroup = "Unit" + InstallGroup = "Install" + ServiceGroup = "Service" + ContainerGroup = "Container" + XContainerGroup = "X-Container" + VolumeGroup = "Volume" + XVolumeGroup = "X-Volume" + + // TODO: These should be configurable + QuadletUserName = "quadlet" + FallbackUIDStart = 1879048192 + FallbackUIDLength = 165536 + FallbackGIDStart = 1879048192 + FallbackGIDLength = 165536 +) + +var validPortRange = regexp.MustCompile(`\d+(-\d+)?(/udp|/tcp)?$`) + +const ( + KeyContainerName = "ContainerName" + KeyImage = "Image" + KeyEnvironment = "Environment" + KeyExec = "Exec" + KeyNoNewPrivileges = "NoNewPrivileges" + KeyDropCapability = "DropCapability" + KeyAddCapability = "AddCapability" + KeyReadOnly = "ReadOnly" + KeyRemapUsers = "RemapUsers" + KeyRemapUIDStart = "RemapUidStart" + KeyRemapGIDStart = "RemapGidStart" + KeyRemapUIDRanges = "RemapUidRanges" + KeyRemapGIDRanges = "RemapGidRanges" + KeyNotify = "Notify" + KeySocketActivated = "SocketActivated" + KeyExposeHostPort = "ExposeHostPort" + KeyPublishPort = "PublishPort" + KeyKeepID = "KeepId" + KeyUser = "User" + KeyGroup = "Group" + KeyHostUser = "HostUser" + KeyHostGroup = "HostGroup" + KeyVolume = "Volume" + KeyPodmanArgs = "PodmanArgs" + KeyLabel = "Label" + KeyAnnotation = "Annotation" + KeyRunInit = "RunInit" + KeyVolatileTmp = "VolatileTmp" + KeyTimezone = "Timezone" +) + +var supportedContainerKeys = map[string]bool{ + KeyContainerName: true, + KeyImage: true, + KeyEnvironment: true, + KeyExec: true, + KeyNoNewPrivileges: true, + KeyDropCapability: true, + KeyAddCapability: true, + KeyReadOnly: true, + KeyRemapUsers: true, + KeyRemapUIDStart: true, + KeyRemapGIDStart: true, + KeyRemapUIDRanges: true, + KeyRemapGIDRanges: true, + KeyNotify: true, + KeySocketActivated: true, + KeyExposeHostPort: true, + KeyPublishPort: true, + KeyKeepID: true, + KeyUser: true, + KeyGroup: true, + KeyHostUser: true, + KeyHostGroup: true, + KeyVolume: true, + KeyPodmanArgs: true, + KeyLabel: true, + KeyAnnotation: true, + KeyRunInit: true, + KeyVolatileTmp: true, + KeyTimezone: true, +} + +var supportedVolumeKeys = map[string]bool{ + KeyUser: true, + KeyGroup: true, + KeyLabel: true, +} + +func replaceExtension(name string, extension string, extraPrefix string, extraSuffix string) string { + baseName := name + + dot := strings.LastIndexByte(name, '.') + if dot > 0 { + baseName = name[:dot] + } + + return extraPrefix + baseName + extraSuffix + extension +} + +var defaultRemapUIDs, defaultRemapGIDs *Ranges + +func getDefaultRemapUids() *Ranges { + if defaultRemapUIDs == nil { + defaultRemapUIDs = lookupHostSubuid(QuadletUserName) + if defaultRemapUIDs == nil { + defaultRemapUIDs = + NewRanges(FallbackUIDStart, FallbackUIDLength) + } + } + return defaultRemapUIDs +} + +func getDefaultRemapGids() *Ranges { + if defaultRemapGIDs == nil { + defaultRemapGIDs = lookupHostSubgid(QuadletUserName) + if defaultRemapGIDs == nil { + defaultRemapGIDs = + NewRanges(FallbackGIDStart, FallbackGIDLength) + } + } + return defaultRemapGIDs +} + +func isPortRange(port string) bool { + return validPortRange.MatchString(port) +} + +func checkForUnknownKeys(unit *systemdparser.UnitFile, groupName string, supportedKeys map[string]bool) error { + keys := unit.ListKeys(groupName) + for _, key := range keys { + if !supportedKeys[key] { + return fmt.Errorf("unsupported key '%s' in group '%s' in %s", key, groupName, unit.Path) + } + } + return nil +} + +func lookupRanges(unit *systemdparser.UnitFile, groupName string, key string, nameLookup func(string) *Ranges, defaultValue *Ranges) *Ranges { + v, ok := unit.Lookup(groupName, key) + if !ok { + if defaultValue != nil { + return defaultValue.Copy() + } + + return NewRangesEmpty() + } + + if len(v) == 0 { + return NewRangesEmpty() + } + + if !unicode.IsDigit(rune(v[0])) { + if nameLookup != nil { + r := nameLookup(v) + if r != nil { + return r + } + } + return NewRangesEmpty() + } + + return ParseRanges(v) +} + +func splitPorts(ports string) []string { + parts := make([]string, 0) + + // IP address could have colons in it. For example: "[::]:8080:80/tcp, so we split carefully + start := 0 + end := 0 + for end < len(ports) { + switch ports[end] { + case '[': + end++ + for end < len(ports) && ports[end] != ']' { + end++ + } + if end < len(ports) { + end++ // Skip ] + } + case ':': + parts = append(parts, ports[start:end]) + end++ + start = end + default: + end++ + } + } + + parts = append(parts, ports[start:end]) + return parts +} + +func addIDMaps(podman *PodmanCmdline, argPrefix string, containerID, hostID, remapStartID uint32, availableHostIDs *Ranges) { + if availableHostIDs == nil { + // Map everything by default + availableHostIDs = NewRangesEmpty() + } + + // Map the first ids up to remapStartID to the host equivalent + unmappedIds := NewRanges(0, remapStartID) + + // The rest we want to map to availableHostIDs. Note that this + // overlaps unmappedIds, because below we may remove ranges from + // unmapped ids and we want to backfill those. + mappedIds := NewRanges(0, math.MaxUint32) + + // Always map specified uid to specified host_uid + podman.addIDMap(argPrefix, containerID, hostID, 1) + + // We no longer want to map this container id as its already mapped + mappedIds.Remove(containerID, 1) + unmappedIds.Remove(containerID, 1) + + // But also, we don't want to use the *host* id again, as we can only map it once + unmappedIds.Remove(hostID, 1) + availableHostIDs.Remove(hostID, 1) + + // Map unmapped ids to equivalent host range, and remove from mappedIds to avoid double-mapping + for _, r := range unmappedIds.Ranges { + start := r.Start + length := r.Length + + podman.addIDMap(argPrefix, start, start, length) + mappedIds.Remove(start, length) + availableHostIDs.Remove(start, length) + } + + for cIdx := 0; cIdx < len(mappedIds.Ranges) && len(availableHostIDs.Ranges) > 0; cIdx++ { + cRange := &mappedIds.Ranges[cIdx] + cStart := cRange.Start + cLength := cRange.Length + + for cLength > 0 && len(availableHostIDs.Ranges) > 0 { + hRange := &availableHostIDs.Ranges[0] + hStart := hRange.Start + hLength := hRange.Length + + nextLength := minUint32(hLength, cLength) + + podman.addIDMap(argPrefix, cStart, hStart, nextLength) + availableHostIDs.Remove(hStart, nextLength) + cStart += nextLength + cLength -= nextLength + } + } +} + +func ConvertContainer(container *systemdparser.UnitFile, isUser bool) (*systemdparser.UnitFile, error) { + service := container.Dup() + service.Filename = replaceExtension(container.Filename, ".service", "", "") + + if container.Path != "" { + service.Add(UnitGroup, "SourcePath", container.Path) + } + + if err := checkForUnknownKeys(container, ContainerGroup, supportedContainerKeys); err != nil { + return nil, err + } + + // Rename old Container group to x-Container so that systemd ignores it + service.RenameGroup(ContainerGroup, XContainerGroup) + + image, ok := container.Lookup(ContainerGroup, KeyImage) + if !ok || len(image) == 0 { + return nil, fmt.Errorf("no Image key specified") + } + + containerName, ok := container.Lookup(ContainerGroup, KeyContainerName) + if !ok || len(containerName) == 0 { + // By default, We want to name the container by the service name + containerName = "systemd-%N" + } + + // Set PODMAN_SYSTEMD_UNIT so that podman auto-update can restart the service. + service.Add(ServiceGroup, "Environment", "PODMAN_SYSTEMD_UNIT=%n") + + // Only allow mixed or control-group, as nothing else works well + killMode, ok := service.Lookup(ServiceGroup, "KillMode") + if !ok || !(killMode == "mixed" || killMode == "control-group") { + if ok { + return nil, fmt.Errorf("invalid KillMode '%s'", killMode) + } + + // We default to mixed instead of control-group, because it lets conmon do its thing + service.Set(ServiceGroup, "KillMode", "mixed") + } + + // Read env early so we can override it below + podmanEnv := container.LookupAllKeyVal(ContainerGroup, KeyEnvironment) + + // Need the containers filesystem mounted to start podman + service.Add(UnitGroup, "RequiresMountsFor", "%t/containers") + + // Remove any leftover cid file before starting, just to be sure. + // We remove any actual pre-existing container by name with --replace=true. + // But --cidfile will fail if the target exists. + service.Add(ServiceGroup, "ExecStartPre", "-rm -f %t/%N.cid") + + // If the conman exited uncleanly it may not have removed the container, so force it, + // -i makes it ignore non-existing files. + service.Add(ServiceGroup, "ExecStopPost", "-/usr/bin/podman rm -f -i --cidfile=%t/%N.cid") + + // Remove the cid file, to avoid confusion as the container is no longer running. + service.Add(ServiceGroup, "ExecStopPost", "-rm -f %t/%N.cid") + + podman := NewPodmanCmdline("run") + + podman.addf("--name=%s", containerName) + + podman.add( + // We store the container id so we can clean it up in case of failure + "--cidfile=%t/%N.cid", + + // And replace any previous container with the same name, not fail + "--replace", + + // On clean shutdown, remove container + "--rm", + + // Detach from container, we don't need the podman process to hang around + "-d", + + // But we still want output to the journal, so use the log driver. + // TODO: Once available we want to use the passthrough log-driver instead. + "--log-driver", "journald", + + // Never try to pull the image during service start + "--pull=never") + + // We use crun as the runtime and delegated groups to it + service.Add(ServiceGroup, "Delegate", "yes") + podman.add( + "--runtime", "/usr/bin/crun", + "--cgroups=split") + + timezone, ok := container.Lookup(ContainerGroup, KeyTimezone) + if ok && len(timezone) > 0 { + podman.addf("--tz=%s", timezone) + } + + // Run with a pid1 init to reap zombies by default (as most apps don't do that) + runInit := container.LookupBoolean(ContainerGroup, KeyRunInit, true) + if runInit { + podman.add("--init") + } + + // By default we handle startup notification with conmon, but allow passing it to the container with Notify=yes + notify := container.LookupBoolean(ContainerGroup, KeyNotify, false) + if notify { + podman.add("--sdnotify=container") + } else { + podman.add("--sdnotify=conmon") + } + service.Setv(ServiceGroup, + "Type", "notify", + "NotifyAccess", "all") + + if !container.HasKey(ServiceGroup, "SyslogIdentifier") { + service.Set(ServiceGroup, "SyslogIdentifier", "%N") + } + + // Default to no higher level privileges or caps + noNewPrivileges := container.LookupBoolean(ContainerGroup, KeyNoNewPrivileges, true) + if noNewPrivileges { + podman.add("--security-opt=no-new-privileges") + } + + dropCaps := []string{"all"} // Default + if container.HasKey(ContainerGroup, KeyDropCapability) { + dropCaps = container.LookupAll(ContainerGroup, KeyDropCapability) + } + + for _, caps := range dropCaps { + podman.addf("--cap-drop=%s", strings.ToLower(caps)) + } + + // But allow overrides with AddCapability + addCaps := container.LookupAll(ContainerGroup, KeyAddCapability) + for _, caps := range addCaps { + podman.addf("--cap-add=%s", strings.ToLower(caps)) + } + + readOnly := container.LookupBoolean(ContainerGroup, KeyReadOnly, false) + if readOnly { + podman.add("--read-only") + } + + // We want /tmp to be a tmpfs, like on rhel host + volatileTmp := container.LookupBoolean(ContainerGroup, KeyVolatileTmp, true) + if volatileTmp { + /* Read only mode already has a tmpfs by default */ + if !readOnly { + podman.add("--tmpfs", "/tmp:rw,size=512M,mode=1777") + } + } else if readOnly { + /* !volatileTmp, disable the default tmpfs from --read-only */ + podman.add("--read-only-tmpfs=false") + } + + socketActivated := container.LookupBoolean(ContainerGroup, KeySocketActivated, false) + if socketActivated { + // TODO: This will not be needed with later podman versions that support activation directly: + // https://github.com/containers/podman/pull/11316 + podman.add("--preserve-fds=1") + podmanEnv["LISTEN_FDS"] = "1" + + // TODO: This will not be 2 when catatonit forwards fds: + // https://github.com/openSUSE/catatonit/pull/15 + podmanEnv["LISTEN_PID"] = "2" + } + + defaultContainerUID := uint32(0) + defaultContainerGID := uint32(0) + + keepID := container.LookupBoolean(ContainerGroup, KeyKeepID, false) + if keepID { + if isUser { + defaultContainerUID = uint32(os.Getuid()) + defaultContainerGID = uint32(os.Getgid()) + podman.add("--userns", "keep-id") + } else { + return nil, fmt.Errorf("key 'KeepId' in '%s' unsupported for system units", container.Path) + } + } + + uid := container.LookupUint32(ContainerGroup, KeyUser, defaultContainerUID) + gid := container.LookupUint32(ContainerGroup, KeyGroup, defaultContainerGID) + + hostUID, err := container.LookupUID(ContainerGroup, KeyHostUser, uid) + if err != nil { + return nil, fmt.Errorf("key 'HostUser' invalid: %s", err) + } + + hostGID, err := container.LookupGID(ContainerGroup, KeyHostGroup, gid) + if err != nil { + return nil, fmt.Errorf("key 'HostGroup' invalid: %s", err) + } + + if uid != defaultContainerUID || gid != defaultContainerGID { + podman.add("--user") + if gid == defaultContainerGID { + podman.addf("%d", uid) + } else { + podman.addf("%d:%d", uid, gid) + } + } + + var remapUsers bool + if isUser { + remapUsers = false + } else { + remapUsers = container.LookupBoolean(ContainerGroup, KeyRemapUsers, false) + } + + if !remapUsers { + // No remapping of users, although we still need maps if the + // main user/group is remapped, even if most ids map one-to-one. + if uid != hostUID { + addIDMaps(podman, "--uidmap", uid, hostUID, math.MaxUint32, nil) + } + if gid != hostGID { + addIDMaps(podman, "--gidmap", gid, hostGID, math.MaxUint32, nil) + } + } else { + uidRemapIDs := lookupRanges(container, ContainerGroup, KeyRemapUIDRanges, lookupHostSubuid, getDefaultRemapUids()) + gidRemapIDs := lookupRanges(container, ContainerGroup, KeyRemapGIDRanges, lookupHostSubgid, getDefaultRemapGids()) + remapUIDStart := container.LookupUint32(ContainerGroup, KeyRemapUIDStart, 1) + remapGIDStart := container.LookupUint32(ContainerGroup, KeyRemapGIDStart, 1) + + addIDMaps(podman, "--uidmap", uid, hostUID, remapUIDStart, uidRemapIDs) + addIDMaps(podman, "--gidmap", gid, hostGID, remapGIDStart, gidRemapIDs) + } + + volumes := container.LookupAll(ContainerGroup, KeyVolume) + for _, volume := range volumes { + parts := strings.SplitN(volume, ":", 3) + + source := "" + var dest string + options := "" + if len(parts) >= 2 { + source = parts[0] + dest = parts[1] + } else { + dest = parts[0] + } + if len(parts) >= 3 { + options = ":" + parts[2] + } + + if source != "" { + if source[0] == '/' { + // Absolute path + service.Add(UnitGroup, "RequiresMountsFor", source) + } else if strings.HasSuffix(source, ".volume") { + // the podman volume name is systemd-$name + volumeName := replaceExtension(source, "", "systemd-", "") + + // the systemd unit name is $name-volume.service + volumeServiceName := replaceExtension(source, ".service", "", "-volume") + + source = volumeName + + service.Add(UnitGroup, "Requires", volumeServiceName) + service.Add(UnitGroup, "After", volumeServiceName) + } + } + + podman.add("-v") + if source == "" { + podman.add(dest) + } else { + podman.addf("%s:%s%s", source, dest, options) + } + } + + exposedPorts := container.LookupAll(ContainerGroup, KeyExposeHostPort) + for _, exposedPort := range exposedPorts { + exposedPort = strings.TrimSpace(exposedPort) // Allow whitespace after + + if !isPortRange(exposedPort) { + return nil, fmt.Errorf("invalid port format '%s'", exposedPort) + } + + podman.addf("--expose=%s", exposedPort) + } + + publishPorts := container.LookupAll(ContainerGroup, KeyPublishPort) + for _, publishPort := range publishPorts { + publishPort = strings.TrimSpace(publishPort) // Allow whitespace after + + // IP address could have colons in it. For example: "[::]:8080:80/tcp, so use custom splitter + parts := splitPorts(publishPort) + + var containerPort string + ip := "" + hostPort := "" + + // format (from podman run): + // ip:hostPort:containerPort | ip::containerPort | hostPort:containerPort | containerPort + // + // ip could be IPv6 with minimum of these chars "[::]" + // containerPort can have a suffix of "/tcp" or "/udp" + // + + switch len(parts) { + case 1: + containerPort = parts[0] + + case 2: + hostPort = parts[0] + containerPort = parts[1] + + case 3: + ip = parts[0] + hostPort = parts[1] + containerPort = parts[2] + + default: + return nil, fmt.Errorf("invalid published port '%s'", publishPort) + } + + if ip == "0.0.0.0" { + ip = "" + } + + if len(hostPort) > 0 && !isPortRange(hostPort) { + return nil, fmt.Errorf("invalid port format '%s'", hostPort) + } + + if len(containerPort) > 0 && !isPortRange(containerPort) { + return nil, fmt.Errorf("invalid port format '%s'", containerPort) + } + + switch { + case len(ip) > 0 && len(hostPort) > 0: + podman.addf("-p=%s:%s:%s", ip, hostPort, containerPort) + case len(ip) > 0: + podman.addf("-p=%s::%s", ip, containerPort) + case len(hostPort) > 0: + podman.addf("-p=%s:%s", hostPort, containerPort) + default: + podman.addf("-p=%s", containerPort) + } + } + + podman.addEnv(podmanEnv) + + labels := container.LookupAllKeyVal(ContainerGroup, KeyLabel) + podman.addLabels(labels) + + annotations := container.LookupAllKeyVal(ContainerGroup, KeyAnnotation) + podman.addAnnotations(annotations) + + podmanArgs := container.LookupAllArgs(ContainerGroup, KeyPodmanArgs) + podman.add(podmanArgs...) + + podman.add(image) + + execArgs, ok := container.LookupLastArgs(ContainerGroup, KeyExec) + if ok { + podman.add(execArgs...) + } + + service.AddCmdline(ServiceGroup, "ExecStart", podman.Args) + + return service, nil +} + +func ConvertVolume(volume *systemdparser.UnitFile, name string) (*systemdparser.UnitFile, error) { + service := volume.Dup() + service.Filename = replaceExtension(volume.Filename, ".service", "", "-volume") + + if err := checkForUnknownKeys(volume, VolumeGroup, supportedVolumeKeys); err != nil { + return nil, err + } + + /* Rename old Volume group to x-Volume so that systemd ignores it */ + service.RenameGroup(VolumeGroup, XVolumeGroup) + + volumeName := replaceExtension(name, "", "systemd-", "") + + // Need the containers filesystem mounted to start podman + service.Add(UnitGroup, "RequiresMountsFor", "%t/containers") + + execCond := fmt.Sprintf("/usr/bin/bash -c \"! /usr/bin/podman volume exists %s\"", volumeName) + + labels := volume.LookupAllKeyVal(VolumeGroup, "Label") + + podman := NewPodmanCmdline("volume", "create") + + var opts strings.Builder + opts.WriteString("o=") + + if volume.HasKey(VolumeGroup, "User") { + uid := volume.LookupUint32(VolumeGroup, "User", 0) + if opts.Len() > 2 { + opts.WriteString(",") + } + opts.WriteString(fmt.Sprintf("uid=%d", uid)) + } + + if volume.HasKey(VolumeGroup, "Group") { + gid := volume.LookupUint32(VolumeGroup, "Group", 0) + if opts.Len() > 2 { + opts.WriteString(",") + } + opts.WriteString(fmt.Sprintf("gid=%d", gid)) + } + + if opts.Len() > 2 { + podman.add("--opt", opts.String()) + } + + podman.addLabels(labels) + podman.add(volumeName) + + service.AddCmdline(ServiceGroup, "ExecStart", podman.Args) + + service.Setv(ServiceGroup, + "Type", "oneshot", + "RemainAfterExit", "yes", + "ExecCondition", execCond, + + // The default syslog identifier is the exec basename (podman) which isn't very useful here + "SyslogIdentifier", "%N") + + return service, nil +} diff --git a/pkg/quadlet/quadlet_test.go b/pkg/quadlet/quadlet_test.go new file mode 100644 index 000000000000..c94678ef6f3f --- /dev/null +++ b/pkg/quadlet/quadlet_test.go @@ -0,0 +1,45 @@ +package quadlet + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQuadlet_SplitPorts(t *testing.T) { + parts := splitPorts("") + assert.Equal(t, len(parts), 1) + assert.Equal(t, parts[0], "") + + parts = splitPorts("foo") + assert.Equal(t, len(parts), 1) + assert.Equal(t, parts[0], "foo") + + parts = splitPorts("foo:bar") + assert.Equal(t, len(parts), 2) + assert.Equal(t, parts[0], "foo") + assert.Equal(t, parts[1], "bar") + + parts = splitPorts("foo:bar:") + assert.Equal(t, len(parts), 3) + assert.Equal(t, parts[0], "foo") + assert.Equal(t, parts[1], "bar") + assert.Equal(t, parts[2], "") + + parts = splitPorts("abc[foo::bar]xyz:foo:bar") + assert.Equal(t, len(parts), 3) + assert.Equal(t, parts[0], "abc[foo::bar]xyz") + assert.Equal(t, parts[1], "foo") + assert.Equal(t, parts[2], "bar") + + parts = splitPorts("foo:abc[foo::bar]xyz:bar") + assert.Equal(t, len(parts), 3) + assert.Equal(t, parts[0], "foo") + assert.Equal(t, parts[1], "abc[foo::bar]xyz") + assert.Equal(t, parts[2], "bar") + + parts = splitPorts("foo:abc[foo::barxyz:bar") + assert.Equal(t, len(parts), 2) + assert.Equal(t, parts[0], "foo") + assert.Equal(t, parts[1], "abc[foo::barxyz:bar") +} diff --git a/pkg/quadlet/ranges.go b/pkg/quadlet/ranges.go new file mode 100644 index 000000000000..ec5da44b01b7 --- /dev/null +++ b/pkg/quadlet/ranges.go @@ -0,0 +1,247 @@ +package quadlet + +import ( + "math" + "strconv" + "strings" +) + +/* Keep track of a list of ranges of uint32, used to manage Uid/Gid ranges for mapping */ + +func minUint32(x, y uint32) uint32 { + if x < y { + return x + } + return y +} + +func maxUint32(x, y uint32) uint32 { + if x > y { + return x + } + return y +} + +type Range struct { + Start uint32 + Length uint32 +} + +type Ranges struct { + Ranges []Range +} + +func (r *Ranges) Add(start, length uint32) { + // The maximum value we can store is UINT32_MAX-1, because if start + // is 0 and length is UINT32_MAX, then the first non-range item is + // 0+UINT32_MAX. So, we limit the start and length here so all + // elements in the ranges are in this area. + if start == math.MaxUint32 { + return + } + length = minUint32(length, math.MaxUint32-start) + + if length == 0 { + return + } + + for i := 0; i < len(r.Ranges); i++ { + current := &r.Ranges[i] + // Check if new range starts before current + if start < current.Start { + // Check if new range is completely before current + if start+length < current.Start { + // insert new range at i + newr := make([]Range, len(r.Ranges)+1) + copy(newr[0:i], r.Ranges[0:i]) + newr[i] = Range{Start: start, Length: length} + copy(newr[i+1:], r.Ranges[i:]) + r.Ranges = newr + + return // All done + } + + // ranges overlap, extend current backward to new start + toExtendLen := current.Start - start + current.Start -= toExtendLen + current.Length += toExtendLen + + // And drop the extended part from new range + start += toExtendLen + length -= toExtendLen + + if length == 0 { + return // That was all + } + + // Move on to next case + } + + if start >= current.Start && start < current.Start+current.Length { + // New range overlaps current + if start+length <= current.Start+current.Length { + return // All overlapped, we're done + } + + // New range extends past end of current + overlapLen := (current.Start + current.Length) - start + + // And drop the overlapped part from current range + start += overlapLen + length -= overlapLen + + // Move on to next case + } + + if start == current.Start+current.Length { + // We're extending current + current.Length += length + + // Might have to merge some old remaining ranges + for i+1 < len(r.Ranges) && + r.Ranges[i+1].Start <= current.Start+current.Length { + next := &r.Ranges[i+1] + + newEnd := maxUint32(current.Start+current.Length, next.Start+next.Length) + + current.Length = newEnd - current.Start + + copy(r.Ranges[i+1:], r.Ranges[i+2:]) + r.Ranges = r.Ranges[:len(r.Ranges)-1] + current = &r.Ranges[i] + } + + return // All done + } + } + + // New range remaining after last old range, append + if length > 0 { + r.Ranges = append(r.Ranges, Range{Start: start, Length: length}) + } +} + +func (r *Ranges) Remove(start, length uint32) { + // Limit ranges, see comment in Add + if start == math.MaxUint32 { + return + } + length = minUint32(length, math.MaxUint32-start) + + if length == 0 { + return + } + + for i := 0; i < len(r.Ranges); i++ { + current := &r.Ranges[i] + + end := start + length + currentStart := current.Start + currentEnd := current.Start + current.Length + + if end > currentStart && start < currentEnd { + remainingAtStart := uint32(0) + remainingAtEnd := uint32(0) + + if start > currentStart { + remainingAtStart = start - currentStart + } + + if end < currentEnd { + remainingAtEnd = currentEnd - end + } + + switch { + case remainingAtStart == 0 && remainingAtEnd == 0: + // Remove whole range + copy(r.Ranges[i:], r.Ranges[i+1:]) + r.Ranges = r.Ranges[:len(r.Ranges)-1] + i-- // undo loop iter + case remainingAtStart != 0 && remainingAtEnd != 0: + // Range is split + + newr := make([]Range, len(r.Ranges)+1) + copy(newr[0:i], r.Ranges[0:i]) + copy(newr[i+1:], r.Ranges[i:]) + newr[i].Start = currentStart + newr[i].Length = remainingAtStart + newr[i+1].Start = currentEnd - remainingAtEnd + newr[i+1].Length = remainingAtEnd + r.Ranges = newr + i++ /* double loop iter */ + case remainingAtStart != 0: + r.Ranges[i].Start = currentStart + r.Ranges[i].Length = remainingAtStart + default: /* remainingAtEnd != 0 */ + r.Ranges[i].Start = currentEnd - remainingAtEnd + r.Ranges[i].Length = remainingAtEnd + } + } + } +} + +func (r *Ranges) Merge(other *Ranges) { + for _, o := range other.Ranges { + r.Add(o.Start, o.Length) + } +} + +func (r *Ranges) Copy() *Ranges { + rs := make([]Range, len(r.Ranges)) + copy(rs, r.Ranges) + return &Ranges{Ranges: rs} +} + +func (r *Ranges) Length() uint32 { + length := uint32(0) + for _, rr := range r.Ranges { + length += rr.Length + } + return length +} + +func NewRangesEmpty() *Ranges { + return &Ranges{Ranges: nil} +} + +func NewRanges(start, length uint32) *Ranges { + r := NewRangesEmpty() + r.Add(start, length) + + return r +} + +func parseEndpoint(str string, defaultVal uint32) uint32 { + str = strings.TrimSpace(str) + intVal, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return defaultVal + } + + if intVal < 0 { + return uint32(0) + } + if intVal > math.MaxUint32 { + return uint32(math.MaxUint32) + } + return uint32(intVal) +} + +// Ranges are specified inclusive. I.e. 1-3 is 1,2,3 +func ParseRanges(str string) *Ranges { + r := NewRangesEmpty() + + for _, part := range strings.Split(str, ",") { + start, end, isPair := strings.Cut(part, "-") + startV := parseEndpoint(start, 0) + endV := startV + if isPair { + endV = parseEndpoint(end, math.MaxUint32) + } + if endV >= startV { + r.Add(startV, endV-startV+1) + } + } + + return r +} diff --git a/pkg/quadlet/ranges_test.go b/pkg/quadlet/ranges_test.go new file mode 100644 index 000000000000..b738c7dc42f8 --- /dev/null +++ b/pkg/quadlet/ranges_test.go @@ -0,0 +1,242 @@ +package quadlet + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRanges_Creation(t *testing.T) { + empty := NewRangesEmpty() + + assert.Equal(t, empty.Length(), uint32(0)) + + one := NewRanges(17, 42) + assert.Equal(t, one.Ranges[0].Start, uint32(17)) + assert.Equal(t, one.Ranges[0].Length, uint32(42)) +} + +func TestRanges_Single(t *testing.T) { + /* Before */ + r := NewRanges(10, 10) + + r.Add(0, 9) + + assert.Equal(t, len(r.Ranges), 2) + assert.Equal(t, r.Ranges[0].Start, uint32(0)) + assert.Equal(t, r.Ranges[0].Length, uint32(9)) + assert.Equal(t, r.Ranges[1].Start, uint32(10)) + assert.Equal(t, r.Ranges[1].Length, uint32(10)) + + /* just before */ + r = NewRanges(10, 10) + + r.Add(0, 10) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(0)) + assert.Equal(t, r.Ranges[0].Length, uint32(20)) + + /* before + inside */ + r = NewRanges(10, 10) + + r.Add(0, 19) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(0)) + assert.Equal(t, r.Ranges[0].Length, uint32(20)) + + /* before + inside, whole */ + r = NewRanges(10, 10) + + r.Add(0, 20) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(0)) + assert.Equal(t, r.Ranges[0].Length, uint32(20)) + + /* before + inside + after */ + r = NewRanges(10, 10) + + r.Add(0, 30) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(0)) + assert.Equal(t, r.Ranges[0].Length, uint32(30)) + + /* just inside */ + r = NewRanges(10, 10) + + r.Add(10, 5) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + + /* inside */ + r = NewRanges(10, 10) + + r.Add(12, 5) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + + /* inside at end */ + r = NewRanges(10, 10) + + r.Add(15, 5) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + + /* inside + after */ + r = NewRanges(10, 10) + + r.Add(15, 10) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(15)) + + /* just after */ + r = NewRanges(10, 10) + + r.Add(20, 10) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(20)) + + /* after */ + r = NewRanges(10, 10) + + r.Add(21, 10) + + assert.Equal(t, len(r.Ranges), 2) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(21)) + assert.Equal(t, r.Ranges[1].Length, uint32(10)) +} + +func TestRanges_Multi(t *testing.T) { + base := NewRanges(10, 10) + base.Add(50, 10) + base.Add(30, 10) + + /* Test copy */ + r := base.Copy() + + assert.Equal(t, len(r.Ranges), 3) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(30)) + assert.Equal(t, r.Ranges[1].Length, uint32(10)) + assert.Equal(t, r.Ranges[2].Start, uint32(50)) + assert.Equal(t, r.Ranges[2].Length, uint32(10)) + + /* overlap everything */ + r = base.Copy() + + r.Add(0, 100) + + assert.Equal(t, len(r.Ranges), 1) + assert.Equal(t, r.Ranges[0].Start, uint32(0)) + assert.Equal(t, r.Ranges[0].Length, uint32(100)) + + /* overlap middle */ + r = base.Copy() + + r.Add(25, 10) + + assert.Equal(t, len(r.Ranges), 3) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(25)) + assert.Equal(t, r.Ranges[1].Length, uint32(15)) + assert.Equal(t, r.Ranges[2].Start, uint32(50)) + assert.Equal(t, r.Ranges[2].Length, uint32(10)) + + /* overlap last */ + r = base.Copy() + + r.Add(45, 10) + + assert.Equal(t, len(r.Ranges), 3) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(30)) + assert.Equal(t, r.Ranges[1].Length, uint32(10)) + assert.Equal(t, r.Ranges[2].Start, uint32(45)) + assert.Equal(t, r.Ranges[2].Length, uint32(15)) +} + +func TestRanges_Remove(t *testing.T) { + base := NewRanges(10, 10) + base.Add(50, 10) + base.Add(30, 10) + + /* overlap all */ + r := base.Copy() + + r.Remove(0, 100) + + assert.Equal(t, len(r.Ranges), 0) + + /* overlap middle 1 */ + + r = base.Copy() + + r.Remove(25, 20) + + assert.Equal(t, len(r.Ranges), 2) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(50)) + assert.Equal(t, r.Ranges[1].Length, uint32(10)) + + /* overlap middle 2 */ + + r = base.Copy() + + r.Remove(25, 10) + + assert.Equal(t, len(r.Ranges), 3) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(35)) + assert.Equal(t, r.Ranges[1].Length, uint32(5)) + assert.Equal(t, r.Ranges[2].Start, uint32(50)) + assert.Equal(t, r.Ranges[2].Length, uint32(10)) + + /* overlap middle 3 */ + r = base.Copy() + + r.Remove(35, 10) + + assert.Equal(t, len(r.Ranges), 3) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(30)) + assert.Equal(t, r.Ranges[1].Length, uint32(5)) + assert.Equal(t, r.Ranges[2].Start, uint32(50)) + assert.Equal(t, r.Ranges[2].Length, uint32(10)) + + /* overlap middle 4 */ + + r = base.Copy() + + r.Remove(34, 2) + + assert.Equal(t, len(r.Ranges), 4) + assert.Equal(t, r.Ranges[0].Start, uint32(10)) + assert.Equal(t, r.Ranges[0].Length, uint32(10)) + assert.Equal(t, r.Ranges[1].Start, uint32(30)) + assert.Equal(t, r.Ranges[1].Length, uint32(4)) + assert.Equal(t, r.Ranges[2].Start, uint32(36)) + assert.Equal(t, r.Ranges[2].Length, uint32(4)) + assert.Equal(t, r.Ranges[3].Start, uint32(50)) + assert.Equal(t, r.Ranges[3].Length, uint32(10)) +} diff --git a/pkg/quadlet/subuids.go b/pkg/quadlet/subuids.go new file mode 100644 index 000000000000..de0d72ea53eb --- /dev/null +++ b/pkg/quadlet/subuids.go @@ -0,0 +1,67 @@ +package quadlet + +import ( + "os" + "strconv" + "strings" +) + +func lookupHostSubid(name string, file string, cache *[]string) *Ranges { + ranges := NewRangesEmpty() + + if len(*cache) == 0 { + data, e := os.ReadFile(file) + if e != nil { + *cache = make([]string, 0) + } else { + *cache = strings.Split(string(data), "\n") + } + for i := range *cache { + (*cache)[i] = strings.TrimSpace((*cache)[i]) + } + + // If file had no lines, add an empty line so the above cache created check works + if len(*cache) == 0 { + *cache = append(*cache, "") + } + } + + for _, line := range *cache { + if strings.HasPrefix(line, name) && + len(line) > len(name)+1 && line[len(name)] == ':' { + parts := strings.SplitN(line, ":", 3) + + if len(parts) != 3 { + continue + } + + start, err := strconv.ParseUint(parts[1], 10, 32) + if err != nil { + continue + } + + len, err := strconv.ParseUint(parts[1], 10, 32) + if err != nil { + continue + } + + if len > 0 { + ranges.Add(uint32(start), uint32(len)) + } + + break + } + } + + return ranges +} + +var subuidCache, subgidCache []string + +func lookupHostSubuid(userName string) *Ranges { + return lookupHostSubid(userName, "/etc/subuid", &subuidCache) +} + +func lookupHostSubgid(userName string) *Ranges { + return lookupHostSubid(userName, "/etc/subgid", &subgidCache) +} diff --git a/pkg/systemdparser/split.go b/pkg/systemdparser/split.go new file mode 100644 index 000000000000..3ff3d83a55b0 --- /dev/null +++ b/pkg/systemdparser/split.go @@ -0,0 +1,505 @@ +package systemdparser + +import ( + "fmt" + "strings" + "unicode" +) + +/* Functions to split/join, unescape/escape strings similar to Exec=... lines in unit files */ + +type SplitFlags = uint64 + +const ( + SplitRelax SplitFlags = 1 << iota // Allow unbalanced quote and eat up trailing backslash. + SplitCUnescape // Unescape known escape sequences. + SplitUnescapeRelax // Allow and keep unknown escape sequences, allow and keep trailing backslash. + SplitUnescapeSeparators // Unescape separators (those specified, or whitespace by default). + SplitKeepQuote // Ignore separators in quoting with "" and ''. + SplitUnquote // Ignore separators in quoting with "" and '', and remove the quotes. + SplitDontCoalesceSeparators // Don't treat multiple adjacent separators as one + SplitRetainEscape // Treat escape character '\' as any other character without special meaning + SplitRetainSeparators // Do not advance the original string pointer past the separator(s) */ +) + +const WhitespaceSeparators = " \t\n\r" + +func unoctchar(v byte) int { + if v >= '0' && v <= '7' { + return int(v - '0') + } + + return -1 +} + +func unhexchar(v byte) int { + if v >= '0' && v <= '9' { + return int(v - '0') + } + + if v >= 'a' && v <= 'f' { + return int(v - 'a' + 10) + } + + if v >= 'A' && v <= 'F' { + return int(v - 'A' + 10) + } + + return -1 +} + +func isValidUnicode(c uint32) bool { + return c <= unicode.MaxRune +} + +/* This is based on code from systemd (src/basic/escape.c), marked LGPL-2.1-or-later and is copyrighted by the systemd developers */ + +func cUnescapeOne(p string, acceptNul bool) (int, rune, bool) { + var count = 1 + var eightBit = false + var ret rune + + // Unescapes C style. Returns the unescaped character in ret. + // Returns eightBit as true if the escaped sequence either fits in + // one byte in UTF-8 or is a non-unicode literal byte and should + // instead be copied directly. + + if len(p) < 1 { + return -1, 0, false + } + + switch p[0] { + case 'a': + ret = '\a' + case 'b': + ret = '\b' + case 'f': + ret = '\f' + case 'n': + ret = '\n' + case 'r': + ret = '\r' + case 't': + ret = '\t' + case 'v': + ret = '\v' + case '\\': + ret = '\\' + case '"': + ret = '"' + case '\'': + ret = '\'' + case 's': + /* This is an extension of the XDG syntax files */ + ret = ' ' + case 'x': + /* hexadecimal encoding */ + if len(p) < 3 { + return -1, 0, false + } + + a := unhexchar(p[1]) + if a < 0 { + return -1, 0, false + } + + b := unhexchar(p[2]) + if b < 0 { + return -1, 0, false + } + + /* Don't allow NUL bytes */ + if a == 0 && b == 0 && !acceptNul { + return -1, 0, false + } + + ret = rune((a << 4) | b) + eightBit = true + count = 3 + case 'u': + /* C++11 style 16bit unicode */ + + if len(p) < 5 { + return -1, 0, false + } + + var a [4]int + for i := 0; i < 4; i++ { + a[i] = unhexchar(p[1+i]) + if a[i] < 0 { + return -1, 0, false + } + } + + c := (uint32(a[0]) << 12) | (uint32(a[1]) << 8) | (uint32(a[2]) << 4) | uint32(a[3]) + + /* Don't allow 0 chars */ + if c == 0 && !acceptNul { + return -1, 0, false + } + + ret = rune(c) + count = 5 + case 'U': + /* C++11 style 32bit unicode */ + + if len(p) < 9 { + return -1, 0, false + } + + var a [8]int + for i := 0; i < 8; i++ { + a[i] = unhexchar(p[1+i]) + if a[i] < 0 { + return -10, 0, false + } + } + + c := (uint32(a[0]) << 28) | (uint32(a[1]) << 24) | (uint32(a[2]) << 20) | (uint32(a[3]) << 16) | + (uint32(a[4]) << 12) | (uint32(a[5]) << 8) | (uint32(a[6]) << 4) | uint32(a[7]) + + /* Don't allow 0 chars */ + if c == 0 && !acceptNul { + return -1, 0, false + } + + /* Don't allow invalid code points */ + if !isValidUnicode(c) { + return -1, 0, false + } + + ret = rune(c) + count = 9 + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + /* octal encoding */ + + if len(p) < 3 { + return -1, 0, false + } + + a := unoctchar(p[0]) + if a < 0 { + return -1, 0, false + } + + b := unoctchar(p[0]) + if b < 0 { + return -1, 0, false + } + + c := unoctchar(p[0]) + if c < 0 { + return -1, 0, false + } + + /* don't allow NUL bytes */ + if a == 0 && b == 0 && c == 0 && !acceptNul { + return -1, 0, false + } + + /* Don't allow bytes above 255 */ + m := (uint32(a) << 6) | (uint32(b) << 3) | uint32(c) + if m > 255 { + return -1, 0, false + } + + ret = rune(m) + eightBit = true + count = 3 + default: + return -1, 0, false + } + + return count, ret, eightBit +} + +/* This is based on code from systemd (src/basic/extract-workd.c), marked LGPL-2.1-or-later and is copyrighted by the systemd developers */ + +// Returns: word, remaining, more-words, error +func extractFirstWord(in string, separators string, flags SplitFlags) (string, string, bool, error) { + var s strings.Builder + var quote byte // 0 or ' or " + backslash := false // whether we've just seen a backslash + + // The string handling in this function is a bit weird, using + // 0 bytes to mark end-of-string. This is because its a direct + // conversion of the C in systemd, and w want to ensure + // exactly the same behaviour of some complex code + + p := 0 + end := len(in) + var c byte + + nextChar := func() byte { + p++ + if p >= end { + return 0 + } + return in[p] + } + + /* Bail early if called after last value or with no input */ + if len(in) == 0 { + goto finish + } + + // Parses the first word of a string, and returns it and the + // remainder. Removes all quotes in the process. When parsing + // fails (because of an uneven number of quotes or similar), + // the rest is at the first invalid character. */ + +loop1: + for c = in[0]; ; c = nextChar() { + switch { + case c == 0: + goto finishForceTerminate + case strings.ContainsRune(separators, rune(c)): + if flags&SplitDontCoalesceSeparators != 0 { + if !(flags&SplitRetainSeparators != 0) { + p++ + } + goto finishForceNext + } + default: + // We found a non-blank character, so we will always + // want to return a string (even if it is empty), + // allocate it here. + break loop1 + } + } + + for ; ; c = nextChar() { + switch { + case backslash: + if c == 0 { + if flags&SplitUnescapeRelax != 0 && + (quote == 0 || flags&SplitRelax != 0) { + // If we find an unquoted trailing backslash and we're in + // SplitUnescapeRelax mode, keep it verbatim in the + // output. + // + // Unbalanced quotes will only be allowed in SplitRelax + // mode, SplitUnescapeRelax mode does not allow them. + s.WriteString("\\") + goto finishForceTerminate + } + if flags&SplitRelax != 0 { + goto finishForceTerminate + } + return "", "", false, fmt.Errorf("unbalanced escape") + } + + if flags&(SplitCUnescape|SplitUnescapeSeparators) != 0 { + var r = -1 + var u rune + + if flags&SplitCUnescape != 0 { + r, u, _ = cUnescapeOne(in[p:], false) + } + + switch { + case r > 0: + p += r - 1 + s.WriteRune(u) + case (flags&SplitUnescapeSeparators != 0) && + (strings.ContainsRune(separators, rune(c)) || c == '\\'): + /* An escaped separator char or the escape char itself */ + s.WriteByte(c) + case flags&SplitUnescapeRelax != 0: + s.WriteByte('\\') + s.WriteByte(c) + default: + return "", "", false, fmt.Errorf("unsupported escape char") + } + } else { + s.WriteByte(c) + } + + backslash = false + case quote != 0: + /* inside either single or double quotes */ + quoteloop: + for ; ; c = nextChar() { + switch { + case c == 0: + if flags&SplitRelax != 0 { + goto finishForceTerminate + } + return "", "", false, fmt.Errorf("unbalanced quotes") + case c == quote: + /* found the end quote */ + quote = 0 + if flags&SplitUnquote != 0 { + break quoteloop + } + case c == '\\' && !(flags&SplitRetainEscape != 0): + backslash = true + break quoteloop + } + + s.WriteByte(c) + + if quote == 0 { + break quoteloop + } + } + default: + nonquoteloop: + for ; ; c = nextChar() { + switch { + case c == 0: + goto finishForceTerminate + case (c == '\'' || c == '"') && (flags&(SplitKeepQuote|SplitUnquote) != 0): + quote = c + if flags&SplitUnquote != 0 { + break nonquoteloop + } + case c == '\\' && !(flags&SplitRetainEscape != 0): + backslash = true + break nonquoteloop + case strings.ContainsRune(separators, rune(c)): + if flags&SplitDontCoalesceSeparators != 0 { + if !(flags&SplitRetainSeparators != 0) { + p++ + } + goto finishForceNext + } + + if !(flags&SplitRetainSeparators != 0) { + /* Skip additional coalesced separators. */ + for ; ; c = nextChar() { + if c == 0 { + goto finishForceTerminate + } + if !strings.ContainsRune(separators, rune(c)) { + break + } + } + } + goto finish + } + + s.WriteByte(c) + + if quote != 0 { + break nonquoteloop + } + } + } + } + +finishForceTerminate: + p = end + +finish: + if s.Len() == 0 { + return "", "", false, nil + } + +finishForceNext: + return s.String(), in[p:], true, nil +} + +func splitStringAppend(appendTo []string, s string, separators string, flags SplitFlags) ([]string, error) { + orig := appendTo + for { + word, remaining, moreWords, err := extractFirstWord(s, separators, flags) + if err != nil { + return orig, err + } + + if !moreWords { + break + } + appendTo = append(appendTo, word) + s = remaining + } + return appendTo, nil +} + +func splitString(s string, separators string, flags SplitFlags) ([]string, error) { + return splitStringAppend(make([]string, 0), s, separators, flags) +} + +func charNeedEscape(c rune) bool { + if c > 128 { + return false /* unicode is ok */ + } + + return unicode.IsSpace(c) || + unicode.IsControl(c) || + c == '"' || + c == '\'' || + c == '\\' +} + +func wordNeedEscape(word string) bool { + for _, c := range word { + if charNeedEscape(c) { + return true + } + } + + return false +} + +func appendEscapeWord(escaped *strings.Builder, word string) { + escaped.WriteRune('"') + for _, c := range word { + if charNeedEscape(c) { + switch c { + case '\a': + escaped.WriteString("\\a") + case '\b': + escaped.WriteString("\\b") + case '\n': + escaped.WriteString("\\n") + case '\r': + escaped.WriteString("\\r") + case '\t': + escaped.WriteString("\\t") + case '\v': + escaped.WriteString("\\v") + case '\f': + escaped.WriteString("\\f") + case '\\': + escaped.WriteString("\\\\") + case ' ': + escaped.WriteString(" ") + case '"': + escaped.WriteString("\\\"") + case '\'': + escaped.WriteString("'") + default: + escaped.WriteString(fmt.Sprintf("\\x%.2x", c)) + } + } else { + escaped.WriteRune(c) + } + } + escaped.WriteRune('"') +} + +func escapeWords(words []string) string { + var escaped strings.Builder + + for i, word := range words { + if i != 0 { + escaped.WriteString(" ") + } + if wordNeedEscape(word) { + appendEscapeWord(&escaped, word) + } else { + escaped.WriteString(word) + } + } + + return escaped.String() +} diff --git a/pkg/systemdparser/unitfile.go b/pkg/systemdparser/unitfile.go new file mode 100644 index 000000000000..912d391f6064 --- /dev/null +++ b/pkg/systemdparser/unitfile.go @@ -0,0 +1,866 @@ +package systemdparser + +import ( + "fmt" + "io" + "math" + "os" + "os/user" + "path" + "strconv" + "strings" + "unicode" +) + +/* Code to parse, modify and re-emit system unit files */ + +type unitLine struct { + key string + value string + isComment bool +} + +type unitGroup struct { + name string + comments []*unitLine // Comments before the groupname + lines []*unitLine +} + +type UnitFile struct { + groups []*unitGroup + groupByName map[string]*unitGroup + + Filename string + Path string +} + +type UnitFileParser struct { + file *UnitFile + + currentGroup *unitGroup + pendingComments []*unitLine + lineNr int +} + +func newUnitLine(key string, value string, isComment bool) *unitLine { + l := &unitLine{ + key: key, + value: value, + isComment: isComment, + } + return l +} + +func (l *unitLine) set(value string) { + l.value = value +} + +func (l *unitLine) dup() *unitLine { + return newUnitLine(l.key, l.value, l.isComment) +} + +func (l *unitLine) isKey(key string) bool { + return !l.isComment && + l.key == key +} + +func (l *unitLine) isEmpty() bool { + return len(l.value) == 0 +} + +func newUnitGroup(name string) *unitGroup { + g := &unitGroup{ + name: name, + comments: make([]*unitLine, 0), + lines: make([]*unitLine, 0), + } + return g +} + +func (g *unitGroup) addLine(line *unitLine) { + g.lines = append(g.lines, line) +} + +func (g *unitGroup) addComment(line *unitLine) { + g.comments = append(g.comments, line) +} + +func (g *unitGroup) prependComment(line *unitLine) { + n := []*unitLine{line} + g.comments = append(n, g.comments...) +} + +func (g *unitGroup) add(key string, value string) { + g.addLine(newUnitLine(key, value, false)) +} + +func (g *unitGroup) findLast(key string) *unitLine { + for i := len(g.lines) - 1; i >= 0; i-- { + l := g.lines[i] + if l.isKey(key) { + return l + } + } + + return nil +} + +func (g *unitGroup) set(key string, value string) { + line := g.findLast(key) + if line != nil { + line.set(value) + } else { + g.add(key, value) + } +} + +func (g *unitGroup) unset(key string) { + newlines := make([]*unitLine, 0, len(g.lines)) + + for _, line := range g.lines { + if !line.isKey(key) { + newlines = append(newlines, line) + } + } + g.lines = newlines +} + +func (g *unitGroup) merge(source *unitGroup) { + for _, l := range source.comments { + g.comments = append(g.comments, l.dup()) + } + for _, l := range source.lines { + g.lines = append(g.lines, l.dup()) + } +} + +func NewUnitFile() *UnitFile { + f := &UnitFile{ + groups: make([]*unitGroup, 0), + groupByName: make(map[string]*unitGroup), + } + + return f +} + +func ParseUnitFile(pathName string) (*UnitFile, error) { + data, e := os.ReadFile(pathName) + if e != nil { + return nil, e + } + + f := NewUnitFile() + f.Path = pathName + f.Filename = path.Base(pathName) + + if e := f.Parse(string(data)); e != nil { + return nil, e + } + + return f, nil +} + +func (f *UnitFile) ensureGroup(groupName string) *unitGroup { + if g, ok := f.groupByName[groupName]; ok { + return g + } + + g := newUnitGroup(groupName) + f.groups = append(f.groups, g) + f.groupByName[groupName] = g + + return g +} + +func (f *UnitFile) merge(source *UnitFile) { + for _, srcGroup := range source.groups { + group := f.ensureGroup(srcGroup.name) + group.merge(srcGroup) + } +} + +func (f *UnitFile) Dup() *UnitFile { + copy := NewUnitFile() + + copy.merge(f) + copy.Filename = f.Filename + return copy +} + +func lineIsComment(line string) bool { + return len(line) == 0 || line[0] == '#' || line[0] == ':' +} + +func lineIsGroup(line string) bool { + if len(line) == 0 { + return false + } + + if line[0] != '[' { + return false + } + + end := strings.Index(line, "]") + if end == -1 { + return false + } + + // silently accept whitespace after the ] + for i := end + 1; i < len(line); i++ { + if line[i] != ' ' && line[i] != '\t' { + return false + } + } + + return true +} + +func lineIsKeyValuePair(line string) bool { + if len(line) == 0 { + return false + } + + p := strings.IndexByte(line, '=') + if p == -1 { + return false + } + + // Key must be non-empty + if p == 0 { + return false + } + + return true +} + +func groupNameIsValid(name string) bool { + if len(name) == 0 { + return false + } + + for _, c := range name { + if c == ']' || c == '[' || unicode.IsControl(c) { + return false + } + } + + return true +} + +func keyNameIsValid(name string) bool { + if len(name) == 0 { + return false + } + + for _, c := range name { + if c == '=' { + return false + } + } + + // No leading/trailing space + if name[0] == ' ' || name[len(name)-1] == ' ' { + return false + } + + return true +} + +func (p *UnitFileParser) parseComment(line string) error { + l := newUnitLine("", line, true) + p.pendingComments = append(p.pendingComments, l) + return nil +} + +func (p *UnitFileParser) parseGroup(line string) error { + end := strings.Index(line, "]") + + groupName := line[1:end] + + if !groupNameIsValid(groupName) { + return fmt.Errorf("invalid group name: %s", groupName) + } + + p.currentGroup = p.file.ensureGroup(groupName) + + if p.pendingComments != nil { + firstComment := p.pendingComments[0] + + // Remove one newline between groups, which is re-added on + // printing, see unitGroup.Write() + if firstComment.isEmpty() { + p.pendingComments = p.pendingComments[1:] + } + + p.flushPendingComments(true) + } + + return nil +} + +func (p *UnitFileParser) parseKeyValuePair(line string) error { + if p.currentGroup == nil { + return fmt.Errorf("key file does not start with a group") + } + + keyEnd := strings.Index(line, "=") + valueStart := keyEnd + 1 + + // Pull the key name from the line (chomping trailing whitespace) + for keyEnd > 0 && unicode.IsSpace(rune(line[keyEnd-1])) { + keyEnd-- + } + key := line[:keyEnd] + if !keyNameIsValid(key) { + return fmt.Errorf("invalid key name: %s", key) + } + + // Pull the value from the line (chugging leading whitespace) + + for valueStart < len(line) && unicode.IsSpace(rune(line[valueStart])) { + valueStart++ + } + + value := line[valueStart:] + + p.flushPendingComments(false) + + p.currentGroup.add(key, value) + + return nil +} + +func (p *UnitFileParser) parseLine(line string) error { + switch { + case lineIsComment(line): + return p.parseComment(line) + case lineIsGroup(line): + return p.parseGroup(line) + case lineIsKeyValuePair(line): + return p.parseKeyValuePair(line) + default: + return fmt.Errorf("file contains line %d: ā€œ%sā€ which is not a key-value pair, group, or comment", p.lineNr, line) + } +} + +func (p *UnitFileParser) flushPendingComments(toComment bool) { + pending := p.pendingComments + if pending == nil { + return + } + p.pendingComments = nil + + for _, pendingLine := range pending { + if toComment { + p.currentGroup.addComment(pendingLine) + } else { + p.currentGroup.addLine(pendingLine) + } + } +} + +func nextLine(data string, afterPos int) (string, string) { + rest := data[afterPos:] + if i := strings.Index(rest, "\n"); i >= 0 { + return data[:i+afterPos], data[i+afterPos+1:] + } + return data, "" +} + +func (f *UnitFile) Parse(data string) error { + p := &UnitFileParser{ + file: f, + lineNr: 1, + } + for len(data) > 0 { + origdata := data + nLines := 1 + var line string + line, data = nextLine(data, 0) + + // Handle multi-line continuations + // Note: This doesn't support coments in the middle of the continuation, which systemd does + if lineIsKeyValuePair(line) { + for len(data) > 0 && line[len(line)-1] == '\\' { + line, data = nextLine(origdata, len(line)+1) + nLines++ + } + } + + if err := p.parseLine(line); err != nil { + return err + } + + p.lineNr += nLines + } + + if p.currentGroup == nil { + // For files without groups, add an empty group name used only for initial comments + p.currentGroup = p.file.ensureGroup("") + } + p.flushPendingComments(false) + + return nil +} + +func (l *unitLine) write(w io.Writer) error { + if l.isComment { + if _, err := fmt.Fprintf(w, "%s\n", l.value); err != nil { + return err + } + } else { + if _, err := fmt.Fprintf(w, "%s=%s\n", l.key, l.value); err != nil { + return err + } + } + + return nil +} + +func (g *unitGroup) write(w io.Writer) error { + for _, c := range g.comments { + if err := c.write(w); err != nil { + return err + } + } + + if g.name == "" { + // Empty name groups are not valid, but used interally to handle comments in empty files + return nil + } + + if _, err := fmt.Fprintf(w, "[%s]\n", g.name); err != nil { + return err + } + + for _, l := range g.lines { + if err := l.write(w); err != nil { + return err + } + } + + return nil +} + +func (f *UnitFile) Write(w io.Writer) error { + for i, g := range f.groups { + // We always add a newline between groups, and strip one if it exists during + // parsing. This looks nicer, and avoids issues of duplicate newlines when + // merging groups or missing ones when creating new groups + if i != 0 { + if _, err := io.WriteString(w, "\n"); err != nil { + return err + } + } + + if err := g.write(w); err != nil { + return err + } + } + + return nil +} + +func (f *UnitFile) ToString() (string, error) { + var str strings.Builder + if err := f.Write(&str); err != nil { + return "", err + } + return str.String(), nil +} + +func applyLineContinuation(raw string) string { + if !strings.Contains(raw, "\\\n") { + return raw + } + + var str strings.Builder + + for len(raw) > 0 { + if first, rest, found := strings.Cut(raw, "\\\n"); found { + str.WriteString(first) + raw = rest + } else { + str.WriteString(raw) + raw = "" + } + } + + return str.String() +} + +func (f *UnitFile) HasGroup(groupName string) bool { + _, ok := f.groupByName[groupName] + return ok +} + +func (f *UnitFile) RemoveGroup(groupName string) { + g, ok := f.groupByName[groupName] + if ok { + delete(f.groupByName, groupName) + + newgroups := make([]*unitGroup, 0, len(f.groups)) + for _, oldgroup := range f.groups { + if oldgroup != g { + newgroups = append(newgroups, oldgroup) + } + } + f.groups = newgroups + } +} + +func (f *UnitFile) RenameGroup(groupName string, newName string) { + group, okOld := f.groupByName[groupName] + if !okOld { + return + } + + newGroup, okNew := f.groupByName[newName] + if !okNew { + // New group doesn't exist, just rename in-place + delete(f.groupByName, groupName) + group.name = newName + f.groupByName[newName] = group + } else if group != newGroup { + /* merge to existing group and delete old */ + newGroup.merge(group) + f.RemoveGroup(groupName) + } +} + +func (f *UnitFile) ListGroups() []string { + groups := make([]string, len(f.groups)) + for i, group := range f.groups { + groups[i] = group.name + } + return groups +} + +func (f *UnitFile) ListKeys(groupName string) []string { + g, ok := f.groupByName[groupName] + if !ok { + return make([]string, 0) + } + + hash := make(map[string]struct{}) + keys := make([]string, 0, len(g.lines)) + for _, line := range g.lines { + if !line.isComment { + if _, ok := hash[line.key]; !ok { + keys = append(keys, line.key) + hash[line.key] = struct{}{} + } + } + } + + return keys +} + +// Last instance of the key wins +// Can have trailing space +// Raw == contains continuations +func (f *UnitFile) LookupLastRaw(groupName string, key string) (string, bool) { + g, ok := f.groupByName[groupName] + if !ok { + return "", false + } + + line := g.findLast(key) + if line == nil { + return "", false + } + + return line.value, true +} + +func (f *UnitFile) HasKey(groupName string, key string) bool { + _, ok := f.LookupLastRaw(groupName, key) + return ok +} + +// Last instance of the key wins +// Can have trailing space +func (f *UnitFile) LookupLast(groupName string, key string) (string, bool) { + raw, ok := f.LookupLastRaw(groupName, key) + if !ok { + return "", false + } + + return applyLineContinuation(raw), true +} + +func (f *UnitFile) Lookup(groupName string, key string) (string, bool) { + v, ok := f.LookupLast(groupName, key) + if !ok { + return "", false + } + + return strings.TrimRightFunc(v, unicode.IsSpace), true +} + +func (f *UnitFile) LookupBoolean(groupName string, key string, defaultValue bool) bool { + v, ok := f.Lookup(groupName, key) + if !ok { + return defaultValue + } + + return strings.EqualFold(v, "1") || + strings.EqualFold(v, "yes") || + strings.EqualFold(v, "true") || + strings.EqualFold(v, "on") +} + +/* Mimics strol, which is what systemd uses */ +func convertNumber(v string) (int64, error) { + var err error + var intVal int64 + + mult := int64(1) + + if strings.HasPrefix(v, "+") { + v = v[1:] + } else if strings.HasPrefix(v, "-") { + v = v[1:] + mult = int64(-11) + } + + switch { + case strings.HasPrefix(v, "0x") || strings.HasPrefix(v, "0X"): + intVal, err = strconv.ParseInt(v[2:], 16, 64) + case strings.HasPrefix(v, "0"): + intVal, err = strconv.ParseInt(v, 8, 64) + default: + intVal, err = strconv.ParseInt(v, 10, 64) + } + + return intVal * mult, err +} + +func (f *UnitFile) LookupInt(groupName string, key string, defaultValue int64) int64 { + v, ok := f.Lookup(groupName, key) + if !ok { + return defaultValue + } + + intVal, err := convertNumber(v) + + if err != nil { + return defaultValue + } + + return intVal +} + +func (f *UnitFile) LookupUint32(groupName string, key string, defaultValue uint32) uint32 { + v := f.LookupInt(groupName, key, int64(defaultValue)) + if v < 0 || v > math.MaxUint32 { + return defaultValue + } + return uint32(v) +} + +func (f *UnitFile) LookupUID(groupName string, key string, defaultValue uint32) (uint32, error) { + v, ok := f.Lookup(groupName, key) + if !ok { + if defaultValue == math.MaxUint32 { + return 0, fmt.Errorf("no key %s", key) + } + return defaultValue, nil + } + + intVal, err := convertNumber(v) + if err == nil { + /* On linux, uids are uint32 values, that can't be (uint32)-1 (== MAXUINT32)*/ + if intVal < 0 || intVal >= math.MaxUint32 { + return 0, fmt.Errorf("invalid numerical uid '%s'", v) + } + + return uint32(intVal), nil + } + + user, err := user.Lookup(v) + if err != nil { + return 0, err + } + + intVal, err = strconv.ParseInt(user.Uid, 10, 64) + if err != nil { + return 0, err + } + + return uint32(intVal), nil +} + +func (f *UnitFile) LookupGID(groupName string, key string, defaultValue uint32) (uint32, error) { + v, ok := f.Lookup(groupName, key) + if !ok { + if defaultValue == math.MaxUint32 { + return 0, fmt.Errorf("no key %s", key) + } + return defaultValue, nil + } + + intVal, err := convertNumber(v) + if err == nil { + /* On linux, uids are uint32 values, that can't be (uint32)-1 (== MAXUINT32)*/ + if intVal < 0 || intVal >= math.MaxUint32 { + return 0, fmt.Errorf("invalid numerical uid '%s'", v) + } + + return uint32(intVal), nil + } + + group, err := user.LookupGroup(v) + if err != nil { + return 0, err + } + + intVal, err = strconv.ParseInt(group.Gid, 10, 64) + if err != nil { + return 0, err + } + + return uint32(intVal), nil +} + +// Can have trailing space +// Raw == contains line continuations +func (f *UnitFile) LookupAllRaw(groupName string, key string) []string { + g, ok := f.groupByName[groupName] + if !ok { + return make([]string, 0) + } + + values := make([]string, 0) + + for _, line := range g.lines { + if line.isKey(key) { + if len(line.value) == 0 { + // Empty value clears all before + values = make([]string, 0) + } else { + values = append(values, line.value) + } + } + } + + return values +} + +func (f *UnitFile) LookupAll(groupName string, key string) []string { + values := f.LookupAllRaw(groupName, key) + for i, raw := range values { + values[i] = applyLineContinuation(raw) + } + return values +} + +// this splits space separated values similar to the systemd config_parse_strv, merging multiple values into a single vector +func (f *UnitFile) LookupAllStrv(groupName string, key string) []string { + res := make([]string, 0) + values := f.LookupAll(groupName, key) + for _, value := range values { + res, _ = splitStringAppend(res, value, WhitespaceSeparators, SplitRetainEscape|SplitUnquote) + } + return res +} + +// Unescapes exec-like arguments for all instances of the key +func (f *UnitFile) LookupAllArgs(groupName string, key string) []string { + res := make([]string, 0) + argsv := f.LookupAll(groupName, key) + for _, argsS := range argsv { + args, err := splitString(argsS, WhitespaceSeparators, SplitRelax|SplitUnquote|SplitCUnescape) + if err == nil { + res = append(res, args...) + } + } + return res +} + +// Unescapes exec-like arguments for the last instance of the key +func (f *UnitFile) LookupLastArgs(groupName string, key string) ([]string, bool) { + execKey, ok := f.LookupLast(groupName, "Exec") + if ok { + execArgs, err := splitString(execKey, WhitespaceSeparators, SplitRelax|SplitUnquote|SplitCUnescape) + if err == nil { + return execArgs, true + } + } + return nil, false +} + +// Look up 'Environment' style key-value keys +func (f *UnitFile) LookupAllKeyVal(groupName string, key string) map[string]string { + res := make(map[string]string) + allKeyvals := f.LookupAll(groupName, key) + for _, keyvals := range allKeyvals { + assigns, err := splitString(keyvals, WhitespaceSeparators, SplitRelax|SplitUnquote|SplitCUnescape) + if err == nil { + for _, assign := range assigns { + key, value, found := strings.Cut(assign, "=") + if found { + res[key] = value + } + } + } + } + return res +} + +func (f *UnitFile) Set(groupName string, key string, value string) { + group := f.ensureGroup(groupName) + group.set(key, value) +} + +func (f *UnitFile) Setv(groupName string, keyvals ...string) { + group := f.ensureGroup(groupName) + for i := 0; i+1 < len(keyvals); i += 2 { + group.set(keyvals[i], keyvals[i+1]) + } +} + +func (f *UnitFile) Add(groupName string, key string, value string) { + group := f.ensureGroup(groupName) + group.add(key, value) +} + +func (f *UnitFile) AddCmdline(groupName string, key string, args []string) { + f.Add(groupName, key, escapeWords(args)) +} + +func (f *UnitFile) Unset(groupName string, key string) { + group, ok := f.groupByName[groupName] + if ok { + group.unset(key) + } +} + +// Empty group name == first group +func (f *UnitFile) AddComment(groupName string, comments ...string) { + var group *unitGroup + if groupName == "" && len(f.groups) > 0 { + group = f.groups[0] + } else { + // Uses magic "" for first comment-only group if no other groups + group = f.ensureGroup(groupName) + } + + for _, comment := range comments { + group.addComment(newUnitLine("", "# "+comment, true)) + } +} + +func (f *UnitFile) PrependComment(groupName string, comments ...string) { + var group *unitGroup + if groupName == "" && len(f.groups) > 0 { + group = f.groups[0] + } else { + // Uses magic "" for first comment-only group if no other groups + group = f.ensureGroup(groupName) + } + // Prepend in reverse order to keep argument order + for i := len(comments) - 1; i >= 0; i-- { + group.prependComment(newUnitLine("", "# "+comments[i], true)) + } +} diff --git a/pkg/systemdparser/unitfile_test.go b/pkg/systemdparser/unitfile_test.go new file mode 100644 index 000000000000..308901b9668c --- /dev/null +++ b/pkg/systemdparser/unitfile_test.go @@ -0,0 +1,245 @@ +package systemdparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const memcachedService = `# It's not recommended to modify this file in-place, because it will be +# overwritten during upgrades. If you want to customize, the best +# way is to use the "systemctl edit" command to create an override unit. +# +# For example, to pass additional options, create an override unit +# (as is done by systemctl edit) and enter the following: +# +# [Service] +# Environment=OPTIONS="-l 127.0.0.1,::1" + + +[Unit] +Description=memcached daemon +Before=httpd.service +After=network.target + +[Service] +EnvironmentFile=/etc/sysconfig/memcached +ExecStart=/usr/bin/memcached -p ${PORT} -u ${USER} -m ${CACHESIZE} -c ${MAXCONN} $OPTIONS + +# Set up a new file system namespace and mounts private /tmp and /var/tmp +# directories so this service cannot access the global directories and +# other processes cannot access this service's directories. +PrivateTmp=true + +# Mounts the /usr, /boot, and /etc directories read-only for processes +# invoked by this unit. +ProtectSystem=full + +# Ensures that the service process and all its children can never gain new +# privileges +NoNewPrivileges=true + +# Sets up a new /dev namespace for the executed processes and only adds API +# pseudo devices such as /dev/null, /dev/zero or /dev/random (as well as +# the pseudo TTY subsystem) to it, but no physical devices such as /dev/sda. +PrivateDevices=true + +# Required for dropping privileges and running as a different user +CapabilityBoundingSet=CAP_SETGID CAP_SETUID CAP_SYS_RESOURCE + +# Restricts the set of socket address families accessible to the processes +# of this unit. Protects against vulnerabilities such as CVE-2016-8655 +RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX + + +# Some security features are not in the older versions of systemd used by +# e.g. RHEL7/CentOS 7. The below settings are automatically edited at package +# build time to uncomment them if the target platform supports them. + +# Attempts to create memory mappings that are writable and executable at +# the same time, or to change existing memory mappings to become executable +# are prohibited. +##safer##MemoryDenyWriteExecute=true + +# Explicit module loading will be denied. This allows to turn off module +# load and unload operations on modular kernels. It is recommended to turn +# this on for most services that do not need special file systems or extra +# kernel modules to work. +##safer##ProtectKernelModules=true + +# Kernel variables accessible through /proc/sys, /sys, /proc/sysrq-trigger, +# /proc/latency_stats, /proc/acpi, /proc/timer_stats, /proc/fs and /proc/irq +# will be made read-only to all processes of the unit. Usually, tunable +# kernel variables should only be written at boot-time, with the sysctl.d(5) +# mechanism. Almost no services need to write to these at runtime; it is hence +# recommended to turn this on for most services. +##safer##ProtectKernelTunables=true + +# The Linux Control Groups (cgroups(7)) hierarchies accessible through +# /sys/fs/cgroup will be made read-only to all processes of the unit. +# Except for container managers no services should require write access +# to the control groups hierarchies; it is hence recommended to turn this +# on for most services +##safer##ProtectControlGroups=true + +# Any attempts to enable realtime scheduling in a process of the unit are +# refused. +##safer##RestrictRealtime=true + +# Takes away the ability to create or manage any kind of namespace +##safer##RestrictNamespaces=true + +[Install] +WantedBy=multi-user.target +` + +const systemdloginService = `# SPDX-License-Identifier: LGPL-2.1-or-later +# +# This file is part of systemd. +# +# systemd is free software; you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation; either version 2.1 of the License, or +# (at your option) any later version. + +[Unit] +Description=User Login Management +Documentation=man:sd-login(3) +Documentation=man:systemd-logind.service(8) +Documentation=man:logind.conf(5) +Documentation=man:org.freedesktop.login1(5) + +Wants=user.slice modprobe@drm.service +After=nss-user-lookup.target user.slice modprobe@drm.service + +# Ask for the dbus socket. +Wants=dbus.socket +After=dbus.socket + +[Service] +BusName=org.freedesktop.login1 +CapabilityBoundingSet=CAP_SYS_ADMIN CAP_MAC_ADMIN CAP_AUDIT_CONTROL CAP_CHOWN CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE CAP_FOWNER CAP_SYS_TTY_CONFIG CAP_LINUX_IMMUTABLE +DeviceAllow=block-* r +DeviceAllow=char-/dev/console rw +DeviceAllow=char-drm rw +DeviceAllow=char-input rw +DeviceAllow=char-tty rw +DeviceAllow=char-vcs rw +ExecStart=/usr/lib/systemd/systemd-logind +FileDescriptorStoreMax=512 +IPAddressDeny=any +LockPersonality=yes +MemoryDenyWriteExecute=yes +NoNewPrivileges=yes +PrivateTmp=yes +ProtectProc=invisible +ProtectClock=yes +ProtectControlGroups=yes +ProtectHome=yes +ProtectHostname=yes +ProtectKernelLogs=yes +ProtectKernelModules=yes +ProtectSystem=strict +ReadWritePaths=/etc /run +Restart=always +RestartSec=0 +RestrictAddressFamilies=AF_UNIX AF_NETLINK +RestrictNamespaces=yes +RestrictRealtime=yes +RestrictSUIDSGID=yes +RuntimeDirectory=systemd/sessions systemd/seats systemd/users systemd/inhibit systemd/shutdown +RuntimeDirectoryPreserve=yes +StateDirectory=systemd/linger +SystemCallArchitectures=native +SystemCallErrorNumber=EPERM +SystemCallFilter=@system-service + + +# Increase the default a bit in order to allow many simultaneous logins since +# we keep one fd open per session. +LimitNOFILE=524288 +` +const systemdnetworkdService = `# SPDX-License-Identifier: LGPL-2.1-or-later +# +# This file is part of systemd. +# +# systemd is free software; you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation; either version 2.1 of the License, or +# (at your option) any later version. + +[Unit] +Description=Network Configuration +Documentation=man:systemd-networkd.service(8) +ConditionCapability=CAP_NET_ADMIN +DefaultDependencies=no +# systemd-udevd.service can be dropped once tuntap is moved to netlink +After=systemd-networkd.socket systemd-udevd.service network-pre.target systemd-sysusers.service systemd-sysctl.service +Before=network.target multi-user.target shutdown.target +Conflicts=shutdown.target +Wants=systemd-networkd.socket network.target + +[Service] +AmbientCapabilities=CAP_NET_ADMIN CAP_NET_BIND_SERVICE CAP_NET_BROADCAST CAP_NET_RAW +BusName=org.freedesktop.network1 +CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_BIND_SERVICE CAP_NET_BROADCAST CAP_NET_RAW +DeviceAllow=char-* rw +ExecStart=!!/usr/lib/systemd/systemd-networkd +ExecReload=networkctl reload +LockPersonality=yes +MemoryDenyWriteExecute=yes +NoNewPrivileges=yes +ProtectProc=invisible +ProtectClock=yes +ProtectControlGroups=yes +ProtectHome=yes +ProtectKernelLogs=yes +ProtectKernelModules=yes +ProtectSystem=strict +Restart=on-failure +RestartKillSignal=SIGUSR2 +RestartSec=0 +RestrictAddressFamilies=AF_UNIX AF_NETLINK AF_INET AF_INET6 AF_PACKET AF_ALG +RestrictNamespaces=yes +RestrictRealtime=yes +RestrictSUIDSGID=yes +RuntimeDirectory=systemd/netif +RuntimeDirectoryPreserve=yes +SystemCallArchitectures=native +SystemCallErrorNumber=EPERM +SystemCallFilter=@system-service +Type=notify +User=systemd-network + + +[Install] +WantedBy=multi-user.target +Also=systemd-networkd.socket +Alias=dbus-org.freedesktop.network1.service + +# We want to enable systemd-networkd-wait-online.service whenever this service +# is enabled. systemd-networkd-wait-online.service has +# WantedBy=network-online.target, so enabling it only has an effect if +# network-online.target itself is enabled or pulled in by some other unit. +Also=systemd-networkd-wait-online.service +` + +var samples = []string{memcachedService, systemdloginService, systemdnetworkdService} + +func TestRanges_Roundtrip(t *testing.T) { + for i := range samples { + sample := samples[i] + + f := NewUnitFile() + if e := f.Parse(sample); e != nil { + panic(e) + } + + asStr, e := f.ToString() + if e != nil { + panic(e) + } + + assert.Equal(t, sample, asStr) + } +}