podman-build/vendor/github.com/containers/libhvee/pkg/wmiext/conversion.go
2025-10-11 12:30:35 +09:00

475 lines
13 KiB
Go

//go:build windows
// +build windows
package wmiext
import (
"errors"
"fmt"
"math"
"reflect"
"strconv"
"strings"
"time"
"unsafe"
"github.com/go-ole/go-ole"
)
// unixEpoch is the Unix epoch (expressed in the local time zone, as returned
// by time.Unix). CIM interval strings are decoded as offsets from this instant
// so that subtracting it recovers the interval as a time.Duration.
var unixEpoch = time.Unix(0, 0)

// zeroTime is the zero time.Time value, returned when a CIM datetime is empty,
// NULL, all zeros, or fails to decode.
var zeroTime time.Time
// NewAutomationVariant converts a Go value into an OLE Automation VARIANT.
//
// Automation variants do not follow the usual OLE type rules; instead they use
// the following CIM-to-VARIANT mapping, which this function follows:
//
//	sint8   VT_I2    Signed 8-bit integer.
//	sint16  VT_I2    Signed 16-bit integer.
//	sint32  VT_I4    Signed 32-bit integer.
//	sint64  VT_BSTR  Signed 64-bit integer in string form (hexadecimal or
//	                 decimal, following ANSI C rules).
//	real32  VT_R4    4-byte IEEE floating-point value.
//	real64  VT_R8    8-byte IEEE floating-point value.
//	uint8   VT_UI1   Unsigned 8-bit integer.
//	uint16  VT_I4    Unsigned 16-bit integer.
//	uint32  VT_I4    Unsigned 32-bit integer.
//	uint64  VT_BSTR  Unsigned 64-bit integer in string form (hexadecimal or
//	                 decimal, following ANSI C rules).
//
// Additional conversions:
//   - bool becomes VT_BOOL (0xffff for true, 0 for false).
//   - int and uint are assumed to fit in 32 bits and become VT_I4.
//   - string becomes VT_BSTR; an empty []string becomes VT_NULL.
//   - Slices of the integer types above are passed to CreateNumericArrayVariant
//     with the element VT from the table; []int64 and []uint64 are rendered as
//     decimal strings and passed to CreateStringArrayVariant.
//   - time.Time and *time.Time are encoded by convertTimeToDataTime, and
//     time.Duration by convertDurationToDateTime.
//   - nil, a nil *ole.IUnknown and a nil *Instance become VT_NULL; non-nil
//     ones become VT_UNKNOWN pointing at the same object (no AddRef is done).
//
// Any other type yields an error. BSTR payloads are allocated here with
// ole.SysAllocStringLen. NOTE(review): presumably the caller is responsible
// for releasing them (e.g. via ole.VariantClear) — confirm against call sites.
//
//gocyclo:ignore
func NewAutomationVariant(value interface{}) (ole.VARIANT, error) {
	switch cast := value.(type) {
	case bool:
		if cast {
			return ole.NewVariant(ole.VT_BOOL, 0xffff), nil
		}
		return ole.NewVariant(ole.VT_BOOL, 0), nil
	case int8:
		return ole.NewVariant(ole.VT_I2, int64(cast)), nil
	case []int8:
		return CreateNumericArrayVariant(cast, ole.VT_I2)
	case int16:
		return ole.NewVariant(ole.VT_I2, int64(cast)), nil
	case []int16:
		return CreateNumericArrayVariant(cast, ole.VT_I2)
	case int32:
		return ole.NewVariant(ole.VT_I4, int64(cast)), nil
	case []int32:
		return CreateNumericArrayVariant(cast, ole.VT_I4)
	case int64:
		// sint64 travels as a decimal string.
		s := strconv.FormatInt(cast, 10)
		return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s))))), nil
	case []int64:
		strs := make([]string, len(cast))
		for i, num := range cast {
			strs[i] = strconv.FormatInt(num, 10)
		}
		return CreateStringArrayVariant(strs)
	case float32:
		return ole.NewVariant(ole.VT_R4, int64(math.Float32bits(cast))), nil
	case float64:
		return ole.NewVariant(ole.VT_R8, int64(math.Float64bits(cast))), nil
	case uint8:
		return ole.NewVariant(ole.VT_UI1, int64(cast)), nil
	case []uint8:
		return CreateNumericArrayVariant(cast, ole.VT_UI1)
	case uint16:
		return ole.NewVariant(ole.VT_I4, int64(cast)), nil
	case []uint16:
		return CreateNumericArrayVariant(cast, ole.VT_I4)
	case uint32:
		return ole.NewVariant(ole.VT_I4, int64(cast)), nil
	case []uint32:
		return CreateNumericArrayVariant(cast, ole.VT_I4)
	case uint64:
		// uint64 travels as a decimal string.
		s := strconv.FormatUint(cast, 10)
		return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s))))), nil
	case []uint64:
		strs := make([]string, len(cast))
		for i, num := range cast {
			strs[i] = strconv.FormatUint(num, 10)
		}
		return CreateStringArrayVariant(strs)
	// Assume 32 bit for generic (u)ints.
	case int:
		return ole.NewVariant(ole.VT_I4, int64(cast)), nil
	case uint:
		return ole.NewVariant(ole.VT_I4, int64(cast)), nil
	case []int:
		return CreateNumericArrayVariant(cast, ole.VT_I4)
	case []uint:
		return CreateNumericArrayVariant(cast, ole.VT_I4)
	case string:
		return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(cast))))), nil
	case []string:
		if len(cast) == 0 {
			return ole.NewVariant(ole.VT_NULL, 0), nil
		}
		return CreateStringArrayVariant(cast)
	case time.Time:
		return convertTimeToDataTime(&cast), nil
	case *time.Time:
		return convertTimeToDataTime(cast), nil
	case time.Duration:
		return convertDurationToDateTime(cast), nil
	case nil:
		return ole.NewVariant(ole.VT_NULL, 0), nil
	case *ole.IUnknown:
		if cast == nil {
			return ole.NewVariant(ole.VT_NULL, 0), nil
		}
		return ole.NewVariant(ole.VT_UNKNOWN, int64(uintptr(unsafe.Pointer(cast)))), nil
	case *Instance:
		if cast == nil {
			return ole.NewVariant(ole.VT_NULL, 0), nil
		}
		return ole.NewVariant(ole.VT_UNKNOWN, int64(uintptr(unsafe.Pointer(cast.object)))), nil
	default:
		return ole.VARIANT{}, fmt.Errorf("unsupported type for automation variants %T", value)
	}
}
// convertToGoType converts variant into a Go value of outputType.
//
// Array variants are delegated to convertVariantToArray, and VT_UNKNOWN
// variants to convertVariantToStruct, regardless of outputType. Otherwise the
// dynamic type of outputValue.Interface() selects the conversion:
//   - bool: true when the raw variant value is non-zero.
//   - time.Time / *time.Time: decoded from a CIM DATETIME string.
//   - time.Duration: decoded from a CIM interval string.
//   - the integer and float types: via convertVariantToInt / convertVariantToFloat.
//   - string: the variant's string form.
//
// Only those exact types match; named types defined on top of them fall
// through to the default case. There, a VT_NULL variant yields (nil, nil) and
// anything else is an error.
//
// outputValue must be a valid reflect.Value (for example an element of the
// slice being filled), because reflect.Value.Interface panics otherwise.
func convertToGoType(variant *ole.VARIANT, outputValue reflect.Value, outputType reflect.Type) (value interface{}, err error) {
	if variant.VT&ole.VT_ARRAY == ole.VT_ARRAY {
		return convertVariantToArray(variant, outputType)
	}
	if variant.VT == ole.VT_UNKNOWN {
		return convertVariantToStruct(variant, outputType)
	}
	switch outputValue.Interface().(type) {
	case bool:
		return variant.Val != 0, nil
	case time.Time:
		return convertDataTimeToTime(variant)
	case *time.Time:
		t, convErr := convertDataTimeToTime(variant)
		return &t, convErr
	case time.Duration:
		return convertIntervalToDuration(variant)
	case uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64:
		return convertVariantToInt(variant, outputType)
	case float32, float64:
		return convertVariantToFloat(variant, outputType)
	case string:
		return variant.ToString(), nil
	default:
		if variant.VT == ole.VT_NULL {
			return nil, nil
		}
		// Report the requested type; the zero value held by outputValue
		// (e.g. "{}" or "0") says nothing useful about what was wanted.
		return nil, fmt.Errorf("could not convert variant type %d to %v", variant.VT, outputType)
	}
}
// convertInt64ToInt converts value into the integer kind of outputType using
// Go's truncating conversions, so out-of-range values wrap (e.g. 300 becomes
// uint8(44) and -1 becomes uint32(4294967295)). Wrapping is relied upon
// because unsigned WMI values can arrive in signed variant slots.
//
// The result has the predeclared type for the kind, not outputType itself, so
// a named integer type is returned as its underlying basic type. Non-integer
// kinds produce the int 0 together with an error.
func convertInt64ToInt(value int64, outputType reflect.Type) (interface{}, error) {
	var out interface{}
	switch outputType.Kind() {
	case reflect.Int64:
		out = value
	case reflect.Int32:
		out = int32(value)
	case reflect.Int16:
		out = int16(value)
	case reflect.Int8:
		out = int8(value)
	case reflect.Int:
		out = int(value)
	case reflect.Uint64:
		out = uint64(value)
	case reflect.Uint32:
		out = uint32(value)
	case reflect.Uint16:
		out = uint16(value)
	case reflect.Uint8:
		out = uint8(value)
	case reflect.Uint:
		out = uint(value)
	default:
		return 0, fmt.Errorf("could not convert int64 to %v", outputType)
	}
	return out, nil
}
// convertStringToInt64 parses the string form of a WMI sint64/uint64 value.
// The base is inferred from the prefix as in strconv with base 0 (0x, 0o, 0b,
// or a leading 0 for octal; decimal otherwise).
//
// When unsigned is true the full uint64 range is accepted and the result is
// that value reinterpreted as int64, so callers can recover it losslessly with
// uint64(result).
func convertStringToInt64(str string, unsigned bool) (int64, error) {
	if !unsigned {
		return strconv.ParseInt(str, 0, 64)
	}
	u, err := strconv.ParseUint(str, 0, 64)
	return int64(u), err
}
// convertVariantToInt converts a scalar VARIANT into the integer kind named by
// outputType (int, int8 … uint64) via convertInt64ToInt.
//
// Integer variants narrower than 64 bits only define the low bytes of the
// VARIANT value union. The raw Val is therefore first narrowed to the width
// implied by the VT and then widened again:
//   - Signed VTs (VT_I1/I2/I4/INT) are sign-extended when the destination kind
//     is signed, so a negative sint32 (VT_I4) stays negative in an int or
//     int64 field.
//   - Otherwise they are zero-extended, which keeps WMI uint16/uint32 values
//     (also delivered as VT_I4) intact in unsigned fields.
//
// VT_NULL yields 0 and VT_BOOL is read as its 16-bit payload. VT_R4/VT_R8 are
// truncated toward zero. VT_BSTR is parsed by convertStringToInt64, as
// unsigned only when the destination is uint64. Other variant types are an
// error.
func convertVariantToInt(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
	var signedOut bool
	switch outputType.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		signedOut = true
	}

	// widen keeps only the low `bits` bits of the raw union value and extends
	// them back to 64 bits, sign-extending when signed is true (>> on int64 is
	// an arithmetic shift).
	widen := func(bits uint, signed bool) int64 {
		shift := 64 - bits
		if signed {
			return variant.Val << shift >> shift
		}
		return int64(uint64(variant.Val) << shift >> shift)
	}

	var value int64
	switch variant.VT {
	case ole.VT_NULL:
		// The value union carries no data for NULL.
		value = 0
	case ole.VT_BOOL:
		value = widen(16, false)
	case ole.VT_I1:
		value = widen(8, signedOut)
	case ole.VT_I2:
		value = widen(16, signedOut)
	case ole.VT_I4, ole.VT_INT:
		value = widen(32, signedOut)
	case ole.VT_UI1:
		value = widen(8, false)
	case ole.VT_UI2:
		value = widen(16, false)
	case ole.VT_UI4, ole.VT_UINT:
		value = widen(32, false)
	case ole.VT_I8, ole.VT_UI8:
		value = variant.Val
	case ole.VT_R4:
		// not necessarily a useful conversion but handle it anyway
		value = int64(*(*float32)(unsafe.Pointer(&variant.Val)))
	case ole.VT_R8:
		value = int64(*(*float64)(unsafe.Pointer(&variant.Val)))
	case ole.VT_BSTR:
		parsed, err := convertStringToInt64(variant.ToString(), outputType.Kind() == reflect.Uint64)
		if err != nil {
			return nil, err
		}
		value = parsed
	default:
		return nil, fmt.Errorf("could not convert variant type %d to %v", variant.VT, outputType)
	}
	return convertInt64ToInt(value, outputType)
}
// convertVariantToFloat converts a scalar VARIANT into float64, or float32
// when outputType's kind is Float32.
//
// VT_R8/VT_R4 payloads are reinterpreted from the value union. VT_BSTR is
// parsed with strconv.ParseFloat; on a parse failure, ParseFloat's result is
// returned alongside the error. VT_NULL, VT_BOOL and the integer VTs convert
// the raw 64-bit Val directly. NOTE(review): that raw Val is not narrowed to
// the VT width, so a negative VT_I2/VT_I4 converts as a large positive number
// when the union's high bytes are zero — confirm whether that is intended.
// Any other variant type is an error.
func convertVariantToFloat(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
	var f float64
	switch variant.VT {
	case ole.VT_R8:
		f = *(*float64)(unsafe.Pointer(&variant.Val))
	case ole.VT_R4:
		f = float64(*(*float32)(unsafe.Pointer(&variant.Val)))
	case ole.VT_BSTR:
		parsed, err := strconv.ParseFloat(variant.ToString(), 64)
		if err != nil {
			return parsed, err
		}
		f = parsed
	case ole.VT_NULL, ole.VT_BOOL,
		ole.VT_I1, ole.VT_I2, ole.VT_I4, ole.VT_I8, ole.VT_INT,
		ole.VT_UI1, ole.VT_UI2, ole.VT_UI4, ole.VT_UI8, ole.VT_UINT:
		f = float64(variant.Val)
	default:
		return nil, fmt.Errorf("could not convert variant type %d to %v", variant.VT, outputType)
	}
	if outputType.Kind() != reflect.Float32 {
		return f, nil
	}
	return float32(f), nil
}
// convertVariantToStruct decodes a VT_UNKNOWN variant (presumably an embedded
// WMI object — confirm) into a newly allocated value of outputType, using
// Instance.GetAll to populate it. Non-VT_UNKNOWN variants are an error.
//
// The populated value (not a pointer) is returned together with GetAll's
// error; the value is returned even when that error is non-nil.
func convertVariantToStruct(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
	if variant.VT != ole.VT_UNKNOWN {
		return nil, fmt.Errorf("could not convert non-IUnknown variant type %d to %v", variant.VT, outputType)
	}
	ptr := variant.ToIUnknown()
	// Build a temporary view pairing the interface pointer with its vtable,
	// reinterpreted as an IWbemClassObject vtable, so it can be treated as an
	// *Instance below.
	// NOTE(review): this relies on Instance's leading memory layout matching
	// {*ole.IUnknown, *IWbemClassObjectVtbl}; any further Instance fields do not
	// exist in this stack value, so GetAll must not touch them — confirm
	// against the Instance definition.
	// NOTE(review): no AddRef/Release is performed here; the reference stays
	// owned by the variant.
	var rawInstance struct {
		*ole.IUnknown
		*IWbemClassObjectVtbl
	}
	rawInstance.IUnknown = ptr
	rawInstance.IWbemClassObjectVtbl = (*IWbemClassObjectVtbl)(unsafe.Pointer(ptr.RawVTable))
	instance := (*Instance)(unsafe.Pointer(&rawInstance))
	// GetAll fills the pointed-to value; Elem unwraps it for the caller.
	val := reflect.New(outputType)
	err := instance.GetAll(val.Interface())
	return val.Elem().Interface(), err
}
// convertVariantToArray converts a SAFEARRAY variant (VT_ARRAY|elemVT) into a
// Go slice whose element type is outputType.Elem(). outputType is presumably a
// slice type — confirm with callers.
//
// Each element is fetched with safeArrayGetAsVariantVal, wrapped in a VARIANT
// of the element VT, and converted with convertToGoType. The first failure
// aborts the conversion. NOTE(review): convertToGoType returns (nil, nil) for
// a VT_NULL element of an unsupported type, and reflect.Value.Set panics on
// the resulting invalid Value.
func convertVariantToArray(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
	if variant.VT&ole.VT_ARRAY != ole.VT_ARRAY {
		return nil, fmt.Errorf("could not convert non-array variant type %d to %v", variant.VT, outputType)
	}
	conv := ole.SafeArrayConversion{Array: *(**ole.SafeArray)(unsafe.Pointer(&variant.Val))}
	total, err := conv.TotalElements(0)
	if err != nil {
		return nil, err
	}
	count := int(total)
	elemType := outputType.Elem()
	elemVT := variant.VT &^ ole.VT_ARRAY

	out := reflect.MakeSlice(reflect.SliceOf(elemType), count, count)
	for idx := 0; idx < count; idx++ {
		raw, getErr := safeArrayGetAsVariantVal(conv.Array, int64(idx), ole.VARIANT{VT: elemVT})
		if getErr != nil {
			return nil, getErr
		}
		elem := ole.VARIANT{VT: elemVT, Val: int64(raw)}
		slot := out.Index(idx)
		converted, convErr := convertToGoType(&elem, slot, elemType)
		if convErr != nil {
			return nil, convErr
		}
		slot.Set(reflect.ValueOf(converted))
	}
	return out.Interface(), nil
}
// convertToGenericValue converts a variant without a target type. Arrays
// become the []interface{} produced by SafeArrayConversion.ToValueArray;
// scalars use the variant's own Value conversion.
func convertToGenericValue(variant *ole.VARIANT) interface{} {
	if variant.VT&ole.VT_ARRAY != ole.VT_ARRAY {
		return variant.Value()
	}
	arr := ole.SafeArrayConversion{Array: *(**ole.SafeArray)(unsafe.Pointer(&variant.Val))}
	return arr.ToValueArray()
}
// convertTimeToDataTime encodes t as a CIM DATETIME string variant of the form
// yyyymmddHHMMSS.mmmmmmsUUU, where sUUU is the signed offset from UTC in
// minutes. A nil t, or one that is not after WindowsEpoch, yields a VT_NULL
// variant.
//
// The parameter is deliberately not named "time", which would shadow the
// time package inside the function.
func convertTimeToDataTime(t *time.Time) ole.VARIANT {
	if t == nil || !t.After(WindowsEpoch) {
		return ole.NewVariant(ole.VT_NULL, 0)
	}
	// Zone reports seconds east of UTC; CIM wants minutes with an explicit
	// sign and at least three digits (%+04d).
	_, offsetSeconds := t.Zone()
	offsetMinutes := offsetSeconds / 60
	s := fmt.Sprintf("%s%+04d", t.Format("20060102150405.000000"), offsetMinutes)
	return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s)))))
}
// convertDurationToDateTime encodes duration as a CIM interval string variant
// of the form ddddddddHHMMSS.mmmmmm:000, with microsecond precision; any
// sub-microsecond remainder is dropped.
//
// Zero maps to a VT_NULL variant. The interval format has no sign, so negative
// durations also map to VT_NULL rather than producing a malformed string
// (formatting a negative day count with %08d would yield e.g. "-0000001").
func convertDurationToDateTime(duration time.Duration) ole.VARIANT {
	const day = 24 * time.Hour
	if duration <= 0 {
		return ole.NewVariant(ole.VT_NULL, 0)
	}
	days := duration / day
	duration %= day
	hours := duration / time.Hour
	duration %= time.Hour
	mins := duration / time.Minute
	duration %= time.Minute
	seconds := duration / time.Second
	duration %= time.Second
	micros := duration / time.Microsecond
	s := fmt.Sprintf("%08d%02d%02d%02d.%06d:000", days, hours, mins, seconds, micros)
	return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s)))))
}
// extractDateTimeString returns the string payload of a datetime/interval
// variant. VT_NULL yields an empty string with no error; any type other than
// VT_BSTR or VT_NULL is an error.
func extractDateTimeString(variant *ole.VARIANT) (string, error) {
	if variant.VT == ole.VT_NULL {
		return "", nil
	}
	if variant.VT != ole.VT_BSTR {
		return "", errors.New("variant not compatible with dateTime field")
	}
	return variant.ToString(), nil
}
// convertDataTimeToTime decodes a CIM DATETIME variant of the form
// yyyymmddHHMMSS.mmmmmmsUUU (sUUU is the signed UTC offset in minutes) into a
// time.Time.
//
// It returns the zero time with no error in three cases: VT_NULL, an empty
// string, and the all-zero datetime ("00000000000000.000000…").
//
// If the fourth-from-last character is ':', the string is a CIM interval.
// Intervals are decoded by parseIntervalTime, which returns a time offset from
// the Unix epoch.
//
// The following are errors:
//   - a variant type other than VT_BSTR or VT_NULL;
//   - a string shorter than five characters;
//   - a zone suffix that is not a decimal integer;
//   - a timestamp portion that time.Parse rejects.
func convertDataTimeToTime(variant *ole.VARIANT) (time.Time, error) {
	dateTime, err := extractDateTimeString(variant)
	if err != nil || len(dateTime) == 0 {
		return zeroTime, err
	}
	dLen := len(dateTime)
	if dLen < 5 {
		return zeroTime, errors.New("invalid datetime string")
	}
	if strings.HasPrefix(dateTime, "00000000000000.000000") {
		// Zero time
		return zeroTime, nil
	}
	zoneStart := dLen - 4
	if dateTime[zoneStart] == ':' {
		// interval ends in :000
		return parseIntervalTime(dateTime)
	}
	zoneMinutes, err := strconv.ParseInt(dateTime[zoneStart:], 10, 0)
	if err != nil {
		return zeroTime, errors.New("invalid datetime string, zone did not parse")
	}
	// Emit the sign separately from the hour digits: for offsets in (-60, 0)
	// the hour part is 0 and would otherwise be rendered as "+00", losing the
	// minus sign (e.g. -30 minutes must become "-0030", not "+0030").
	sign := "+"
	if zoneMinutes < 0 {
		sign = "-"
	}
	withZone := fmt.Sprintf("%s%s%02d%02d", dateTime[:zoneStart], sign,
		abs(int(zoneMinutes/60)), abs(int(zoneMinutes%60)))
	return time.Parse("20060102150405.000000-0700", withZone)
}
// parseIntervalTime decodes a CIM interval string (ddddddddHHMMSS.mmmmmm:000)
// into the instant that lies the interval's length after the Unix epoch.
// Subtracting the epoch then recovers the interval as a Duration without
// losing microsecond precision.
//
// The string must be at least 25 characters long with ':' at index 21, and
// every numeric field must be a base-10 unsigned integer; otherwise an error
// is returned.
func parseIntervalTime(interval string) (time.Time, error) {
	if len(interval) < 25 || interval[21] != ':' {
		return time.Time{}, fmt.Errorf("invalid interval time: %s", interval)
	}
	dayPart, err := parseUintChain(interval[:8], nil)
	hourPart, err := parseUintChain(interval[8:10], err)
	minutePart, err := parseUintChain(interval[10:12], err)
	secondPart, err := parseUintChain(interval[12:14], err)
	microPart, err := parseUintChain(interval[15:21], err)
	if err != nil {
		return time.Time{}, err
	}
	totalSeconds := ((dayPart*24+hourPart)*60+minutePart)*60 + secondPart
	return time.Unix(int64(totalSeconds), int64(microPart)*int64(time.Microsecond)), nil
}
// convertIntervalToDuration decodes a CIM interval variant
// (ddddddddHHMMSS.mmmmmm:000) into a time.Duration.
//
// VT_NULL or an empty string yields 0 with no error. A malformed interval is
// reported as an error rather than being silently turned into 0. Intervals
// longer than a Duration can hold (about 292 years) saturate, following
// time.Time.Sub.
func convertIntervalToDuration(variant *ole.VARIANT) (time.Duration, error) {
	interval, err := extractDateTimeString(variant)
	if err != nil || len(interval) == 0 {
		return 0, err
	}
	t, err := parseIntervalTime(interval)
	if err != nil {
		return 0, err
	}
	return t.Sub(unixEpoch), nil
}
// parseUintChain parses str as a base-10 unsigned integer (limited to the
// platform's uint size). If err is already non-nil, it skips parsing and
// returns 0 with that same error. This lets a sequence of parses share a
// single error check at the end.
func parseUintChain(str string, err error) (uint64, error) {
	if err == nil {
		return strconv.ParseUint(str, 10, 0)
	}
	return 0, err
}
// abs returns the magnitude of num. As with plain negation, the most negative
// int has no positive counterpart and is returned unchanged.
func abs(num int) int {
	if num >= 0 {
		return num
	}
	return -num
}