Merged SVN plugins
[idea/community.git] / native / focusKiller / HookImportFunction.cpp
1 /*
2 Module : HookImportFunction.cpp
3 Purpose: Defines the implementation for code to hook a call to any imported Win32 SDK
4 Created: PJN / 23-10-1999
5 History: PJN / 01-01-2001 1. Now includes copyright message in the source code and documentation.
6                           2. Fixed an access violation in where I was getting the name of the import
7                           function but not checking for failure.
8                           3. Fixed a compiler error where I was incorrectly casting to a PDWORD instead
9                           of a DWORD
10          PJN / 20-04-2002 1. Fixed a potential infinite loop in HookImportFunctionByName. Thanks to
11                           David Defoort for spotting this problem.
12
13 Copyright (c) 1996 - 2002 by PJ Naughter.  (Web: www.naughter.com, Email: pjna@naughter.com)
14
15 All rights reserved.
16
17 Copyright / Usage Details:
18
19 You are allowed to include the source code in any product (commercial, shareware, freeware or otherwise) 
20 when your product is released in binary form. You are allowed to modify the source code in any way you want 
21 except you cannot modify the copyright details at the top of each module. If you want to distribute source 
22 code with your application, then you are only allowed to distribute versions released by the author. This is 
23 to maintain a single distribution point for the source code. 
24
25 */
26
27
28 ////////////////// Includes ////////////////////////////////////
29
30 #include <windows.h>
31 #include "HookImportFunction.h"
32
33 #define ASSERT(e)
34 #define VERIFY(e) e
35 #define TRACE0(s) OutputDebugString(s)
36 #define _T(s) s
37
38 ////////////////// Defines / Locals ////////////////////////////
39
40 #ifdef _DEBUG
41   #define new DEBUG_NEW
42   #undef THIS_FILE
43   static char THIS_FILE[] = __FILE__;
44 #endif
45
46 #define MakePtr(cast, ptr, AddValue) (cast)((DWORD)(ptr)+(DWORD)(AddValue))
47
48 BOOL IsNT();
49
50
51
52 ////////////////// Implementation //////////////////////////////
53
54 BOOL HookImportFunctionsByName(HMODULE hModule, LPCSTR szImportMod, UINT uiCount, 
55                                LPHOOKFUNCDESC paHookArray, PROC* paOrigFuncs, UINT* puiHooked)
56 {
57   // Double check the parameters.
58   ASSERT(szImportMod);
59   ASSERT(uiCount);
60   ASSERT(!IsBadReadPtr(paHookArray, sizeof(HOOKFUNCDESC)*uiCount));
61
62 #ifdef _DEBUG
63   if (paOrigFuncs)
64     ASSERT(!IsBadWritePtr(paOrigFuncs, sizeof(PROC)*uiCount));
65   if (puiHooked)
66     ASSERT(!IsBadWritePtr(puiHooked, sizeof(UINT)));
67
68   //Check each function name in the hook array.
69   for (UINT i = 0; i<uiCount; i++)
70   {
71     ASSERT(paHookArray[i].szFunc);
72     ASSERT(*paHookArray[i].szFunc != _T('\0'));
73
74     //If the proc is not NULL, then it is checked.
75     if (paHookArray[i].pProc)
76       ASSERT(!IsBadCodePtr(paHookArray[i].pProc));
77   }
78 #endif
79
80   //Do the parameter validation for real.
81   if (uiCount == 0 || szImportMod == NULL || IsBadReadPtr(paHookArray, sizeof(HOOKFUNCDESC)* uiCount))
82   {
83     ASSERT(FALSE);
84     SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
85     return FALSE;
86   }
87
88   if (paOrigFuncs && IsBadWritePtr(paOrigFuncs, sizeof(PROC)*uiCount))
89   {
90     ASSERT(FALSE);
91     SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
92     return FALSE;
93   }
94
95   if (puiHooked && IsBadWritePtr(puiHooked, sizeof(UINT)))
96   {
97     ASSERT(FALSE);
98     SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR );
99     return FALSE;
100   }
101
102   //Is this a system DLL, which Windows95 will not let you patch
103   //since it is above the 2GB line?
104   if (!IsNT() && ((DWORD)hModule >= 0x80000000))
105   {
106     #ifdef _DEBUG
107     CString sMsg;
108     sMsg.Format(_T("Could not hook module %x because we are on Win9x and it is in shared memory\n"), hModule);
109     OutputDebugString(sMsg);
110     #endif
111     SetLastErrorEx(ERROR_INVALID_HANDLE, SLE_ERROR);
112     return FALSE;
113   }
114
115   //TODO TODO
116   // Should each item in the hook array be checked in release builds?
117
118   if (puiHooked)
119     *puiHooked = 0; //Set the number of functions hooked to zero.
120
121   //Get the specific import descriptor.
122   PIMAGE_IMPORT_DESCRIPTOR pImportDesc = GetNamedImportDescriptor(hModule, szImportMod);
123   if (NULL == pImportDesc)
124     return FALSE; // The requested module was not imported.
125
126   HINSTANCE hImportMod = GetModuleHandle(szImportMod);
127   if (NULL == hImportMod)
128   {
129     ASSERT(FALSE);
130     SetLastErrorEx(ERROR_HOOK_NEEDS_HMOD, SLE_ERROR);
131     return FALSE; // The requested module was not available.
132   }
133
134   //Set all the values in paOrigFuncs to NULL.
135   if (NULL != paOrigFuncs)
136     memset(paOrigFuncs, NULL, sizeof(PROC)*uiCount);
137
138   //Get the original thunk information for this DLL.  I cannot use
139   // the thunk information stored in the pImportDesc->FirstThunk
140   // because the that is the array that the loader
141   // has already bashed to fix up all the imports. 
142   // This pointer gives us acess to the function names.
143   PIMAGE_THUNK_DATA pOrigThunk = MakePtr(PIMAGE_THUNK_DATA, hModule, pImportDesc->OriginalFirstThunk);
144
145   //Get the array pointed to by the pImportDesc->FirstThunk. 
146   // This is where I will do the actual bash.
147   PIMAGE_THUNK_DATA pRealThunk = MakePtr(PIMAGE_THUNK_DATA, hModule, pImportDesc->FirstThunk);
148
149   //Loop through and look for the one that matches the name.
150   for (; NULL != pOrigThunk->u1.Function;
151       // Increment both tables.
152       pOrigThunk++, pRealThunk++)
153   {
154     //Only look at those that are imported by name, not ordinal.
155     if (IMAGE_ORDINAL_FLAG == (IMAGE_ORDINAL_FLAG & pOrigThunk->u1.Ordinal))
156       continue;
157
158     //Look get the name of this imported function.
159     PIMAGE_IMPORT_BY_NAME pByName = MakePtr(PIMAGE_IMPORT_BY_NAME, hModule, pOrigThunk->u1.AddressOfData);
160
161     if (IsBadReadPtr(pByName, MAX_PATH+4))
162     {
163       SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
164       continue;
165     }
166
167     //If the name starts with NULL, then just skip to next.
168     if (_T('\0') == pByName->Name[0])
169       continue;
170
171     //Determines if we do the hook.
172     BOOL bDoHook = FALSE;
173
174     //TODO {
175     // Might want to consider bsearch here.
176     //TODO }
177     //See if the particular function name is in the import
178     // list.  It might be good to consider requiring the
179     // paHookArray to be in sorted order so bsearch could be
180     // used so the lookup will be faster.  However, the size of
181     // uiCount coming into this function should be rather small
182     // but it is called for each function imported by szImportMod.
183     for (UINT i = 0; i<uiCount; i++)
184     {
185       if ((paHookArray[i].szFunc[0] == pByName->Name[0]) &&
186         (strcmpi(paHookArray[i].szFunc, (char*)pByName->Name) == 0))
187       {
188         //If the proc is NULL, kick out, otherwise
189         // go ahead and hook it.
190         if (paHookArray[i].pProc)
191           bDoHook = TRUE;
192         break;
193       }
194     }
195
196     if (FALSE == bDoHook)
197       continue;
198
199     // I found it.  Now I need to change the protection to
200     //  writable before I do the blast.  Note that I am now
201     //  blasting into the real thunk area!
202     MEMORY_BASIC_INFORMATION mbi_thunk;
203     VirtualQuery(pRealThunk, &mbi_thunk, sizeof(MEMORY_BASIC_INFORMATION));
204     VERIFY(VirtualProtect(mbi_thunk.BaseAddress, mbi_thunk.RegionSize, PAGE_READWRITE, &mbi_thunk.Protect));
205
206     // Get fast/simple pointer
207     PROC* pFunction = (PROC*) &(pRealThunk->u1.Function);
208     if (*pFunction == paHookArray[i].pProc)
209     {
210       SetLastErrorEx(ERROR_ALREADY_INITIALIZED, SLE_ERROR);
211       return FALSE;
212     }
213     if (IsBadCodePtr(*pFunction))
214     {
215       ASSERT(FALSE);
216       SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
217       return FALSE;
218     }
219     //Save the original address if requested.
220     if (NULL != paOrigFuncs)
221     {
222       if ((DWORD)(*pFunction) < (DWORD)hImportMod && ((DWORD)(0x80000000) > (DWORD)hImportMod))
223       {
224         ASSERT(FALSE);
225         SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
226         return FALSE;
227       }
228       if (*pFunction != paOrigFuncs[i])
229       {
230         if (NULL != paOrigFuncs[i])
231         {
232           if (paHookArray[i].pProc != paOrigFuncs[i])
233           {
234             ASSERT(FALSE);
235             SetLastErrorEx(ERROR_INVALID_ADDRESS, SLE_ERROR);
236             return FALSE;
237           }
238         }
239         paOrigFuncs[i] = * pFunction;
240       }
241     }
242     //Do the actual hook.
243     *pFunction = paHookArray[i].pProc;
244
245     //Increment the total number hooked.
246     if (puiHooked)
247       *puiHooked += 1; 
248
249     //Change the protection back to what it was before I blasted.
250     DWORD dwOldProtect;
251     VERIFY(VirtualProtect(mbi_thunk.BaseAddress, mbi_thunk.RegionSize, mbi_thunk.Protect, &dwOldProtect));
252   }
253   //All OK, JumpMaster!
254   SetLastError(ERROR_SUCCESS);
255   return TRUE;
256 }
257
258 PIMAGE_IMPORT_DESCRIPTOR GetNamedImportDescriptor(HMODULE hModule, LPCSTR szImportMod)
259 {
260   //Always check parameters.
261   ASSERT(szImportMod);
262   ASSERT(hModule);
263   if ((szImportMod == NULL) || (hModule == NULL))
264   {
265     ASSERT(FALSE);
266     SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
267     return NULL;
268   }
269
270   //Get the Dos header.
271   PIMAGE_DOS_HEADER pDOSHeader = (PIMAGE_DOS_HEADER) hModule;
272
273   // Is this the MZ header?
274   if (IsBadReadPtr(pDOSHeader, sizeof(IMAGE_DOS_HEADER)) || (pDOSHeader->e_magic != IMAGE_DOS_SIGNATURE))
275   {
276     #ifdef _DEBUG
277     CString sMsg;
278     sMsg.Format(_T("Could not find the MZ Header for %x\n"), hModule);
279     OutputDebugString(sMsg);
280     #endif
281     SetLastErrorEx( ERROR_BAD_EXE_FORMAT, SLE_ERROR);
282     return NULL;
283   }
284
285   // Get the PE header.
286   PIMAGE_NT_HEADERS pNTHeader = MakePtr(PIMAGE_NT_HEADERS, pDOSHeader, pDOSHeader->e_lfanew);
287
288   //Is this a real PE image?
289   if (IsBadReadPtr(pNTHeader, sizeof(IMAGE_NT_HEADERS)) || (pNTHeader->Signature != IMAGE_NT_SIGNATURE))
290   {
291     ASSERT(FALSE);
292     SetLastErrorEx( ERROR_INVALID_EXE_SIGNATURE, SLE_ERROR);
293     return NULL;
294   }
295
296   //If there is no imports section, leave now.
297   if (pNTHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress == 0)
298     return NULL;
299
300   // Get the pointer to the imports section.
301   PIMAGE_IMPORT_DESCRIPTOR pImportDesc = MakePtr(PIMAGE_IMPORT_DESCRIPTOR, pDOSHeader, pNTHeader->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress);
302
303   //Loop through the import module descriptors looking for the module whose name matches szImportMod.
304   while (pImportDesc->Name)
305   {
306     PSTR szCurrMod = MakePtr(PSTR, pDOSHeader, pImportDesc->Name);
307     if (stricmp(szCurrMod, szImportMod) == 0)
308       break; // Found it.
309
310     //Look at the next one.
311     pImportDesc++;
312   }
313
314   //If the name is NULL, then the module is not imported.
315   if (pImportDesc->Name == NULL)
316     return NULL;
317
318   //All OK, Jumpmaster!
319   return pImportDesc;
320 }
321
322 BOOL IsNT()
323 {
324   OSVERSIONINFO stOSVI;
325   memset(&stOSVI, NULL, sizeof(OSVERSIONINFO));
326   stOSVI.dwOSVersionInfoSize = sizeof(OSVERSIONINFO);
327
328   BOOL bRet = GetVersionEx(&stOSVI);
329   ASSERT(TRUE == bRet);
330   if (FALSE == bRet)
331   {
332     TRACE0("GetVersionEx failed!\n");
333     return FALSE;
334   }
335
336   //Check the version and call the appropriate thing.
337   return (VER_PLATFORM_WIN32_NT == stOSVI.dwPlatformId);
338 }