first commit

This commit is contained in:
2025-08-22 17:42:23 -04:00
commit a6c09a5890
120 changed files with 11443 additions and 0 deletions

1260
internal/codegen/codegen.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,41 @@
package codegen
import (
"fmt"
)
// Config is the configuration for the code generation.
type Config struct {
Debug bool
Class string
ValidateOnly bool
methodFilters []string
}
// NewConfig returns a new Config with default values.
func NewConfig() *Config {
return &Config{}
}
// AddMethodFilter adds a method to the list of methodFilters to generate.
func (cfg *Config) AddMethodFilter(methodFilter string) {
cfg.methodFilters = append(cfg.methodFilters, methodFilter)
}
// MethodFilter creates and returns a new method filter for the current config.
func (cfg *Config) MethodFilter() *MethodFilter {
return NewMethodFilter(cfg.methodFilters)
}
// Validate validates the Config and returns an error if there's any problem.
func (cfg *Config) Validate() error {
if cfg == nil {
return fmt.Errorf("config is nil")
}
if cfg.Class == "" {
return fmt.Errorf("generated classes may not be empty")
}
return nil
}

View File

@ -0,0 +1,28 @@
package codegen
// MethodFilter is a filter for methods to be generated.
type MethodFilter struct {
filters []string
}
// NewMethodFilter creates a new MethodFilter.
func NewMethodFilter(filters []string) *MethodFilter {
return &MethodFilter{filters}
}
// Filter returns true if the method matches one of the filters.
// In case no filter matches the method, the method is allowed.
func (md *MethodFilter) Filter(method string) bool {
for _, filter := range md.filters {
result := true
if filter[0] == '!' {
filter = filter[1:]
result = false
}
if filter == "*" || filter == method {
return result
}
}
return true // everything matches by default
}

View File

@ -0,0 +1,311 @@
package codegen
import (
"embed"
"strings"
"text/template"
"github.com/saltosystems/winrt-go/internal/winmd"
)
type genDataFile struct {
Filename string
Data genData
}
type genData struct {
Package string
Imports []string
Classes []*genClass
Enums []*genEnum
Interfaces []*genInterface
Structs []*genStruct
Delegates []*genDelegate
}
func (g *genData) ComputeImports(typeDef *winmd.TypeDef) {
// gather all imports
imports := make([]*genImport, 0)
if g.Classes != nil {
for _, c := range g.Classes {
imports = append(imports, c.GetRequiredImports()...)
}
}
if g.Interfaces != nil {
for _, i := range g.Interfaces {
imports = append(imports, i.GetRequiredImports()...)
}
}
for _, i := range imports {
if typeDef.TypeNamespace != i.Namespace {
g.Imports = append(g.Imports, i.ToGoImport())
}
}
}
type genInterface struct {
Name string
GUID string
Signature string
Funcs []*genFunc
}
func (g *genInterface) GetRequiredImports() []*genImport {
imports := make([]*genImport, 0)
for _, f := range g.Funcs {
imports = append(imports, f.RequiresImports...)
}
return imports
}
type genClass struct {
Name string
Signature string
RequiresImports []*genImport
FullyQualifiedName string
ImplInterfaces []*genInterface
ExclusiveInterfaces []*genInterface
HasEmptyConstructor bool
IsAbstract bool
}
func (g *genClass) GetRequiredImports() []*genImport {
imports := make([]*genImport, 0)
if g.RequiresImports != nil {
imports = append(imports, g.RequiresImports...)
}
if g.ExclusiveInterfaces != nil {
for _, i := range g.ExclusiveInterfaces {
imports = append(imports, i.GetRequiredImports()...)
}
}
return imports
}
type genDelegate struct {
Name string
GUID string
Signature string
InParams []*genParam
ReturnParam *genParam // this may be nil
}
type genEnum struct {
Name string
Type string
Signature string
Values []*genEnumValue
}
type genEnumValue struct {
Name string
Value string
}
type genFunc struct {
Name string
RequiresImports []*genImport
Implement bool
FuncOwner string
InParams []*genParam
ReturnParams []*genParam // this may be empty
// ExclusiveTo is the name of the class that this function is exclusive to.
// The funcion will be called statically using the RoGetActivationFactory function.
ExclusiveTo string
RequiresActivation bool
InheritedFrom winmd.QualifiedID
}
type genImport struct {
Namespace, Name string
}
func (i genImport) ToGoImport() string {
if !strings.Contains(i.Namespace, ".") && i.Namespace != "Windows" {
// This is probably a built-in package
return i.Namespace
}
folder := typeToFolder(i.Namespace, i.Name)
return "github.com/saltosystems/winrt-go/" + folder
}
// some of the variables are not public to avoid using them
// by mistake in the code.
type genDefaultValue struct {
value string
isPrimitive bool
}
// some of the variables are not public to avoid using them
// by mistake in the code.
type genParamType struct {
namespace string
name string
IsPointer bool
IsGeneric bool
IsArray bool
IsPrimitive bool
IsEnum bool
UnderlyingEnumType string
defaultValue genDefaultValue
}
// some of the variables are not public to avoid using them
// by mistake in the code.
type genParam struct {
callerPackage string
varName string
Type *genParamType
IsOut bool
}
func (g *genParam) GoVarName() string {
return typeNameToGoName(g.varName, true) // assume all are public
}
func (g *genParam) GoTypeName() string {
if g.Type.IsPrimitive {
return g.Type.name
}
name := typeNameToGoName(g.Type.name, true) // assume all are public
pkg := typePackage(g.Type.namespace, g.Type.name)
if g.callerPackage != pkg {
name = pkg + "." + name
}
return name
}
func (g *genParam) GoDefaultValue() string {
if g.Type.defaultValue.isPrimitive {
return g.Type.defaultValue.value
}
pkg := typePackage(g.Type.namespace, g.Type.name)
if g.callerPackage != pkg {
return pkg + "." + g.Type.defaultValue.value
}
return g.Type.defaultValue.value
}
type genStruct struct {
Name string
Signature string
Fields []*genParam
}
//go:embed templates/*
var templatesFS embed.FS
func getTemplates() (*template.Template, error) {
return template.New("").
Funcs(funcs()).
ParseFS(templatesFS, "templates/*")
}
func funcs() template.FuncMap {
return template.FuncMap{
"funcName": funcName,
"concat": func(a, b []*genParam) []*genParam {
return append(a, b...)
},
"toLower": func(s string) string {
return strings.ToLower(s[:1]) + s[1:]
},
}
}
// funcName is used to generate the name of a function.
func funcName(m genFunc) string {
// There are some special prefixes applied to methods that we need to replace
replacer := strings.NewReplacer(
"get_", "Get",
"put_", "Set",
"add_", "Add",
"remove_", "Remove",
)
name := replacer.Replace(m.Name)
// Add a prefix to static methods to include the owner class of the method.
// This is necessary to avoid conflicts with method names within the same package.
// Static methods are those that are exclusive to a class and require activation.
prefix := ""
if m.ExclusiveTo != "" && m.RequiresActivation {
nsAndName := strings.Split(m.ExclusiveTo, ".")
prefix = typeNameToGoName(nsAndName[len(nsAndName)-1], true)
}
return prefix + name
}
func typeToFolder(ns, name string) string {
fullName := ns
return strings.ToLower(strings.Replace(fullName, ".", "/", -1))
}
func typePackage(ns, name string) string {
sns := strings.Split(ns, ".")
return strings.ToLower(sns[len(sns)-1])
}
func enumName(typeName string, enumName string) string {
return typeName + enumName
}
func typeDefGoName(typeName string, public bool) string {
name := typeName
if isParameterizedName(typeName) {
name = strings.Split(name, "`")[0]
}
if !public {
name = strings.ToLower(name[0:1]) + name[1:]
}
return name
}
func isParameterizedName(typeName string) bool {
// parameterized types contain a '`' followed by the amount of generic parameters in their name.
return strings.Contains(typeName, "`")
}
func typeFilename(typeName string) string {
// public boolean is not relevant, we are going to lower everything
goname := typeDefGoName(typeName, true)
return strings.ToLower(goname)
}
// removes Go reserved words from param names
func cleanReservedWords(name string) string {
switch name {
case "type":
return "mType"
}
return name
}
func typeNameToGoName(typeName string, public bool) string {
name := typeName
if isParameterizedName(typeName) {
name = strings.Split(name, "`")[0]
}
if !public {
name = strings.ToLower(name[0:1]) + name[1:]
}
return name
}

View File

@ -0,0 +1,60 @@
{{if not .IsAbstract}}
const Signature{{.Name}} string = "{{.Signature}}"
type {{.Name}} struct {
ole.IUnknown
}
{{if .HasEmptyConstructor}}
func New{{.Name}}() (*{{.Name}}, error) {
inspectable, err := ole.RoActivateInstance("{{.FullyQualifiedName}}")
if err != nil {
return nil, err
}
return (*{{.Name}})(unsafe.Pointer(inspectable)), nil
}
{{end}}
{{end}}
{{$owner := .Name}}
{{range .ImplInterfaces}}
{{range .Funcs}}
{{if not .Implement}}{{continue}}{{end}}
func (impl *{{$owner}}) {{funcName .}} (
{{- range .InParams -}}
{{/*do not include out parameters, they are used as return values*/ -}}
{{ if .IsOut }}{{continue}}{{ end -}}
{{.GoVarName}} {{template "variabletype.tmpl" . }},
{{- end -}}
)
{{- /* return params */ -}}
( {{range .InParams -}}
{{ if not .IsOut }}{{continue}}{{ end -}}
{{template "variabletype.tmpl" . }},{{end -}}
{{range .ReturnParams}}{{template "variabletype.tmpl" . }},{{end}} error )
{{- /* method body */ -}}
{
itf := impl.MustQueryInterface(ole.NewGUID({{if .InheritedFrom.Namespace}}{{.InheritedFrom.Namespace}}.{{end}}GUID{{.InheritedFrom.Name}}))
defer itf.Release()
v := (*{{if .InheritedFrom.Namespace}}{{.InheritedFrom.Namespace}}.{{end}}{{.InheritedFrom.Name}})(unsafe.Pointer(itf))
return v.{{funcName . -}}
(
{{- range .InParams -}}
{{if .IsOut -}}
{{continue -}}
{{end -}}
{{.GoVarName -}}
,
{{- end -}}
)
}
{{end}}
{{end}}
{{range .ExclusiveInterfaces}}
{{ template "interface.tmpl" .}}
{{end}}

View File

@ -0,0 +1,204 @@
const GUID{{.Name}} string = "{{.GUID}}"
const Signature{{.Name}} string = "{{.Signature}}"
type {{.Name}} struct {
ole.IUnknown
sync.Mutex
refs uintptr
IID ole.GUID
}
type {{.Name}}Vtbl struct {
ole.IUnknownVtbl
Invoke uintptr
}
type {{.Name}}Callback func(instance *{{.Name}},{{- range .InParams -}}
{{.GoVarName}} {{template "variabletype.tmpl" . }},
{{- end -}})
var callbacks{{.Name}} = &{{.Name | toLower}}Callbacks {
mu: &sync.Mutex{},
callbacks: make(map[unsafe.Pointer]{{.Name}}Callback),
}
var releaseChannels{{.Name}} = &{{.Name | toLower}}ReleaseChannels {
mu: &sync.Mutex{},
chans: make(map[unsafe.Pointer]chan struct{}),
}
func New{{.Name}}(iid *ole.GUID, callback {{.Name}}Callback) *{{.Name}} {
// create type instance
size := unsafe.Sizeof(*(*{{.Name}})(nil))
instPtr := kernel32.Malloc(size)
inst := (*{{.Name}})(instPtr)
// get the callbacks for the VTable
callbacks := delegate.RegisterCallbacks(instPtr, inst)
// the VTable should also be allocated in the heap
sizeVTable := unsafe.Sizeof(*(*{{.Name}}Vtbl)(nil))
vTablePtr := kernel32.Malloc(sizeVTable)
inst.RawVTable = (*interface{})(vTablePtr)
vTable := (*{{.Name}}Vtbl)(vTablePtr)
vTable.IUnknownVtbl = ole.IUnknownVtbl{
QueryInterface: callbacks.QueryInterface,
AddRef: callbacks.AddRef,
Release: callbacks.Release,
}
vTable.Invoke = callbacks.Invoke
// Initialize all properties: the malloc may contain garbage
inst.IID = *iid // copy contents
inst.Mutex = sync.Mutex{}
inst.refs = 0
callbacks{{.Name}}.add(unsafe.Pointer(inst), callback)
// See the docs in the releaseChannels{{.Name}} struct
releaseChannels{{.Name}}.acquire(unsafe.Pointer(inst))
inst.addRef()
return inst
}
func (r *{{.Name}}) GetIID() *ole.GUID {
return &r.IID
}
// addRef increments the reference counter by one
func (r *{{.Name}}) addRef() uintptr {
r.Lock()
defer r.Unlock()
r.refs++
return r.refs
}
// removeRef decrements the reference counter by one. If it was already zero, it will just return zero.
func (r *{{.Name}}) removeRef() uintptr {
r.Lock()
defer r.Unlock()
if r.refs > 0 {
r.refs--
}
return r.refs
}
func (instance *{{.Name}}) Invoke(instancePtr, rawArgs0, rawArgs1, rawArgs2, rawArgs3, rawArgs4, rawArgs5, rawArgs6, rawArgs7, rawArgs8 unsafe.Pointer) uintptr {
{{range $i, $arg := .InParams -}}
{{- if $arg.Type.IsEnum -}}
{{$arg.GoVarName}}Raw := ({{$arg.Type.UnderlyingEnumType}})(uintptr(rawArgs{{$i}}))
{{- else -}}
{{$arg.GoVarName}}Ptr := rawArgs{{$i}}
{{- end}}
{{end}}
// See the quote above.
{{range .InParams -}}
{{if .Type.IsEnum -}}
{{.GoVarName}} := ({{template "variabletype.tmpl" . }})({{.GoVarName}}Raw)
{{else -}}
{{.GoVarName}} := ({{template "variabletype.tmpl" . }})({{.GoVarName}}Ptr)
{{end -}}
{{end -}}
if callback, ok := callbacks{{.Name}}.get(instancePtr); ok {
callback(instance, {{range .InParams}}{{.GoVarName}},{{end}})
}
return ole.S_OK
}
func (instance *{{.Name}}) AddRef() uintptr {
return instance.addRef()
}
func (instance *{{.Name}}) Release() uintptr {
rem := instance.removeRef()
if rem == 0 {
// We're done.
instancePtr := unsafe.Pointer(instance)
callbacks{{.Name}}.delete(instancePtr)
// stop release channels used to avoid
// https://github.com/golang/go/issues/55015
releaseChannels{{.Name}}.release(instancePtr)
kernel32.Free(unsafe.Pointer(instance.RawVTable))
kernel32.Free(instancePtr)
}
return rem
}
type {{.Name | toLower}}Callbacks struct {
mu *sync.Mutex
callbacks map[unsafe.Pointer]{{.Name}}Callback
}
func (m *{{.Name | toLower}}Callbacks) add(p unsafe.Pointer, v {{.Name}}Callback) {
m.mu.Lock()
defer m.mu.Unlock()
m.callbacks[p] = v
}
func (m *{{.Name | toLower}}Callbacks) get(p unsafe.Pointer) ({{.Name}}Callback, bool) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.callbacks[p]
return v, ok
}
func (m *{{.Name | toLower}}Callbacks) delete(p unsafe.Pointer) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.callbacks, p)
}
// typedEventHandlerReleaseChannels keeps a map with channels
// used to keep a goroutine alive during the lifecycle of this object.
// This is required to avoid causing a deadlock error.
// See this: https://github.com/golang/go/issues/55015
type {{.Name | toLower}}ReleaseChannels struct {
mu *sync.Mutex
chans map[unsafe.Pointer]chan struct{}
}
func (m *{{.Name | toLower}}ReleaseChannels) acquire(p unsafe.Pointer) {
m.mu.Lock()
defer m.mu.Unlock()
c := make(chan struct{})
m.chans[p] = c
go func() {
// we need a timer to trick the go runtime into
// thinking there's still something going on here
// but we are only really interested in <-c
t := time.NewTimer(time.Minute)
for {
select {
case <-t.C:
t.Reset(time.Minute)
case <-c:
t.Stop()
return
}
}
}()
}
func (m *{{.Name | toLower}}ReleaseChannels) release(p unsafe.Pointer) {
m.mu.Lock()
defer m.mu.Unlock()
if c, ok := m.chans[p]; ok {
close(c)
delete(m.chans, p)
}
}

View File

@ -0,0 +1,6 @@
type {{.Name}} {{.Type}}
const Signature{{.Name}} string = "{{.Signature}}"
const ({{range .Values}}
{{.Name}} {{$.Name}} = {{.Value}}{{end}}
)

View File

@ -0,0 +1,36 @@
// Code generated by winrt-go-gen. DO NOT EDIT.
//go:build windows
//nolint:all
package {{.Package}}
import (
"syscall"
"unsafe"
"github.com/go-ole/go-ole"
"github.com/saltosystems/winrt-go"
"github.com/saltosystems/winrt-go/internal/kernel32"
{{range .Imports}}"{{.}}"
{{end}}
)
{{range .Interfaces}}
{{template "interface.tmpl" .}}
{{end}}
{{range .Classes}}
{{template "class.tmpl" .}}
{{end}}
{{range .Enums}}
{{template "enum.tmpl" .}}
{{end}}
{{range .Structs}}
{{template "struct.tmpl" .}}
{{end}}
{{range .Delegates}}
{{template "delegate.tmpl" .}}
{{end}}

View File

@ -0,0 +1,30 @@
{{if .Implement}}
func {{if and .FuncOwner (not .RequiresActivation)}}
(v *{{.FuncOwner}})
{{- end -}}
{{funcName .}}
{{- /* in params */ -}}
(
{{- range .InParams -}}
{{/*do not include out parameters, they are used as return values*/ -}}
{{ if .IsOut }}{{continue}}{{ end -}}
{{.GoVarName}} {{template "variabletype.tmpl" . }},
{{- end -}}
)
{{- /* return params */ -}}
( {{range .InParams -}}
{{ if not .IsOut }}{{continue}}{{ end -}}
{{template "variabletype.tmpl" . }},{{end -}}
{{range .ReturnParams}}{{template "variabletype.tmpl" . }},{{end}} error )
{{- /* method body */ -}}
{
{{template "funcimpl.tmpl" .}}
}
{{end}}

View File

@ -0,0 +1,86 @@
{{if .RequiresActivation}}{{/*Activate class*/ -}}
inspectable, err := ole.RoGetActivationFactory("{{.ExclusiveTo}}", ole.NewGUID(GUID{{.FuncOwner}}))
if err != nil {
return {{range .ReturnParams -}}
{{.GoDefaultValue}}, {{end}}err
}
v := (*{{.FuncOwner}})(unsafe.Pointer(inspectable))
{{end -}}
{{- /* Declare out variables*/ -}}
{{range (concat .InParams .ReturnParams) -}}
{{ if not .IsOut}}{{continue}}{{end -}}
{{if eq .GoTypeName "string" -}}
var {{.GoVarName}}HStr ole.HString
{{ else -}}
var {{.GoVarName}} {{template "variabletype.tmpl" . -}}
{{if .Type.IsArray}} = make({{template "variabletype.tmpl" . -}}, {{.GoVarName}}Size){{end}}
{{ end -}}
{{ end -}}
{{- /* Convert in variables to winrt types */ -}}
{{range .InParams -}}
{{ if .IsOut}}{{continue}}{{end -}}
{{if eq .GoTypeName "string" -}}
{{.GoVarName}}HStr, err := ole.NewHString({{.GoVarName}})
if err != nil{
return {{range $.InParams}}{{if .IsOut}}{{.GoDefaultValue}}, {{end}}{{end -}}
{{range $.ReturnParams }}{{.GoDefaultValue}}, {{end}}err
}
{{ end -}}
{{ end -}}
hr, _, _ := syscall.SyscallN(
v.VTable().{{funcName .}},
uintptr(unsafe.Pointer(v)), // this
{{range (concat .InParams .ReturnParams) -}}
{{if .Type.IsArray -}}
{{/* Arrays need to pass a pointer to their first element */ -}}
uintptr(unsafe.Pointer(&{{.GoVarName}}[0])), // {{if .IsOut}}out{{else}}in{{end}} {{.GoTypeName}}
{{else if .IsOut -}}
{{if (or .Type.IsPrimitive .Type.IsEnum) -}}
{{if eq .GoTypeName "string" -}}
uintptr(unsafe.Pointer(&{{.GoVarName}}HStr)), // out {{.GoTypeName}}
{{else -}}
uintptr(unsafe.Pointer(&{{.GoVarName}})), // out {{.GoTypeName}}
{{end -}}
{{else -}}
uintptr(unsafe.Pointer(&{{.GoVarName}})), // out {{.GoTypeName}}
{{end -}}
{{else if .Type.IsPointer -}}
uintptr(unsafe.Pointer({{.GoVarName}})), // in {{.GoTypeName}}
{{else if (or .Type.IsPrimitive .Type.IsEnum) -}}
{{ if eq .GoTypeName "bool" -}}
uintptr(*(*byte)(unsafe.Pointer(&{{.GoVarName}}))), // in {{.GoTypeName}}
{{ else if eq .GoTypeName "string" -}}
uintptr({{.GoVarName}}HStr), // in {{.GoTypeName}}
{{else -}}
uintptr({{.GoVarName}}), // in {{.GoTypeName}}
{{end -}}
{{else if .Type.IsGeneric -}}
uintptr({{.GoVarName}}), // in {{.GoTypeName}}
{{else -}}
uintptr(unsafe.Pointer(&{{.GoVarName}})), // in {{.GoTypeName}}
{{end -}}
{{end -}}
)
if hr != 0 {
return {{range .InParams}}{{if .IsOut}}{{.GoDefaultValue}}, {{end}}{{end -}}
{{range .ReturnParams }}{{.GoDefaultValue}}, {{end}}ole.NewError(hr)
}
{{range (concat .InParams .ReturnParams) -}}
{{ if not .IsOut}}{{continue}}{{end -}}
{{if eq .GoTypeName "string" -}}
{{.GoVarName}} := {{.GoVarName}}HStr.String()
ole.DeleteHString({{.GoVarName}}HStr)
{{ end -}}
{{ end -}}
return {{range .InParams}}{{if .IsOut}}{{.GoVarName}}, {{end}}{{end -}}
{{range .ReturnParams }}{{.GoVarName}},{{end}} nil
{{- /* remove trailing white space*/ -}}

View File

@ -0,0 +1,22 @@
const GUID{{.Name}} string = "{{.GUID}}"
const Signature{{.Name}} string = "{{.Signature}}"
type {{.Name}} struct {
ole.IInspectable
}
type {{.Name}}Vtbl struct {
ole.IInspectableVtbl
{{range .Funcs}}
{{funcName .}} uintptr
{{- end}}
}
func (v *{{.Name}}) VTable() *{{.Name}}Vtbl {
return (*{{.Name}}Vtbl)(unsafe.Pointer(v.RawVTable))
}
{{range .Funcs}}
{{template "func.tmpl" .}}
{{end}}

View File

@ -0,0 +1,7 @@
const Signature{{.Name}} string = "{{.Signature}}"
type {{.Name}} struct {
{{range .Fields}}
{{.GoVarName}} {{.GoTypeName}}
{{end}}
}

View File

@ -0,0 +1,5 @@
{{if .Type.IsArray}}[]{{end -}}
{{if .Type.IsPointer}}*{{end -}}
{{.GoTypeName -}}
{{- /*remove trailing whitespace*/ -}}