diff --git a/.gitignore b/.gitignore index b443ae4..4a59bb3 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,6 @@ bin .history .vscode *.conf +*.yaml logs/ +cfg/ diff --git a/cmd/vidai/main.go b/cmd/vidai/main.go index 1834012..6dc4e01 100644 --- a/cmd/vidai/main.go +++ b/cmd/vidai/main.go @@ -2,18 +2,11 @@ package main import ( "context" - "flag" - "fmt" "log" "os" "os/signal" - "runtime/debug" - "strings" - "time" - "github.com/igolaizola/vidai" - "github.com/peterbourgon/ff/v3" - "github.com/peterbourgon/ff/v3/ffcli" + "github.com/igolaizola/vidai/pkg/cli" ) // Build flags @@ -27,183 +20,8 @@ func main() { defer cancel() // Launch command - cmd := newCommand() + cmd := cli.NewCommand(version, commit, date) if err := cmd.ParseAndRun(ctx, os.Args[1:]); err != nil { log.Fatal(err) } } - -func newCommand() *ffcli.Command { - fs := flag.NewFlagSet("vidai", flag.ExitOnError) - - return &ffcli.Command{ - ShortUsage: "vidai [flags] ", - FlagSet: fs, - Exec: func(context.Context, []string) error { - return flag.ErrHelp - }, - Subcommands: []*ffcli.Command{ - newVersionCommand(), - newGenerateCommand(), - newExtendCommand(), - newLoopCommand(), - }, - } -} - -func newVersionCommand() *ffcli.Command { - return &ffcli.Command{ - Name: "version", - ShortUsage: "vidai version", - ShortHelp: "print version", - Exec: func(ctx context.Context, args []string) error { - v := version - if v == "" { - if buildInfo, ok := debug.ReadBuildInfo(); ok { - v = buildInfo.Main.Version - } - } - if v == "" { - v = "dev" - } - versionFields := []string{v} - if commit != "" { - versionFields = append(versionFields, commit) - } - if date != "" { - versionFields = append(versionFields, date) - } - fmt.Println(strings.Join(versionFields, " ")) - return nil - }, - } -} - -func newGenerateCommand() *ffcli.Command { - cmd := "generate" - fs := flag.NewFlagSet(cmd, flag.ExitOnError) - _ = fs.String("config", "", "config file (optional)") - - var cfg vidai.Config - fs.BoolVar(&cfg.Debug, "debug", false, "debug mode") - fs.DurationVar(&cfg.Wait, "wait", 2*time.Second, "wait time between requests") - fs.StringVar(&cfg.Token, "token", "", "runway token") - image := fs.String("image", "", "source image") - text := fs.String("text", "", "source text") - output := fs.String("output", "", "output file (optional, if omitted it won't be saved)") - extend := fs.Int("extend", 0, "extend the video by this many times (optional)") - interpolate := fs.Bool("interpolate", true, "interpolate frames (optional)") - upscale := fs.Bool("upscale", false, "upscale frames (optional)") - watermark := fs.Bool("watermark", false, "add watermark (optional)") - - return &ffcli.Command{ - Name: cmd, - ShortUsage: fmt.Sprintf("vidai %s [flags] ", cmd), - Options: []ff.Option{ - ff.WithConfigFileFlag("config"), - ff.WithConfigFileParser(ff.PlainParser), - ff.WithEnvVarPrefix("VIDAI"), - }, - ShortHelp: fmt.Sprintf("vidai %s command", cmd), - FlagSet: fs, - Exec: func(ctx context.Context, args []string) error { - if cfg.Token == "" { - return fmt.Errorf("token is required") - } - if *image == "" && *text == "" { - return fmt.Errorf("image or text is required") - } - c := vidai.New(&cfg) - id, u, err := c.Generate(ctx, *image, *text, *output, *extend, - *interpolate, *upscale, *watermark) - if err != nil { - return err - } - fmt.Printf("ID: %s URL: %s\n", id, u) - return nil - }, - } -} - -func newExtendCommand() *ffcli.Command { - cmd := "extend" - fs := flag.NewFlagSet(cmd, flag.ExitOnError) - _ = fs.String("config", "", "config file (optional)") - - var cfg vidai.Config - fs.BoolVar(&cfg.Debug, "debug", false, "debug mode") - fs.DurationVar(&cfg.Wait, "wait", 2*time.Second, "wait time between requests") - fs.StringVar(&cfg.Token, "token", "", "runway token") - input := fs.String("input", "", "input video") - output := fs.String("output", "", "output file (optional, if omitted it won't be saved)") - n := fs.Int("n", 1, "extend the video by this many times") - interpolate := fs.Bool("interpolate", true, "interpolate frames (optional)") - upscale := fs.Bool("upscale", false, "upscale frames (optional)") - watermark := fs.Bool("watermark", false, "add watermark (optional)") - - return &ffcli.Command{ - Name: cmd, - ShortUsage: fmt.Sprintf("vidai %s [flags] ", cmd), - Options: []ff.Option{ - ff.WithConfigFileFlag("config"), - ff.WithConfigFileParser(ff.PlainParser), - ff.WithEnvVarPrefix("VIDAI"), - }, - ShortHelp: fmt.Sprintf("vidai %s command", cmd), - FlagSet: fs, - Exec: func(ctx context.Context, args []string) error { - if cfg.Token == "" { - return fmt.Errorf("token is required") - } - if *input == "" { - return fmt.Errorf("input is required") - } - if *n < 1 { - return fmt.Errorf("n must be greater than 0") - } - - c := vidai.New(&cfg) - urls, err := c.Extend(ctx, *input, *output, *n, *interpolate, *upscale, *watermark) - if err != nil { - return err - } - for i, u := range urls { - fmt.Printf("Video URL %d: %s\n", i+1, u) - } - return nil - }, - } -} - -func newLoopCommand() *ffcli.Command { - cmd := "loop" - fs := flag.NewFlagSet(cmd, flag.ExitOnError) - _ = fs.String("config", "", "config file (optional)") - - input := fs.String("input", "", "input video") - output := fs.String("output", "", "output file") - - return &ffcli.Command{ - Name: cmd, - ShortUsage: fmt.Sprintf("vidai %s [flags] ", cmd), - Options: []ff.Option{ - ff.WithConfigFileFlag("config"), - ff.WithConfigFileParser(ff.PlainParser), - ff.WithEnvVarPrefix("VIDAI"), - }, - ShortHelp: fmt.Sprintf("vidai %s command", cmd), - FlagSet: fs, - Exec: func(ctx context.Context, args []string) error { - if *input == "" { - return fmt.Errorf("input is required") - } - if *output == "" { - return fmt.Errorf("output is required") - } - if err := vidai.Loop(ctx, *input, *output); err != nil { - return err - } - return nil - }, - } -} diff --git a/go.mod b/go.mod index 0f30c4c..ebce2d5 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,24 @@ module github.com/igolaizola/vidai go 1.20 -require github.com/peterbourgon/ff/v3 v3.3.0 +require ( + github.com/bogdanfinn/fhttp v0.5.28 + github.com/bogdanfinn/tls-client v1.7.5 + github.com/peterbourgon/ff/v3 v3.3.0 +) + +require ( + github.com/andybalholm/brotli v1.0.5 // indirect + github.com/bogdanfinn/utls v1.6.1 // indirect + github.com/cloudflare/circl v1.3.6 // indirect + github.com/klauspost/compress v1.16.7 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/quic-go/quic-go v0.37.4 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect + golang.org/x/crypto v0.14.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/go.sum b/go.sum index bdc8b56..f647f49 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,44 @@ +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/bogdanfinn/fhttp v0.5.28 h1:G6thT8s8v6z1IuvXMUsX9QKy3ZHseTQTzxuIhSiaaAw= +github.com/bogdanfinn/fhttp v0.5.28/go.mod h1:oJiYPG3jQTKzk/VFmogH8jxjH5yiv2rrOH48Xso2lrE= +github.com/bogdanfinn/tls-client v1.7.5 h1:R1aTwe5oja5niLnQggzbWnzJEssw9n+3O4kR0H/Tjl4= +github.com/bogdanfinn/tls-client v1.7.5/go.mod h1:pQwF0eqfL0gf0mu8hikvu6deZ3ijSPruJDzEKEnnXjU= +github.com/bogdanfinn/utls v1.6.1 h1:dKDYAcXEyFFJ3GaWaN89DEyjyRraD1qb4osdEK89ass= +github.com/bogdanfinn/utls v1.6.1/go.mod h1:VXIbRZaiY/wHZc6Hu+DZ4O2CgTzjhjCg/Ou3V4r/39Y= +github.com/cloudflare/circl v1.3.6 h1:/xbKIqSHbZXHwkhbrhrt2YOHIwYJlXH94E3tI/gDlUg= +github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/peterbourgon/ff/v3 v3.3.0 h1:PaKe7GW8orVFh8Unb5jNHS+JZBwWUMa2se0HM6/BI24= github.com/peterbourgon/ff/v3 v3.3.0/go.mod h1:zjJVUhx+twciwfDl0zBcFzl4dW8axCRyXE/eKY9RztQ= +github.com/quic-go/quic-go v0.37.4 h1:ke8B73yMCWGq9MfrCCAw0Uzdm7GaViC3i39dsIdDlH4= +github.com/quic-go/quic-go v0.37.4/go.mod h1:YsbH1r4mSHPJcLF4k4zruUkLBqctEMBDR6VPvcYjIsU= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go new file mode 100644 index 0000000..a1bd1d9 --- /dev/null +++ b/pkg/cli/cli.go @@ -0,0 +1,166 @@ +package cli + +import ( + "context" + "flag" + "fmt" + "runtime/debug" + "strings" + "time" + + "github.com/igolaizola/vidai/pkg/cmd/extend" + "github.com/igolaizola/vidai/pkg/cmd/generate" + "github.com/igolaizola/vidai/pkg/cmd/loop" + "github.com/peterbourgon/ff/v3" + "github.com/peterbourgon/ff/v3/ffcli" + "github.com/peterbourgon/ff/v3/ffyaml" +) + +func NewCommand(version, commit, date string) *ffcli.Command { + fs := flag.NewFlagSet("vidai", flag.ExitOnError) + + return &ffcli.Command{ + ShortUsage: "vidai [flags] ", + FlagSet: fs, + Exec: func(context.Context, []string) error { + return flag.ErrHelp + }, + Subcommands: []*ffcli.Command{ + newVersionCommand(version, commit, date), + newGenerateCommand(), + newExtendCommand(), + newLoopCommand(), + }, + } +} + +func newVersionCommand(version, commit, date string) *ffcli.Command { + return &ffcli.Command{ + Name: "version", + ShortUsage: "vidai version", + ShortHelp: "print version", + Exec: func(ctx context.Context, args []string) error { + v := version + if v == "" { + if buildInfo, ok := debug.ReadBuildInfo(); ok { + v = buildInfo.Main.Version + } + } + if v == "" { + v = "dev" + } + versionFields := []string{v} + if commit != "" { + versionFields = append(versionFields, commit) + } + if date != "" { + versionFields = append(versionFields, date) + } + fmt.Println(strings.Join(versionFields, " ")) + return nil + }, + } +} + +func newGenerateCommand() *ffcli.Command { + cmd := "generate" + fs := flag.NewFlagSet(cmd, flag.ExitOnError) + _ = fs.String("config", "", "config file (optional)") + + var cfg generate.Config + fs.BoolVar(&cfg.Debug, "debug", false, "debug mode") + fs.DurationVar(&cfg.Wait, "wait", 2*time.Second, "wait time between requests") + fs.StringVar(&cfg.Token, "token", "", "runway token") + + fs.StringVar(&cfg.Model, "model", "gen2", "model to use (gen2 or gen3)") + fs.StringVar(&cfg.Image, "image", "", "source image") + fs.StringVar(&cfg.Text, "text", "", "source text") + fs.StringVar(&cfg.Output, "output", "", "output file (optional, if omitted it won't be saved)") + fs.IntVar(&cfg.Extend, "extend", 0, "extend the video by this many times (optional)") + fs.BoolVar(&cfg.Interpolate, "interpolate", true, "interpolate frames (optional)") + fs.BoolVar(&cfg.Upscale, "upscale", false, "upscale frames (optional)") + fs.BoolVar(&cfg.Watermark, "watermark", false, "add watermark (optional)") + fs.IntVar(&cfg.Width, "width", 0, "output video width (optional)") + fs.IntVar(&cfg.Height, "height", 0, "output video height (optional)") + + return &ffcli.Command{ + Name: cmd, + ShortUsage: fmt.Sprintf("vidai %s [flags] ", cmd), + Options: []ff.Option{ + ff.WithConfigFileFlag("config"), + ff.WithConfigFileParser(ffyaml.Parser), + ff.WithEnvVarPrefix("VIDAI"), + }, + ShortHelp: fmt.Sprintf("vidai %s command", cmd), + FlagSet: fs, + Exec: func(ctx context.Context, args []string) error { + return generate.Run(ctx, &cfg) + }, + } +} + +func newExtendCommand() *ffcli.Command { + cmd := "extend" + fs := flag.NewFlagSet(cmd, flag.ExitOnError) + _ = fs.String("config", "", "config file (optional)") + + var cfg extend.Config + fs.BoolVar(&cfg.Debug, "debug", false, "debug mode") + fs.DurationVar(&cfg.Wait, "wait", 2*time.Second, "wait time between requests") + fs.StringVar(&cfg.Token, "token", "", "runway token") + fs.StringVar(&cfg.Input, "input", "", "input video") + fs.StringVar(&cfg.Output, "output", "", "output file (optional, if omitted it won't be saved)") + fs.IntVar(&cfg.N, "n", 1, "extend the video by this many times") + fs.StringVar(&cfg.Model, "model", "gen2", "model to use (gen2 or gen3)") + fs.BoolVar(&cfg.Interpolate, "interpolate", true, "interpolate frames (optional)") + fs.BoolVar(&cfg.Upscale, "upscale", false, "upscale frames (optional)") + fs.BoolVar(&cfg.Watermark, "watermark", false, "add watermark (optional)") + + return &ffcli.Command{ + Name: cmd, + ShortUsage: fmt.Sprintf("vidai %s [flags] ", cmd), + Options: []ff.Option{ + ff.WithConfigFileFlag("config"), + ff.WithConfigFileParser(ffyaml.Parser), + ff.WithEnvVarPrefix("VIDAI"), + }, + ShortHelp: fmt.Sprintf("vidai %s command", cmd), + FlagSet: fs, + Exec: func(ctx context.Context, args []string) error { + return extend.Run(ctx, &cfg) + }, + } +} + +func newLoopCommand() *ffcli.Command { + cmd := "loop" + fs := flag.NewFlagSet(cmd, flag.ExitOnError) + _ = fs.String("config", "", "config file (optional)") + + input := fs.String("input", "", "input video") + output := fs.String("output", "", "output file") + + return &ffcli.Command{ + Name: cmd, + ShortUsage: fmt.Sprintf("vidai %s [flags] ", cmd), + Options: []ff.Option{ + ff.WithConfigFileFlag("config"), + ff.WithConfigFileParser(ffyaml.Parser), + ff.WithEnvVarPrefix("VIDAI"), + }, + ShortHelp: fmt.Sprintf("vidai %s command", cmd), + FlagSet: fs, + Exec: func(ctx context.Context, args []string) error { + if *input == "" { + return fmt.Errorf("input is required") + } + if *output == "" { + return fmt.Errorf("output is required") + } + if err := loop.Run(ctx, *input, *output); err != nil { + return err + } + return nil + }, + } +} diff --git a/pkg/cmd/extend/extend.go b/pkg/cmd/extend/extend.go new file mode 100644 index 0000000..f4bb8aa --- /dev/null +++ b/pkg/cmd/extend/extend.go @@ -0,0 +1,172 @@ +package extend + +import ( + "context" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/igolaizola/vidai/pkg/runway" +) + +type Config struct { + Token string + Wait time.Duration + Debug bool + Proxy string + + Input string + Output string + N int + Model string + Interpolate bool + Upscale bool + Watermark bool +} + +// Run generates a video from an image and a text prompt. +func Run(ctx context.Context, cfg *Config) error { + if cfg.Input == "" { + return fmt.Errorf("input is required") + } + if cfg.N < 1 { + return fmt.Errorf("n must be greater than 0") + } + if cfg.Token == "" { + return fmt.Errorf("token is required") + } + client := runway.New(&runway.Config{ + Token: cfg.Token, + Wait: cfg.Wait, + Debug: cfg.Debug, + Proxy: cfg.Proxy, + }) + + base := strings.TrimSuffix(filepath.Base(cfg.Input), filepath.Ext(cfg.Input)) + + // Copy input video to temp file + vid := filepath.Join(os.TempDir(), fmt.Sprintf("%s-0.mp4", base)) + if err := copyFile(cfg.Input, vid); err != nil { + return fmt.Errorf("vidai: couldn't copy input video: %w", err) + } + + videos := []string{vid} + var urls []string + for i := 0; i < cfg.N; i++ { + img := filepath.Join(os.TempDir(), fmt.Sprintf("%s-%d.jpg", base, i)) + + // Extract last frame from video using the following command: + // ffmpeg -sseof -1 -i input.mp4 -update 1 -q:v 1 output.jpg + // This will seek to the last second of the input and output all frames. + // But since -update 1 is set, each frame will be overwritten to the + // same file, leaving only the last frame remaining. + cmd := exec.CommandContext(ctx, "ffmpeg", "-sseof", "-1", "-i", vid, "-update", "1", "-q:v", "1", img) + cmdOut, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("vidai: couldn't extract last frame (%s): %w", string(cmdOut), err) + } + + // Read image + b, err := os.ReadFile(img) + if err != nil { + return fmt.Errorf("vidai: couldn't read image: %w", err) + } + name := filepath.Base(img) + + // Generate video + imageURL, err := client.Upload(ctx, name, b) + if err != nil { + return fmt.Errorf("vidai: couldn't upload image: %w", err) + } + gen, err := client.Generate(ctx, &runway.GenerateRequest{ + Model: cfg.Model, + AssetURL: imageURL, + Prompt: "", + Interpolate: cfg.Interpolate, + Upscale: cfg.Upscale, + Watermark: cfg.Watermark, + Extend: false, + }) + if err != nil { + return fmt.Errorf("vidai: couldn't generate video: %w", err) + } + urls = append(urls, gen.URL) + + // Remove temporary image + if err := os.Remove(img); err != nil { + log.Println(fmt.Errorf("vidai: couldn't remove image: %w", err)) + } + + // Download video to temp file + vid = filepath.Join(os.TempDir(), fmt.Sprintf("%s-%d.mp4", base, i+1)) + if err := client.Download(ctx, gen.URL, vid); err != nil { + return fmt.Errorf("vidai: couldn't download video: %w", err) + } + videos = append(videos, vid) + } + + if cfg.Output != "" { + // Create list of videos + var listData string + for _, v := range videos { + listData += fmt.Sprintf("file '%s'\n", filepath.Base(v)) + } + list := filepath.Join(os.TempDir(), fmt.Sprintf("%s-list.txt", base)) + if err := os.WriteFile(list, []byte(listData), 0644); err != nil { + return fmt.Errorf("vidai: couldn't create list file: %w", err) + } + + // Combine videos using the following command: + // ffmpeg -f concat -safe 0 -i list.txt -c copy output.mp4 + cmd := exec.CommandContext(ctx, "ffmpeg", "-f", "concat", "-safe", "0", "-i", list, "-c", "copy", "-y", cfg.Output) + cmdOut, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("vidai: couldn't combine videos (%s): %w", string(cmdOut), err) + } + + // Remove temporary list file + if err := os.Remove(list); err != nil { + log.Println(fmt.Errorf("vidai: couldn't remove list file: %w", err)) + } + } + + // Remove temporary videos + for _, v := range videos { + if err := os.Remove(v); err != nil { + log.Println(fmt.Errorf("vidai: couldn't remove video: %w", err)) + } + } + + fmt.Println("URLs:") + for _, u := range urls { + fmt.Println(u) + } + return nil +} + +func copyFile(src, dst string) error { + // Open source file + srcFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("vidai: couldn't open source file: %w", err) + } + defer srcFile.Close() + + // Create destination file + dstFile, err := os.Create(dst) + if err != nil { + return fmt.Errorf("vidai: couldn't create destination file: %w", err) + } + defer dstFile.Close() + + // Copy source to destination + if _, err := io.Copy(dstFile, srcFile); err != nil { + return fmt.Errorf("vidai: couldn't copy source to destination: %w", err) + } + return nil +} diff --git a/pkg/cmd/generate/generate.go b/pkg/cmd/generate/generate.go new file mode 100644 index 0000000..2e5086b --- /dev/null +++ b/pkg/cmd/generate/generate.go @@ -0,0 +1,112 @@ +package generate + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/igolaizola/vidai/pkg/runway" +) + +type Config struct { + Token string + Wait time.Duration + Debug bool + Proxy string + + Output string + Model string + Image string + Text string + Extend int + Interpolate bool + Upscale bool + Watermark bool + Width int + Height int +} + +// Run generates a video from an image and a text prompt. +func Run(ctx context.Context, cfg *Config) error { + if cfg.Image == "" && cfg.Text == "" { + return fmt.Errorf("vidai: image or text is required") + } + if cfg.Token == "" { + return fmt.Errorf("token is required") + } + client := runway.New(&runway.Config{ + Token: cfg.Token, + Wait: cfg.Wait, + Debug: cfg.Debug, + Proxy: cfg.Proxy, + }) + + var imageURL string + if cfg.Image != "" { + b, err := os.ReadFile(cfg.Image) + if err != nil { + return fmt.Errorf("vidai: couldn't read image: %w", err) + } + name := filepath.Base(cfg.Image) + + imageURL, err = client.Upload(ctx, name, b) + if err != nil { + return fmt.Errorf("vidai: couldn't upload image: %w", err) + } + } + gen, err := client.Generate(ctx, &runway.GenerateRequest{ + Model: cfg.Model, + AssetURL: imageURL, + Prompt: cfg.Text, + Interpolate: cfg.Interpolate, + Upscale: cfg.Upscale, + Watermark: cfg.Watermark, + Extend: false, + Width: cfg.Width, + Height: cfg.Height, + }) + if err != nil { + return fmt.Errorf("vidai: couldn't generate video: %w", err) + } + + // Extend video + for i := 0; i < cfg.Extend; i++ { + gen, err = client.Generate(ctx, &runway.GenerateRequest{ + Model: cfg.Model, + AssetURL: gen.URL, + Prompt: "", + Interpolate: cfg.Interpolate, + Upscale: cfg.Upscale, + Watermark: cfg.Watermark, + Extend: true, + }) + if err != nil { + return fmt.Errorf("vidai: couldn't extend video: %w", err) + } + } + + // Use temp file if no output is set and we need to extend the video + videoPath := cfg.Output + if videoPath == "" && cfg.Extend > 0 { + base := strings.TrimSuffix(filepath.Base(cfg.Image), filepath.Ext(cfg.Image)) + videoPath = filepath.Join(os.TempDir(), fmt.Sprintf("%s.mp4", base)) + } + + // Download video + if videoPath != "" { + if err := client.Download(ctx, gen.URL, videoPath); err != nil { + return fmt.Errorf("vidai: couldn't download video: %w", err) + } + } + + js, err := json.MarshalIndent(gen, "", " ") + if err != nil { + return fmt.Errorf("vidai: couldn't marshal json: %w", err) + } + fmt.Println(string(js)) + return nil +} diff --git a/pkg/cmd/loop/loop.go b/pkg/cmd/loop/loop.go new file mode 100644 index 0000000..c54800f --- /dev/null +++ b/pkg/cmd/loop/loop.go @@ -0,0 +1,42 @@ +package loop + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" +) + +func Run(ctx context.Context, input, output string) error { + // Reverse video using the following command: + // ffmpeg -i input.mp4 -vf reverse temp.mp4 + tmp := filepath.Join(os.TempDir(), fmt.Sprintf("%s-reversed.mp4", filepath.Base(input))) + cmd := exec.CommandContext(ctx, "ffmpeg", "-i", input, "-vf", "reverse", tmp) + cmdOut, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("vidai: couldn't reverse video (%s): %w", string(cmdOut), err) + } + + // Obtain absolute path to input video + absInput, err := filepath.Abs(input) + if err != nil { + return fmt.Errorf("vidai: couldn't get absolute path to input video: %w", err) + } + + // Generate list of videos + listData := fmt.Sprintf("file '%s'\nfile '%s'\n", absInput, filepath.Base(tmp)) + list := filepath.Join(os.TempDir(), fmt.Sprintf("%s-list.txt", filepath.Base(input))) + if err := os.WriteFile(list, []byte(listData), 0644); err != nil { + return fmt.Errorf("vidai: couldn't create list file: %w", err) + } + + // Combine videos using the following command: + // ffmpeg -f concat -safe 0 -i list.txt -c copy output.mp4 + cmd = exec.CommandContext(ctx, "ffmpeg", "-f", "concat", "-safe", "0", "-i", list, "-c", "copy", "-y", output) + cmdOut, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("vidai: couldn't combine videos (%s): %w", string(cmdOut), err) + } + return nil +} diff --git a/pkg/fhttp/http.go b/pkg/fhttp/http.go new file mode 100644 index 0000000..c218bec --- /dev/null +++ b/pkg/fhttp/http.go @@ -0,0 +1,40 @@ +package fhttp + +import ( + "time" + + tlsclient "github.com/bogdanfinn/tls-client" + "github.com/bogdanfinn/tls-client/profiles" +) + +type Client interface { + tlsclient.HttpClient +} + +type client struct { + tlsclient.HttpClient +} + +func NewClient(timeout time.Duration, useJar bool, proxy string) Client { + jar := tlsclient.NewCookieJar() + secs := int(timeout.Seconds()) + if secs <= 0 { + secs = 30 + } + options := []tlsclient.HttpClientOption{ + tlsclient.WithTimeoutSeconds(secs), + tlsclient.WithClientProfile(profiles.Chrome_120), + tlsclient.WithNotFollowRedirects(), + } + if useJar { + options = append(options, tlsclient.WithCookieJar(jar)) + } + if proxy != "" { + options = append(options, tlsclient.WithProxyUrl(proxy)) + } + c, err := tlsclient.NewHttpClient(tlsclient.NewNoopLogger(), options...) + if err != nil { + panic(err) + } + return &client{HttpClient: c} +} diff --git a/internal/ratelimit/ratelimit.go b/pkg/ratelimit/ratelimit.go similarity index 100% rename from internal/ratelimit/ratelimit.go rename to pkg/ratelimit/ratelimit.go diff --git a/pkg/runway/client.go b/pkg/runway/client.go new file mode 100644 index 0000000..7428757 --- /dev/null +++ b/pkg/runway/client.go @@ -0,0 +1,239 @@ +package runway + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "strings" + "time" + + http "github.com/bogdanfinn/fhttp" + "github.com/igolaizola/vidai/pkg/fhttp" + "github.com/igolaizola/vidai/pkg/ratelimit" +) + +type Client struct { + client fhttp.Client + debug bool + ratelimit ratelimit.Lock + token string + teamID int +} + +type Config struct { + Token string + Wait time.Duration + Debug bool + Proxy string +} + +func New(cfg *Config) *Client { + wait := cfg.Wait + if wait == 0 { + wait = 1 * time.Second + } + client := fhttp.NewClient(2*time.Minute, true, cfg.Proxy) + return &Client{ + client: client, + ratelimit: ratelimit.New(wait), + debug: cfg.Debug, + token: cfg.Token, + } +} + +func (c *Client) log(format string, args ...interface{}) { + if c.debug { + format += "\n" + log.Printf(format, args...) + } +} + +var backoff = []time.Duration{ + 30 * time.Second, + 1 * time.Minute, + 2 * time.Minute, +} + +func (c *Client) do(ctx context.Context, method, path string, in, out any) ([]byte, error) { + maxAttempts := 3 + attempts := 0 + var err error + for { + if err != nil { + log.Println("retrying...", err) + } + var b []byte + b, err = c.doAttempt(ctx, method, path, in, out) + if err == nil { + return b, nil + } + // Increase attempts and check if we should stop + attempts++ + if attempts >= maxAttempts { + return nil, err + } + // If the error is temporary retry + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + continue + } + // Check status code + var errStatus errStatusCode + if errors.As(err, &errStatus) { + switch int(errStatus) { + // These errors are retriable but we should wait before retry + case http.StatusBadGateway, http.StatusGatewayTimeout, http.StatusTooManyRequests, http.StatusInternalServerError, 520, 522: + default: + return nil, err + } + + idx := attempts - 1 + if idx >= len(backoff) { + idx = len(backoff) - 1 + } + wait := backoff[idx] + c.log("server seems to be down, waiting %s before retrying\n", wait) + t := time.NewTimer(wait) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-t.C: + } + continue + } + return nil, err + } +} + +type errStatusCode int + +func (e errStatusCode) Error() string { + return fmt.Sprintf("%d", e) +} + +func (c *Client) doAttempt(ctx context.Context, method, path string, in, out any) ([]byte, error) { + var body []byte + var reqBody io.Reader + contentType := "application/json" + if f, ok := in.(*uploadFile); ok { + body = f.data + ext := f.extension + if ext == "jpg" { + ext = "jpeg" + } + contentType = fmt.Sprintf("image/%s", ext) + reqBody = bytes.NewReader(body) + } else if in != nil { + var err error + body, err = json.Marshal(in) + if err != nil { + return nil, fmt.Errorf("runway: couldn't marshal request body: %w", err) + } + reqBody = bytes.NewReader(body) + } + logBody := string(body) + /*if len(logBody) > 100 { + logBody = logBody[:100] + "..." + }*/ + c.log("runway: do %s %s %s", method, path, logBody) + + // Check if path is absolute + u := fmt.Sprintf("https://api.runwayml.com/v1/%s", path) + var uploadLen int + if strings.HasPrefix(path, "http") { + u = path + uploadLen = len(body) + } + req, err := http.NewRequestWithContext(ctx, method, u, reqBody) + if err != nil { + return nil, fmt.Errorf("runway: couldn't create request: %w", err) + } + c.addHeaders(req, path, contentType, uploadLen) + + unlock := c.ratelimit.Lock(ctx) + defer unlock() + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("runway: couldn't %s %s: %w", method, u, err) + } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("runway: couldn't read response body: %w", err) + } + c.log("runway: response %s %s %d %s", method, path, resp.StatusCode, string(respBody)) + if resp.StatusCode != http.StatusOK { + errMessage := string(respBody) + if len(errMessage) > 100 { + errMessage = errMessage[:100] + "..." + } + _ = os.WriteFile(fmt.Sprintf("logs/debug_%s.json", time.Now().Format("20060102_150405")), respBody, 0644) + return nil, fmt.Errorf("runway: %s %s returned (%s): %w", method, u, errMessage, errStatusCode(resp.StatusCode)) + } + if out != nil { + if err := json.Unmarshal(respBody, out); err != nil { + // Write response body to file for debugging. + _ = os.WriteFile(fmt.Sprintf("logs/debug_%s.json", time.Now().Format("20060102_150405")), respBody, 0644) + return nil, fmt.Errorf("runway: couldn't unmarshal response body (%T): %w", out, err) + } + } + return respBody, nil +} + +func (c *Client) addHeaders(req *http.Request, path, contentType string, uploadLen int) { + switch { + case uploadLen > 0: + req.Header.Set("accept", "*/*") + req.Header.Set("accept-language", "en-US,en;q=0.9") + req.Header.Set("content-length", fmt.Sprintf("%d", uploadLen)) + req.Header.Set("content-type", contentType) + req.Header.Set("connection", "keep-alive") + req.Header.Set("origin", "https://app.runwayml.com") + req.Header.Set("priority", "u=1, i") + req.Header.Set("referer", "https://app.runwayml.com/") + req.Header.Set("sec-ch-ua", `"Not/A)Brand";v="8", "Chromium";v="126", "Google Chrome";v="126"`) + req.Header.Set("sec-ch-ua-mobile", "?0") + req.Header.Set("sec-ch-ua-platform", "\"Windows\"") + req.Header.Set("sec-fetch-dest", "empty") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("sec-fetch-site", "cross-site") + req.Header.Set("user-agent", `Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36`) + case !strings.HasPrefix(path, "http"): + req.Header.Set("accept", "application/json") + req.Header.Set("accept-language", "en-US,en;q=0.9") + req.Header.Set("authorization", fmt.Sprintf("Bearer %s", c.token)) + req.Header.Set("content-type", contentType) + req.Header.Set("origin", "https://app.runwayml.com") + req.Header.Set("priority", "u=1, i") + req.Header.Set("referer", "https://app.runwayml.com/") + req.Header.Set("sec-ch-ua", `"Not/A)Brand";v="8", "Chromium";v="126", "Google Chrome";v="126"`) + req.Header.Set("sec-ch-ua-mobile", "?0") + req.Header.Set("sec-ch-ua-platform", "\"Windows\"") + req.Header.Set("sec-fetch-dest", "empty") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("sec-fetch-site", "same-site") + req.Header.Set("user-agent", `Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36`) + // TODO: Add sentry trace if needed. + // req.Header.Set("Sentry-Trace", "TODO") + default: + req.Header.Set("accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8") + req.Header.Set("accept-language", "en-US,en;q=0.9") + req.Header.Set("origin", fmt.Sprintf("%s://%s", req.URL.Scheme, req.URL.Host)) + req.Header.Set("priority", "u=1, i") + req.Header.Set("referer", "https://app.runwayml.com/") + req.Header.Set("sec-ch-ua", `"Not/A)Brand";v="8", "Chromium";v="126", "Google Chrome";v="126"`) + req.Header.Set("sec-ch-ua-mobile", "?0") + req.Header.Set("sec-ch-ua-platform", "\"Windows\"") + req.Header.Set("sec-fetch-dest", "empty") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("sec-fetch-site", "same-site") + req.Header.Set("user-agent", `Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36`) + } +} diff --git a/pkg/runway/runway.go b/pkg/runway/runway.go index 3fbcebd..8ddceea 100644 --- a/pkg/runway/runway.go +++ b/pkg/runway/runway.go @@ -1,59 +1,16 @@ package runway import ( - "bytes" "context" "crypto/md5" - "encoding/json" - "errors" "fmt" - "io" - "log" "math/rand" - "net" - "net/http" "os" "path/filepath" "strings" "time" - - "github.com/igolaizola/vidai/internal/ratelimit" ) -type Client struct { - client *http.Client - debug bool - ratelimit ratelimit.Lock - token string - teamID int -} - -type Config struct { - Token string - Wait time.Duration - Debug bool - Client *http.Client -} - -func New(cfg *Config) *Client { - wait := cfg.Wait - if wait == 0 { - wait = 1 * time.Second - } - client := cfg.Client - if client == nil { - client = &http.Client{ - Timeout: 2 * time.Minute, - } - } - return &Client{ - client: client, - ratelimit: ratelimit.New(wait), - debug: cfg.Debug, - token: cfg.Token, - } -} - type profileResponse struct { User struct { ID int `json:"id"` @@ -83,7 +40,7 @@ func (c *Client) loadTeamID(ctx context.Context) error { return nil } var resp profileResponse - if err := c.do(ctx, "GET", "profile", nil, &resp); err != nil { + if _, err := c.do(ctx, "GET", "profile", nil, &resp); err != nil { return fmt.Errorf("runway: couldn't get profile: %w", err) } if len(resp.User.Organizations) > 0 { @@ -145,7 +102,7 @@ func (c *Client) Upload(ctx context.Context, name string, data []byte) (string, Type: t, } var uploadResp uploadResponse - if err := c.do(ctx, "POST", "uploads", uploadReq, &uploadResp); err != nil { + if _, err := c.do(ctx, "POST", "uploads", uploadReq, &uploadResp); err != nil { return "", fmt.Errorf("runway: couldn't obtain upload url: %w", err) } if len(uploadResp.UploadURLs) == 0 { @@ -154,7 +111,7 @@ func (c *Client) Upload(ctx context.Context, name string, data []byte) (string, // Upload file uploadURL := uploadResp.UploadURLs[0] - if err := c.do(ctx, "PUT", uploadURL, file, nil); err != nil { + if _, err := c.do(ctx, "PUT", uploadURL, file, nil); err != nil { return "", fmt.Errorf("runway: couldn't upload file: %w", err) } @@ -172,7 +129,7 @@ func (c *Client) Upload(ctx context.Context, name string, data []byte) (string, }, } var completeResp uploadCompleteResponse - if err := c.do(ctx, "POST", completeURL, completeReq, &completeResp); err != nil { + if _, err := c.do(ctx, "POST", completeURL, completeReq, &completeResp); err != nil { return "", fmt.Errorf("runway: couldn't complete upload: %w", err) } c.log("runway: upload complete %s", completeResp.URL) @@ -184,7 +141,7 @@ func (c *Client) Upload(ctx context.Context, name string, data []byte) (string, return imageURL, nil } -type createTaskRequest struct { +type createGen2TaskRequest struct { TaskType string `json:"taskType"` Internal bool `json:"internal"` Options struct { @@ -210,23 +167,38 @@ type gen2Options struct { MotionScore int `json:"motion_score"` UseMotionScore bool `json:"use_motion_score"` UseMotionVectors bool `json:"use_motion_vectors"` + Width int `json:"width"` + Height int `json:"height"` +} + +type createGen3TaskRequest struct { + TaskType string `json:"taskType"` + Internal bool `json:"internal"` + Options gen3Options `json:"options"` + AsTeamID int `json:"asTeamId"` +} + +type gen3Options struct { + Name string `json:"name"` + Seconds int `json:"seconds"` + TextPrompt string `json:"text_prompt"` + Seed int `json:"seed"` + ExploreMode bool `json:"exploreMode"` + Watermark bool `json:"watermark"` + EnhancePrompt bool `json:"enhance_prompt"` + Width int `json:"width"` + Height int `json:"height"` + AssetGroupName string `json:"assetGroupName"` } type taskResponse struct { Task struct { - ID string `json:"id"` - Name string `json:"name"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - TaskType string `json:"taskType"` - Options struct { - Seconds int `json:"seconds"` - Gen2Options gen2Options `json:"gen2Options"` - Name string `json:"name"` - AssetGroupName string `json:"assetGroupName"` - ExploreMode bool `json:"exploreMode"` - Recording bool `json:"recordingEnabled"` - } `json:"options"` + ID string `json:"id"` + Name string `json:"name"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + TaskType string `json:"taskType"` + Options any `json:"options"` Status string `json:"status"` ProgressText string `json:"progressText"` ProgressRatio string `json:"progressRatio"` @@ -241,8 +213,6 @@ type artifact struct { ID string `json:"id"` CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - MediaType string `json:"mediaType"` UserID int `json:"userId"` CreatedBy int `json:"createdBy"` TaskID string `json:"taskId"` @@ -261,60 +231,123 @@ type artifact struct { FrameRate int `json:"frameRate"` Duration float32 `json:"duration"` Dimensions []int `json:"dimensions"` + Size struct { + Width int `json:"width"` + Height int `json:"height"` + } `json:"size"` } `json:"metadata"` } -func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, interpolate, upscale, watermark, extend bool) (string, string, error) { +type Generation struct { + ID string `json:"id"` + URL string `json:"url"` + PreviewURLs []string `json:"previewUrls"` +} + +type GenerateRequest struct { + Model string + AssetURL string + Prompt string + Interpolate bool + Upscale bool + Watermark bool + Extend bool + Width int + Height int +} + +func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generation, error) { // Load team ID if err := c.loadTeamID(ctx); err != nil { - return "", "", fmt.Errorf("runway: couldn't load team id: %w", err) + return nil, fmt.Errorf("runway: couldn't load team id: %w", err) } - // Generate seed - seed := rand.Intn(1000000000) + // Generate seed between 2000000000 and 2999999999 + seed := rand.Intn(1000000000) + 2000000000 var imageURL string var videoURL string - if extend { - videoURL = assetURL + if cfg.Extend { + videoURL = cfg.AssetURL } else { - imageURL = assetURL + imageURL = cfg.AssetURL + } + + width := cfg.Width + height := cfg.Height + if width == 0 || height == 0 { + width = 1280 + height = 768 } // Create task - createReq := &createTaskRequest{ - TaskType: "gen2", - Internal: false, - Options: struct { - Seconds int `json:"seconds"` - Gen2Options gen2Options `json:"gen2Options"` - Name string `json:"name"` - AssetGroupName string `json:"assetGroupName"` - ExploreMode bool `json:"exploreMode"` - }{ - Seconds: 4, - Gen2Options: gen2Options{ - Interpolate: interpolate, + var createReq any + switch cfg.Model { + case "gen2": + name := fmt.Sprintf("Gen-2 %d, %s", seed, cfg.Prompt) + if len(name) > 44 { + name = name[:44] + } + createReq = &createGen2TaskRequest{ + TaskType: "gen2", + Internal: false, + Options: struct { + Seconds int `json:"seconds"` + Gen2Options gen2Options `json:"gen2Options"` + Name string `json:"name"` + AssetGroupName string `json:"assetGroupName"` + ExploreMode bool `json:"exploreMode"` + }{ + Seconds: 4, + Gen2Options: gen2Options{ + Interpolate: cfg.Interpolate, + Seed: seed, + Upscale: cfg.Upscale, + TextPrompt: cfg.Prompt, + Watermark: cfg.Watermark, + ImagePrompt: imageURL, + InitImage: imageURL, + InitVideo: videoURL, + Mode: "gen2", + UseMotionScore: true, + MotionScore: 22, + Width: width, + Height: height, + }, + Name: name, + AssetGroupName: "Generative Video", + ExploreMode: false, + }, + AsTeamID: c.teamID, + } + case "gen3": + name := fmt.Sprintf("Gen-3 Alpha %d, %s", seed, cfg.Prompt) + if len(name) > 44 { + name = name[:44] + } + createReq = &createGen3TaskRequest{ + TaskType: "europa", + Internal: false, + Options: gen3Options{ + Name: name, + Seconds: 10, + TextPrompt: cfg.Prompt, Seed: seed, - Upscale: upscale, - TextPrompt: textPrompt, - Watermark: watermark, - ImagePrompt: imageURL, - InitImage: imageURL, - InitVideo: videoURL, - Mode: "gen2", - UseMotionScore: true, - MotionScore: 22, + ExploreMode: false, + Watermark: cfg.Watermark, + EnhancePrompt: true, + Width: width, + Height: height, + AssetGroupName: "Generative Video", }, - Name: fmt.Sprintf("Gen-2, %d", seed), - AssetGroupName: "Gen-2", - ExploreMode: false, - }, - AsTeamID: c.teamID, + AsTeamID: c.teamID, + } + default: + return nil, fmt.Errorf("runway: unknown model %s", cfg.Model) } var taskResp taskResponse - if err := c.do(ctx, "POST", "tasks", createReq, &taskResp); err != nil { - return "", "", fmt.Errorf("runway: couldn't create task: %w", err) + if _, err := c.do(ctx, "POST", "tasks", createReq, &taskResp); err != nil { + return nil, fmt.Errorf("runway: couldn't create task: %w", err) } // Wait for task to finish @@ -322,28 +355,32 @@ func (c *Client) Generate(ctx context.Context, assetURL, textPrompt string, inte switch taskResp.Task.Status { case "SUCCEEDED": if len(taskResp.Task.Artifacts) == 0 { - return "", "", fmt.Errorf("runway: no artifacts returned") + return nil, fmt.Errorf("runway: no artifacts returned") } artifact := taskResp.Task.Artifacts[0] if artifact.URL == "" { - return "", "", fmt.Errorf("runway: empty artifact url") + return nil, fmt.Errorf("runway: empty artifact url") } - return artifact.ID, artifact.URL, nil + return &Generation{ + ID: artifact.ID, + URL: artifact.URL, + PreviewURLs: artifact.PreviewURLs, + }, nil case "PENDING", "RUNNING": c.log("runway: task %s: %s", taskResp.Task.ID, taskResp.Task.ProgressRatio) default: - return "", "", fmt.Errorf("runway: task failed: %s", taskResp.Task.Status) + return nil, fmt.Errorf("runway: task failed: %s", taskResp.Task.Status) } select { case <-ctx.Done(): - return "", "", fmt.Errorf("runway: %w", ctx.Err()) + return nil, fmt.Errorf("runway: %w", ctx.Err()) case <-time.After(5 * time.Second): } path := fmt.Sprintf("tasks/%s?asTeamId=%d", taskResp.Task.ID, c.teamID) - if err := c.do(ctx, "GET", path, nil, &taskResp); err != nil { - return "", "", fmt.Errorf("runway: couldn't get task: %w", err) + if _, err := c.do(ctx, "GET", path, nil, &taskResp); err != nil { + return nil, fmt.Errorf("runway: couldn't get task: %w", err) } } } @@ -362,7 +399,7 @@ type assetResponse struct { func (c *Client) DeleteAsset(ctx context.Context, id string) error { path := fmt.Sprintf("assets/%s", id) var resp assetDeleteResponse - if err := c.do(ctx, "DELETE", path, &assetDeleteRequest{}, &resp); err != nil { + if _, err := c.do(ctx, "DELETE", path, &assetDeleteRequest{}, &resp); err != nil { return fmt.Errorf("runway: couldn't delete asset %s: %w", id, err) } if !resp.Success { @@ -374,7 +411,7 @@ func (c *Client) DeleteAsset(ctx context.Context, id string) error { func (c *Client) GetAsset(ctx context.Context, id string) (string, error) { path := fmt.Sprintf("assets/%s", id) var resp assetResponse - if err := c.do(ctx, "GET", path, nil, &resp); err != nil { + if _, err := c.do(ctx, "GET", path, nil, &resp); err != nil { return "", fmt.Errorf("runway: couldn't get asset %s: %w", id, err) } if resp.Asset.URL == "" { @@ -383,166 +420,17 @@ func (c *Client) GetAsset(ctx context.Context, id string) (string, error) { return resp.Asset.URL, nil } -func (c *Client) log(format string, args ...interface{}) { - if c.debug { - format += "\n" - log.Printf(format, args...) - } -} - -var backoff = []time.Duration{ - 30 * time.Second, - 1 * time.Minute, - 2 * time.Minute, -} - -func (c *Client) do(ctx context.Context, method, path string, in, out any) error { - maxAttempts := 3 - attempts := 0 - var err error - for { - if err != nil { - log.Println("retrying...", err) - } - err = c.doAttempt(ctx, method, path, in, out) - if err == nil { - return nil - } - // Increase attempts and check if we should stop - attempts++ - if attempts >= maxAttempts { - return err - } - // If the error is temporary retry - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - continue - } - // Check status code - var errStatus errStatusCode - if errors.As(err, &errStatus) { - switch int(errStatus) { - // These errors are retriable but we should wait before retry - case http.StatusBadGateway, http.StatusGatewayTimeout, http.StatusTooManyRequests: - default: - return err - } - - idx := attempts - 1 - if idx >= len(backoff) { - idx = len(backoff) - 1 - } - wait := backoff[idx] - c.log("server seems to be down, waiting %s before retrying\n", wait) - t := time.NewTimer(wait) - select { - case <-ctx.Done(): - return ctx.Err() - case <-t.C: - } - continue - } - return err - } -} - -type errStatusCode int - -func (e errStatusCode) Error() string { - return fmt.Sprintf("%d", e) -} - -func (c *Client) doAttempt(ctx context.Context, method, path string, in, out any) error { - var body []byte - var reqBody io.Reader - contentType := "application/json" - if f, ok := in.(*uploadFile); ok { - body = f.data - ext := f.extension - if ext == "jpg" { - ext = "jpeg" - } - contentType = fmt.Sprintf("image/%s", ext) - reqBody = bytes.NewReader(body) - } else if in != nil { - var err error - body, err = json.Marshal(in) - if err != nil { - return fmt.Errorf("runway: couldn't marshal request body: %w", err) - } - reqBody = bytes.NewReader(body) - } - logBody := string(body) - if len(logBody) > 100 { - logBody = logBody[:100] + "..." - } - c.log("runway: do %s %s %s", method, path, logBody) - - // Check if path is absolute - u := fmt.Sprintf("https://api.runwayml.com/v1/%s", path) - var uploadLen int - if strings.HasPrefix(path, "http") { - u = path - uploadLen = len(body) - } - req, err := http.NewRequestWithContext(ctx, method, u, reqBody) - if err != nil { - return fmt.Errorf("runway: couldn't create request: %w", err) - } - c.addHeaders(req, contentType, uploadLen) - - unlock := c.ratelimit.Lock(ctx) - defer unlock() - - resp, err := c.client.Do(req) +func (c *Client) Download(ctx context.Context, u, output string) error { + b, err := c.do(ctx, "GET", u, nil, nil) if err != nil { - return fmt.Errorf("runway: couldn't %s %s: %w", method, u, err) + return fmt.Errorf("runway: couldn't download video: %w", err) } - defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("runway: couldn't read response body: %w", err) - } - c.log("runway: response %s %s %d %s", method, path, resp.StatusCode, string(respBody)) - if resp.StatusCode != http.StatusOK { - errMessage := string(respBody) - if len(errMessage) > 100 { - errMessage = errMessage[:100] + "..." - } - _ = os.WriteFile(fmt.Sprintf("logs/debug_%s.json", time.Now().Format("20060102_150405")), respBody, 0644) - return fmt.Errorf("runway: %s %s returned (%s): %w", method, u, errMessage, errStatusCode(resp.StatusCode)) + // Write video to output + if err := os.MkdirAll(filepath.Dir(output), 0755); err != nil { + return fmt.Errorf("runway: couldn't create output directory: %w", err) } - if out != nil { - if err := json.Unmarshal(respBody, out); err != nil { - // Write response body to file for debugging. - _ = os.WriteFile(fmt.Sprintf("logs/debug_%s.json", time.Now().Format("20060102_150405")), respBody, 0644) - return fmt.Errorf("runway: couldn't unmarshal response body (%T): %w", out, err) - } + if err := os.WriteFile(output, b, 0644); err != nil { + return fmt.Errorf("runway: couldn't write video to file: %w", err) } return nil } - -func (c *Client) addHeaders(req *http.Request, contentType string, uploadLen int) { - if uploadLen > 0 { - req.Header.Set("Accept", "*/*") - req.Header.Set("Content-Length", fmt.Sprintf("%d", uploadLen)) - req.Header.Set("Sec-Fetch-Site", "cross-site") - } else { - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) - req.Header.Set("Sec-Fetch-Site", "same-site") - } - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Content-Type", contentType) - req.Header.Set("Origin", "https://app.runwayml.com") - req.Header.Set("Referer", "https://app.runwayml.com/") - req.Header.Set("Sec-Ch-Ua", `"Not.A/Brand";v="8", "Chromium";v="114", "Microsoft Edge";v="114"`) - req.Header.Set("Sec-Ch-Ua-Mobile", "?0") - req.Header.Set("Sec-Ch-Ua-Platform", "\"Windows\"") - req.Header.Set("Sec-Fetch-Dest", "empty") - req.Header.Set("Sec-Fetch-Mode", "cors") - // TODO: Add sentry trace if needed. - // req.Header.Set("Sentry-Trace", "TODO") - req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1823.82") -} diff --git a/vidai.go b/vidai.go index 6046584..de79d06 100644 --- a/vidai.go +++ b/vidai.go @@ -1,269 +1 @@ package vidai - -import ( - "context" - "fmt" - "io" - "log" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - - "github.com/igolaizola/vidai/pkg/runway" -) - -type Client struct { - client *runway.Client - httpClient *http.Client -} - -type Config struct { - Token string - Wait time.Duration - Debug bool - Client *http.Client -} - -func New(cfg *Config) *Client { - httpClient := cfg.Client - if httpClient == nil { - httpClient = &http.Client{ - Timeout: 2 * time.Minute, - } - } - client := runway.New(&runway.Config{ - Token: cfg.Token, - Wait: cfg.Wait, - Debug: cfg.Debug, - Client: httpClient, - }) - return &Client{ - client: client, - httpClient: httpClient, - } -} - -// Generate generates a video from an image and a text prompt. -func (c *Client) Generate(ctx context.Context, image, text, output string, - extend int, interpolate, upscale, watermark bool) (string, string, error) { - b, err := os.ReadFile(image) - if err != nil { - return "", "", fmt.Errorf("vidai: couldn't read image: %w", err) - } - name := filepath.Base(image) - - var imageURL string - if image != "" { - imageURL, err = c.client.Upload(ctx, name, b) - if err != nil { - return "", "", fmt.Errorf("vidai: couldn't upload image: %w", err) - } - } - id, videoURL, err := c.client.Generate(ctx, imageURL, text, interpolate, upscale, watermark, false) - if err != nil { - return "", "", fmt.Errorf("vidai: couldn't generate video: %w", err) - } - - // Extend video - for i := 0; i < extend; i++ { - id, videoURL, err = c.client.Generate(ctx, videoURL, "", interpolate, upscale, watermark, true) - if err != nil { - return "", "", fmt.Errorf("vidai: couldn't extend video: %w", err) - } - } - - // Use temp file if no output is set and we need to extend the video - videoPath := output - if videoPath == "" && extend > 0 { - base := strings.TrimSuffix(filepath.Base(image), filepath.Ext(image)) - videoPath = filepath.Join(os.TempDir(), fmt.Sprintf("%s.mp4", base)) - } - - // Download video - if videoPath != "" { - if err := c.download(ctx, videoURL, videoPath); err != nil { - return "", "", fmt.Errorf("vidai: couldn't download video: %w", err) - } - } - - return id, videoURL, nil -} - -// Extend extends a video using the previous video. -func (c *Client) Extend(ctx context.Context, input, output string, n int, - interpolate, upscale, watermark bool) ([]string, error) { - base := strings.TrimSuffix(filepath.Base(input), filepath.Ext(input)) - - // Copy input video to temp file - vid := filepath.Join(os.TempDir(), fmt.Sprintf("%s-0.mp4", base)) - if err := copyFile(input, vid); err != nil { - return nil, fmt.Errorf("vidai: couldn't copy input video: %w", err) - } - - videos := []string{vid} - var urls []string - for i := 0; i < n; i++ { - img := filepath.Join(os.TempDir(), fmt.Sprintf("%s-%d.jpg", base, i)) - - // Extract last frame from video using the following command: - // ffmpeg -sseof -1 -i input.mp4 -update 1 -q:v 1 output.jpg - // This will seek to the last second of the input and output all frames. - // But since -update 1 is set, each frame will be overwritten to the - // same file, leaving only the last frame remaining. - cmd := exec.CommandContext(ctx, "ffmpeg", "-sseof", "-1", "-i", vid, "-update", "1", "-q:v", "1", img) - cmdOut, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("vidai: couldn't extract last frame (%s): %w", string(cmdOut), err) - } - - // Read image - b, err := os.ReadFile(img) - if err != nil { - return nil, fmt.Errorf("vidai: couldn't read image: %w", err) - } - name := filepath.Base(img) - - // Generate video - imageURL, err := c.client.Upload(ctx, name, b) - if err != nil { - return nil, fmt.Errorf("vidai: couldn't upload image: %w", err) - } - _, videoURL, err := c.client.Generate(ctx, imageURL, "", interpolate, upscale, watermark, false) - if err != nil { - return nil, fmt.Errorf("vidai: couldn't generate video: %w", err) - } - - // Remove temporary image - if err := os.Remove(img); err != nil { - log.Println(fmt.Errorf("vidai: couldn't remove image: %w", err)) - } - - // Download video to temp file - vid = filepath.Join(os.TempDir(), fmt.Sprintf("%s-%d.mp4", base, i+1)) - if err := c.download(ctx, videoURL, vid); err != nil { - return nil, fmt.Errorf("vidai: couldn't download video: %w", err) - } - videos = append(videos, vid) - } - - if output != "" { - // Create list of videos - var listData string - for _, v := range videos { - listData += fmt.Sprintf("file '%s'\n", filepath.Base(v)) - } - list := filepath.Join(os.TempDir(), fmt.Sprintf("%s-list.txt", base)) - if err := os.WriteFile(list, []byte(listData), 0644); err != nil { - return nil, fmt.Errorf("vidai: couldn't create list file: %w", err) - } - - // Combine videos using the following command: - // ffmpeg -f concat -safe 0 -i list.txt -c copy output.mp4 - cmd := exec.CommandContext(ctx, "ffmpeg", "-f", "concat", "-safe", "0", "-i", list, "-c", "copy", "-y", output) - cmdOut, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("vidai: couldn't combine videos (%s): %w", string(cmdOut), err) - } - - // Remove temporary list file - if err := os.Remove(list); err != nil { - log.Println(fmt.Errorf("vidai: couldn't remove list file: %w", err)) - } - } - - // Remove temporary videos - for _, v := range videos { - if err := os.Remove(v); err != nil { - log.Println(fmt.Errorf("vidai: couldn't remove video: %w", err)) - } - } - - return urls, nil -} - -// URL returns the URL of a video. -func (c *Client) URL(ctx context.Context, id string) (string, error) { - return c.client.GetAsset(ctx, id) -} - -func (c *Client) download(ctx context.Context, url, output string) error { - // Create request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return fmt.Errorf("vidai: couldn't create request: %w", err) - } - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("vidai: couldn't download video: %w", err) - } - defer resp.Body.Close() - - // Write video to output - f, err := os.Create(output) - if err != nil { - return fmt.Errorf("vidai: couldn't create temp file: %w", err) - } - defer f.Close() - if _, err := io.Copy(f, resp.Body); err != nil { - return fmt.Errorf("vidai: couldn't write to temp file: %w", err) - } - return nil -} - -func copyFile(src, dst string) error { - // Open source file - srcFile, err := os.Open(src) - if err != nil { - return fmt.Errorf("vidai: couldn't open source file: %w", err) - } - defer srcFile.Close() - - // Create destination file - dstFile, err := os.Create(dst) - if err != nil { - return fmt.Errorf("vidai: couldn't create destination file: %w", err) - } - defer dstFile.Close() - - // Copy source to destination - if _, err := io.Copy(dstFile, srcFile); err != nil { - return fmt.Errorf("vidai: couldn't copy source to destination: %w", err) - } - return nil -} - -func Loop(ctx context.Context, input, output string) error { - // Reverse video using the following command: - // ffmpeg -i input.mp4 -vf reverse temp.mp4 - tmp := filepath.Join(os.TempDir(), fmt.Sprintf("%s-reversed.mp4", filepath.Base(input))) - cmd := exec.CommandContext(ctx, "ffmpeg", "-i", input, "-vf", "reverse", tmp) - cmdOut, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("vidai: couldn't reverse video (%s): %w", string(cmdOut), err) - } - - // Obtain absolute path to input video - absInput, err := filepath.Abs(input) - if err != nil { - return fmt.Errorf("vidai: couldn't get absolute path to input video: %w", err) - } - - // Generate list of videos - listData := fmt.Sprintf("file '%s'\nfile '%s'\n", absInput, filepath.Base(tmp)) - list := filepath.Join(os.TempDir(), fmt.Sprintf("%s-list.txt", filepath.Base(input))) - if err := os.WriteFile(list, []byte(listData), 0644); err != nil { - return fmt.Errorf("vidai: couldn't create list file: %w", err) - } - - // Combine videos using the following command: - // ffmpeg -f concat -safe 0 -i list.txt -c copy output.mp4 - cmd = exec.CommandContext(ctx, "ffmpeg", "-f", "concat", "-safe", "0", "-i", list, "-c", "copy", "-y", output) - cmdOut, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("vidai: couldn't combine videos (%s): %w", string(cmdOut), err) - } - return nil -}