1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 #include <windows.h>
  27 #include <shellapi.h>
  28 
  29 #include "WinSysInfo.h"
  30 #include "FileUtils.h"
  31 #include "WinErrorHandling.h"
  32 
  33 #pragma comment(lib, "Shell32")
  34 
  35 namespace SysInfo {
  36 
  37 tstring getTempDir() {
  38     std::vector<TCHAR> buffer(MAX_PATH);
  39     DWORD res = GetTempPath(static_cast<DWORD>(buffer.size()), buffer.data());
  40     if (res > buffer.size()) {
  41         buffer.resize(res);
  42         GetTempPath(static_cast<DWORD>(buffer.size()), buffer.data());
  43     }
  44     return FileUtils::removeTrailingSlash(buffer.data());
  45 }
  46 
  47 namespace {
  48 
  49 template <class Func>
  50 tstring getSystemDirImpl(Func func, const std::string& label) {
  51     std::vector<TCHAR> buffer(MAX_PATH);
  52     for (int i=0; i<2; i++) {
  53         DWORD res = func(buffer.data(), static_cast<DWORD>(buffer.size()));
  54         if (!res) {
  55             JP_THROW(SysError(label + " failed", func));
  56         }
  57         if (res < buffer.size()) {
  58             return FileUtils::removeTrailingSlash(buffer.data());
  59         }
  60         buffer.resize(res + 1);
  61     }
  62     JP_THROW("Unexpected reply from" + label);
  63 }
  64 
  65 } // namespace
  66 
  67 tstring getSystem32Dir() {
  68     return getSystemDirImpl(GetSystemDirectory, "GetSystemDirectory");
  69 }
  70 
  71 tstring getWIPath() {
  72     return FileUtils::mkpath() << getSystem32Dir() << _T("msiexec.exe");
  73 }
  74 
  75 namespace {
  76 
  77 tstring getModulePath(HMODULE h)
  78 {
  79     std::vector<TCHAR> buf(MAX_PATH);
  80     DWORD len = 0;
  81     while (true) {
  82         len = GetModuleFileName(h, buf.data(), (DWORD)buf.size());
  83         if (len < buf.size()) {
  84             break;
  85         }
  86         // buffer is too small, increase it
  87         buf.resize(buf.size() * 2);
  88     }
  89 
  90     if (len == 0) {
  91         // error occured
  92         JP_THROW(SysError("GetModuleFileName failed", GetModuleFileName));
  93     }
  94     return tstring(buf.begin(), buf.begin() + len);
  95 }
  96 
  97 } // namespace
  98 
  99 tstring getProcessModulePath() {
 100     return getModulePath(NULL);
 101 }
 102 
 103 HMODULE getCurrentModuleHandle()
 104 {
 105     // get module handle for the address of this function
 106     LPCWSTR address = reinterpret_cast<LPCWSTR>(getCurrentModuleHandle);
 107     HMODULE hmodule = NULL;
 108     if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS
 109             | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, address, &hmodule))
 110     {
 111         JP_THROW(SysError(tstrings::any() << "GetModuleHandleExW failed",
 112                 GetModuleHandleExW));
 113     }
 114     return hmodule;
 115 }
 116 
 117 tstring getCurrentModulePath()
 118 {
 119     return getModulePath(getCurrentModuleHandle());
 120 }
 121 
 122 tstring_array getCommandArgs(CommandArgProgramNameMode progNameMode)
 123 {
 124     int argc = 0;
 125     tstring_array result;
 126 
 127     LPWSTR *parsedArgs = CommandLineToArgvW(GetCommandLineW(), &argc);
 128     if (parsedArgs == NULL) {
 129         JP_THROW(SysError("CommandLineToArgvW failed", CommandLineToArgvW));
 130     }
 131     // the 1st element contains program name
 132     for (int i = progNameMode == ExcludeProgramName ? 1 : 0; i < argc; i++) {
 133         result.push_back(parsedArgs[i]);
 134     }
 135     LocalFree(parsedArgs);
 136 
 137     return result;
 138 }
 139 
 140 namespace {
 141 
 142 tstring getEnvVariableImpl(const tstring& name, bool* errorOccured=0) {
 143     std::vector<TCHAR> buf(10);
 144     SetLastError(ERROR_SUCCESS);
 145     const DWORD size = GetEnvironmentVariable(name.c_str(), buf.data(),
 146                                                             DWORD(buf.size()));
 147     if (GetLastError() == ERROR_ENVVAR_NOT_FOUND) {
 148         if (errorOccured) {
 149             *errorOccured = true;
 150             return tstring();
 151         }
 152         JP_THROW(SysError(tstrings::any() << "GetEnvironmentVariable("
 153             << name << ") failed. Variable not set", GetEnvironmentVariable));
 154     }
 155 
 156     if (size > buf.size()) {
 157         buf.resize(size);
 158         GetEnvironmentVariable(name.c_str(), buf.data(), DWORD(buf.size()));
 159         if (GetLastError() != ERROR_SUCCESS) {
 160             if (errorOccured) {
 161                 *errorOccured = true;
 162                 return tstring();
 163             }
 164             JP_THROW(SysError(tstrings::any() << "GetEnvironmentVariable("
 165                             << name << ") failed", GetEnvironmentVariable));
 166         }
 167     }
 168 
 169     if (errorOccured) {
 170         *errorOccured = false;
 171     }
 172     return tstring(buf.data());
 173 }
 174 
 175 } // namespace
 176 
 177 tstring getEnvVariable(const tstring& name) {
 178     return getEnvVariableImpl(name);
 179 }
 180 
 181 tstring getEnvVariable(const std::nothrow_t&, const tstring& name,
 182                                                     const tstring& defValue) {
 183     bool errorOccured = false;
 184     const tstring reply = getEnvVariableImpl(name, &errorOccured);
 185     if (errorOccured) {
 186         return defValue;
 187     }
 188     return reply;
 189 }
 190 
 191 bool isEnvVariableSet(const tstring& name) {
 192     TCHAR unused[1];
 193     SetLastError(ERROR_SUCCESS);
 194     GetEnvironmentVariable(name.c_str(), unused, _countof(unused));
 195     return GetLastError() != ERROR_ENVVAR_NOT_FOUND;
 196 }
 197 
 198 } // end of namespace SysInfo