nvidia-container-toolkit/internal/nvsandboxutils/gen/nvsandboxutils/generateapi.go
Sananya Majumder aa946f3f59 nvsandboxutils: Add script to generate bindings
This change adds a script and related files to generate the internal
bindings for Sandboxutils library with the help of c-for-go.
This can be used to update the bindings when the header file is modified
with reference to how they are generated with the Makefile in go-nvml.

Run: ./update-bindings.sh

Signed-off-by: Evan Lezar <elezar@nvidia.com>
Signed-off-by: Huy Nguyen <huyn@nvidia.com>
Signed-off-by: Sananya Majumder <sananyam@nvidia.com>
2024-09-24 10:04:48 -07:00

390 lines
9.2 KiB
Go

/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package main
import (
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"slices"
	"sort"
	"strings"
	"unicode"
)
// GeneratableInterfacePoperties describes one type for which an interface
// definition (and, optionally, package-level method aliases) is generated.
// The spelling of the name is kept as-is because it is referenced throughout
// this file.
type GeneratableInterfacePoperties struct {
// Type is the name of the receiver type whose methods are collected.
Type string
// Interface is the name given to the generated interface.
Interface string
// Exclude lists method names that are left out of the generated output.
Exclude []string
// PackageMethodsAliasedFrom is the identifier that the generated
// package-level variables are bound to (emitted as "Name = <value>.Name").
// When it is empty, no package-level variables are generated.
PackageMethodsAliasedFrom string
}
// GeneratableInterfaces lists every type that main generates bindings for,
// in the order in which the output is written.
var GeneratableInterfaces = []GeneratableInterfacePoperties{
{
// Methods of the "library" type are exposed through "Interface" and
// aliased at package level from "libnvsandboxutils".
Type: "library",
Interface: "Interface",
PackageMethodsAliasedFrom: "libnvsandboxutils",
},
}
// main parses the command-line flags and writes the generated bindings for
// every entry in GeneratableInterfaces to the requested output.
//
// Diagnostics are written to stderr and every failure results in a non-zero
// exit status, so that calling scripts can detect the failure and so that
// error text can never end up inside generated code written to stdout.
func main() {
	sourceDir := flag.String("sourceDir", "", "Path to the source directory for all go files")
	output := flag.String("output", "", "Path to the output file (default: stdout)")
	flag.Parse()

	// Check if required flags are provided
	if *sourceDir == "" {
		flag.Usage()
		os.Exit(2)
	}

	// os.Exit does not run deferred functions, so all resource handling
	// lives in writeBindings and only the exit status is decided here.
	if err := writeBindings(*sourceDir, *output); err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}
}

// writeBindings renders the bindings for sourceDir and writes them to
// outputFile (stdout when outputFile is empty).
//
// The content is rendered completely before the output is opened, so a
// generation failure does not truncate an existing output file. Errors from
// writing and from closing the output are reported to the caller.
func writeBindings(sourceDir string, outputFile string) (err error) {
	content, err := renderBindings(sourceDir)
	if err != nil {
		return err
	}

	writer, closer, err := getWriter(outputFile)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean that the data never reached the file;
		// report it unless an earlier error is already being returned.
		if closeErr := closer(); err == nil {
			err = closeErr
		}
	}()

	_, err = fmt.Fprint(writer, content)
	return err
}

// renderBindings returns the complete generated file: the header followed,
// for each entry of GeneratableInterfaces, by the optional package-level
// method aliases and the interface definition. Consecutive entries are
// separated by a blank line.
func renderBindings(sourceDir string) (string, error) {
	var out strings.Builder

	header, err := generateHeader()
	if err != nil {
		return "", fmt.Errorf("generating header: %w", err)
	}
	out.WriteString(header)

	for i, p := range GeneratableInterfaces {
		if p.PackageMethodsAliasedFrom != "" {
			comment, err := generatePackageMethodsComment(p)
			if err != nil {
				return "", fmt.Errorf("generating package methods comment for %s: %w", p.Type, err)
			}
			out.WriteString(comment)

			methods, err := generatePackageMethods(sourceDir, p)
			if err != nil {
				return "", fmt.Errorf("generating package methods for %s: %w", p.Type, err)
			}
			out.WriteString(methods)
			out.WriteString("\n")
		}

		comment, err := generateInterfaceComment(p)
		if err != nil {
			return "", fmt.Errorf("generating interface comment for %s: %w", p.Interface, err)
		}
		out.WriteString(comment)

		definition, err := generateInterface(sourceDir, p)
		if err != nil {
			return "", fmt.Errorf("generating interface %s: %w", p.Interface, err)
		}
		out.WriteString(definition)

		if i < len(GeneratableInterfaces)-1 {
			out.WriteString("\n")
		}
	}

	return out.String(), nil
}
// getWriter returns the destination for the generated output together with a
// function that releases it.
//
// An empty outputFile selects os.Stdout, paired with a no-op close function.
// Otherwise the named file is created (truncating any existing file) and its
// Close method is returned. If the file cannot be created, the writer and the
// close function are nil and the error from os.Create is returned.
func getWriter(outputFile string) (io.Writer, func() error, error) {
	if outputFile != "" {
		f, err := os.Create(outputFile)
		if err != nil {
			return nil, nil, err
		}
		return f, f.Close, nil
	}

	// Stdout is not owned by this program, so closing it is a no-op.
	noop := func() error { return nil }
	return os.Stdout, noop, nil
}
// generateHeader returns the fixed preamble of the generated file: the
// license comment, the generated-code marker and the package clause, each
// followed by a blank line. The returned error is always nil.
func generateHeader() (string, error) {
	license := []string{
		"/**",
		"# Copyright 2024 NVIDIA CORPORATION",
		"#",
		"# Licensed under the Apache License, Version 2.0 (the \"License\");",
		"# you may not use this file except in compliance with the License.",
		"# You may obtain a copy of the License at",
		"#",
		"# http://www.apache.org/licenses/LICENSE-2.0",
		"#",
		"# Unless required by applicable law or agreed to in writing, software",
		"# distributed under the License is distributed on an \"AS IS\" BASIS,",
		"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
		"# See the License for the specific language governing permissions and",
		"# limitations under the License.",
		"**/",
	}

	// The two trailing empty entries terminate the package clause with a
	// newline and leave one blank line before whatever is written next.
	preamble := []string{
		"",
		"// Generated Code; DO NOT EDIT.",
		"",
		"package nvsandboxutils",
		"",
		"",
	}

	all := append(license, preamble...)
	return strings.Join(all, "\n"), nil
}
// generatePackageMethodsComment returns the one-line comment, terminated by a
// newline, that precedes the package-level variable block for input.Type.
// The returned error is always nil.
func generatePackageMethodsComment(input GeneratableInterfacePoperties) (string, error) {
	var comment strings.Builder
	fmt.Fprintf(&comment, "// The variables below represent package level methods from the %s type.", input.Type)
	comment.WriteString("\n")
	return comment.String(), nil
}
// generateInterfaceComment returns the comment block, terminated by a
// newline, that precedes the generated interface: a description line, an
// empty comment line and a go:generate directive that runs moq to write a
// mock of the interface to mock/<lower-cased interface name>.go.
// The returned error is always nil.
func generateInterfaceComment(input GeneratableInterfacePoperties) (string, error) {
	name := input.Interface
	lines := []string{
		fmt.Sprintf("// %s represents the interface for the %s type.", name, input.Type),
		"//",
		fmt.Sprintf("//go:generate moq -out mock/%s.go -pkg mock . %s:%s", strings.ToLower(name), name, name),
	}
	return strings.Join(lines, "\n") + "\n", nil
}
// generatePackageMethods returns a "var ( ... )" block that declares, for
// every method extracted for input from sourceDir, a package-level variable
// with the method's name bound to the same-named member of
// input.PackageMethodsAliasedFrom. An extraction failure is returned
// unchanged together with an empty string.
func generatePackageMethods(sourceDir string, input GeneratableInterfacePoperties) (string, error) {
	methods, err := extractMethodsFromPackage(sourceDir, input)
	if err != nil {
		return "", err
	}

	var block strings.Builder
	block.WriteString("var (\n")
	for _, m := range methods {
		fmt.Fprintf(&block, "\t%s = %s.%s\n", m.Name.Name, input.PackageMethodsAliasedFrom, m.Name.Name)
	}
	block.WriteString(")\n")

	return block.String(), nil
}
// generateInterface returns the declaration of the interface named
// input.Interface, containing one tab-indented method specification for
// every method extracted for input from sourceDir. An extraction failure is
// returned unchanged together with an empty string.
func generateInterface(sourceDir string, input GeneratableInterfacePoperties) (string, error) {
	methods, err := extractMethodsFromPackage(sourceDir, input)
	if err != nil {
		return "", err
	}

	var decl strings.Builder
	fmt.Fprintf(&decl, "type %s interface {\n", input.Interface)
	for _, m := range methods {
		fmt.Fprintf(&decl, "\t%s\n", formatMethodSignature(m))
	}
	decl.WriteString("}\n")

	return decl.String(), nil
}
// getGoFiles walks sourceDir recursively and returns the contents of every
// file whose extension is ".go", keyed by the path reported by the walk.
// Any error from the walk or from reading a file aborts the walk and is
// returned wrapped with the name of sourceDir.
func getGoFiles(sourceDir string) (map[string][]byte, error) {
	contents := map[string][]byte{}

	visit := func(path string, entry fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			// Must be checked first: entry may be nil when walkErr is set.
			return walkErr
		case entry.IsDir(), filepath.Ext(path) != ".go":
			return nil
		}

		data, readErr := os.ReadFile(path)
		if readErr != nil {
			return readErr
		}
		contents[path] = data
		return nil
	}

	if err := filepath.WalkDir(sourceDir, visit); err != nil {
		return nil, fmt.Errorf("walking %s: %w", sourceDir, err)
	}
	return contents, nil
}
// extractMethodsFromPackage collects the methods selected by input from
// every Go file found under sourceDir and returns them ordered by method
// name. The first error from reading or parsing a file is returned as-is.
func extractMethodsFromPackage(sourceDir string, input GeneratableInterfacePoperties) ([]*ast.FuncDecl, error) {
	sources, err := getGoFiles(sourceDir)
	if err != nil {
		return nil, err
	}

	var collected []*ast.FuncDecl
	for path, src := range sources {
		found, err := extractMethods(path, src, input)
		if err != nil {
			return nil, err
		}
		collected = append(collected, found...)
	}

	// Map iteration order is random; sorting makes the result independent
	// of the order in which the files were visited.
	byName := func(a, b int) bool {
		return collected[a].Name.Name < collected[b].Name.Name
	}
	sort.Slice(collected, byName)

	return collected, nil
}
// extractMethods parses sourceContent (reported under the name sourceFile)
// and returns, in declaration order, the methods that
//   - have a receiver of type input.Type or *input.Type,
//   - are public according to isPublic, and
//   - are not named in input.Exclude.
//
// A parse failure is returned unchanged.
func extractMethods(sourceFile string, sourceContent []byte, input GeneratableInterfacePoperties) ([]*ast.FuncDecl, error) {
	parsed, err := parser.ParseFile(token.NewFileSet(), sourceFile, sourceContent, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	var selected []*ast.FuncDecl
	for _, decl := range parsed.Decls {
		fn, isFunc := decl.(*ast.FuncDecl)
		if !isFunc || fn.Recv == nil {
			// Only methods are of interest; plain functions and other
			// declarations are skipped.
			continue
		}

		var receiver *ast.Ident
		for _, field := range fn.Recv.List {
			// Look through a single pointer so that T and *T receivers
			// are treated alike.
			expr := field.Type
			if ptr, isPtr := expr.(*ast.StarExpr); isPtr {
				expr = ptr.X
			}
			if id, isIdent := expr.(*ast.Ident); isIdent {
				receiver = id
			}

			keep := receiver != nil &&
				receiver.Name == input.Type &&
				isPublic(fn.Name.Name) &&
				!slices.Contains(input.Exclude, fn.Name.Name)
			if keep {
				selected = append(selected, fn)
			}
		}
	}

	return selected, nil
}
// formatMethodSignature renders decl as an interface method specification:
// the method name, its parameter types in parentheses and, if there are any,
// its result types, e.g. "Foo(int, string) (bool, error)". Parameter and
// result names are dropped; only the types are emitted. Multiple results are
// wrapped in parentheses.
func formatMethodSignature(decl *ast.FuncDecl) string {
	var signature strings.Builder

	// Write method name and parameters
	signature.WriteString(decl.Name.Name)
	signature.WriteString("(")
	signature.WriteString(formatFields(decl.Type.Params))
	signature.WriteString(")")

	// Write return types. NumFields is nil-safe and counts every declared
	// value, so a grouped result such as "(a, b int)" is correctly treated
	// as two results that need parentheses.
	switch results := decl.Type.Results; results.NumFields() {
	case 0:
		// No results: nothing follows the parameter list.
	case 1:
		signature.WriteString(" ")
		signature.WriteString(formatFields(results))
	default:
		signature.WriteString(" (")
		signature.WriteString(formatFields(results))
		signature.WriteString(")")
	}

	return signature.String()
}

// formatFields renders every field of fields with formatFieldList and joins
// the results with ", ". A nil list yields the empty string.
func formatFields(fields *ast.FieldList) string {
	if fields == nil {
		return ""
	}
	parts := make([]string, 0, len(fields.List))
	for _, field := range fields.List {
		parts = append(parts, formatFieldList(field))
	}
	return strings.Join(parts, ", ")
}

// formatFieldList renders the type of a single parameter or result field.
//
// Any type expression is supported (identifiers, pointers, slices, arrays
// with their length, qualified names such as unsafe.Pointer, maps, channels,
// function types, variadic parameters, ...). A field that declares several
// names with one type, such as "a, b int", yields the type once per name
// ("int, int") so that the arity of the signature is preserved. A nil field
// or a field without a type yields the empty string.
func formatFieldList(field *ast.Field) string {
	if field == nil || field.Type == nil {
		return ""
	}

	typ := types.ExprString(field.Type)
	if len(field.Names) <= 1 {
		return typ
	}

	repeated := make([]string, len(field.Names))
	for i := range repeated {
		repeated[i] = typ
	}
	return strings.Join(repeated, ", ")
}
// isPublic reports whether name begins with an upper-case letter, i.e.
// whether it is an exported Go identifier. The empty string is not public.
func isPublic(name string) bool {
	// Ranging over the string decodes its first rune; the loop body runs at
	// most once and not at all for the empty string.
	for _, first := range name {
		return unicode.IsUpper(first)
	}
	return false
}