475 lines
13 KiB
Go
475 lines
13 KiB
Go
//go:build windows
|
|
// +build windows
|
|
|
|
package wmiext
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/go-ole/go-ole"
|
|
)
|
|
|
|
// unixEpoch is the Unix time origin (zero seconds, zero nanoseconds). It is the
// reference point subtracted from interval-encoded times to obtain a duration.
var unixEpoch = time.Unix(0, 0)

// zeroTime is the zero value of time.Time, returned when a datetime variant is
// empty, all zeroes, or cannot be parsed.
var zeroTime time.Time
|
|
|
|
// Automation variants do not follow the OLE rules; instead they use the following mapping:
//
//	sint8   VT_I2    Signed 8-bit integer.
//	sint16  VT_I2    Signed 16-bit integer.
//	sint32  VT_I4    Signed 32-bit integer.
//	sint64  VT_BSTR  Signed 64-bit integer in string form. This type follows hexadecimal or
//	                 decimal format according to the American National Standards Institute
//	                 (ANSI) C rules.
//	real32  VT_R4    4-byte floating-point value that follows the Institute of Electrical
//	                 and Electronics Engineers, Inc. (IEEE) standard.
//	real64  VT_R8    8-byte floating-point value that follows the IEEE standard.
//	uint8   VT_UI1   Unsigned 8-bit integer.
//	uint16  VT_I4    Unsigned 16-bit integer.
//	uint32  VT_I4    Unsigned 32-bit integer.
//	uint64  VT_BSTR  Unsigned 64-bit integer in string form. This type follows hexadecimal or
//	                 decimal format according to ANSI C rules.
|
|
|
|
// NewAutomationVariant returns a new VARIANT com
|
|
//
|
|
//gocyclo:ignore
|
|
func NewAutomationVariant(value interface{}) (ole.VARIANT, error) {
|
|
switch cast := value.(type) {
|
|
case bool:
|
|
if cast {
|
|
return ole.NewVariant(ole.VT_BOOL, 0xffff), nil
|
|
} else {
|
|
return ole.NewVariant(ole.VT_BOOL, 0), nil
|
|
}
|
|
case int8:
|
|
return ole.NewVariant(ole.VT_I2, int64(cast)), nil
|
|
case []int8:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I2)
|
|
case int16:
|
|
return ole.NewVariant(ole.VT_I2, int64(cast)), nil
|
|
case []int16:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I2)
|
|
case int32:
|
|
return ole.NewVariant(ole.VT_I4, int64(cast)), nil
|
|
case []int32:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I4)
|
|
case int64:
|
|
s := fmt.Sprintf("%d", cast)
|
|
return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s))))), nil
|
|
case []int64:
|
|
strs := make([]string, len(cast))
|
|
for i, num := range cast {
|
|
strs[i] = fmt.Sprintf("%d", num)
|
|
}
|
|
return CreateStringArrayVariant(strs)
|
|
case float32:
|
|
return ole.NewVariant(ole.VT_R4, int64(math.Float32bits(cast))), nil
|
|
case float64:
|
|
return ole.NewVariant(ole.VT_R8, int64(math.Float64bits(cast))), nil
|
|
case uint8:
|
|
return ole.NewVariant(ole.VT_UI1, int64(cast)), nil
|
|
case []uint8:
|
|
return CreateNumericArrayVariant(cast, ole.VT_UI1)
|
|
case uint16:
|
|
return ole.NewVariant(ole.VT_I4, int64(cast)), nil
|
|
case []uint16:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I4)
|
|
case uint32:
|
|
return ole.NewVariant(ole.VT_I4, int64(cast)), nil
|
|
case []uint32:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I4)
|
|
case uint64:
|
|
s := fmt.Sprintf("%d", cast)
|
|
return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s))))), nil
|
|
case []uint64:
|
|
strs := make([]string, len(cast))
|
|
for i, num := range cast {
|
|
strs[i] = fmt.Sprintf("%d", num)
|
|
}
|
|
return CreateStringArrayVariant(strs)
|
|
|
|
// Assume 32 bit for generic (u)ints
|
|
case int:
|
|
return ole.NewVariant(ole.VT_I4, int64(cast)), nil
|
|
case uint:
|
|
return ole.NewVariant(ole.VT_I4, int64(cast)), nil
|
|
case []int:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I4)
|
|
case []uint:
|
|
return CreateNumericArrayVariant(cast, ole.VT_I4)
|
|
|
|
case string:
|
|
return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(value.(string)))))), nil
|
|
case []string:
|
|
if len(cast) == 0 {
|
|
return ole.NewVariant(ole.VT_NULL, 0), nil
|
|
}
|
|
return CreateStringArrayVariant(cast)
|
|
|
|
case time.Time:
|
|
return convertTimeToDataTime(&cast), nil
|
|
case *time.Time:
|
|
return convertTimeToDataTime(cast), nil
|
|
case time.Duration:
|
|
return convertDurationToDateTime(cast), nil
|
|
case nil:
|
|
return ole.NewVariant(ole.VT_NULL, 0), nil
|
|
case *ole.IUnknown:
|
|
if cast == nil {
|
|
return ole.NewVariant(ole.VT_NULL, 0), nil
|
|
}
|
|
return ole.NewVariant(ole.VT_UNKNOWN, int64(uintptr(unsafe.Pointer(cast)))), nil
|
|
case *Instance:
|
|
if cast == nil {
|
|
return ole.NewVariant(ole.VT_NULL, 0), nil
|
|
}
|
|
return ole.NewVariant(ole.VT_UNKNOWN, int64(uintptr(unsafe.Pointer(cast.object)))), nil
|
|
default:
|
|
return ole.VARIANT{}, fmt.Errorf("unsupported type for automation variants %T", value)
|
|
}
|
|
}
|
|
|
|
func convertToGoType(variant *ole.VARIANT, outputValue reflect.Value, outputType reflect.Type) (value interface{}, err error) {
|
|
if variant.VT&ole.VT_ARRAY == ole.VT_ARRAY {
|
|
return convertVariantToArray(variant, outputType)
|
|
}
|
|
|
|
if variant.VT == ole.VT_UNKNOWN {
|
|
return convertVariantToStruct(variant, outputType)
|
|
}
|
|
|
|
switch cast := outputValue.Interface().(type) {
|
|
case bool:
|
|
return variant.Val != 0, nil
|
|
case time.Time:
|
|
return convertDataTimeToTime(variant)
|
|
case *time.Time:
|
|
x, err := convertDataTimeToTime(variant)
|
|
return &x, err
|
|
case time.Duration:
|
|
return convertIntervalToDuration(variant)
|
|
case uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64:
|
|
return convertVariantToInt(variant, outputType)
|
|
case float32, float64:
|
|
return convertVariantToFloat(variant, outputType)
|
|
case string:
|
|
return variant.ToString(), nil
|
|
default:
|
|
if variant.VT == ole.VT_NULL {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("could not convert %d to %v", variant.VT, cast)
|
|
}
|
|
}
|
|
|
|
// convertInt64ToInt converts value to the built-in integer type whose kind
// matches outputType and returns it boxed in an interface. The conversion is a
// plain Go integer conversion, so values outside the target range wrap or
// truncate. A non-integer outputType yields an error (alongside an int zero).
func convertInt64ToInt(value int64, outputType reflect.Type) (interface{}, error) {
	var narrowed interface{}

	switch outputType.Kind() {
	// Signed targets.
	case reflect.Int64:
		narrowed = value
	case reflect.Int32:
		narrowed = int32(value)
	case reflect.Int16:
		narrowed = int16(value)
	case reflect.Int8:
		narrowed = int8(value)
	case reflect.Int:
		narrowed = int(value)

	// Unsigned targets.
	case reflect.Uint64:
		narrowed = uint64(value)
	case reflect.Uint32:
		narrowed = uint32(value)
	case reflect.Uint16:
		narrowed = uint16(value)
	case reflect.Uint8:
		narrowed = uint8(value)
	case reflect.Uint:
		narrowed = uint(value)

	default:
		return 0, fmt.Errorf("could not convert int64 to %v", outputType)
	}

	return narrowed, nil
}
|
|
|
|
// convertStringToInt64 parses str as a 64-bit integer. The base is inferred
// from the string's prefix (base 0), so decimal, hexadecimal and octal forms
// are accepted. When unsigned is true the text is parsed as a uint64 and its
// bit pattern is returned reinterpreted as an int64.
func convertStringToInt64(str string, unsigned bool) (int64, error) {
	if !unsigned {
		return strconv.ParseInt(str, 0, 64)
	}

	parsed, err := strconv.ParseUint(str, 0, 64)
	return int64(parsed), err
}
|
|
|
|
func convertVariantToInt(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
|
|
var value int64
|
|
switch variant.VT {
|
|
case ole.VT_NULL:
|
|
fallthrough
|
|
case ole.VT_BOOL:
|
|
fallthrough
|
|
case ole.VT_I1, ole.VT_I2, ole.VT_I4, ole.VT_I8, ole.VT_INT:
|
|
fallthrough
|
|
case ole.VT_UI1, ole.VT_UI2, ole.VT_UI4, ole.VT_UI8, ole.VT_UINT:
|
|
value = variant.Val
|
|
case ole.VT_R4:
|
|
// not necessarily a useful conversion but handle it anyway
|
|
value = int64(*(*float32)(unsafe.Pointer(&variant.Val)))
|
|
case ole.VT_R8:
|
|
value = int64(*(*float64)(unsafe.Pointer(&variant.Val)))
|
|
case ole.VT_BSTR:
|
|
var err error
|
|
value, err = convertStringToInt64(variant.ToString(), outputType.Kind() == reflect.Uint64)
|
|
if err != nil {
|
|
return value, err
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("could not convert variant type %d to %v", variant.VT, outputType)
|
|
}
|
|
|
|
return convertInt64ToInt(value, outputType)
|
|
}
|
|
|
|
func convertVariantToFloat(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
|
|
var value float64
|
|
switch variant.VT {
|
|
case ole.VT_NULL:
|
|
fallthrough
|
|
case ole.VT_BOOL:
|
|
fallthrough
|
|
case ole.VT_I1, ole.VT_I2, ole.VT_I4, ole.VT_I8, ole.VT_INT:
|
|
fallthrough
|
|
case ole.VT_UI1, ole.VT_UI2, ole.VT_UI4, ole.VT_UI8, ole.VT_UINT:
|
|
value = float64(variant.Val)
|
|
case ole.VT_R4:
|
|
value = float64(*(*float32)(unsafe.Pointer(&variant.Val)))
|
|
case ole.VT_R8:
|
|
value = *(*float64)(unsafe.Pointer(&variant.Val))
|
|
case ole.VT_BSTR:
|
|
var err error
|
|
value, err = strconv.ParseFloat(variant.ToString(), 64)
|
|
if err != nil {
|
|
return value, err
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("could not convert variant type %d to %v", variant.VT, outputType)
|
|
}
|
|
|
|
if outputType.Kind() == reflect.Float32 {
|
|
return float32(value), nil
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
func convertVariantToStruct(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
|
|
if variant.VT != ole.VT_UNKNOWN {
|
|
return nil, fmt.Errorf("could not convert non-IUnknown variant type %d to %v", variant.VT, outputType)
|
|
}
|
|
|
|
ptr := variant.ToIUnknown()
|
|
|
|
var rawInstance struct {
|
|
*ole.IUnknown
|
|
*IWbemClassObjectVtbl
|
|
}
|
|
|
|
rawInstance.IUnknown = ptr
|
|
rawInstance.IWbemClassObjectVtbl = (*IWbemClassObjectVtbl)(unsafe.Pointer(ptr.RawVTable))
|
|
|
|
instance := (*Instance)(unsafe.Pointer(&rawInstance))
|
|
val := reflect.New(outputType)
|
|
err := instance.GetAll(val.Interface())
|
|
return val.Elem().Interface(), err
|
|
}
|
|
|
|
func convertVariantToArray(variant *ole.VARIANT, outputType reflect.Type) (interface{}, error) {
|
|
if variant.VT&ole.VT_ARRAY != ole.VT_ARRAY {
|
|
return nil, fmt.Errorf("could not convert non-array variant type %d to %v", variant.VT, outputType)
|
|
}
|
|
|
|
safeArrayConversion := ole.SafeArrayConversion{Array: *(**ole.SafeArray)(unsafe.Pointer(&variant.Val))}
|
|
|
|
arrayLen, err := safeArrayConversion.TotalElements(0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
elemVT := (^ole.VT_ARRAY) & variant.VT
|
|
slice := reflect.MakeSlice(reflect.SliceOf(outputType.Elem()), int(arrayLen), int(arrayLen))
|
|
|
|
for i := 0; i < int(arrayLen); i++ {
|
|
elemVariant := ole.VARIANT{VT: elemVT}
|
|
elemSrc, err := safeArrayGetAsVariantVal(safeArrayConversion.Array, int64(i), elemVariant)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
elemVariant.Val = int64(elemSrc)
|
|
elemDest, err := convertToGoType(&elemVariant, slice.Index(i), outputType.Elem())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
slice.Index(i).Set(reflect.ValueOf(elemDest))
|
|
}
|
|
|
|
return slice.Interface(), nil
|
|
}
|
|
|
|
func convertToGenericValue(variant *ole.VARIANT) interface{} {
|
|
var result interface{}
|
|
if variant.VT&ole.VT_ARRAY == ole.VT_ARRAY {
|
|
safeArrayConversion := ole.SafeArrayConversion{Array: *(**ole.SafeArray)(unsafe.Pointer(&variant.Val))}
|
|
result = safeArrayConversion.ToValueArray()
|
|
} else {
|
|
result = variant.Value()
|
|
}
|
|
return result
|
|
}
|
|
|
|
func convertTimeToDataTime(time *time.Time) ole.VARIANT {
|
|
if time == nil || !time.After(WindowsEpoch) {
|
|
return ole.NewVariant(ole.VT_NULL, 0)
|
|
}
|
|
_, offset := time.Zone()
|
|
// convert to minutes
|
|
offset /= 60
|
|
//yyyymmddHHMMSS.mmmmmmsUUU
|
|
s := fmt.Sprintf("%s%+04d", time.Format("20060102150405.000000"), offset)
|
|
return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s)))))
|
|
}
|
|
|
|
func convertDurationToDateTime(duration time.Duration) ole.VARIANT {
|
|
const daySeconds = time.Second * 86400
|
|
|
|
if duration == 0 {
|
|
return ole.NewVariant(ole.VT_NULL, 0)
|
|
}
|
|
|
|
days := duration / daySeconds
|
|
duration = duration % daySeconds
|
|
|
|
hours := duration / time.Hour
|
|
duration = duration % time.Hour
|
|
|
|
mins := duration / time.Minute
|
|
duration = duration % time.Minute
|
|
|
|
seconds := duration / time.Second
|
|
duration = duration % time.Second
|
|
|
|
micros := duration / time.Microsecond
|
|
|
|
s := fmt.Sprintf("%08d%02d%02d%02d.%06d:000", days, hours, mins, seconds, micros)
|
|
return ole.NewVariant(ole.VT_BSTR, int64(uintptr(unsafe.Pointer(ole.SysAllocStringLen(s)))))
|
|
}
|
|
|
|
func extractDateTimeString(variant *ole.VARIANT) (string, error) {
|
|
switch variant.VT {
|
|
case ole.VT_BSTR:
|
|
return variant.ToString(), nil
|
|
case ole.VT_NULL:
|
|
return "", nil
|
|
default:
|
|
return "", errors.New("variant not compatible with dateTime field")
|
|
}
|
|
}
|
|
|
|
func convertDataTimeToTime(variant *ole.VARIANT) (time.Time, error) {
|
|
var err error
|
|
dateTime, err := extractDateTimeString(variant)
|
|
if err != nil || len(dateTime) == 0 {
|
|
return zeroTime, err
|
|
}
|
|
|
|
dLen := len(dateTime)
|
|
if dLen < 5 {
|
|
return zeroTime, errors.New("invalid datetime string")
|
|
}
|
|
|
|
if strings.HasPrefix(dateTime, "00000000000000.000000") {
|
|
// Zero time
|
|
return zeroTime, nil
|
|
}
|
|
|
|
zoneStart := dLen - 4
|
|
timePortion := dateTime[0:zoneStart]
|
|
|
|
var zoneMinutes int64
|
|
if dateTime[zoneStart] == ':' {
|
|
// interval ends in :000
|
|
return parseIntervalTime(dateTime)
|
|
}
|
|
|
|
zoneSuffix := dateTime[zoneStart:dLen]
|
|
zoneMinutes, err = strconv.ParseInt(zoneSuffix, 10, 0)
|
|
if err != nil {
|
|
return zeroTime, errors.New("invalid datetime string, zone did not parse")
|
|
}
|
|
|
|
timePortion = fmt.Sprintf("%s%+03d%02d", timePortion, zoneMinutes/60, abs(int(zoneMinutes%60)))
|
|
return time.Parse("20060102150405.000000-0700", timePortion)
|
|
}
|
|
|
|
// parseIntervalTime encodes an interval time as an offset to Unix time
|
|
// allowing a duration to be computed without precision loss
|
|
func parseIntervalTime(interval string) (time.Time, error) {
|
|
if len(interval) < 25 || interval[21:22] != ":" {
|
|
return time.Time{}, fmt.Errorf("invalid interval time: %s", interval)
|
|
}
|
|
|
|
days, err := parseUintChain(interval[0:8], nil)
|
|
hours, err := parseUintChain(interval[8:10], err)
|
|
mins, err := parseUintChain(interval[10:12], err)
|
|
secs, err := parseUintChain(interval[12:14], err)
|
|
micros, err := parseUintChain(interval[15:21], err)
|
|
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
|
|
var stamp uint64 = secs
|
|
stamp += days * 86400
|
|
stamp += hours * 3600
|
|
stamp += mins * 60
|
|
|
|
return time.Unix(int64(stamp), int64(micros*1000)), nil
|
|
}
|
|
|
|
func convertIntervalToDuration(variant *ole.VARIANT) (time.Duration, error) {
|
|
var err error
|
|
interval, err := extractDateTimeString(variant)
|
|
if err != nil || len(interval) == 0 {
|
|
return 0, err
|
|
}
|
|
|
|
t, err := parseIntervalTime(interval)
|
|
if err != nil {
|
|
return 0, nil
|
|
}
|
|
|
|
return t.Sub(unixEpoch), nil
|
|
}
|
|
|
|
// parseUintChain parses str as a base-10 unsigned integer, unless an earlier
// step in a chain of calls has already failed: a non-nil err is passed straight
// through with a zero result, so a sequence of calls keeps its first error.
func parseUintChain(str string, err error) (uint64, error) {
	if err == nil {
		return strconv.ParseUint(str, 10, 0)
	}

	return 0, err
}
|
|
|
|
// abs returns the absolute value of num.
func abs(num int) int {
	if num >= 0 {
		return num
	}

	return -num
}
|